summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-08-26 08:12:52 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-08-26 08:12:52 +0000
commita1f10f8d39404d9bae42a64efaf505fa12f34c1a (patch)
tree9eb894268f2a145aa9d42b1726a555ab1359810f /tests/test_optimizer.py
parentAdding upstream version 25.8.1. (diff)
downloadsqlglot-upstream/25.16.1.tar.xz
sqlglot-upstream/25.16.1.zip
Adding upstream version 25.16.1.upstream/25.16.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py95
1 files changed, 71 insertions, 24 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 604a364..c746a78 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -360,6 +360,37 @@ class TestOptimizer(unittest.TestCase):
"SELECT _q_0.id AS id, _q_0.dt AS dt, _q_0.v AS v FROM (SELECT t1.id AS id, t1.dt AS dt, sum(coalesce(t2.v, 0)) AS v FROM t1 AS t1 LEFT JOIN lkp AS lkp ON t1.id = lkp.id LEFT JOIN t2 AS t2 ON lkp.other_id = t2.other_id AND t1.dt = t2.dt AND COALESCE(t1.common, lkp.common) = t2.common WHERE t1.id > 10 GROUP BY t1.id, t1.dt) AS _q_0",
)
+ # Detection of correlation where columns are referenced in derived tables nested within subqueries
+ self.assertEqual(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT a.g FROM a WHERE a.e < (SELECT MAX(u) FROM (SELECT SUM(c.b) AS u FROM c WHERE c.d = f GROUP BY c.e) w)"
+ ),
+ schema={
+ "a": {"g": "INT", "e": "INT", "f": "INT"},
+ "c": {"d": "INT", "e": "INT", "b": "INT"},
+ },
+ quote_identifiers=False,
+ ).sql(),
+ "SELECT a.g AS g FROM a AS a WHERE a.e < (SELECT MAX(w.u) AS _col_0 FROM (SELECT SUM(c.b) AS u FROM c AS c WHERE c.d = a.f GROUP BY c.e) AS w)",
+ )
+
+ # Detection of correlation where columns are referenced in derived tables nested within lateral joins
+ self.assertEqual(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT u.user_id, l.log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date FROM (SELECT l.log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
+ dialect="postgres",
+ ),
+ schema={
+ "users": {"user_id": "text", "log_date": "date"},
+ "logs": {"user_id": "text", "log_date": "date"},
+ },
+ quote_identifiers=False,
+ ).sql("postgres"),
+ "SELECT u.user_id AS user_id, l.log_date AS log_date FROM users AS u CROSS JOIN LATERAL (SELECT l1.log_date AS log_date FROM (SELECT l.log_date AS log_date FROM logs AS l WHERE l.user_id = u.user_id AND l.log_date <= 100 ORDER BY l.log_date LIMIT 1) AS l1) AS l",
+ )
+
self.check_file(
"qualify_columns",
qualify_columns,
@@ -591,7 +622,7 @@ SELECT
"_q_0"."n_comment" AS "n_comment"
FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0"
""".strip(),
- optimizer.optimize(expression).sql(pretty=True),
+ optimizer.optimize(expression, infer_csv_schemas=True).sql(pretty=True),
)
def test_scope(self):
@@ -1028,31 +1059,14 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola (arg)
+ # Ensures we don't raise if there are unqualified columns
annotate_types(parse_one("select x from y lateral view explode(y) as x")).expressions[0]
- def test_null_annotation(self):
- expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
- self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
-
- # NULL <op> UNKNOWN should yield NULL
- sql = "SELECT NULL + SOME_ANONYMOUS_FUNC() AS result"
-
- concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
- self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
-
- concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
-
- def test_nullable_annotation(self):
- nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
- expression = annotate_types(parse_one("NULL AND FALSE"))
-
- self.assertEqual(expression.type, nullable)
- self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+ # NULL <op> UNKNOWN should yield UNKNOWN
+ self.assertEqual(
+ annotate_types(parse_one("SELECT NULL + ANONYMOUS_FUNC()")).expressions[0].type.this,
+ exp.DataType.Type.UNKNOWN,
+ )
def test_predicate_annotation(self):
expression = annotate_types(parse_one("x BETWEEN a AND b"))
@@ -1181,6 +1195,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
exp.DataType.build("date"),
)
+ self.assertEqual(
+ annotate_types(
+ optimizer.qualify.qualify(
+ parse_one(
+ "SELECT x FROM UNNEST(GENERATE_TIMESTAMP_ARRAY('2016-10-05 00:00:00', '2016-10-06 02:00:00', interval 1 day)) AS x"
+ )
+ )
+ )
+ .selects[0]
+ .type,
+ exp.DataType.build("timestamp"),
+ )
+
def test_map_annotation(self):
# ToMap annotation
expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
@@ -1196,6 +1223,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))
+ def test_union_annotation(self):
+ for left, right, expected_type in (
+ ("SELECT 1::INT AS c", "SELECT 2::BIGINT AS c", "BIGINT"),
+ ("SELECT 1 AS c", "SELECT NULL AS c", "INT"),
+ ("SELECT FOO() AS c", "SELECT 1 AS c", "UNKNOWN"),
+ ("SELECT FOO() AS c", "SELECT BAR() AS c", "UNKNOWN"),
+ ):
+ with self.subTest(f"left: {left}, right: {right}, expected: {expected_type}"):
+ lr = annotate_types(parse_one(f"SELECT t.c FROM ({left} UNION ALL {right}) t(c)"))
+ rl = annotate_types(parse_one(f"SELECT t.c FROM ({right} UNION ALL {left}) t(c)"))
+ assert lr.selects[0].type == rl.selects[0].type == exp.DataType.build(expected_type)
+
+ union_by_name = annotate_types(
+ parse_one(
+ "SELECT t.a, t.d FROM (SELECT 1 a, 3 d, UNION ALL BY NAME SELECT 7.0 d, 8::BIGINT a) AS t(a, d)"
+ )
+ )
+ self.assertEqual(union_by_name.selects[0].type.this, exp.DataType.Type.BIGINT)
+ self.assertEqual(union_by_name.selects[1].type.this, exp.DataType.Type.DOUBLE)
+
def test_recursive_cte(self):
query = parse_one(
"""