From 20739a12c39121a9e7ad3c9a2469ec5a6876199d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 3 Jun 2023 01:59:40 +0200 Subject: Merging upstream version 15.0.0. Signed-off-by: Daniel Baumann --- tests/test_optimizer.py | 105 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 78 insertions(+), 27 deletions(-) (limited to 'tests/test_optimizer.py') diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 423cb84..2ae6da9 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -20,19 +20,20 @@ from tests.helpers import ( ) -def parse_and_optimize(func, sql, dialect, **kwargs): - return func(parse_one(sql, read=dialect), **kwargs) +def parse_and_optimize(func, sql, read_dialect, **kwargs): + return func(parse_one(sql, read=read_dialect), **kwargs) def qualify_columns(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.qualify.qualify( + expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs + ) return expression def pushdown_projections(expression, **kwargs): expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.qualify_columns.qualify_columns(expression, infer_schema=True, **kwargs) expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs) return expression @@ -98,7 +99,7 @@ class TestOptimizer(unittest.TestCase): }, } - def check_file(self, file, func, pretty=False, execute=False, **kwargs): + def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs): with ProcessPoolExecutor() as pool: results = {} @@ -113,6 +114,9 @@ class TestOptimizer(unittest.TestCase): if leave_tables_isolated is not None: func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) + if set_dialect and dialect: + func_kwargs["dialect"] = dialect + future = pool.submit(parse_and_optimize, func, sql, dialect, **func_kwargs) results[future] = ( sql, @@ -141,13 +145,24 @@ class TestOptimizer(unittest.TestCase): assert_frame_equal(df1, df2) def test_optimize(self): + self.assertEqual(optimizer.optimize("x = 1 + 1", identify=None).sql(), "x = 2") + schema = { "x": {"a": "INT", "b": "INT"}, "y": {"b": "INT", "c": "INT"}, "z": {"a": "INT", "c": "INT"}, + "u": {"f": "INT", "g": "INT", "h": "TEXT"}, } - self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema) + self.check_file( + "optimizer", + optimizer.optimize, + infer_schema=True, + pretty=True, + execute=True, + schema=schema, + set_dialect=True, + ) def test_isolate_table_selects(self): self.check_file( @@ -183,6 +198,15 @@ class TestOptimizer(unittest.TestCase): self.check_file("normalize", normalize) def test_qualify_columns(self): + self.assertEqual( + optimizer.qualify_columns.qualify_columns( + parse_one("select y from x"), + schema={}, + infer_schema=False, + ).sql(), + "SELECT y AS y FROM x", + ) + self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) def test_qualify_columns__with_invisible(self): @@ -198,8 +222,12 @@ class TestOptimizer(unittest.TestCase): ) optimizer.qualify_columns.validate_qualify_columns(expression) - def test_lower_identities(self): - self.check_file("lower_identities", optimizer.lower_identities.lower_identities) + def test_normalize_identifiers(self): + self.check_file( + "normalize_identifiers", + optimizer.normalize_identifiers.normalize_identifiers, + set_dialect=True, + ) def test_pushdown_projection(self): self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) @@ -221,24 +249,20 @@ class TestOptimizer(unittest.TestCase): def test_pushdown_predicates(self): self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates) - def test_expand_laterals(self): + def test_expand_alias_refs(self): # check order of lateral expansion with no schema self.assertEqual( - optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x " "").sql(), - 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x"', - ) - - self.check_file( - "expand_laterals", - optimizer.expand_laterals.expand_laterals, - pretty=True, - execute=True, + optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x WHERE e > 1 GROUP BY e").sql(), + 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x" WHERE "x"."a" + 2 > 1 GROUP BY "x"."a" + 2', ) - def test_expand_multi_table_selects(self): - self.check_file( - "expand_multi_table_selects", - optimizer.expand_multi_table_selects.expand_multi_table_selects, + self.assertEqual( + optimizer.qualify_columns.qualify_columns( + parse_one("SELECT CAST(x AS INT) AS y FROM z AS z"), + schema={"l": {"c": "int"}}, + infer_schema=False, + ).sql(), + "SELECT CAST(x AS INT) AS y FROM z AS z", ) def test_optimize_joins(self): @@ -280,8 +304,8 @@ class TestOptimizer(unittest.TestCase): optimize = partial( optimizer.optimize, rules=[ - optimizer.qualify_tables.qualify_tables, - optimizer.qualify_columns.qualify_columns, + optimizer.qualify.qualify, + optimizer.qualify_columns.quote_identifiers, annotate_types, optimizer.canonicalize.canonicalize, ], @@ -396,7 +420,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ) self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR) self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT) + self.assertEqual(expression.args["to"].expressions[0].this.type.this, exp.DataType.Type.INT) expression = annotate_types(parse_one("ARRAY(1)::ARRAY")) self.assertEqual(expression.type, parse_one("ARRAY", into=exp.DataType)) @@ -450,7 +474,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') expression.expressions[0].type.this, exp.DataType.Type.FLOAT ) # a.cola AS cola - addition_alias = expression.args["from"].expressions[0].this.expressions[0] + addition_alias = expression.args["from"].this.this.expressions[0] self.assertEqual( addition_alias.type.this, exp.DataType.Type.FLOAT ) # x.cola + y.cola AS cola @@ -663,3 +687,30 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema), parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'), ) + + def test_quotes(self): + schema = { + "example": { + '"source"': { + "id": "text", + '"name"': "text", + '"payload"': "text", + } + } + } + + expected = parse_one( + """ + SELECT + "source"."ID" AS "ID", + "source"."name" AS "name", + "source"."payload" AS "payload" + FROM "EXAMPLE"."source" AS "source" + """, + read="snowflake", + ).sql(pretty=True, dialect="snowflake") + + for func in (optimizer.qualify.qualify, optimizer.optimize): + source_query = parse_one('SELECT * FROM example."source"', read="snowflake") + transformed = func(source_query, dialect="snowflake", schema=schema) + self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected) -- cgit v1.2.3