diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_clickhouse.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_prql.py | 61 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 41 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 36 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 89 | ||||
-rw-r--r-- | tests/fixtures/optimizer/annotate_types.sql | 57 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 42 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_columns.sql | 10 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_tables.sql | 7 | ||||
-rw-r--r-- | tests/test_build.py | 24 | ||||
-rw-r--r-- | tests/test_expressions.py | 23 | ||||
-rw-r--r-- | tests/test_optimizer.py | 66 | ||||
-rw-r--r-- | tests/test_transpile.py | 8 |
19 files changed, 403 insertions, 100 deletions
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index c5f9847..df3caaf 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -810,3 +810,6 @@ LIFETIME(MIN 0 MAX 0)""", }, pretty=True, ) + self.validate_identity( + "CREATE TABLE t1 (a String EPHEMERAL, b String EPHEMERAL func(), c String MATERIALIZED func(), d String ALIAS func()) ENGINE=TinyLog()" + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 94f2dc2..c15cf09 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -7,6 +7,8 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("DESCRIBE HISTORY a.b") + self.validate_identity("DESCRIBE history.tbl") self.validate_identity("CREATE TABLE t (c STRUCT<interval: DOUBLE COMMENT 'aaa'>)") self.validate_identity("CREATE TABLE my_table TBLPROPERTIES (a.b=15)") self.validate_identity("CREATE TABLE my_table TBLPROPERTIES ('a.b'=15)") @@ -24,6 +26,9 @@ class TestDatabricks(Validator): self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") self.validate_identity( + "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t" + ) + self.validate_identity( "SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))" ) self.validate_identity( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 76ab94b..691beb9 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1333,6 +1333,15 @@ class TestDialect(Validator): def test_set_operators(self): self.validate_all( + "SELECT * FROM a UNION SELECT * FROM b ORDER BY x LIMIT 1", + write={ + "": "SELECT * FROM a UNION SELECT * FROM b ORDER BY x LIMIT 1", + "clickhouse": "SELECT * FROM (SELECT * FROM a UNION DISTINCT SELECT * FROM b) AS _l_0 ORDER BY x NULLS FIRST LIMIT 1", + "tsql": "SELECT TOP 1 * FROM (SELECT * FROM a UNION SELECT * FROM b) AS _l_0 ORDER BY x", + }, + ) + + self.validate_all( "SELECT * FROM a UNION SELECT * FROM b", read={ "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", @@ -1667,7 +1676,7 @@ class TestDialect(Validator): "presto": "CAST(a AS DOUBLE) / b", "redshift": "CAST(a AS DOUBLE PRECISION) / b", "sqlite": "CAST(a AS REAL) / b", - "teradata": "CAST(a AS DOUBLE) / b", + "teradata": "CAST(a AS DOUBLE PRECISION) / b", "trino": "CAST(a AS DOUBLE) / b", "tsql": "CAST(a AS FLOAT) / b", }, diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 5a7e93e..0b13a70 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -240,6 +240,7 @@ class TestDuckDB(Validator): self.validate_identity("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])") self.validate_identity("SELECT MAP {'x': 1}") + self.validate_identity("SELECT (MAP {'x': 1})['x']") self.validate_identity("SELECT df1.*, df2.* FROM df1 POSITIONAL JOIN df2") self.validate_identity("MAKE_TIMESTAMP(1992, 9, 20, 13, 34, 27.123456)") self.validate_identity("MAKE_TIMESTAMP(1667810584123456)") diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 7a9d6bf..6558c97 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -1103,7 +1103,7 @@ COMMENT='客户账户表'""" "presto": "CAST(a AS DOUBLE) / NULLIF(b, 0)", "redshift": "CAST(a AS DOUBLE PRECISION) / NULLIF(b, 0)", "sqlite": "CAST(a AS REAL) / b", - "teradata": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "teradata": "CAST(a AS DOUBLE PRECISION) / NULLIF(b, 0)", "trino": "CAST(a AS DOUBLE) / NULLIF(b, 0)", "tsql": "CAST(a AS FLOAT) / NULLIF(b, 0)", }, diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 7a41cef..5a55a7d 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -50,6 +50,7 @@ class TestPostgres(Validator): self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)") self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") self.validate_identity("COMMENT ON TABLE mytable IS 'this'") + self.validate_identity("COMMENT ON MATERIALIZED VIEW my_view IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") @@ -467,6 +468,12 @@ class TestPostgres(Validator): }, ) self.validate_all( + "SELECT DATE_PART('epoch', CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + read={ + "": "SELECT TIME_TO_UNIX(TIMESTAMP '2023-01-04 04:05:06.789')", + }, + ) + self.validate_all( "x ^ y", write={ "": "POWER(x, y)", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 2162499..e1d8c06 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -7,6 +7,7 @@ class TestPresto(Validator): dialect = "presto" def test_cast(self): + self.validate_identity("SELECT * FROM x qualify", "SELECT * FROM x AS qualify") self.validate_identity("CAST(x AS IPADDRESS)") self.validate_identity("CAST(x AS IPPREFIX)") @@ -611,6 +612,15 @@ class TestPresto(Validator): self.validate_identity( "SELECT * FROM example.testdb.customer_orders FOR TIMESTAMP AS OF CAST('2022-03-23 09:59:29.803 Europe/Vienna' AS TIMESTAMP)" ) + self.validate_identity( + "SELECT origin_state, destination_state, origin_zip, SUM(package_weight) FROM shipping GROUP BY ALL CUBE (origin_state, destination_state), ROLLUP (origin_state, origin_zip)" + ) + self.validate_identity( + "SELECT origin_state, destination_state, origin_zip, SUM(package_weight) FROM shipping GROUP BY DISTINCT CUBE (origin_state, destination_state), ROLLUP (origin_state, origin_zip)" + ) + self.validate_identity( + "SELECT JSON_EXTRACT_SCALAR(CAST(extra AS JSON), '$.value_b'), COUNT(*) FROM table_a GROUP BY DISTINCT (JSON_EXTRACT_SCALAR(CAST(extra AS JSON), '$.value_b'))" + ) self.validate_all( "SELECT LAST_DAY_OF_MONTH(CAST('2008-11-25' AS DATE))", diff --git a/tests/dialects/test_prql.py b/tests/dialects/test_prql.py index 9a42d0c..69e2e28 100644 --- a/tests/dialects/test_prql.py +++ b/tests/dialects/test_prql.py @@ -5,13 +5,56 @@ class TestPRQL(Validator): dialect = "prql" def test_prql(self): - self.validate_identity("FROM x", "SELECT * FROM x") - self.validate_identity("FROM x DERIVE a + 1", "SELECT *, a + 1 FROM x") - self.validate_identity("FROM x DERIVE x = a + 1", "SELECT *, a + 1 AS x FROM x") - self.validate_identity("FROM x DERIVE {a + 1}", "SELECT *, a + 1 FROM x") - self.validate_identity("FROM x DERIVE {x = a + 1, b}", "SELECT *, a + 1 AS x, b FROM x") - self.validate_identity("FROM x TAKE 10", "SELECT * FROM x LIMIT 10") - self.validate_identity("FROM x TAKE 10 TAKE 5", "SELECT * FROM x LIMIT 5") - self.validate_identity( - "FROM x DERIVE {x = a + 1, b} SELECT {y = x, 2}", "SELECT a + 1 AS y, 2 FROM x" + self.validate_identity("from x", "SELECT * FROM x") + self.validate_identity("from x derive a + 1", "SELECT *, a + 1 FROM x") + self.validate_identity("from x derive x = a + 1", "SELECT *, a + 1 AS x FROM x") + self.validate_identity("from x derive {a + 1}", "SELECT *, a + 1 FROM x") + self.validate_identity("from x derive {x = a + 1, b}", "SELECT *, a + 1 AS x, b FROM x") + self.validate_identity( + "from x derive {x = a + 1, b} select {y = x, 2}", "SELECT a + 1 AS y, 2 FROM x" + ) + self.validate_identity("from x take 10", "SELECT * FROM x LIMIT 10") + self.validate_identity("from x take 10 take 5", "SELECT * FROM x LIMIT 5") + self.validate_identity("from x filter age > 25", "SELECT * FROM x WHERE age > 25") + self.validate_identity( + "from x derive {x = a + 1, b} filter age > 25", + "SELECT *, a + 1 AS x, b FROM x WHERE age > 25", + ) + self.validate_identity("from x filter dept != 'IT'", "SELECT * FROM x WHERE dept <> 'IT'") + self.validate_identity( + "from x filter p == 'product' select { a, b }", "SELECT a, b FROM x WHERE p = 'product'" + ) + self.validate_identity( + "from x filter age > 25 filter age < 27", "SELECT * FROM x WHERE age > 25 AND age < 27" + ) + self.validate_identity( + "from x filter (age > 25 && age < 27)", "SELECT * FROM x WHERE (age > 25 AND age < 27)" + ) + self.validate_identity( + "from x filter (age > 25 || age < 27)", "SELECT * FROM x WHERE (age > 25 OR age < 27)" + ) + self.validate_identity( + "from x filter (age > 25 || age < 22) filter age > 26 filter age < 27", + "SELECT * FROM x WHERE ((age > 25 OR age < 22) AND age > 26) AND age < 27", + ) + self.validate_identity( + "from x sort age", + "SELECT * FROM x ORDER BY age", + ) + self.validate_identity( + "from x sort {-age}", + "SELECT * FROM x ORDER BY age DESC", + ) + self.validate_identity( + "from x sort {age, name}", + "SELECT * FROM x ORDER BY age, name", + ) + self.validate_identity( + "from x sort {-age, +name}", + "SELECT * FROM x ORDER BY age DESC, name", + ) + self.validate_identity("from x append y", "SELECT * FROM x UNION ALL SELECT * FROM y") + self.validate_identity("from x remove y", "SELECT * FROM x EXCEPT ALL SELECT * FROM y") + self.validate_identity( + "from x intersect y", "SELECT * FROM x INTERSECT ALL SELECT * FROM y" ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index a41d35a..b652541 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -66,6 +66,7 @@ WHERE self.validate_identity("SELECT DAYOFYEAR(CURRENT_TIMESTAMP())") self.validate_identity("LISTAGG(data['some_field'], ',')") self.validate_identity("WEEKOFYEAR(tstamp)") + self.validate_identity("SELECT QUARTER(CURRENT_TIMESTAMP())") self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x") @@ -1575,22 +1576,26 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene ) def test_match_recognize(self): - for row in ( - "ONE ROW PER MATCH", - "ALL ROWS PER MATCH", - "ALL ROWS PER MATCH SHOW EMPTY MATCHES", - "ALL ROWS PER MATCH OMIT EMPTY MATCHES", - "ALL ROWS PER MATCH WITH UNMATCHED ROWS", - ): - for after in ( - "AFTER MATCH SKIP", - "AFTER MATCH SKIP PAST LAST ROW", - "AFTER MATCH SKIP TO NEXT ROW", - "AFTER MATCH SKIP TO FIRST x", - "AFTER MATCH SKIP TO LAST x", + for window_frame in ("", "FINAL ", "RUNNING "): + for row in ( + "ONE ROW PER MATCH", + "ALL ROWS PER MATCH", + "ALL ROWS PER MATCH SHOW EMPTY MATCHES", + "ALL ROWS PER MATCH OMIT EMPTY MATCHES", + "ALL ROWS PER MATCH WITH UNMATCHED ROWS", ): - self.validate_identity( - f"""SELECT + for after in ( + "AFTER MATCH SKIP", + "AFTER MATCH SKIP PAST LAST ROW", + "AFTER MATCH SKIP TO NEXT ROW", + "AFTER MATCH SKIP TO FIRST x", + "AFTER MATCH SKIP TO LAST x", + ): + with self.subTest( + f"MATCH_RECOGNIZE with window frame {window_frame}, rows {row}, after {after}: " + ): + self.validate_identity( + f"""SELECT * FROM x MATCH_RECOGNIZE ( @@ -1598,15 +1603,15 @@ MATCH_RECOGNIZE ( ORDER BY x DESC MEASURES - y AS b + {window_frame}y AS b {row} {after} PATTERN (^ S1 S2*? ( {{- S3 -}} S4 )+ | PERMUTE(S1, S2){{1,2}} $) DEFINE x AS y )""", - pretty=True, - ) + pretty=True, + ) def test_show_users(self): self.validate_identity("SHOW USERS") diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 18f1fb7..d2285e0 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -2,6 +2,7 @@ from unittest import mock from sqlglot import exp, parse_one from sqlglot.dialects.dialect import Dialects +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -223,17 +224,16 @@ TBLPROPERTIES ( ) def test_spark(self): - self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS") - self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS") - self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS") - self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS") - self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS") - self.assertEqual( parse_one("REFRESH TABLE t", read="spark").assert_is(exp.Refresh).sql(dialect="spark"), "REFRESH TABLE t", ) + self.validate_identity("any_value(col, true)", "ANY_VALUE(col) IGNORE NULLS") + self.validate_identity("first(col, true)", "FIRST(col) IGNORE NULLS") + self.validate_identity("first_value(col, true)", "FIRST_VALUE(col) IGNORE NULLS") + self.validate_identity("last(col, true)", "LAST(col) IGNORE NULLS") + self.validate_identity("last_value(col, true)", "LAST_VALUE(col) IGNORE NULLS") self.validate_identity("DESCRIBE EXTENDED db.table") self.validate_identity("SELECT * FROM test TABLESAMPLE (50 PERCENT)") self.validate_identity("SELECT * FROM test TABLESAMPLE (5 ROWS)") @@ -284,6 +284,30 @@ TBLPROPERTIES ( "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", ) + with self.assertLogs(helper_logger): + self.validate_all( + "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + read={ + "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + }, + write={ + "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + "duckdb": "SELECT ([1, 2, 3])[3]", + "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + }, + ) + + self.validate_all( + "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)", + read={ + "databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)", + }, + write={ + "databricks": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)", + "duckdb": "SELECT (MAP([1, 2], ['a', 'b'])[2])[1]", + "spark": "SELECT TRY_ELEMENT_AT(MAP(1, 'a', 2, 'b'), 2)", + }, + ) self.validate_all( "SELECT SPLIT('123|789', '\\\\|')", read={ diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index a85ca8c..010b683 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -210,3 +210,92 @@ class TestTeradata(Validator): "teradata": "TRYCAST('-2.5' AS DECIMAL(5, 2))", }, ) + + def test_time(self): + self.validate_all( + "CURRENT_TIMESTAMP", + read={ + "teradata": "CURRENT_TIMESTAMP", + "snowflake": "CURRENT_TIMESTAMP()", + }, + ) + + self.validate_all( + "SELECT '2023-01-01' + INTERVAL '5' YEAR", + read={ + "teradata": "SELECT '2023-01-01' + INTERVAL '5' YEAR", + "snowflake": "SELECT DATEADD(YEAR, 5, '2023-01-01')", + }, + ) + self.validate_all( + "SELECT '2023-01-01' - INTERVAL '5' YEAR", + read={ + "teradata": "SELECT '2023-01-01' - INTERVAL '5' YEAR", + "snowflake": "SELECT DATEADD(YEAR, -5, '2023-01-01')", + }, + ) + self.validate_all( + "SELECT '2023-01-01' - INTERVAL '5' YEAR", + read={ + "teradata": "SELECT '2023-01-01' - INTERVAL '5' YEAR", + "sqlite": "SELECT DATE_SUB('2023-01-01', 5, YEAR)", + }, + ) + self.validate_all( + "SELECT '2023-01-01' + INTERVAL '5' YEAR", + read={ + "teradata": "SELECT '2023-01-01' + INTERVAL '5' YEAR", + "sqlite": "SELECT DATE_SUB('2023-01-01', -5, YEAR)", + }, + ) + self.validate_all( + "SELECT (90 * INTERVAL '1' DAY)", + read={ + "teradata": "SELECT (90 * INTERVAL '1' DAY)", + "snowflake": "SELECT INTERVAL '1' QUARTER", + }, + ) + self.validate_all( + "SELECT (7 * INTERVAL '1' DAY)", + read={ + "teradata": "SELECT (7 * INTERVAL '1' DAY)", + "snowflake": "SELECT INTERVAL '1' WEEK", + }, + ) + self.validate_all( + "SELECT '2023-01-01' + (90 * INTERVAL '5' DAY)", + read={ + "teradata": "SELECT '2023-01-01' + (90 * INTERVAL '5' DAY)", + "snowflake": "SELECT DATEADD(QUARTER, 5, '2023-01-01')", + }, + ) + self.validate_all( + "SELECT '2023-01-01' + (7 * INTERVAL '5' DAY)", + read={ + "teradata": "SELECT '2023-01-01' + (7 * INTERVAL '5' DAY)", + "snowflake": "SELECT DATEADD(WEEK, 5, '2023-01-01')", + }, + ) + self.validate_all( + "CAST(TO_CHAR(x, 'Q') AS INT)", + read={ + "teradata": "CAST(TO_CHAR(x, 'Q') AS INT)", + "snowflake": "DATE_PART(QUARTER, x)", + "bigquery": "EXTRACT(QUARTER FROM x)", + }, + ) + self.validate_all( + "EXTRACT(MONTH FROM x)", + read={ + "teradata": "EXTRACT(MONTH FROM x)", + "snowflake": "DATE_PART(MONTH, x)", + "bigquery": "EXTRACT(MONTH FROM x)", + }, + ) + self.validate_all( + "CAST(TO_CHAR(x, 'Q') AS INT)", + read={ + "snowflake": "quarter(x)", + "teradata": "CAST(TO_CHAR(x, 'Q') AS INT)", + }, + ) diff --git a/tests/fixtures/optimizer/annotate_types.sql b/tests/fixtures/optimizer/annotate_types.sql new file mode 100644 index 0000000..e781765 --- /dev/null +++ b/tests/fixtures/optimizer/annotate_types.sql @@ -0,0 +1,57 @@ +5; +INT; + +5.3; +DOUBLE; + +'bla'; +VARCHAR; + +True; +bool; + +false; +bool; + +null; +null; +CASE WHEN x THEN NULL ELSE 1 END; +INT; + +CASE WHEN x THEN 1 ELSE NULL END; +INT; + +IF(true, 1, null); +INT; + +IF(true, null, 1); +INT; + +STRUCT(1 AS col); +STRUCT<col INT>; + +STRUCT(1 AS col, 2.5 AS row); +STRUCT<col INT, row DOUBLE>; + +STRUCT(1); +STRUCT<INT>; + +STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct); +STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>; + +STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo'); +STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>; + +STRUCT(1, 2.5, 'bar'); +STRUCT<INT, DOUBLE, VARCHAR>; + +STRUCT(1 AS "CaseSensitive"); +STRUCT<"CaseSensitive" INT>; + +# dialect: duckdb +STRUCT_PACK(a := 1, b := 2.5); +STRUCT<a INT, b DOUBLE>; + +# dialect: presto +ROW(1, 2.5, 'foo'); +STRUCT<INT, DOUBLE, VARCHAR>; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index cc72e6d..37ef4fd 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -1388,3 +1388,45 @@ WHERE ORDER BY COUNT(DISTINCT `cs1`.`cs_order_number`) LIMIT 100; + +# execute: false +SELECT + * +FROM event +WHERE priority = 'High' AND tagname IN ( + SELECT + tag_input AS tagname + FROM cascade + WHERE tag_input = 'XXX' OR tag_output = 'XXX' + UNION + SELECT + tag_output AS tagname + FROM cascade + WHERE tag_input = 'XXX' OR tag_output = 'XXX' +); +WITH "_u_0" AS ( + SELECT + "cascade"."tag_input" AS "tagname" + FROM "cascade" AS "cascade" + WHERE + "cascade"."tag_input" = 'XXX' OR "cascade"."tag_output" = 'XXX' + UNION + SELECT + "cascade"."tag_output" AS "tagname" + FROM "cascade" AS "cascade" + WHERE + "cascade"."tag_input" = 'XXX' OR "cascade"."tag_output" = 'XXX' +), "_u_1" AS ( + SELECT + "cascade"."tag_input" AS "tagname" + FROM "_u_0" AS "_u_0" + GROUP BY + "cascade"."tag_input" +) +SELECT + * +FROM "event" AS "event" +LEFT JOIN "_u_1" AS "_u_1" + ON "_u_1"."tagname" = "event"."tagname" +WHERE + "event"."priority" = 'High' AND NOT "_u_1"."tagname" IS NULL; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 289145b..8baf961 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -492,6 +492,11 @@ SELECT x AS x, offset AS offset FROM UNNEST([1, 2]) AS x WITH OFFSET AS offset; select * from unnest ([1, 2]) as x with offset as y; SELECT x AS x, y AS y FROM UNNEST([1, 2]) AS x WITH OFFSET AS y; +# dialect: bigquery +# execute: false +select x, a, x.a from unnest([STRUCT(1 AS a)]) as x; +SELECT x AS x, a AS a, x.a AS a FROM UNNEST([STRUCT(1 AS a)]) AS x; + # dialect: presto SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(CAST(b AS VARCHAR), ',')) AS i(b); SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(CAST(x.b AS VARCHAR), ',')) AS i(b); @@ -508,6 +513,11 @@ SELECT t.c1 AS c1, t.c2 AS c2, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3); SELECT c1, c3 FROM foo(bar) AS t(c1, c2, c3); SELECT t.c1 AS c1, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3); +# dialect: redshift +# execute: false +SELECT c.f::VARCHAR(MAX) AS f, e AS e FROM a.b AS c, c.d AS e; +SELECT CAST(c.f AS VARCHAR(MAX)) AS f, e AS e FROM a.b AS c, c.d AS e; + -------------------------------------- -- Window functions -------------------------------------- diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index f651a87..104400e 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -166,3 +166,10 @@ WITH cte AS (SELECT b FROM c.db.y AS y) INSERT INTO c.db.s SELECT * FROM cte AS # title: qualify wrapped query (SELECT x FROM t); (SELECT x FROM c.db.t AS t); + +# title: replace columns with db/catalog refs +SELECT db1.a.id, db2.a.id FROM db1.a JOIN db2.a ON db1.a.id = db2.a.id; +SELECT a.id, a_2.id FROM c.db1.a AS a JOIN c.db2.a AS a_2 ON a.id = a_2.id; + +SELECT cat.db1.a.id, db2.a.id FROM cat.db1.a JOIN db2.a ON cat.db1.a.id = db2.a.id; +SELECT a.id, a_2.id FROM cat.db1.a AS a JOIN c.db2.a AS a_2 ON a.id = a_2.id; diff --git a/tests/test_build.py b/tests/test_build.py index cdddd4f..ad0bb9a 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -301,6 +301,10 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl ORDER BY y", ), ( + lambda: parse_one("select * from x union select * from y").order_by("y"), + "SELECT * FROM x UNION SELECT * FROM y ORDER BY y", + ), + ( lambda: select("x").from_("tbl").cluster_by("y"), "SELECT x FROM tbl CLUSTER BY y", "hive", @@ -505,15 +509,19 @@ class TestBuild(unittest.TestCase): (lambda: parse_one("(SELECT 1)").select("2"), "(SELECT 1, 2)"), ( lambda: parse_one("(SELECT 1)").limit(1), - "SELECT * FROM ((SELECT 1)) AS _l_0 LIMIT 1", + "(SELECT 1) LIMIT 1", ), ( lambda: parse_one("WITH t AS (SELECT 1) (SELECT 1)").limit(1), - "SELECT * FROM (WITH t AS (SELECT 1) (SELECT 1)) AS _l_0 LIMIT 1", + "WITH t AS (SELECT 1) (SELECT 1) LIMIT 1", ), ( lambda: parse_one("(SELECT 1 LIMIT 2)").limit(1), - "SELECT * FROM ((SELECT 1 LIMIT 2)) AS _l_0 LIMIT 1", + "(SELECT 1 LIMIT 2) LIMIT 1", + ), + ( + lambda: parse_one("SELECT 1 UNION SELECT 2").limit(5).offset(2), + "SELECT 1 UNION SELECT 2 LIMIT 5 OFFSET 2", ), (lambda: parse_one("(SELECT 1)").subquery(), "((SELECT 1))"), (lambda: parse_one("(SELECT 1)").subquery("alias"), "((SELECT 1)) AS alias"), @@ -665,14 +673,8 @@ class TestBuild(unittest.TestCase): "(x, y) IN ((1, 2), (3, 4))", "postgres", ), - ( - lambda: exp.cast_unless("CAST(x AS INT)", "int", "int"), - "CAST(x AS INT)", - ), - ( - lambda: exp.cast_unless("CAST(x AS TEXT)", "int", "int"), - "CAST(CAST(x AS TEXT) AS INT)", - ), + (lambda: exp.cast("CAST(x AS INT)", "int"), "CAST(x AS INT)"), + (lambda: exp.cast("CAST(x AS TEXT)", "int"), "CAST(CAST(x AS TEXT) AS INT)"), ( lambda: exp.rename_column("table1", "c1", "c2", True), "ALTER TABLE table1 RENAME COLUMN IF EXISTS c1 TO c2", diff --git a/tests/test_expressions.py b/tests/test_expressions.py index ed19ac1..85560b8 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -634,6 +634,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("MAX(a)"), exp.Max) self.assertIsInstance(parse_one("MIN(a)"), exp.Min) self.assertIsInstance(parse_one("MONTH(a)"), exp.Month) + self.assertIsInstance(parse_one("QUARTER(a)"), exp.Quarter) self.assertIsInstance(parse_one("POSITION(' ' IN a)"), exp.StrPosition) self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow) self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow) @@ -716,6 +717,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(dot, exp.Dot) self.assertEqual(dot.sql(), "a.b.c.d.e.f") + dot = exp.column("d", "c", "b", "a", fields=["e", "f"], quoted=True) + self.assertEqual(dot.sql(), '"a"."b"."c"."d"."e"."f"') + def test_text(self): column = parse_one("a.b.c.d.e") self.assertEqual(column.text("expression"), "e") @@ -893,8 +897,6 @@ FROM foo""", self.assertEqual(catalog_db_and_table.name, "table_name") self.assertEqual(catalog_db_and_table.args.get("db"), exp.to_identifier("db")) self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog")) - with self.assertRaises(ValueError): - exp.to_table(1) def test_to_column(self): column_only = exp.to_column("column_name") @@ -903,8 +905,14 @@ FROM foo""", table_and_column = exp.to_column("table_name.column_name") self.assertEqual(table_and_column.name, "column_name") self.assertEqual(table_and_column.args.get("table"), exp.to_identifier("table_name")) - with self.assertRaises(ValueError): - exp.to_column(1) + + self.assertEqual(exp.to_column("foo bar").sql(), '"foo bar"') + self.assertEqual(exp.to_column("`column_name`", dialect="spark").sql(), '"column_name"') + self.assertEqual(exp.to_column("column_name", quoted=True).sql(), '"column_name"') + self.assertEqual( + exp.to_column("column_name", table=exp.to_identifier("table_name")).sql(), + "table_name.column_name", + ) def test_union(self): expression = parse_one("SELECT cola, colb UNION SELECT colx, coly") @@ -996,6 +1004,13 @@ FROM foo""", "ALTER TABLE t1 RENAME TO t2", ) + def test_is_negative(self): + self.assertTrue(parse_one("-1").is_negative) + self.assertTrue(parse_one("- 1.0").is_negative) + self.assertTrue(exp.Literal.number("-1").is_negative) + self.assertFalse(parse_one("1").is_negative) + self.assertFalse(parse_one("x").is_negative) + def test_is_star(self): assert parse_one("*").is_star assert parse_one("foo.*").is_star diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 0e8ce15..c0b362c 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -230,6 +230,17 @@ class TestOptimizer(unittest.TestCase): def test_qualify_columns(self, logger): self.assertEqual( optimizer.qualify_columns.qualify_columns( + parse_one( + "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT x + 1 FROM t AS child WHERE x < 10) SELECT * FROM t" + ), + schema={}, + infer_schema=False, + ).sql(), + "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT child.x + 1 AS _col_0 FROM t AS child WHERE child.x < 10) SELECT t.x AS x FROM t", + ) + + self.assertEqual( + optimizer.qualify_columns.qualify_columns( parse_one("WITH x AS (SELECT a FROM db.y) SELECT * FROM db.x"), schema={"db": {"x": {"z": "int"}, "y": {"a": "int"}}}, expand_stars=False, @@ -617,53 +628,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') level="warning", ) - def test_struct_type_annotation(self): - tests = { - ("SELECT STRUCT(1 AS col)", "spark"): "STRUCT<col INT>", - ("SELECT STRUCT(1 AS col, 2.5 AS row)", "spark"): "STRUCT<col INT, row DOUBLE>", - ("SELECT STRUCT(1)", "bigquery"): "STRUCT<INT>", - ( - "SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)", - "spark", - ): "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>", - ( - "SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')", - "bigquery", - ): "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>", - ("SELECT STRUCT(1, 2.5, 'bar')", "spark"): "STRUCT<INT, DOUBLE, VARCHAR>", - ('SELECT STRUCT(1 AS "CaseSensitive")', "spark"): 'STRUCT<"CaseSensitive" INT>', - ("SELECT STRUCT_PACK(a := 1, b := 2.5)", "duckdb"): "STRUCT<a INT, b DOUBLE>", - ("SELECT ROW(1, 2.5, 'foo')", "presto"): "STRUCT<INT, DOUBLE, VARCHAR>", - } - - for (sql, dialect), target_type in tests.items(): - with self.subTest(sql): - expression = annotate_types(parse_one(sql, read=dialect)) - assert expression.expressions[0].is_type(target_type) - - def test_literal_type_annotation(self): - tests = { - "SELECT 5": exp.DataType.Type.INT, - "SELECT 5.3": exp.DataType.Type.DOUBLE, - "SELECT 'bla'": exp.DataType.Type.VARCHAR, - "5": exp.DataType.Type.INT, - "5.3": exp.DataType.Type.DOUBLE, - "'bla'": exp.DataType.Type.VARCHAR, - } - - for sql, target_type in tests.items(): - expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Literal).type.this, target_type) - - def test_boolean_type_annotation(self): - tests = { - "SELECT TRUE": exp.DataType.Type.BOOLEAN, - "FALSE": exp.DataType.Type.BOOLEAN, - } + def test_annotate_types(self): + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs("optimizer/annotate_types.sql"), start=1 + ): + title = meta.get("title") or f"{i}, {sql}" + dialect = meta.get("dialect") + result = parse_and_optimize(annotate_types, sql, dialect) - for sql, target_type in tests.items(): - expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Boolean).type.this, target_type) + with self.subTest(title): + self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql()) def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 0c65da4..95fba30 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -69,12 +69,12 @@ class TestTranspile(unittest.TestCase): self.validate( "SELECT a, b, c FROM (SELECT a, b, c FROM t)", "SELECT\n" - " a\n" + " a\n" " , b\n" " , c\n" "FROM (\n" " SELECT\n" - " a\n" + " a\n" " , b\n" " , c\n" " FROM t\n" @@ -86,13 +86,13 @@ class TestTranspile(unittest.TestCase): ) self.validate( "SELECT FOO, BAR, BAZ", - "SELECT\n FOO\n , BAR\n , BAZ", + "SELECT\n FOO\n , BAR\n , BAZ", leading_comma=True, pretty=True, ) self.validate( "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ", - "SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ", + "SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ", leading_comma=True, pretty=True, ) |