Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 95
1 file changed, 71 insertions(+), 24 deletions(-)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 604a364..c746a78 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -360,6 +360,37 @@ class TestOptimizer(unittest.TestCase):
             "SELECT _q_0.id AS id, _q_0.dt AS dt, _q_0.v AS v FROM (SELECT t1.id AS id, t1.dt AS dt, sum(coalesce(t2.v, 0)) AS v FROM t1 AS t1 LEFT JOIN lkp AS lkp ON t1.id = lkp.id LEFT JOIN t2 AS t2 ON lkp.other_id = t2.other_id AND t1.dt = t2.dt AND COALESCE(t1.common, lkp.common) = t2.common WHERE t1.id > 10 GROUP BY t1.id, t1.dt) AS _q_0",
         )
 
+        # Detection of correlation where columns are referenced in derived tables nested within subqueries
+        self.assertEqual(
+            optimizer.qualify.qualify(
+                parse_one(
+                    "SELECT a.g FROM a WHERE a.e < (SELECT MAX(u) FROM (SELECT SUM(c.b) AS u FROM c WHERE c.d = f GROUP BY c.e) w)"
+                ),
+                schema={
+                    "a": {"g": "INT", "e": "INT", "f": "INT"},
+                    "c": {"d": "INT", "e": "INT", "b": "INT"},
+                },
+                quote_identifiers=False,
+            ).sql(),
+            "SELECT a.g AS g FROM a AS a WHERE a.e < (SELECT MAX(w.u) AS _col_0 FROM (SELECT SUM(c.b) AS u FROM c AS c WHERE c.d = a.f GROUP BY c.e) AS w)",
+        )
+
+        # Detection of correlation where columns are referenced in derived tables nested within lateral joins
+        self.assertEqual(
+            optimizer.qualify.qualify(
+                parse_one(
+                    "SELECT u.user_id, l.log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date FROM (SELECT l.log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
+                    dialect="postgres",
+                ),
+                schema={
+                    "users": {"user_id": "text", "log_date": "date"},
+                    "logs": {"user_id": "text", "log_date": "date"},
+                },
+                quote_identifiers=False,
+            ).sql("postgres"),
+            "SELECT u.user_id AS user_id, l.log_date AS log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date AS log_date FROM (SELECT l.log_date AS log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
+        )
+
         self.check_file(
             "qualify_columns",
             qualify_columns,
@@ -591,7 +622,7 @@ SELECT
   "_q_0"."n_comment" AS "n_comment"
 FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0"
 """.strip(),
-            optimizer.optimize(expression).sql(pretty=True),
+            optimizer.optimize(expression, infer_csv_schemas=True).sql(pretty=True),
         )
 
     def test_scope(self):
@@ -1028,31 +1059,14 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
             concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
         )  # x.cola (arg)
 
+        # Ensures we don't raise if there are unqualified columns
         annotate_types(parse_one("select x from y lateral view explode(y) as x")).expressions[0]
 
-    def test_null_annotation(self):
-        expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
-        self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
-        self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
-
-        # NULL <op> UNKNOWN should yield NULL
-        sql = "SELECT NULL + SOME_ANONYMOUS_FUNC() AS result"
-
-        concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
-        self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
-
-        concat_expr = concat_expr_alias.this
-        self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
-        self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
-        self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
-
-    def test_nullable_annotation(self):
-        nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
-        expression = annotate_types(parse_one("NULL AND FALSE"))
-
-        self.assertEqual(expression.type, nullable)
-        self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
-        self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+        # NULL <op> UNKNOWN should yield UNKNOWN
+        self.assertEqual(
+            annotate_types(parse_one("SELECT NULL + ANONYMOUS_FUNC()")).expressions[0].type.this,
+            exp.DataType.Type.UNKNOWN,
+        )
 
     def test_predicate_annotation(self):
         expression = annotate_types(parse_one("x BETWEEN a AND b"))
@@ -1181,6 +1195,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
             exp.DataType.build("date"),
         )
 
+        self.assertEqual(
+            annotate_types(
+                optimizer.qualify.qualify(
+                    parse_one(
+                        "SELECT x FROM UNNEST(GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-06 02:00:00', interval 1 day)) AS x"
+                    )
+                )
+            )
+            .selects[0]
+            .type,
+            exp.DataType.build("timestamp"),
+        )
+
     def test_map_annotation(self):
         # ToMap annotation
         expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
@@ -1196,6 +1223,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
         self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))
 
+    def test_union_annotation(self):
+        for left, right, expected_type in (
+            ("SELECT 1::INT AS c", "SELECT 2::BIGINT AS c", "BIGINT"),
+            ("SELECT 1 AS c", "SELECT NULL AS c", "INT"),
+            ("SELECT FOO() AS c", "SELECT 1 AS c", "UNKNOWN"),
+            ("SELECT FOO() AS c", "SELECT BAR() AS c", "UNKNOWN"),
+        ):
+            with self.subTest(f"left: {left}, right: {right}, expected: {expected_type}"):
+                lr = annotate_types(parse_one(f"SELECT t.c FROM ({left} UNION ALL {right}) t(c)"))
+                rl = annotate_types(parse_one(f"SELECT t.c FROM ({right} UNION ALL {left}) t(c)"))
+                assert lr.selects[0].type == rl.selects[0].type == exp.DataType.build(expected_type)
+
+        union_by_name = annotate_types(
+            parse_one(
+                "SELECT t.a, t.d FROM (SELECT 1 a, 3 d, UNION ALL BY NAME SELECT 7.0 d, 8::BIGINT a) AS t(a, d)"
+            )
+        )
+        self.assertEqual(union_by_name.selects[0].type.this, exp.DataType.Type.BIGINT)
+        self.assertEqual(union_by_name.selects[1].type.this, exp.DataType.Type.DOUBLE)
+
     def test_recursive_cte(self):
         query = parse_one(
             """
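To try the correlation handling covered by the new tests outside the test suite, the snippet below is a minimal sketch built from the first assertion added above. It uses only calls that appear in the diff (parse_one and sqlglot.optimizer.qualify.qualify); the printed string should match the qualified SQL the test asserts.

from sqlglot import parse_one
from sqlglot.optimizer import qualify

# Same query and schema as the new subquery-correlation test above.
sql = (
    "SELECT a.g FROM a WHERE a.e < "
    "(SELECT MAX(u) FROM (SELECT SUM(c.b) AS u FROM c WHERE c.d = f GROUP BY c.e) w)"
)

qualified = qualify.qualify(
    parse_one(sql),
    schema={
        "a": {"g": "INT", "e": "INT", "f": "INT"},
        "c": {"d": "INT", "e": "INT", "b": "INT"},
    },
    quote_identifiers=False,
)

# The unqualified column f inside the nested derived table is resolved to the
# outer table, so the predicate becomes c.d = a.f in the qualified output.
print(qualified.sql())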