 tests/test_optimizer.py | 201 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 174 insertions(+), 27 deletions(-)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 81b9731..857ba1a 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -27,11 +27,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs):
return func(parse_one(sql, read=read_dialect), **kwargs)
-def qualify_columns(expression, **kwargs):
+def qualify_columns(expression, validate_qualify_columns=True, **kwargs):
expression = optimizer.qualify.qualify(
expression,
infer_schema=True,
- validate_qualify_columns=False,
+ validate_qualify_columns=validate_qualify_columns,
identify=False,
**kwargs,
)
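
As a minimal sketch of what the forwarded flag controls (table and column names here are hypothetical): with validation on, a column that cannot be resolved against the schema is expected to raise OptimizeError, while turning it off lets the column pass through unqualified.

from sqlglot import parse_one
from sqlglot.errors import OptimizeError
from sqlglot.optimizer import qualify

schema = {"t": {"a": "int"}, "u": {"b": "int"}}
expr = parse_one("SELECT mystery_col FROM t CROSS JOIN u")
try:
    qualify.qualify(expr.copy(), schema=schema)  # validate_qualify_columns defaults to True
except OptimizeError:
    pass  # mystery_col resolves to no source, so validation rejects it
qualify.qualify(expr, schema=schema, validate_qualify_columns=False)  # left unqualified instead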
@@ -135,11 +135,17 @@ class TestOptimizer(unittest.TestCase):
continue
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
+ validate_qualify_columns = meta.get("validate_qualify_columns")
func_kwargs = {**kwargs}
if leave_tables_isolated is not None:
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
+ if validate_qualify_columns is not None:
+ func_kwargs["validate_qualify_columns"] = string_to_bool(
+ validate_qualify_columns
+ )
+
if set_dialect and dialect:
func_kwargs["dialect"] = dialect
@@ -341,6 +347,88 @@ class TestOptimizer(unittest.TestCase):
"WITH tbl1 AS (SELECT STRUCT(1 AS `f0`, 2 AS f1) AS col) SELECT tbl1.col.`f0` AS `f0`, tbl1.col.f1 AS f1 FROM tbl1",
)
+ # can't coalesce USING columns because they don't exist in every already-joined table
+ self.assertEqual(
+ optimizer.qualify_columns.qualify_columns(
+ parse_one(
+ "SELECT id, dt, v FROM (SELECT t1.id, t1.dt, sum(coalesce(t2.v, 0)) AS v FROM t1 AS t1 LEFT JOIN lkp AS lkp USING (id) LEFT JOIN t2 AS t2 USING (other_id, dt, common) WHERE t1.id > 10 GROUP BY 1, 2) AS _q_0",
+ dialect="bigquery",
+ ),
+ schema=MappingSchema(
+ schema={
+ "t1": {"id": "int64", "dt": "date", "common": "int64"},
+ "lkp": {"id": "int64", "other_id": "int64", "common": "int64"},
+ "t2": {"other_id": "int64", "dt": "date", "v": "int64", "common": "int64"},
+ },
+ dialect="bigquery",
+ ),
+ ).sql(dialect="bigquery"),
+ "SELECT _q_0.id AS id, _q_0.dt AS dt, _q_0.v AS v FROM (SELECT t1.id AS id, t1.dt AS dt, sum(coalesce(t2.v, 0)) AS v FROM t1 AS t1 LEFT JOIN lkp AS lkp ON t1.id = lkp.id LEFT JOIN t2 AS t2 ON lkp.other_id = t2.other_id AND t1.dt = t2.dt AND COALESCE(t1.common, lkp.common) = t2.common WHERE t1.id > 10 GROUP BY t1.id, t1.dt) AS _q_0",
+ )
+
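
A hedged sketch of the general USING rewrite exercised above, with hypothetical tables: each USING join becomes an ON equality, and when the join column exists in more than one already-joined source, the left-hand side becomes a COALESCE over all of them. In the test above, dt and other_id are each missing from some joined table, so only common can be coalesced.

from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns

schema = {"t1": {"k": "int"}, "t2": {"k": "int"}, "t3": {"k": "int"}}
sql = "SELECT k FROM t1 JOIN t2 USING (k) JOIN t3 USING (k)"
print(qualify_columns(parse_one(sql), schema=schema).sql())
# the second join is expected to compare COALESCE(t1.k, t2.k) = t3.k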
+ # Detection of correlation where columns are referenced in derived tables nested within subqueries
+ self.assertEqual(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT a.g FROM a WHERE a.e < (SELECT MAX(u) FROM (SELECT SUM(c.b) AS u FROM c WHERE c.d = f GROUP BY c.e) w)"
+ ),
+ schema={
+ "a": {"g": "INT", "e": "INT", "f": "INT"},
+ "c": {"d": "INT", "e": "INT", "b": "INT"},
+ },
+ quote_identifiers=False,
+ ).sql(),
+ "SELECT a.g AS g FROM a AS a WHERE a.e < (SELECT MAX(w.u) AS _col_0 FROM (SELECT SUM(c.b) AS u FROM c AS c WHERE c.d = a.f GROUP BY c.e) AS w)",
+ )
+
+ # Detection of correlation where columns are referenced in derived tables nested within lateral joins
+ self.assertEqual(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT u.user_id, l.log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date FROM (SELECT l.log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
+ dialect="postgres",
+ ),
+ schema={
+ "users": {"user_id": "text", "log_date": "date"},
+ "logs": {"user_id": "text", "log_date": "date"},
+ },
+ quote_identifiers=False,
+ ).sql("postgres"),
+ "SELECT u.user_id AS user_id, l.log_date AS log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date AS log_date FROM (SELECT l.log_date AS log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
+ )
+
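
Both regressions hinge on the scope builder noticing that a column belongs to an outer query; a minimal sketch of how that surfaces through the scope API (assuming sqlglot.optimizer.scope's build_scope and the is_correlated_subquery property):

from sqlglot import parse_one
from sqlglot.optimizer.scope import build_scope

sql = "SELECT * FROM a WHERE a.x > (SELECT MAX(b.y) FROM b WHERE b.z = a.z)"
for scope in build_scope(parse_one(sql)).traverse():
    # a.z inside the subquery is an external column, which marks its scope correlated
    print(scope.expression.sql()[:40], scope.is_correlated_subquery)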
+ self.assertEqual(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT A.b_id FROM A JOIN B ON A.b_id=B.b_id JOIN C USING(c_id)",
+ dialect="postgres",
+ ),
+ schema={
+ "A": {"b_id": "int"},
+ "B": {"b_id": "int", "c_id": "int"},
+ "C": {"c_id": "int"},
+ },
+ quote_identifiers=False,
+ ).sql("postgres"),
+ "SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.c_id = c.c_id",
+ )
+ self.assertEqual(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT A.b_id FROM A JOIN B ON A.b_id=B.b_id JOIN C ON B.b_id = C.b_id JOIN D USING(d_id)",
+ dialect="postgres",
+ ),
+ schema={
+ "A": {"b_id": "int"},
+ "B": {"b_id": "int", "d_id": "int"},
+ "C": {"b_id": "int"},
+ "D": {"d_id": "int"},
+ },
+ quote_identifiers=False,
+ ).sql("postgres"),
+ "SELECT a.b_id AS b_id FROM a AS a JOIN b AS b ON a.b_id = b.b_id JOIN c AS c ON b.b_id = c.b_id JOIN d AS d ON b.d_id = d.d_id",
+ )
+
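
The lowercased output strings also reflect identifier normalization: unquoted names are case-normalized during qualification, independently of the USING handling. A small sketch:

from sqlglot import parse_one
from sqlglot.optimizer import qualify

sql = qualify.qualify(
    parse_one("SELECT A.x FROM A", dialect="postgres"),
    schema={"A": {"x": "int"}},
    quote_identifiers=False,
).sql("postgres")
print(sql)  # expected: SELECT a.x AS x FROM a AS a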
self.check_file(
"qualify_columns",
qualify_columns,
@@ -473,15 +561,35 @@ SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expr
'SELECT "x"."a" + 1 AS "d", "x"."a" + 1 + 1 AS "e" FROM "x" AS "x" WHERE ("x"."a" + 2) > 1 GROUP BY "x"."a" + 1 + 1',
)
+ unused_schema = {"l": {"c": "int"}}
self.assertEqual(
optimizer.qualify_columns.qualify_columns(
parse_one("SELECT CAST(x AS INT) AS y FROM z AS z"),
- schema={"l": {"c": "int"}},
+ schema=unused_schema,
infer_schema=False,
).sql(),
"SELECT CAST(x AS INT) AS y FROM z AS z",
)
+ # BigQuery expands overlapping alias only for GROUP BY + HAVING
+ sql = "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT id AS my_id, CONCAT(id, name) AS full_name FROM data WHERE my_id = 1 GROUP BY my_id, full_name HAVING my_id = 1"
+ self.assertEqual(
+ optimizer.qualify_columns.qualify_columns(
+ parse_one(sql, dialect="bigquery"),
+ schema=MappingSchema(schema=unused_schema, dialect="bigquery"),
+ ).sql(),
+ "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT data.id AS my_id, CONCAT(data.id, data.name) AS full_name FROM data WHERE data.my_id = 1 GROUP BY data.id, CONCAT(data.id, data.name) HAVING data.id = 1",
+ )
+
+ # Clickhouse expands overlapping alias across the entire query
+ self.assertEqual(
+ optimizer.qualify_columns.qualify_columns(
+ parse_one(sql, dialect="clickhouse"),
+ schema=MappingSchema(schema=unused_schema, dialect="clickhouse"),
+ ).sql(),
+ "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT data.id AS my_id, CONCAT(data.id, data.name) AS full_name FROM data WHERE data.id = 1 GROUP BY data.id, CONCAT(data.id, data.name) HAVING data.id = 1",
+ )
+
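
A hedged sketch of the split those two assertions pin down, reduced to one overlapping alias (the dialects should only diverge in how the WHERE clause resolves x):

from sqlglot import parse_one
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.schema import MappingSchema

sql = "SELECT 1 AS x FROM t WHERE x = 0 GROUP BY x HAVING x = 1"
for dialect in ("bigquery", "clickhouse"):
    out = qualify_columns(
        parse_one(sql, dialect=dialect),
        schema=MappingSchema({"t": {"x": "int"}}, dialect=dialect),
    ).sql(dialect=dialect)
    print(dialect, "->", out)
# bigquery is expected to keep WHERE on t.x; clickhouse expands it to the alias value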
def test_optimize_joins(self):
self.check_file(
"optimize_joins",
@@ -552,7 +660,7 @@ SELECT
"_q_0"."n_comment" AS "n_comment"
FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0"
""".strip(),
- optimizer.optimize(expression).sql(pretty=True),
+ optimizer.optimize(expression, infer_csv_schemas=True).sql(pretty=True),
)
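
READ_CSV schema sniffing is now opt-in; a sketch of the flag with a hypothetical path (the optimizer reads the file's header to build the schema, so the CSV must exist on disk):

from sqlglot import parse_one
from sqlglot.optimizer import optimize

expression = parse_one(
    "SELECT * FROM READ_CSV('data/example.csv', 'delimiter', '|')", read="duckdb"
)
optimized = optimize(expression, infer_csv_schemas=True)
# without the flag, the star can no longer be expanded from the file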
def test_scope(self):
@@ -989,31 +1097,14 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola (arg)
+ # Ensures we don't raise if there are unqualified columns
annotate_types(parse_one("select x from y lateral view explode(y) as x")).expressions[0]
- def test_null_annotation(self):
- expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
- self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
-
- # NULL <op> UNKNOWN should yield NULL
- sql = "SELECT NULL + SOME_ANONYMOUS_FUNC() AS result"
-
- concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
- self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
-
- concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
-
- def test_nullable_annotation(self):
- nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
- expression = annotate_types(parse_one("NULL AND FALSE"))
-
- self.assertEqual(expression.type, nullable)
- self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+ # NULL <op> UNKNOWN should yield UNKNOWN
+ self.assertEqual(
+ annotate_types(parse_one("SELECT NULL + ANONYMOUS_FUNC()")).expressions[0].type.this,
+ exp.DataType.Type.UNKNOWN,
+ )
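
The rewritten test reflects the new rule; for reference, an unregistered function annotates as UNKNOWN on its own, and that now propagates through the arithmetic instead of collapsing to NULL. A minimal sketch:

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

node = annotate_types(parse_one("SELECT SOME_UDF(1) AS c")).expressions[0]
assert node.type.this == exp.DataType.Type.UNKNOWN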
def test_predicate_annotation(self):
expression = annotate_types(parse_one("x BETWEEN a AND b"))
@@ -1142,6 +1233,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
exp.DataType.build("date"),
)
+ self.assertEqual(
+ annotate_types(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT x FROM UNNEST(GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-06 02:00:00', interval 1 day)) AS x"
+ )
+ )
+ )
+ .selects[0]
+ .type,
+ exp.DataType.build("timestamp"),
+ )
+
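
Same mechanism with a simpler generator, as a hedged sketch (assuming GENERATE_DATE_ARRAY annotates to ARRAY<DATE>, in line with the date case referenced above): the UNNEST alias takes the array's element type.

from sqlglot import parse_one
from sqlglot.optimizer import qualify
from sqlglot.optimizer.annotate_types import annotate_types

expr = qualify.qualify(
    parse_one("SELECT d FROM UNNEST(GENERATE_DATE_ARRAY('2020-01-01', '2020-01-07')) AS d")
)
print(annotate_types(expr).selects[0].type)  # expected: DATE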
def test_map_annotation(self):
# ToMap annotation
expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
@@ -1157,6 +1261,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))
+ def test_union_annotation(self):
+ for left, right, expected_type in (
+ ("SELECT 1::INT AS c", "SELECT 2::BIGINT AS c", "BIGINT"),
+ ("SELECT 1 AS c", "SELECT NULL AS c", "INT"),
+ ("SELECT FOO() AS c", "SELECT 1 AS c", "UNKNOWN"),
+ ("SELECT FOO() AS c", "SELECT BAR() AS c", "UNKNOWN"),
+ ):
+ with self.subTest(f"left: {left}, right: {right}, expected: {expected_type}"):
+ lr = annotate_types(parse_one(f"SELECT t.c FROM ({left} UNION ALL {right}) t(c)"))
+ rl = annotate_types(parse_one(f"SELECT t.c FROM ({right} UNION ALL {left}) t(c)"))
+ assert lr.selects[0].type == rl.selects[0].type == exp.DataType.build(expected_type)
+
+ union_by_name = annotate_types(
+ parse_one(
+ "SELECT t.a, t.d FROM (SELECT 1 a, 3 d, UNION ALL BY NAME SELECT 7.0 d, 8::BIGINT a) AS t(a, d)"
+ )
+ )
+ self.assertEqual(union_by_name.selects[0].type.this, exp.DataType.Type.BIGINT)
+ self.assertEqual(union_by_name.selects[1].type.this, exp.DataType.Type.DOUBLE)
+
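
One instance of the loop unrolled outside the harness, as a sketch of the widening rule:

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types

q = annotate_types(
    parse_one("SELECT t.c FROM (SELECT 1 AS c UNION ALL SELECT CAST(2 AS BIGINT) AS c) AS t(c)")
)
print(q.selects[0].type)  # expected: BIGINT, the wider of INT and BIGINT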
def test_recursive_cte(self):
query = parse_one(
"""
@@ -1253,3 +1377,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
+
+ def test_custom_annotators(self):
+ # In Spark hierarchy, SUBSTRING result type is dependent on input expr type
+ for dialect in ("spark2", "spark", "databricks"):
+ for expr_type_pair in (
+ ("col", "STRING"),
+ ("col", "BINARY"),
+ ("'str_literal'", "STRING"),
+ ("CAST('str_literal' AS BINARY)", "BINARY"),
+ ):
+ with self.subTest(
+ f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}"
+ ):
+ expr, type = expr_type_pair
+ ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect)
+
+ subst_type = (
+ optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect)
+ .expressions[0]
+ .type
+ )
+
+ self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect))
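
One iteration of the loop above unrolled as a usage sketch, grounded in the BINARY case the test asserts:

from sqlglot import parse_one
from sqlglot.optimizer import optimize

ast = parse_one("SELECT substring(col, 2, 3) AS x FROM tbl", read="spark")
typ = optimize(ast, schema={"tbl": {"col": "BINARY"}}, dialect="spark").expressions[0].type
print(typ.sql("spark"))  # expected: BINARY, mirroring the input column's type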