summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py88
1 files changed, 83 insertions, 5 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 046e5a6..0e8ce15 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -298,7 +298,9 @@ class TestOptimizer(unittest.TestCase):
self.check_file(
"qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
)
- self.check_file("qualify_columns_ddl", qualify_columns, schema=self.schema)
+ self.check_file(
+ "qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
+ )
def test_qualify_columns__with_invisible(self):
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
@@ -340,6 +342,9 @@ class TestOptimizer(unittest.TestCase):
def test_simplify(self):
self.check_file("simplify", simplify, set_dialect=True)
+ expression = parse_one("SELECT a, c, b FROM table1 WHERE 1 = 1")
+ self.assertEqual(simplify(simplify(expression.find(exp.Where))).sql(), "WHERE TRUE")
+
expression = parse_one("TRUE AND TRUE AND TRUE")
self.assertEqual(exp.true(), optimizer.simplify.simplify(expression))
self.assertEqual(exp.true(), optimizer.simplify.simplify(expression.this))
@@ -359,15 +364,18 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql())
anon_unquoted_str = parse_one("anonymous(x, y)")
- self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS x,y")
+ self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS(x,y)")
+
+ query = parse_one("SELECT x FROM t")
+ self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))
anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
)
- self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS x,y")
+ self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")
anon_quoted = parse_one('"anonymous"(x, y)')
- self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous" x,y')
+ self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous"(x,y)')
with self.assertRaises(ValueError) as e:
anon_invalid = exp.Anonymous(this=5)
@@ -375,6 +383,28 @@ class TestOptimizer(unittest.TestCase):
self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
+ sql = parse_one(
+ """
+ WITH cte AS (select 1 union select 2), cte2 AS (
+ SELECT ROW() OVER (PARTITION BY y) FROM (
+ (select 1) limit 10
+ )
+ )
+ SELECT
+ *,
+ a + 1,
+ a div 1,
+ filter("B", (x, y) -> x + y)
+ FROM (z AS z CROSS JOIN z) AS f(a) LEFT JOIN a.b.c.d.e.f.g USING(n) ORDER BY 1
+ """
+ )
+ self.assertEqual(
+ optimizer.simplify.gen(sql),
+ """
+SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True
+""".strip(),
+ )
+
def test_unnest_subqueries(self):
self.check_file(
"unnest_subqueries",
@@ -475,6 +505,18 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
)
def test_scope(self):
+ ast = parse_one("SELECT IF(a IN UNNEST(b), 1, 0) AS c FROM t", dialect="bigquery")
+ self.assertEqual(build_scope(ast).columns, [exp.column("a"), exp.column("b")])
+
+ many_unions = parse_one(" UNION ALL ".join(["SELECT x FROM t"] * 10000))
+ scopes_using_traverse = list(build_scope(many_unions).traverse())
+ scopes_using_traverse_scope = traverse_scope(many_unions)
+ self.assertEqual(len(scopes_using_traverse), len(scopes_using_traverse_scope))
+ assert all(
+ x.expression is y.expression
+ for x, y in zip(scopes_using_traverse, scopes_using_traverse_scope)
+ )
+
sql = """
WITH q AS (
SELECT x.b FROM x
@@ -522,7 +564,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(
{
node.sql()
- for node, *_ in walk_in_scope(expression.find(exp.Where))
+ for node in walk_in_scope(expression.find(exp.Where))
if isinstance(node, exp.Column)
},
{"s.b"},
@@ -667,6 +709,14 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT)
self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE)
+ expressions = annotate_types(
+ parse_one("SELECT SUM(2 / 3), CAST(2 AS DECIMAL) / 3", dialect="mysql")
+ ).expressions
+
+ self.assertEqual(expressions[0].type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expressions[0].this.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expressions[1].type.this, exp.DataType.Type.DECIMAL)
+
def test_bracket_annotation(self):
expression = annotate_types(parse_one("SELECT A[:]")).expressions[0]
@@ -1056,6 +1106,34 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.selects[1].type, exp.DataType.build("STRUCT<c int>"))
self.assertEqual(expression.selects[2].type, exp.DataType.build("int"))
+ self.assertEqual(
+ annotate_types(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT x FROM UNNEST(GENERATE_DATE_ARRAY('2021-01-01', current_date(), interval 1 day)) AS x"
+ )
+ )
+ )
+ .selects[0]
+ .type,
+ exp.DataType.build("date"),
+ )
+
+ def test_map_annotation(self):
+ # ToMap annotation
+ expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
+ self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))
+
+ # Map annotation
+ expression = annotate_types(
+ parse_one("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])", read="duckdb")
+ )
+ self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))
+
+ # VarMap annotation
+ expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
+ self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))
+
def test_recursive_cte(self):
query = parse_one(
"""