diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 88 |
1 files changed, 83 insertions, 5 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 046e5a6..0e8ce15 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -298,7 +298,9 @@ class TestOptimizer(unittest.TestCase): self.check_file( "qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True ) - self.check_file("qualify_columns_ddl", qualify_columns, schema=self.schema) + self.check_file( + "qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True + ) def test_qualify_columns__with_invisible(self): schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}}) @@ -340,6 +342,9 @@ class TestOptimizer(unittest.TestCase): def test_simplify(self): self.check_file("simplify", simplify, set_dialect=True) + expression = parse_one("SELECT a, c, b FROM table1 WHERE 1 = 1") + self.assertEqual(simplify(simplify(expression.find(exp.Where))).sql(), "WHERE TRUE") + expression = parse_one("TRUE AND TRUE AND TRUE") self.assertEqual(exp.true(), optimizer.simplify.simplify(expression)) self.assertEqual(exp.true(), optimizer.simplify.simplify(expression.this)) @@ -359,15 +364,18 @@ class TestOptimizer(unittest.TestCase): self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql()) anon_unquoted_str = parse_one("anonymous(x, y)") - self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS x,y") + self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS(x,y)") + + query = parse_one("SELECT x FROM t") + self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy())) anon_unquoted_identifier = exp.Anonymous( this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")] ) - self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS x,y") + self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)") anon_quoted = parse_one('"anonymous"(x, y)') - self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous" x,y') + self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous"(x,y)') with self.assertRaises(ValueError) as e: anon_invalid = exp.Anonymous(this=5) @@ -375,6 +383,28 @@ class TestOptimizer(unittest.TestCase): self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception)) + sql = parse_one( + """ + WITH cte AS (select 1 union select 2), cte2 AS ( + SELECT ROW() OVER (PARTITION BY y) FROM ( + (select 1) limit 10 + ) + ) + SELECT + *, + a + 1, + a div 1, + filter("B", (x, y) -> x + y) + FROM (z AS z CROSS JOIN z) AS f(a) LEFT JOIN a.b.c.d.e.f.g USING(n) ORDER BY 1 + """ + ) + self.assertEqual( + optimizer.simplify.gen(sql), + """ +SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True +""".strip(), + ) + def test_unnest_subqueries(self): self.check_file( "unnest_subqueries", @@ -475,6 +505,18 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ) def test_scope(self): + ast = parse_one("SELECT IF(a IN UNNEST(b), 1, 0) AS c FROM t", dialect="bigquery") + self.assertEqual(build_scope(ast).columns, [exp.column("a"), exp.column("b")]) + + many_unions = parse_one(" UNION ALL ".join(["SELECT x FROM t"] * 10000)) + scopes_using_traverse = list(build_scope(many_unions).traverse()) + scopes_using_traverse_scope = traverse_scope(many_unions) + self.assertEqual(len(scopes_using_traverse), len(scopes_using_traverse_scope)) + assert all( + x.expression is y.expression + for x, y in zip(scopes_using_traverse, scopes_using_traverse_scope) + ) + sql = """ WITH q AS ( SELECT x.b FROM x @@ -522,7 +564,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual( { node.sql() - for node, *_ in walk_in_scope(expression.find(exp.Where)) + for node in walk_in_scope(expression.find(exp.Where)) if isinstance(node, exp.Column) }, {"s.b"}, @@ -667,6 +709,14 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT) self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE) + expressions = annotate_types( + parse_one("SELECT SUM(2 / 3), CAST(2 AS DECIMAL) / 3", dialect="mysql") + ).expressions + + self.assertEqual(expressions[0].type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expressions[0].this.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expressions[1].type.this, exp.DataType.Type.DECIMAL) + def test_bracket_annotation(self): expression = annotate_types(parse_one("SELECT A[:]")).expressions[0] @@ -1056,6 +1106,34 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.selects[1].type, exp.DataType.build("STRUCT<c int>")) self.assertEqual(expression.selects[2].type, exp.DataType.build("int")) + self.assertEqual( + annotate_types( + optimizer.qualify.qualify( + parse_one( + "SELECT x FROM UNNEST(GENERATE_DATE_ARRAY('2021-01-01', current_date(), interval 1 day)) AS x" + ) + ) + ) + .selects[0] + .type, + exp.DataType.build("date"), + ) + + def test_map_annotation(self): + # ToMap annotation + expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb")) + self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)")) + + # Map annotation + expression = annotate_types( + parse_one("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])", read="duckdb") + ) + self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)")) + + # VarMap annotation + expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark")) + self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)")) + def test_recursive_cte(self): query = parse_one( """ |