summary refs log tree commit diff stats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dialects/test_bigquery.py18
-rw-r--r--tests/dialects/test_clickhouse.py10
-rw-r--r--tests/dialects/test_dialect.py45
-rw-r--r--tests/dialects/test_hive.py18
-rw-r--r--tests/dialects/test_postgres.py1
-rw-r--r--tests/dialects/test_snowflake.py32
-rw-r--r--tests/dialects/test_spark.py58
-rw-r--r--tests/fixtures/identity.sql6
-rw-r--r--tests/fixtures/optimizer/merge_subqueries.sql168
-rw-r--r--tests/fixtures/optimizer/optimizer.sql140
-rw-r--r--tests/fixtures/optimizer/qualify_columns.sql43
-rw-r--r--tests/fixtures/optimizer/qualify_columns__with_invisible.sql35
-rw-r--r--tests/fixtures/optimizer/simplify.sql6
-rw-r--r--tests/fixtures/optimizer/tpc-h/tpc-h.sql13
-rw-r--r--tests/helpers.py8
-rw-r--r--tests/test_build.py63
-rw-r--r--tests/test_expressions.py22
-rw-r--r--tests/test_optimizer.py225
-rw-r--r--tests/test_transpile.py2
19 files changed, 810 insertions, 103 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 7110eac..8921924 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -153,6 +153,10 @@ class TestBigQuery(Validator):
)
self.validate_identity(
+ "SELECT item, purchases, LAST_VALUE(item) OVER (item_window ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce WINDOW item_window AS (ORDER BY purchases)"
+ )
+
+ self.validate_identity(
"SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)",
)
@@ -223,6 +227,20 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', DAY)",
+ write={
+ "bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), DAY)",
+ "mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))",
+ },
+ )
+ self.validate_all(
+ "DATE_DIFF(DATE '2010-07-07', DATE '2008-12-25', MINUTE)",
+ write={
+ "bigquery": "DATE_DIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE), MINUTE)",
+ "mysql": "DATEDIFF(CAST('2010-07-07' AS DATE), CAST('2008-12-25' AS DATE))",
+ },
+ )
+ self.validate_all(
"CURRENT_DATE('UTC')",
write={
"mysql": "CURRENT_DATE AT TIME ZONE 'UTC'",
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index e5b1516..715bf10 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -8,6 +8,8 @@ class TestClickhouse(Validator):
self.validate_identity("dictGet(x, 'y')")
self.validate_identity("SELECT * FROM x FINAL")
self.validate_identity("SELECT * FROM x AS y FINAL")
+ self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))")
+ self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))")
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@@ -20,6 +22,12 @@ class TestClickhouse(Validator):
self.validate_all(
"CAST(1 AS NULLABLE(Int64))",
write={
- "clickhouse": "CAST(1 AS Nullable(BIGINT))",
+ "clickhouse": "CAST(1 AS Nullable(Int64))",
+ },
+ )
+ self.validate_all(
+ "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
+ write={
+ "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))",
},
)
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index a9a313c..53edb42 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -82,6 +82,24 @@ class TestDialect(Validator):
},
)
self.validate_all(
+ "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
+ write={
+ "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
+ },
+ )
+ self.validate_all(
+ "CAST(ARRAY(1, 2) AS ARRAY<TINYINT>)",
+ write={
+ "clickhouse": "CAST([1, 2] AS Array(Int8))",
+ },
+ )
+ self.validate_all(
+ "CAST((1, 2) AS STRUCT<a: TINYINT, b: SMALLINT, c: INT, d: BIGINT>)",
+ write={
+ "clickhouse": "CAST((1, 2) AS Tuple(a Int8, b Int16, c Int32, d Int64))",
+ },
+ )
+ self.validate_all(
"CAST(a AS DATETIME)",
write={
"postgres": "CAST(a AS TIMESTAMP)",
@@ -170,7 +188,7 @@ class TestDialect(Validator):
"CAST(a AS DOUBLE)",
write={
"bigquery": "CAST(a AS FLOAT64)",
- "clickhouse": "CAST(a AS DOUBLE)",
+ "clickhouse": "CAST(a AS Float64)",
"duckdb": "CAST(a AS DOUBLE)",
"mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)",
@@ -234,6 +252,8 @@ class TestDialect(Validator):
write={
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
+ "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
+ "postgres": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')",
"redshift": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
"spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
@@ -245,6 +265,8 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')",
+ "oracle": "TO_TIMESTAMP(x, 'YY')",
+ "postgres": "TO_TIMESTAMP(x, 'YY')",
"redshift": "TO_TIMESTAMP(x, 'YY')",
"spark": "TO_TIMESTAMP(x, 'yy')",
},
@@ -288,6 +310,8 @@ class TestDialect(Validator):
write={
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
+ "oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
+ "postgres": "TO_CHAR(x, 'YYYY-MM-DD')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
"redshift": "TO_CHAR(x, 'YYYY-MM-DD')",
},
@@ -348,6 +372,8 @@ class TestDialect(Validator):
write={
"duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))",
"hive": "FROM_UNIXTIME(x)",
+ "oracle": "TO_DATE('1970-01-01','YYYY-MM-DD') + (x / 86400)",
+ "postgres": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)",
"starrocks": "FROM_UNIXTIME(x)",
},
@@ -704,6 +730,7 @@ class TestDialect(Validator):
"SELECT * FROM a UNION SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
@@ -719,6 +746,7 @@ class TestDialect(Validator):
"SELECT * FROM a UNION ALL SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a UNION ALL SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b",
"presto": "SELECT * FROM a UNION ALL SELECT * FROM b",
"spark": "SELECT * FROM a UNION ALL SELECT * FROM b",
@@ -848,15 +876,28 @@ class TestDialect(Validator):
"postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)",
+ "clickhouse": "position(x, ' ')",
+ "snowflake": "POSITION(' ', x)",
},
)
self.validate_all(
- "STR_POSITION(x, 'a')",
+ "STR_POSITION('a', x)",
write={
"duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)",
+ "clickhouse": "position(x, 'a')",
+ "snowflake": "POSITION('a', x)",
+ },
+ )
+ self.validate_all(
+ "POSITION('a', x, 3)",
+ write={
+ "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
+ "spark": "LOCATE('a', x, 3)",
+ "clickhouse": "position(x, 'a', 3)",
+ "snowflake": "POSITION('a', x, 3)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index d335921..acb3be9 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -247,7 +247,7 @@ class TestHive(Validator):
"presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))",
"hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
"spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))",
- "": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
+ "": "DATEDIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))",
},
)
self.validate_all(
@@ -295,7 +295,7 @@ class TestHive(Validator):
"presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))",
"hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
"spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))",
- "": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
+ "": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))",
},
)
self.validate_all(
@@ -450,11 +450,21 @@ class TestHive(Validator):
)
self.validate_all(
"MAP(a, b, c, d)",
+ read={
+ "": "VAR_MAP(a, b, c, d)",
+ "clickhouse": "map(a, b, c, d)",
+ "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
+ "hive": "MAP(a, b, c, d)",
+ "presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
+ "spark": "MAP(a, b, c, d)",
+ },
write={
+ "": "MAP(ARRAY(a, c), ARRAY(b, d))",
+ "clickhouse": "map(a, b, c, d)",
"duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"hive": "MAP(a, b, c, d)",
- "spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))",
+ "spark": "MAP(a, b, c, d)",
},
)
self.validate_all(
@@ -463,7 +473,7 @@ class TestHive(Validator):
"duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))",
"presto": "MAP(ARRAY[a], ARRAY[b])",
"hive": "MAP(a, b)",
- "spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))",
+ "spark": "MAP(a, b)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index e0934d7..dc93c3a 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -67,6 +67,7 @@ class TestPostgres(Validator):
self.validate_identity("SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))")
self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)")
+ self.validate_identity("SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')")
self.validate_all(
"CREATE TABLE x (a UUID, b BYTEA)",
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 2145966..8a33e2d 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -305,3 +305,35 @@ class TestSnowflake(Validator):
self.validate_identity(
"CREATE PROCEDURE a.b.c(x INT, y VARIANT) RETURNS OBJECT EXECUTE AS CALLER AS 'BEGIN SELECT 1; END;'"
)
+
+ def test_table_literal(self):
+ # All examples from https://docs.snowflake.com/en/sql-reference/literals-table.html
+ self.validate_all(
+ r"""SELECT * FROM TABLE('MYTABLE')""", write={"snowflake": r"""SELECT * FROM TABLE('MYTABLE')"""}
+ )
+
+ self.validate_all(
+ r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')""",
+ write={"snowflake": r"""SELECT * FROM TABLE('MYDB."MYSCHEMA"."MYTABLE"')"""},
+ )
+
+ # Per the Snowflake documentation at https://docs.snowflake.com/en/sql-reference/literals-table.html,
+ # one can use either ' or $$ to enclose the object identifier.
+ # Capturing the individual delimiter tokens seems like a lot of work, hence the tests use the two interchangeably.
+ self.validate_all(
+ r"""SELECT * FROM TABLE($$MYDB. "MYSCHEMA"."MYTABLE"$$)""",
+ write={"snowflake": r"""SELECT * FROM TABLE('MYDB. "MYSCHEMA"."MYTABLE"')"""},
+ )
+
+ self.validate_all(r"""SELECT * FROM TABLE($MYVAR)""", write={"snowflake": r"""SELECT * FROM TABLE($MYVAR)"""})
+
+ self.validate_all(r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""})
+
+ self.validate_all(
+ r"""SELECT * FROM TABLE(:BINDING)""", write={"snowflake": r"""SELECT * FROM TABLE(:BINDING)"""}
+ )
+
+ self.validate_all(
+ r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10""",
+ write={"snowflake": r"""SELECT * FROM TABLE($MYVAR) WHERE COL1 = 10"""},
+ )
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 8377e47..9a7e64c 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -111,12 +111,70 @@ TBLPROPERTIES (
"SELECT /*+ COALESCE(3) */ * FROM x",
write={
"spark": "SELECT /*+ COALESCE(3) */ * FROM x",
+ "bigquery": "SELECT * FROM x",
},
)
self.validate_all(
"SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
write={
"spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x",
+ "bigquery": "SELECT * FROM x",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ BROADCAST(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ BROADCAST(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ BROADCASTJOIN(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ MAPJOIN(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ MAPJOIN(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ MERGE(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ MERGE(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ SHUFFLEMERGE(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ MERGEJOIN(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ MERGEJOIN(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ SHUFFLE_HASH(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
+ },
+ )
+ self.validate_all(
+ "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table",
+ write={
+ "spark": "SELECT /*+ SHUFFLE_REPLICATE_NL(table) */ cola FROM table",
+ "bigquery": "SELECT cola FROM table",
},
)
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index a0de281..40e7cc1 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -321,6 +321,10 @@ SELECT 1 FROM a INNER JOIN b ON a.x = b.x
SELECT 1 FROM a LEFT JOIN b ON a.x = b.x
SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x
SELECT 1 FROM a CROSS JOIN b ON a.x = b.x
+SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x
+SELECT 1 FROM a LEFT ANTI JOIN b ON a.x = b.x
+SELECT 1 FROM a RIGHT SEMI JOIN b ON a.x = b.x
+SELECT 1 FROM a RIGHT ANTI JOIN b ON a.x = b.x
SELECT 1 FROM a JOIN b USING (x)
SELECT 1 FROM a JOIN b USING (x, y, z)
SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2
@@ -529,12 +533,14 @@ UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234
UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234
TRUNCATE TABLE x
OPTIMIZE TABLE y
+VACUUM FREEZE my_table
WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a
WITH a AS (SELECT * FROM b) UPDATE a SET col = 1
WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a
WITH a AS (SELECT * FROM b) DELETE FROM a
WITH a AS (SELECT * FROM b) CACHE TABLE a
SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ?
+SELECT :hello, ? FROM x LIMIT :my_limit
WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a
WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a
SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z
diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql
index e13d3b3..c8186cc 100644
--- a/tests/fixtures/optimizer/merge_subqueries.sql
+++ b/tests/fixtures/optimizer/merge_subqueries.sql
@@ -1,107 +1,189 @@
--- Simple
+# title: Simple
SELECT a, b FROM (SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
--- Inner table alias is merged
+# title: Inner table alias is merged
SELECT a, b FROM (SELECT a, b FROM x AS q) AS r;
SELECT q.a AS a, q.b AS b FROM x AS q;
--- Double nesting
+# title: Double nesting
SELECT a, b FROM (SELECT a, b FROM (SELECT a, b FROM x));
SELECT x.a AS a, x.b AS b FROM x AS x;
--- WHERE clause is merged
-SELECT a, SUM(b) FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
-SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
+# title: WHERE clause is merged
+SELECT a, SUM(b) AS b FROM (SELECT a, b FROM x WHERE a > 1) GROUP BY a;
+SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a;
--- Outer query has join
-SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
-SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
-
--- Outer query has join
+# title: Outer query has join
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
+# title: Leave tables isolated
# leave_tables_isolated: true
SELECT a, c FROM (SELECT a, b FROM x WHERE a > 1) AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x WHERE x.a > 1) AS x JOIN y AS y ON x.b = y.b;
--- Join on derived table
+# title: Join on derived table
SELECT a, c FROM x JOIN (SELECT b, c FROM y) AS y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
--- Inner query has a join
+# title: Inner query has a join
SELECT a, c FROM (SELECT a, c FROM x JOIN y ON x.b = y.b);
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
--- Inner query has conflicting name in outer query
+# title: Inner query has conflicting name in outer query
SELECT a, c FROM (SELECT q.a, q.b FROM x AS q) AS x JOIN y AS q ON x.b = q.b;
SELECT q_2.a AS a, q.c AS c FROM x AS q_2 JOIN y AS q ON q_2.b = q.b;
--- Inner query has conflicting name in joined source
+# title: Inner query has conflicting name in joined source
SELECT x.a, q.c FROM (SELECT a, x.b FROM x JOIN y AS q ON x.b = q.b) AS x JOIN y AS q ON x.b = q.b;
SELECT x.a AS a, q.c AS c FROM x AS x JOIN y AS q_2 ON x.b = q_2.b JOIN y AS q ON x.b = q.b;
--- Inner query has multiple conflicting names
-SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b;
-SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b;
+# title: Inner query has multiple conflicting names
+SELECT x.a, q.c, r.c FROM (SELECT q.a, r.b FROM x AS q JOIN y AS r ON q.b = r.b) AS x JOIN y AS q ON x.b = q.b JOIN y AS r ON x.b = r.b ORDER BY x.a, q.c, r.c;
+SELECT q_2.a AS a, q.c AS c, r.c AS c FROM x AS q_2 JOIN y AS r_2 ON q_2.b = r_2.b JOIN y AS q ON r_2.b = q.b JOIN y AS r ON r_2.b = r.b ORDER BY q_2.a, q.c, r.c;
--- Inner queries have conflicting names with each other
+# title: Inner queries have conflicting names with each other
SELECT r.b FROM (SELECT b FROM x AS x) AS q JOIN (SELECT b FROM x) AS r ON q.b = r.b;
SELECT x_2.b AS b FROM x AS x JOIN x AS x_2 ON x.b = x_2.b;
--- WHERE clause in joined derived table is merged to ON clause
-SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y;
-SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON y.c > 1;
+# title: WHERE clause in joined derived table is merged to ON clause
+SELECT x.a, y.c FROM x JOIN (SELECT b, c FROM y WHERE c > 1) AS y ON x.b = y.b;
+SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b AND y.c > 1;
--- Comma JOIN in outer query
+# title: Comma JOIN in outer query
SELECT x.a, y.c FROM (SELECT a FROM x) AS x, (SELECT c FROM y) AS y;
SELECT x.a AS a, y.c AS c FROM x AS x, y AS y;
--- Comma JOIN in inner query
+# title: Comma JOIN in inner query
SELECT x.a, x.c FROM (SELECT x.a, z.c FROM x, y AS z) AS x;
SELECT x.a AS a, z.c AS c FROM x AS x CROSS JOIN y AS z;
--- (Regression) Column in ORDER BY
+# title: (Regression) Column in ORDER BY
SELECT * FROM (SELECT * FROM (SELECT * FROM x)) ORDER BY a LIMIT 1;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a LIMIT 1;
--- CTE
+# title: CTE
WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x;
SELECT x.a AS a, x.b AS b FROM x AS x;
--- CTE with outer table alias
+# title: CTE with outer table alias
WITH y AS (SELECT a, b FROM x) SELECT a, b FROM y AS z;
SELECT x.a AS a, x.b AS b FROM x AS x;
--- Nested CTE
-WITH x AS (SELECT a FROM x), x2 AS (SELECT a FROM x) SELECT a FROM x2;
+# title: Nested CTE
+WITH x2 AS (SELECT a FROM x), x3 AS (SELECT a FROM x2) SELECT a FROM x3;
SELECT x.a AS a FROM x AS x;
--- CTE WHERE clause is merged
-WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) FROM x GROUP BY a;
-SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 GROUP BY x.a;
+# title: CTE WHERE clause is merged
+WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, SUM(b) AS b FROM x GROUP BY a;
+SELECT x.a AS a, SUM(x.b) AS b FROM x AS x WHERE x.a > 1 GROUP BY x.a;
--- CTE Outer query has join
-WITH x AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x AS x JOIN y ON x.b = y.b;
+# title: CTE Outer query has join
+WITH x2 AS (SELECT a, b FROM x WHERE a > 1) SELECT a, c FROM x2 AS x JOIN y ON x.b = y.b;
SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b WHERE x.a > 1;
--- CTE with inner table alias
+# title: CTE with inner table alias
WITH y AS (SELECT a, b FROM x AS q) SELECT a, b FROM y AS z;
SELECT q.a AS a, q.b AS b FROM x AS q;
--- Duplicate queries to CTE
-WITH x AS (SELECT a, b FROM x) SELECT x.a, y.b FROM x JOIN x AS y;
-WITH x AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT x.a AS a, y.b AS b FROM x JOIN x AS y;
-
--- Nested CTE
+# title: Nested CTE
SELECT * FROM (WITH x AS (SELECT a, b FROM x) SELECT a, b FROM x);
SELECT x.a AS a, x.b AS b FROM x AS x;
--- Inner select is an expression
+# title: Inner select is an expression
SELECT a FROM (SELECT a FROM (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) AS x) AS x;
SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
--- CTE select is an expression
-WITH x AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x AS x) AS x;
+# title: CTE select is an expression
+WITH x2 AS (SELECT COALESCE(a) AS a FROM x LEFT JOIN y ON x.a = y.b) SELECT a FROM (SELECT a FROM x2 AS x) AS x;
SELECT COALESCE(x.a) AS a FROM x AS x LEFT JOIN y AS y ON x.a = y.b;
+
+# title: Full outer join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+
+# title: Full outer join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x FULL OUTER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x FULL OUTER JOIN y AS y ON x.b = y.b;
+
+# title: Left join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b WHERE x.b = 1;
+
+# title: Left join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x LEFT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x LEFT JOIN y AS y ON x.b = y.b;
+
+# title: Right join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+
+# title: Right join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x RIGHT JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x RIGHT JOIN y AS y ON x.b = y.b;
+
+# title: Inner join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x INNER JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b AND y.b = 2 WHERE x.b = 1;
+
+# title: Inner join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x INNER JOIN (SELECT y.b AS b FROM y AS y) AS y ON x.b = y.b;
+SELECT x.b AS b, y.b AS b2 FROM x AS x INNER JOIN y AS y ON x.b = y.b;
+
+# title: Cross join
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x WHERE x.b = 1) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y WHERE y.b = 2) AS y;
+SELECT x.b AS b, y.b AS b2 FROM x AS x JOIN y AS y ON y.b = 2 WHERE x.b = 1;
+
+# title: Cross join, no predicates
+SELECT x.b AS b, y.b AS b2 FROM (SELECT x.b AS b FROM x AS x) AS x CROSS JOIN (SELECT y.b AS b FROM y AS y) AS y;
+SELECT x.b AS b, y.b AS b2 FROM x AS x CROSS JOIN y AS y;
+
+# title: Broadcast hint
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(k) */ m.a, k.c FROM m JOIN n AS k ON m.b = k.b) SELECT joined.a, joined.c FROM joined;
+SELECT /*+ BROADCAST(y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Broadcast hint multiple tables
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined;
+SELECT /*+ BROADCAST(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Multiple Table Hints
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT joined.a, joined.c FROM joined;
+SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Mix Table and Column Hints
+# dialect: spark
+WITH m AS (SELECT x.a, x.b FROM x), n AS (SELECT y.b, y.c FROM y), joined as (SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM m JOIN n ON m.b = n.b) SELECT /*+ COALESCE(3) */ joined.a, joined.c FROM joined;
+SELECT /*+ COALESCE(3), BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Hint Subquery
+# dialect: spark
+SELECT
+ subquery.a,
+ subquery.c
+FROM (
+ SELECT /*+ BROADCAST(m), MERGE(m, n) */ m.a, n.c FROM (SELECT x.a, x.b FROM x) AS m JOIN (SELECT y.b, y.c FROM y) AS n ON m.b = n.b
+) AS subquery;
+SELECT /*+ BROADCAST(x), MERGE(x, y) */ x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b;
+
+# title: Subquery Test
+# dialect: spark
+SELECT /*+ BROADCAST(x) */
+ x.a,
+ x.c
+FROM (
+ SELECT
+ x.a,
+ x.c
+ FROM (
+ SELECT
+ x.a,
+ COUNT(1) AS c
+ FROM x
+ GROUP BY x.a
+ ) AS x
+) AS x;
+SELECT /*+ BROADCAST(x) */ x.a AS a, x.c AS c FROM (SELECT x.a AS a, COUNT(1) AS c FROM x AS x GROUP BY x.a) AS x;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index eb6761a..ab4f769 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -1,3 +1,5 @@
+# title: lateral
+# execute: false
SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m;
SELECT
"z"."a" AS "a",
@@ -6,11 +8,13 @@ FROM "z" AS "z"
LATERAL VIEW
EXPLODE(ARRAY(1, 2)) q AS "m";
+# title: unnest
SELECT x FROM UNNEST([1, 2]) AS q(x, y);
SELECT
"q"."x" AS "x"
FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y");
+# title: Union in CTE
WITH cte AS (
(
SELECT
@@ -21,7 +25,7 @@ WITH cte AS (
UNION ALL
(
SELECT
- a
+ b AS a
FROM
y
)
@@ -39,7 +43,7 @@ WITH "cte" AS (
UNION ALL
(
SELECT
- "y"."a" AS "a"
+ "y"."b" AS "a"
FROM "y" AS "y"
)
)
@@ -47,6 +51,7 @@ SELECT
"cte"."a" AS "a"
FROM "cte";
+# title: Chained CTEs
WITH cte1 AS (
SELECT a
FROM x
@@ -74,30 +79,31 @@ SELECT
"cte1"."a" + 1 AS "a"
FROM "cte1";
-SELECT a, SUM(b)
+# title: Correlated subquery
+SELECT a, SUM(b) AS sum_b
FROM (
SELECT x.a, y.b
FROM x, y
- WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a
+ WHERE (SELECT max(b) FROM y WHERE x.b = y.b) >= 0 AND x.b = y.b
) d
WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1
GROUP BY a;
WITH "_u_0" AS (
SELECT
MAX("y"."b") AS "_col_0",
- "y"."a" AS "_u_1"
+ "y"."b" AS "_u_1"
FROM "y" AS "y"
GROUP BY
- "y"."a"
+ "y"."b"
)
SELECT
"x"."a" AS "a",
- SUM("y"."b") AS "_col_1"
+ SUM("y"."b") AS "sum_b"
FROM "x" AS "x"
LEFT JOIN "_u_0" AS "_u_0"
- ON "x"."a" = "_u_0"."_u_1"
+ ON "x"."b" = "_u_0"."_u_1"
JOIN "y" AS "y"
- ON "x"."a" = "y"."a"
+ ON "x"."b" = "y"."b"
WHERE
"_u_0"."_col_0" >= 0
AND "x"."a" > 1
@@ -105,6 +111,7 @@ WHERE
GROUP BY
"x"."a";
+# title: Root subquery
(SELECT a FROM x) LIMIT 1;
(
SELECT
@@ -113,6 +120,7 @@ GROUP BY
)
LIMIT 1;
+# title: Root subquery is union
(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
(
SELECT
@@ -125,6 +133,7 @@ LIMIT 1;
)
LIMIT 1;
+# title: broadcast
# dialect: spark
SELECT /*+ BROADCAST(y) */ x.b FROM x JOIN y ON x.b = y.b;
SELECT /*+ BROADCAST(`y`) */
@@ -133,11 +142,14 @@ FROM `x` AS `x`
JOIN `y` AS `y`
ON `x`.`b` = `y`.`b`;
+# title: aggregate
+# execute: false
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT
AGGREGATE(ARRAY("x"."a", "x"."b"), 0, ("x", "acc") -> "x" + "acc" + "x"."a") AS "sum_agg"
FROM "x" AS "x";
+# title: values
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
SELECT
"tab"."cola" AS "cola",
@@ -146,6 +158,7 @@ FROM (VALUES
(1, 'test'),
(2, 'test2')) AS "tab"("cola", "colb");
+# title: spark values
# dialect: spark
SELECT cola, colb FROM (VALUES (1, 'test'), (2, 'test2')) AS tab(cola, colb);
SELECT
@@ -154,3 +167,112 @@ SELECT
FROM VALUES
(1, 'test'),
(2, 'test2') AS `tab`(`cola`, `colb`);
+
+# title: complex CTE dependencies
+WITH m AS (
+ SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)
+), n AS (
+ SELECT a, b FROM m WHERE m.a = 1
+), o AS (
+ SELECT a, b FROM m WHERE m.a = 2
+) SELECT
+ n.a,
+ n.b,
+ o.b
+FROM n
+FULL OUTER JOIN o ON n.a = o.a
+CROSS JOIN n AS n2
+WHERE o.b > 0 AND n.a = n2.a;
+WITH "m" AS (
+ SELECT
+ "a1"."a" AS "a",
+ "a1"."b" AS "b"
+ FROM (VALUES
+ (1, 2)) AS "a1"("a", "b")
+), "n" AS (
+ SELECT
+ "m"."a" AS "a",
+ "m"."b" AS "b"
+ FROM "m"
+ WHERE
+ "m"."a" = 1
+), "o" AS (
+ SELECT
+ "m"."a" AS "a",
+ "m"."b" AS "b"
+ FROM "m"
+ WHERE
+ "m"."a" = 2
+)
+SELECT
+ "n"."a" AS "a",
+ "n"."b" AS "b",
+ "o"."b" AS "b"
+FROM "n"
+FULL JOIN "o"
+ ON "n"."a" = "o"."a"
+JOIN "n" AS "n2"
+ ON "n"."a" = "n2"."a"
+WHERE
+ "o"."b" > 0;
+
+# title: Broadcast hint
+# dialect: spark
+WITH m AS (
+ SELECT
+ x.a,
+ x.b
+ FROM x
+), n AS (
+ SELECT
+ y.b,
+ y.c
+ FROM y
+), joined as (
+ SELECT /*+ BROADCAST(n) */
+ m.a,
+ n.c
+ FROM m JOIN n ON m.b = n.b
+)
+SELECT
+ joined.a,
+ joined.c
+FROM joined;
+SELECT /*+ BROADCAST(`y`) */
+ `x`.`a` AS `a`,
+ `y`.`c` AS `c`
+FROM `x` AS `x`
+JOIN `y` AS `y`
+ ON `x`.`b` = `y`.`b`;
+
+# title: Mix Table and Column Hints
+# dialect: spark
+WITH m AS (
+ SELECT
+ x.a,
+ x.b
+ FROM x
+), n AS (
+ SELECT
+ y.b,
+ y.c
+ FROM y
+), joined as (
+ SELECT /*+ BROADCAST(m), MERGE(m, n) */
+ m.a,
+ n.c
+ FROM m JOIN n ON m.b = n.b
+)
+SELECT
+ /*+ COALESCE(3) */
+ joined.a,
+ joined.c
+FROM joined;
+SELECT /*+ COALESCE(3),
+ BROADCAST(`x`),
+ MERGE(`x`, `y`) */
+ `x`.`a` AS `a`,
+ `y`.`c` AS `c`
+FROM `x` AS `x`
+JOIN `y` AS `y`
+ ON `x`.`b` = `y`.`b`;
diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql
index f848e7a..83a3bf8 100644
--- a/tests/fixtures/optimizer/qualify_columns.sql
+++ b/tests/fixtures/optimizer/qualify_columns.sql
@@ -19,38 +19,49 @@ SELECT x.a AS a FROM x AS x;
SELECT a AS b FROM x;
SELECT x.a AS b FROM x AS x;
+# execute: false
SELECT 1, 2 FROM x;
SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x;
+# execute: false
SELECT a + b FROM x;
SELECT x.a + x.b AS "_col_0" FROM x AS x;
-SELECT a + b FROM x;
-SELECT x.a + x.b AS "_col_0" FROM x AS x;
-
+# execute: false
SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a;
SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a;
SELECT a AS j, b FROM x ORDER BY j;
SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j;
-SELECT a AS j, b FROM x GROUP BY j;
-SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a;
+SELECT a AS j, b AS a FROM x ORDER BY 1;
+SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a;
+
+SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2;
+SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
+
+# execute: false
+SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2;
+SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b);
+
+SELECT a AS j, b FROM x GROUP BY j, b;
+SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b;
SELECT a, b FROM x GROUP BY 1, 2;
SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b;
SELECT a, b FROM x ORDER BY 1, 2;
-SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b;
+SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b;
+# execute: false
SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2;
SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b);
-SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
-SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
+SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c;
+SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c;
-SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
-SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a);
+SELECT COALESCE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d;
+SELECT COALESCE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY COALESCE(x.a);
SELECT a AS a, b FROM x ORDER BY a;
SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a;
@@ -69,6 +80,7 @@ SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x
SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1;
SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1;
+# execute: false
SELECT AGGREGATE(ARRAY(a, x.b), 0, (x, acc) -> x + acc + a) AS sum_agg FROM x;
SELECT AGGREGATE(ARRAY(x.a, x.b), 0, (x, acc) -> x + acc + x.a) AS sum_agg FROM x AS x;
@@ -93,8 +105,8 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
SELECT a FROM (SELECT a FROM (SELECT a FROM x));
SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1";
-SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
-SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0";
+SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a;
+SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a;
--------------------------------------
-- Joins
@@ -123,6 +135,7 @@ SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FRO
SELECT a FROM x WHERE b IN (SELECT c FROM y);
SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y);
+# execute: false
SELECT (SELECT c FROM y) FROM x;
SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x;
@@ -144,10 +157,12 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x);
SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b));
SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b));
+# execute: false
# dialect: bigquery
SELECT aa FROM x, UNNEST(a) AS aa;
SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa;
+# execute: false
SELECT aa FROM x, UNNEST(a) AS t(aa);
SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa);
@@ -205,15 +220,19 @@ WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) O
--------------------------------------
-- Except and Replace
--------------------------------------
+# execute: false
SELECT * REPLACE(a AS d) FROM x;
SELECT x.a AS d, x.b AS b FROM x AS x;
+# execute: false
SELECT * EXCEPT(b) REPLACE(a AS d) FROM x;
SELECT x.a AS d FROM x AS x;
+# execute: false
SELECT x.* EXCEPT(a), y.* FROM x, y;
SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y;
+# execute: false
SELECT * EXCEPT(a) FROM x;
SELECT x.b AS b FROM x AS x;
diff --git a/tests/fixtures/optimizer/qualify_columns__with_invisible.sql b/tests/fixtures/optimizer/qualify_columns__with_invisible.sql
new file mode 100644
index 0000000..ee46c23
--- /dev/null
+++ b/tests/fixtures/optimizer/qualify_columns__with_invisible.sql
@@ -0,0 +1,35 @@
+--------------------------------------
+-- Qualify columns
+--------------------------------------
+SELECT a FROM x;
+SELECT x.a AS a FROM x AS x;
+
+SELECT b FROM x;
+SELECT x.b AS b FROM x AS x;
+
+--------------------------------------
+-- Derived tables
+--------------------------------------
+SELECT x.a FROM x AS x JOIN (SELECT * FROM x);
+SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0";
+
+SELECT x.b FROM x AS x JOIN (SELECT b FROM x);
+SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0";
+
+--------------------------------------
+-- Expand *
+--------------------------------------
+SELECT * FROM x;
+SELECT x.a AS a FROM x AS x;
+
+SELECT * FROM y JOIN z ON y.b = z.b;
+SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.b = z.b;
+
+SELECT * FROM y JOIN z ON y.c = z.c;
+SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c;
+
+SELECT a FROM (SELECT * FROM x);
+SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
+
+SELECT * FROM (SELECT a FROM x);
+SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0";
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index d7217cf..07e818f 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -52,6 +52,9 @@ TRUE;
NULL AND TRUE;
NULL;
+NULL AND FALSE;
+FALSE;
+
NULL AND NULL;
NULL;
@@ -70,6 +73,9 @@ FALSE;
NOT FALSE;
TRUE;
+NOT NULL;
+NULL;
+
NULL = NULL;
NULL;
diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
index d2f10fc..936a0af 100644
--- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql
+++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
@@ -769,13 +769,20 @@ group by
order by
custdist desc,
c_count desc;
-WITH "c_orders" AS (
+WITH "orders_2" AS (
+ SELECT
+ "orders"."o_orderkey" AS "o_orderkey",
+ "orders"."o_custkey" AS "o_custkey",
+ "orders"."o_comment" AS "o_comment"
+ FROM "orders" AS "orders"
+ WHERE
+ NOT "orders"."o_comment" LIKE '%special%requests%'
+), "c_orders" AS (
SELECT
COUNT("orders"."o_orderkey") AS "c_count"
FROM "customer" AS "customer"
- LEFT JOIN "orders" AS "orders"
+ LEFT JOIN "orders_2" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
- AND NOT "orders"."o_comment" LIKE '%special%requests%'
GROUP BY
"customer"."c_custkey"
)
diff --git a/tests/helpers.py b/tests/helpers.py
index ad50483..2d200f6 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -45,6 +45,14 @@ def load_sql_fixture_pairs(filename):
yield meta, sql, expected
+def string_to_bool(string):
+ if string is None:
+ return False
+ if string in (True, False):
+ return string
+ return string and string.lower() in ("true", "1")
+
+
TPCH_SCHEMA = {
"lineitem": {
"l_orderkey": "uint64",
diff --git a/tests/test_build.py b/tests/test_build.py
index b5d657c..fa9e7f8 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -1,6 +1,19 @@
import unittest
-from sqlglot import and_, condition, exp, from_, not_, or_, parse_one, select
+from sqlglot import (
+ alias,
+ and_,
+ condition,
+ except_,
+ exp,
+ from_,
+ intersect,
+ not_,
+ or_,
+ parse_one,
+ select,
+ union,
+)
class TestBuild(unittest.TestCase):
@@ -320,6 +333,54 @@ class TestBuild(unittest.TestCase):
lambda: exp.update("tbl", {"x": 1}, from_="tbl2"),
"UPDATE tbl SET x = 1 FROM tbl2",
),
+ (
+ lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
+ "SELECT * FROM foo UNION SELECT * FROM bla",
+ ),
+ (
+ lambda: parse_one("SELECT * FROM foo").union("SELECT * FROM bla"),
+ "SELECT * FROM foo UNION SELECT * FROM bla",
+ ),
+ (
+ lambda: intersect("SELECT * FROM foo", "SELECT * FROM bla"),
+ "SELECT * FROM foo INTERSECT SELECT * FROM bla",
+ ),
+ (
+ lambda: parse_one("SELECT * FROM foo").intersect("SELECT * FROM bla"),
+ "SELECT * FROM foo INTERSECT SELECT * FROM bla",
+ ),
+ (
+ lambda: except_("SELECT * FROM foo", "SELECT * FROM bla"),
+ "SELECT * FROM foo EXCEPT SELECT * FROM bla",
+ ),
+ (
+ lambda: parse_one("SELECT * FROM foo").except_("SELECT * FROM bla"),
+ "SELECT * FROM foo EXCEPT SELECT * FROM bla",
+ ),
+ (
+ lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla"),
+ "(SELECT * FROM foo) UNION SELECT * FROM bla",
+ ),
+ (
+ lambda: parse_one("(SELECT * FROM foo)").union("SELECT * FROM bla", distinct=False),
+ "(SELECT * FROM foo) UNION ALL SELECT * FROM bla",
+ ),
+ (
+ lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y)"), "a"),
+ "LAG(x) OVER (PARTITION BY y) AS a",
+ ),
+ (
+ lambda: alias(parse_one("LAG(x) OVER (ORDER BY z)"), "a"),
+ "LAG(x) OVER (ORDER BY z) AS a",
+ ),
+ (
+ lambda: alias(parse_one("LAG(x) OVER (PARTITION BY y ORDER BY z)"), "a"),
+ "LAG(x) OVER (PARTITION BY y ORDER BY z) AS a",
+ ),
+ (
+ lambda: alias(parse_one("LAG(x) OVER ()"), "a"),
+ "LAG(x) OVER () AS a",
+ ),
]:
with self.subTest(sql):
self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index cc41307..abc95cb 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -115,6 +115,21 @@ class TestExpressions(unittest.TestCase):
["first", "second", "third"],
)
+ def test_table_name(self):
+ self.assertEqual(exp.table_name(parse_one("a", into=exp.Table)), "a")
+ self.assertEqual(exp.table_name(parse_one("a.b", into=exp.Table)), "a.b")
+ self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
+ self.assertEqual(exp.table_name("a.b.c"), "a.b.c")
+
+ def test_replace_tables(self):
+ self.assertEqual(
+ exp.replace_tables(
+ parse_one("select * from a join b join c.a join d.a join e.a"),
+ {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"},
+ ).sql(),
+ 'SELECT * FROM "a1" JOIN "b"."a" JOIN "c"."a2" JOIN "d2" JOIN e.a',
+ )
+
def test_named_selects(self):
expression = parse_one("SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz")
self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"])
@@ -474,3 +489,10 @@ class TestExpressions(unittest.TestCase):
]:
with self.subTest(value):
self.assertEqual(exp.convert(value).sql(), expected)
+
+ def test_annotation_alias(self):
+ expression = parse_one("SELECT a, b AS B, c #comment, d AS D #another_comment FROM foo")
+ self.assertEqual(
+ [e.alias_or_name for e in expression.expressions],
+ ["a", "B", "c", "D"],
+ )
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index aad84ed..36a7785 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1,17 +1,55 @@
import unittest
from functools import partial
+import duckdb
+from pandas.testing import assert_frame_equal
+
+import sqlglot
from sqlglot import exp, optimizer, parse_one, table
from sqlglot.errors import OptimizeError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.schema import MappingSchema, ensure_schema
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
-from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
+from tests.helpers import (
+ TPCH_SCHEMA,
+ load_sql_fixture_pairs,
+ load_sql_fixtures,
+ string_to_bool,
+)
class TestOptimizer(unittest.TestCase):
maxDiff = None
+ @classmethod
+ def setUpClass(cls):
+ cls.conn = duckdb.connect()
+ cls.conn.execute(
+ """
+ CREATE TABLE x (a INT, b INT);
+ CREATE TABLE y (b INT, c INT);
+ CREATE TABLE z (b INT, c INT);
+
+ INSERT INTO x VALUES (1, 1);
+ INSERT INTO x VALUES (2, 2);
+ INSERT INTO x VALUES (2, 2);
+ INSERT INTO x VALUES (3, 3);
+ INSERT INTO x VALUES (null, null);
+
+ INSERT INTO y VALUES (2, 2);
+ INSERT INTO y VALUES (2, 2);
+ INSERT INTO y VALUES (3, 3);
+ INSERT INTO y VALUES (4, 4);
+ INSERT INTO y VALUES (null, null);
+
+ INSERT INTO y VALUES (3, 3);
+ INSERT INTO y VALUES (3, 3);
+ INSERT INTO y VALUES (4, 4);
+ INSERT INTO y VALUES (5, 5);
+ INSERT INTO y VALUES (null, null);
+ """
+ )
+
def setUp(self):
self.schema = {
"x": {
@@ -28,29 +66,42 @@ class TestOptimizer(unittest.TestCase):
},
}
- def check_file(self, file, func, pretty=False, **kwargs):
+ def check_file(self, file, func, pretty=False, execute=False, **kwargs):
for i, (meta, sql, expected) in enumerate(load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1):
+ title = meta.get("title") or f"{i}, {sql}"
dialect = meta.get("dialect")
leave_tables_isolated = meta.get("leave_tables_isolated")
func_kwargs = {**kwargs}
if leave_tables_isolated is not None:
- func_kwargs["leave_tables_isolated"] = leave_tables_isolated.lower() in ("true", "1")
+ func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
+
+ optimized = func(parse_one(sql, read=dialect), **func_kwargs)
- with self.subTest(f"{i}, {sql}"):
+ with self.subTest(title):
self.assertEqual(
- func(parse_one(sql, read=dialect), **func_kwargs).sql(pretty=pretty, dialect=dialect),
+ optimized.sql(pretty=pretty, dialect=dialect),
expected,
)
+ should_execute = meta.get("execute")
+ if should_execute is None:
+ should_execute = execute
+
+ if string_to_bool(should_execute):
+ with self.subTest(f"(execute) {title}"):
+ df1 = self.conn.execute(sqlglot.transpile(sql, read=dialect, write="duckdb")[0]).df()
+ df2 = self.conn.execute(optimized.sql(pretty=pretty, dialect="duckdb")).df()
+ assert_frame_equal(df1, df2)
+
def test_optimize(self):
schema = {
"x": {"a": "INT", "b": "INT"},
- "y": {"a": "INT", "b": "INT"},
+ "y": {"b": "INT", "c": "INT"},
"z": {"a": "INT", "c": "INT"},
}
- self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema)
+ self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema)
def test_isolate_table_selects(self):
self.check_file(
@@ -86,7 +137,16 @@ class TestOptimizer(unittest.TestCase):
expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
return expression
- self.check_file("qualify_columns", qualify_columns, schema=self.schema)
+ self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)
+
+ def test_qualify_columns__with_invisible(self):
+ def qualify_columns(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ return expression
+
+ schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
+ self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)
def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
@@ -141,7 +201,7 @@ class TestOptimizer(unittest.TestCase):
],
)
- self.check_file("merge_subqueries", optimize, schema=self.schema)
+ self.check_file("merge_subqueries", optimize, execute=True, schema=self.schema)
def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
@@ -301,10 +361,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
}
for sql, target_type in tests.items():
- expression = parse_one(sql)
- annotated_expression = annotate_types(expression)
-
- self.assertEqual(annotated_expression.find(exp.Literal).type, target_type)
+ expression = annotate_types(parse_one(sql))
+ self.assertEqual(expression.find(exp.Literal).type, target_type)
def test_boolean_type_annotation(self):
tests = {
@@ -313,14 +371,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
}
for sql, target_type in tests.items():
- expression = parse_one(sql)
- annotated_expression = annotate_types(expression)
-
- self.assertEqual(annotated_expression.find(exp.Boolean).type, target_type)
+ expression = annotate_types(parse_one(sql))
+ self.assertEqual(expression.find(exp.Boolean).type, target_type)
def test_cast_type_annotation(self):
- expression = parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")
- annotate_types(expression)
+ expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
@@ -328,16 +383,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
def test_cache_annotation(self):
- expression = parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
- annotated_expression = annotate_types(expression)
-
- self.assertEqual(annotated_expression.expression.expressions[0].type, exp.DataType.Type.INT)
+ expression = annotate_types(parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1"))
+ self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
def test_binary_annotation(self):
- expression = parse_one("SELECT 0.0 + (2 + 3)")
- annotate_types(expression)
-
- expression = expression.expressions[0]
+ expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
@@ -345,3 +395,124 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
+
+ def test_derived_tables_column_annotation(self):
+ schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
+ sql = """
+ SELECT a.cola AS cola
+ FROM (
+ SELECT x.cola + y.cola AS cola
+ FROM (
+ SELECT x.cola AS cola
+ FROM x AS x
+ ) AS x
+ JOIN (
+ SELECT y.cola AS cola
+ FROM y AS y
+ ) AS y
+ ) AS a
+ """
+
+ expression = annotate_types(parse_one(sql), schema=schema)
+ self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola
+
+ addition_alias = expression.args["from"].expressions[0].this.expressions[0]
+ self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola
+
+ addition = addition_alias.this
+ self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.this.type, exp.DataType.Type.INT)
+ self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
+
+ def test_cte_column_annotation(self):
+ schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
+ sql = """
+ WITH tbl AS (
+ SELECT x.cola + 'bla' AS cola, y.colb AS colb
+ FROM (
+ SELECT x.cola AS cola
+ FROM x AS x
+ ) AS x
+ JOIN (
+ SELECT y.colb AS colb
+ FROM y AS y
+ ) AS y
+ )
+ SELECT tbl.cola + tbl.colb + 'foo' AS col
+ FROM tbl AS tbl
+ """
+
+ expression = annotate_types(parse_one(sql), schema=schema)
+ self.assertEqual(expression.expressions[0].type, exp.DataType.Type.TEXT) # tbl.cola + tbl.colb + 'foo' AS col
+
+ outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
+ self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
+
+ inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
+ self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
+
+ cte_select = expression.args["with"].expressions[0].this
+ self.assertEqual(cte_select.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola + 'bla' AS cola
+ self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
+
+ cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
+ self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
+ self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
+
+ # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
+ for d, t in zip(cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]):
+ self.assertEqual(d.this.expressions[0].this.type, t)
+
+ def test_function_annotation(self):
+ schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
+ sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
+
+ concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
+ self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
+
+ concat_expr = concat_expr_alias.this
+ self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
+ self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
+
+ def test_unknown_annotation(self):
+ schema = {"x": {"cola": "VARCHAR"}}
+ sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
+
+ concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
+ self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+
+ concat_expr = concat_expr_alias.this
+ self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) # SOME_ANONYMOUS_FUNC(x.cola)
+ self.assertEqual(concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR) # x.cola (arg)
+
+ def test_null_annotation(self):
+ expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
+ self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+
+ # NULL <op> UNKNOWN should yield NULL
+ sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
+
+ concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
+ self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+
+ concat_expr = concat_expr_alias.this
+ self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+
+ def test_nullable_annotation(self):
+ nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+ expression = annotate_types(parse_one("NULL AND FALSE"))
+
+ self.assertEqual(expression.type, nullable)
+ self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 4bec2ac..01b8205 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -338,7 +338,7 @@ class TestTranspile(unittest.TestCase):
unsupported_level=level,
)
- error = "Cannot convert array columns into map use SparkSQL instead."
+ error = "Cannot convert array columns into map."
unsupported(ErrorLevel.WARN)
assert_logger_contains("\n".join([error] * 4), logger, level="warning")