diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_bigquery.py | 217 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 23 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 104 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 97 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 13 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 19 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 302 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 112 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 3 | ||||
-rw-r--r-- | tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 30 | ||||
-rw-r--r-- | tests/fixtures/optimizer/unnest_subqueries.sql | 312 | ||||
-rw-r--r-- | tests/test_executor.py | 10 | ||||
-rw-r--r-- | tests/test_expressions.py | 19 | ||||
-rw-r--r-- | tests/test_lineage.py | 8 | ||||
-rw-r--r-- | tests/test_optimizer.py | 6 | ||||
-rw-r--r-- | tests/test_parser.py | 35 | ||||
-rw-r--r-- | tests/test_schema.py | 28 | ||||
-rw-r--r-- | tests/test_transpile.py | 7 |
19 files changed, 907 insertions, 458 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 301cd57..728785c 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -21,6 +21,13 @@ class TestBigQuery(Validator): def test_bigquery(self): self.validate_identity( + """CREATE TEMPORARY FUNCTION FOO() +RETURNS STRING +LANGUAGE js AS +'return "Hello world!"'""", + pretty=True, + ) + self.validate_identity( "[a, a(1, 2,3,4444444444444444, tttttaoeunthaoentuhaoentuheoantu, toheuntaoheutnahoeunteoahuntaoeh), b(3, 4,5), c, d, tttttttttttttttteeeeeeeeeeeeeett, 12312312312]", """[ a, @@ -279,6 +286,13 @@ class TestBigQuery(Validator): ) self.validate_all( + "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s", + write={ + "bigquery": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s", + "duckdb": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS _t0(h), UNNEST(h.t3) AS _t1(s)", + }, + ) + self.validate_all( "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)", write={ "bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)", @@ -289,7 +303,7 @@ class TestBigQuery(Validator): "SELECT results FROM Coordinates, Coordinates.position AS results", write={ "bigquery": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS results", - "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t(results)", + "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t0(results)", }, ) self.validate_all( @@ -307,7 +321,7 @@ class TestBigQuery(Validator): }, write={ "bigquery": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results", - "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t(results)", + "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t0(results)", "redshift": "SELECT results FROM Coordinates AS c, c.position AS results", }, ) @@ -525,7 +539,7 @@ class TestBigQuery(Validator): "SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)", write={ "bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)", - "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)", + "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t0(x) WHERE x > 1)", }, ) self.validate_all( @@ -618,12 +632,87 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "LOWER(TO_HEX(x))", + write={ + "": "LOWER(HEX(x))", + "bigquery": "TO_HEX(x)", + "clickhouse": "LOWER(HEX(x))", + "duckdb": "LOWER(HEX(x))", + "hive": "LOWER(HEX(x))", + "mysql": "LOWER(HEX(x))", + "spark": "LOWER(HEX(x))", + "sqlite": "LOWER(HEX(x))", + "presto": "LOWER(TO_HEX(x))", + "trino": "LOWER(TO_HEX(x))", + }, + ) + self.validate_all( + "TO_HEX(x)", + read={ + "": "LOWER(HEX(x))", + "clickhouse": "LOWER(HEX(x))", + "duckdb": "LOWER(HEX(x))", + "hive": "LOWER(HEX(x))", + "mysql": "LOWER(HEX(x))", + "spark": "LOWER(HEX(x))", + "sqlite": "LOWER(HEX(x))", + "presto": "LOWER(TO_HEX(x))", + "trino": "LOWER(TO_HEX(x))", + }, + write={ + "": "LOWER(HEX(x))", + "bigquery": "TO_HEX(x)", + "clickhouse": "LOWER(HEX(x))", + "duckdb": "LOWER(HEX(x))", + "hive": "LOWER(HEX(x))", + "mysql": "LOWER(HEX(x))", + "presto": "LOWER(TO_HEX(x))", + "spark": "LOWER(HEX(x))", + "sqlite": "LOWER(HEX(x))", + "trino": "LOWER(TO_HEX(x))", + }, + ) + self.validate_all( + "UPPER(TO_HEX(x))", + read={ + "": "HEX(x)", + "clickhouse": "HEX(x)", + "duckdb": "HEX(x)", + "hive": "HEX(x)", + "mysql": "HEX(x)", + "presto": "TO_HEX(x)", + "spark": "HEX(x)", + "sqlite": "HEX(x)", + "trino": "TO_HEX(x)", + }, + write={ + "": "HEX(x)", + "bigquery": "UPPER(TO_HEX(x))", + "clickhouse": "HEX(x)", + "duckdb": "HEX(x)", + "hive": "HEX(x)", + "mysql": "HEX(x)", + "presto": "TO_HEX(x)", + "spark": "HEX(x)", + "sqlite": "HEX(x)", + "trino": "TO_HEX(x)", + }, + ) + self.validate_all( "MD5(x)", + read={ + "clickhouse": "MD5(x)", + "presto": "MD5(x)", + "trino": "MD5(x)", + }, write={ "": "MD5_DIGEST(x)", "bigquery": "MD5(x)", + "clickhouse": "MD5(x)", "hive": "UNHEX(MD5(x))", + "presto": "MD5(x)", "spark": "UNHEX(MD5(x))", + "trino": "MD5(x)", }, ) self.validate_all( @@ -631,25 +720,69 @@ class TestBigQuery(Validator): read={ "duckdb": "SELECT MD5(some_string)", "spark": "SELECT MD5(some_string)", + "clickhouse": "SELECT LOWER(HEX(MD5(some_string)))", + "presto": "SELECT LOWER(TO_HEX(MD5(some_string)))", + "trino": "SELECT LOWER(TO_HEX(MD5(some_string)))", }, write={ "": "SELECT MD5(some_string)", "bigquery": "SELECT TO_HEX(MD5(some_string))", "duckdb": "SELECT MD5(some_string)", + "clickhouse": "SELECT LOWER(HEX(MD5(some_string)))", + "presto": "SELECT LOWER(TO_HEX(MD5(some_string)))", + "trino": "SELECT LOWER(TO_HEX(MD5(some_string)))", + }, + ) + self.validate_all( + "SHA1(x)", + read={ + "clickhouse": "SHA1(x)", + "presto": "SHA1(x)", + "trino": "SHA1(x)", + }, + write={ + "clickhouse": "SHA1(x)", + "bigquery": "SHA1(x)", + "": "SHA(x)", + "presto": "SHA1(x)", + "trino": "SHA1(x)", + }, + ) + self.validate_all( + "SHA1(x)", + write={ + "bigquery": "SHA1(x)", + "": "SHA(x)", }, ) self.validate_all( "SHA256(x)", + read={ + "clickhouse": "SHA256(x)", + "presto": "SHA256(x)", + "trino": "SHA256(x)", + }, write={ "bigquery": "SHA256(x)", "spark2": "SHA2(x, 256)", + "clickhouse": "SHA256(x)", + "presto": "SHA256(x)", + "trino": "SHA256(x)", }, ) self.validate_all( "SHA512(x)", + read={ + "clickhouse": "SHA512(x)", + "presto": "SHA512(x)", + "trino": "SHA512(x)", + }, write={ + "clickhouse": "SHA512(x)", "bigquery": "SHA512(x)", "spark2": "SHA2(x, 512)", + "presto": "SHA512(x)", + "trino": "SHA512(x)", }, ) self.validate_all( @@ -860,8 +993,8 @@ class TestBigQuery(Validator): }, write={ "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x", - "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t(x)", - "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t(x)", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t0(x)", + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t0(x)", }, ) self.validate_all( @@ -982,6 +1115,69 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "DELETE db.example_table WHERE x = 1", + write={ + "bigquery": "DELETE db.example_table WHERE x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE db.example_table tb WHERE tb.x = 1", + write={ + "bigquery": "DELETE db.example_table AS tb WHERE tb.x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE db.example_table AS tb WHERE tb.x = 1", + write={ + "bigquery": "DELETE db.example_table AS tb WHERE tb.x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE FROM db.example_table WHERE x = 1", + write={ + "bigquery": "DELETE FROM db.example_table WHERE x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE FROM db.example_table tb WHERE tb.x = 1", + write={ + "bigquery": "DELETE FROM db.example_table AS tb WHERE tb.x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE FROM db.example_table AS tb WHERE tb.x = 1", + write={ + "bigquery": "DELETE FROM db.example_table AS tb WHERE tb.x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE FROM db.example_table AS tb WHERE example_table.x = 1", + write={ + "bigquery": "DELETE FROM db.example_table AS tb WHERE example_table.x = 1", + "presto": "DELETE FROM db.example_table WHERE x = 1", + }, + ) + self.validate_all( + "DELETE FROM db.example_table WHERE example_table.x = 1", + write={ + "bigquery": "DELETE FROM db.example_table WHERE example_table.x = 1", + "presto": "DELETE FROM db.example_table WHERE example_table.x = 1", + }, + ) + self.validate_all( + "DELETE FROM db.t1 AS t1 WHERE NOT t1.c IN (SELECT db.t2.c FROM db.t2)", + write={ + "bigquery": "DELETE FROM db.t1 AS t1 WHERE NOT t1.c IN (SELECT db.t2.c FROM db.t2)", + "presto": "DELETE FROM db.t1 WHERE NOT c IN (SELECT c FROM db.t2)", + }, + ) + self.validate_all( "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])", write={ "bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])", @@ -1464,3 +1660,14 @@ OPTIONS ( with self.assertRaises(ParseError): transpile("SELECT JSON_OBJECT('a', 1, 'b') AS json_data", read="bigquery") + + def test_mod(self): + for sql in ("MOD(a, b)", "MOD('a', b)", "MOD(5, 2)", "MOD((a + 1) * 8, 5 - 1)"): + with self.subTest(f"Testing BigQuery roundtrip of modulo operation: {sql}"): + self.validate_identity(sql) + + self.validate_identity("SELECT MOD((SELECT 1), 2)") + self.validate_identity( + "MOD((a + 1), b)", + "MOD(a + 1, b)", + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index af552d1..15adda8 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -425,6 +425,21 @@ class TestClickhouse(Validator): }, ) + self.validate_identity("ALTER TABLE visits DROP PARTITION 201901") + self.validate_identity("ALTER TABLE visits DROP PARTITION ALL") + self.validate_identity( + "ALTER TABLE visits DROP PARTITION tuple(toYYYYMM(toDate('2019-01-25')))" + ) + self.validate_identity("ALTER TABLE visits DROP PARTITION ID '201901'") + + self.validate_identity("ALTER TABLE visits REPLACE PARTITION 201901 FROM visits_tmp") + self.validate_identity("ALTER TABLE visits REPLACE PARTITION ALL FROM visits_tmp") + self.validate_identity( + "ALTER TABLE visits REPLACE PARTITION tuple(toYYYYMM(toDate('2019-01-25'))) FROM visits_tmp" + ) + self.validate_identity("ALTER TABLE visits REPLACE PARTITION ID '201901' FROM visits_tmp") + self.validate_identity("ALTER TABLE visits ON CLUSTER test_cluster DROP COLUMN col1") + def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") self.validate_identity("WITH ['c'] AS field_names SELECT field_names") @@ -829,6 +844,9 @@ LIFETIME(MIN 0 MAX 0)""", self.validate_identity( "CREATE TABLE t1 (a String EPHEMERAL, b String EPHEMERAL func(), c String MATERIALIZED func(), d String ALIAS func()) ENGINE=TinyLog()" ) + self.validate_identity( + "CREATE TABLE t (a String, b String, c UInt64, PROJECTION p1 (SELECT a, sum(c) GROUP BY a, b), PROJECTION p2 (SELECT b, sum(c) GROUP BY b)) ENGINE=MergeTree()" + ) def test_agg_functions(self): def extract_agg_func(query): @@ -856,3 +874,8 @@ LIFETIME(MIN 0 MAX 0)""", ) parse_one("foobar(x)").assert_is(exp.Anonymous) + + def test_drop_on_cluster(self): + for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"): + with self.subTest(f"Test DROP {creatable} ON CLUSTER"): + self.validate_identity(f"DROP {creatable} test ON CLUSTER test_cluster") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index dda0eb2..77306dc 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1019,6 +1019,19 @@ class TestDialect(Validator): }, ) + self.validate_all( + "TIMESTAMP_TRUNC(x, DAY, 'UTC')", + write={ + "": "TIMESTAMP_TRUNC(x, DAY, 'UTC')", + "duckdb": "DATE_TRUNC('DAY', x)", + "presto": "DATE_TRUNC('DAY', x)", + "postgres": "DATE_TRUNC('DAY', x, 'UTC')", + "snowflake": "DATE_TRUNC('DAY', x)", + "databricks": "DATE_TRUNC('DAY', x)", + "clickhouse": "DATE_TRUNC('DAY', x, 'UTC')", + }, + ) + for unit in ("DAY", "MONTH", "YEAR"): self.validate_all( f"{unit}(x)", @@ -1681,6 +1694,26 @@ class TestDialect(Validator): "tsql": "CAST(a AS FLOAT) / b", }, ) + self.validate_all( + "MOD(8 - 1 + 7, 7)", + write={ + "": "(8 - 1 + 7) % 7", + "hive": "(8 - 1 + 7) % 7", + "presto": "(8 - 1 + 7) % 7", + "snowflake": "(8 - 1 + 7) % 7", + "bigquery": "MOD(8 - 1 + 7, 7)", + }, + ) + self.validate_all( + "MOD(a, b + 1)", + write={ + "": "a % (b + 1)", + "hive": "a % (b + 1)", + "presto": "a % (b + 1)", + "snowflake": "a % (b + 1)", + "bigquery": "MOD(a, b + 1)", + }, + ) def test_typeddiv(self): typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True) @@ -2186,6 +2219,8 @@ SELECT ) def test_cast_to_user_defined_type(self): + self.validate_identity("CAST(x AS some_udt(1234))") + self.validate_all( "CAST(x AS some_udt)", write={ @@ -2214,6 +2249,18 @@ SELECT "tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", }, ) + self.validate_all( + 'SELECT "user id", some_id, 1 as other_id, 2 as "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1', + write={ + "duckdb": 'SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1', + "snowflake": 'SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1', + "clickhouse": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1', + "mysql": "SELECT `user id`, some_id, other_id, `2 nd id` FROM (SELECT `user id`, some_id, 1 AS other_id, 2 AS `2 nd id`, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + "oracle": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1', + "postgres": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1', + "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + }, + ) def test_nested_ctes(self): self.validate_all( @@ -2249,7 +2296,7 @@ SELECT "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", write={ "duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", - "tsql": "WITH t2(y) AS (SELECT 2), t1(x) AS (SELECT 1) SELECT * FROM (SELECT y AS y FROM t2) AS subq", + "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq", }, ) self.validate_all( @@ -2273,6 +2320,59 @@ FROM c""", "hive": "WITH a1 AS (SELECT 1), a2 AS (SELECT 2), b AS (SELECT * FROM a1, a2), c AS (SELECT * FROM b) SELECT * FROM c", }, ) + self.validate_all( + """ +WITH subquery1 AS ( + WITH tmp AS ( + SELECT + * + FROM table0 + ) + SELECT + * + FROM tmp +), subquery2 AS ( + WITH tmp2 AS ( + SELECT + * + FROM table1 + WHERE + a IN subquery1 + ) + SELECT + * + FROM tmp2 +) +SELECT + * +FROM subquery2 +""", + write={ + "hive": """WITH tmp AS ( + SELECT + * + FROM table0 +), subquery1 AS ( + SELECT + * + FROM tmp +), tmp2 AS ( + SELECT + * + FROM table1 + WHERE + a IN subquery1 +), subquery2 AS ( + SELECT + * + FROM tmp2 +) +SELECT + * +FROM subquery2""", + }, + pretty=True, + ) def test_unsupported_null_ordering(self): # We'll transpile a portable query from the following dialects to MySQL / T-SQL, which @@ -2372,7 +2472,7 @@ FROM c""", "hive": UnsupportedError, "mysql": UnsupportedError, "oracle": UnsupportedError, - "postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t(x) WHERE pred), 1) <> 0)", + "postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t0(x) WHERE pred), 1) <> 0)", "presto": "ANY_MATCH(arr, x -> pred)", "redshift": UnsupportedError, "snowflake": UnsupportedError, diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 2d0af13..03dea93 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -272,6 +272,14 @@ class TestDuckDB(Validator): "SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE" ) self.validate_identity( + "SELECT JSON_EXTRACT_STRING(c, '$.k1') = 'v1'", + "SELECT (c ->> '$.k1') = 'v1'", + ) + self.validate_identity( + "SELECT JSON_EXTRACT(c, '$.k1') = 'v1'", + "SELECT (c -> '$.k1') = 'v1'", + ) + self.validate_identity( """SELECT '{"foo": [1, 2, 3]}' -> 'foo' -> 0""", """SELECT '{"foo": [1, 2, 3]}' -> '$.foo' -> '$[0]'""", ) @@ -734,6 +742,28 @@ class TestDuckDB(Validator): ) self.validate_identity("COPY lineitem (l_orderkey) TO 'orderkey.tbl' WITH (DELIMITER '|')") + self.validate_all( + "VARIANCE(a)", + write={ + "duckdb": "VARIANCE(a)", + "clickhouse": "varSamp(a)", + }, + ) + self.validate_all( + "STDDEV(a)", + write={ + "duckdb": "STDDEV(a)", + "clickhouse": "stddevSamp(a)", + }, + ) + self.validate_all( + "DATE_TRUNC('DAY', x)", + write={ + "duckdb": "DATE_TRUNC('DAY', x)", + "clickhouse": "DATE_TRUNC('DAY', x)", + }, + ) + def test_array_index(self): with self.assertLogs(helper_logger) as cm: self.validate_all( @@ -803,7 +833,7 @@ class TestDuckDB(Validator): write={"duckdb": "SELECT (90 * INTERVAL '1' DAY)"}, ) self.validate_all( - "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - (DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7 % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1", + "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1", read={ "presto": "SELECT ((DATE_ADD('week', -5, DATE_TRUNC('DAY', DATE_ADD('day', (0 - MOD((DAY_OF_WEEK(CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7, 7)), CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)))))) AS t1", }, @@ -827,6 +857,9 @@ class TestDuckDB(Validator): "duckdb": "EPOCH_MS(x)", "presto": "FROM_UNIXTIME(CAST(x AS DOUBLE) / POW(10, 3))", "spark": "TIMESTAMP_MILLIS(x)", + "clickhouse": "fromUnixTimestamp64Milli(CAST(x AS Int64))", + "postgres": "TO_TIMESTAMP(CAST(x AS DOUBLE PRECISION) / 10 ^ 3)", + "mysql": "FROM_UNIXTIME(x / POWER(10, 3))", }, ) self.validate_all( @@ -892,11 +925,11 @@ class TestDuckDB(Validator): def test_sample(self): self.validate_identity( "SELECT * FROM tbl USING SAMPLE 5", - "SELECT * FROM tbl USING SAMPLE (5 ROWS)", + "SELECT * FROM tbl USING SAMPLE RESERVOIR (5 ROWS)", ) self.validate_identity( "SELECT * FROM tbl USING SAMPLE 10%", - "SELECT * FROM tbl USING SAMPLE (10 PERCENT)", + "SELECT * FROM tbl USING SAMPLE SYSTEM (10 PERCENT)", ) self.validate_identity( "SELECT * FROM tbl USING SAMPLE 10 PERCENT (bernoulli)", @@ -920,14 +953,13 @@ class TestDuckDB(Validator): ) self.validate_all( - "SELECT * FROM example TABLESAMPLE (3 ROWS) REPEATABLE (82)", + "SELECT * FROM example TABLESAMPLE RESERVOIR (3 ROWS) REPEATABLE (82)", read={ "duckdb": "SELECT * FROM example TABLESAMPLE (3) REPEATABLE (82)", "snowflake": "SELECT * FROM example SAMPLE (3 ROWS) SEED (82)", }, write={ - "duckdb": "SELECT * FROM example TABLESAMPLE (3 ROWS) REPEATABLE (82)", - "snowflake": "SELECT * FROM example TABLESAMPLE (3 ROWS) SEED (82)", + "duckdb": "SELECT * FROM example TABLESAMPLE RESERVOIR (3 ROWS) REPEATABLE (82)", }, ) @@ -946,10 +978,6 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS DOUBLE)") self.validate_identity("CAST(x AS DECIMAL(15, 4))") self.validate_identity("CAST(x AS STRUCT(number BIGINT))") - self.validate_identity( - "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" - ) - self.validate_identity("CAST(x AS INT64)", "CAST(x AS BIGINT)") self.validate_identity("CAST(x AS INT32)", "CAST(x AS INT)") self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)") @@ -969,6 +997,32 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)") self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)") self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)") + self.validate_identity( + "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" + ) + self.validate_identity( + "123::CHARACTER VARYING", + "CAST(123 AS TEXT)", + ) + self.validate_identity( + "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", + "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])", + ) + self.validate_identity( + "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", + "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", + ) + + self.validate_all( + "CAST(x AS DECIMAL(38, 0))", + read={ + "snowflake": "CAST(x AS NUMBER)", + "duckdb": "CAST(x AS DECIMAL(38, 0))", + }, + write={ + "snowflake": "CAST(x AS DECIMAL(38, 0))", + }, + ) self.validate_all( "CAST(x AS NUMERIC)", write={ @@ -994,12 +1048,6 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "123::CHARACTER VARYING", - write={ - "duckdb": "CAST(123 AS TEXT)", - }, - ) - self.validate_all( "cast([[1]] as int[][])", write={ "duckdb": "CAST([[1]] AS INT[][])", @@ -1007,7 +1055,10 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "CAST(x AS DATE) + INTERVAL (7 * -1) DAY", read={"spark": "DATE_SUB(x, 7)"} + "CAST(x AS DATE) + INTERVAL (7 * -1) DAY", + read={ + "spark": "DATE_SUB(x, 7)", + }, ) self.validate_all( "TRY_CAST(1 AS DOUBLE)", @@ -1034,18 +1085,6 @@ class TestDuckDB(Validator): "snowflake": "CAST(COL AS ARRAY(BIGINT))", }, ) - self.validate_all( - "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", - write={ - "duckdb": "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", - }, - ) - self.validate_all( - "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", - write={ - "duckdb": "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])", - }, - ) def test_bool_or(self): self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 53e2dab..84fb3c2 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -149,6 +149,9 @@ class TestMySQL(Validator): "SELECT * FROM t1, t2, t3 FOR SHARE OF t1 NOWAIT FOR UPDATE OF t2, t3 SKIP LOCKED" ) self.validate_identity( + "REPLACE INTO table SELECT id FROM table2 WHERE cnt > 100", check_command_warning=True + ) + self.validate_identity( """SELECT * FROM foo WHERE 3 MEMBER OF(info->'$.value')""", """SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""", ) @@ -608,6 +611,16 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "SELECT CONCAT('11', '22')", + read={ + "postgres": "SELECT '11' || '22'", + }, + write={ + "mysql": "SELECT CONCAT('11', '22')", + "postgres": "SELECT CONCAT('11', '22')", + }, + ) + self.validate_all( "SELECT department, GROUP_CONCAT(name) AS employee_names FROM data GROUP BY department", read={ "postgres": "SELECT department, array_agg(name) AS employee_names FROM data GROUP BY department", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 6b6117e..a8a6c12 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -733,6 +733,13 @@ class TestPostgres(Validator): self.validate_identity("TRUNCATE TABLE t1 RESTRICT") self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY CASCADE") self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY RESTRICT") + self.validate_identity("ALTER TABLE t1 SET LOGGED") + self.validate_identity("ALTER TABLE t1 SET UNLOGGED") + self.validate_identity("ALTER TABLE t1 SET WITHOUT CLUSTER") + self.validate_identity("ALTER TABLE t1 SET WITHOUT OIDS") + self.validate_identity("ALTER TABLE t1 SET ACCESS METHOD method") + self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace") + self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)") self.validate_identity( "CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))" ) @@ -798,23 +805,21 @@ class TestPostgres(Validator): "CREATE TABLE test (x TIMESTAMP[][])", ) self.validate_identity( - "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE", - check_command_warning=True, + "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT AS 'select $1 + $2;'", ) self.validate_identity( - "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT", - check_command_warning=True, + "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE STRICT AS 'select $1 + $2;'" ) self.validate_identity( - "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT", + "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE", check_command_warning=True, ) self.validate_identity( - "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT", + "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT", check_command_warning=True, ) self.validate_identity( - "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE STRICT", + "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT", check_command_warning=True, ) self.validate_identity( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index e227ea9..3925e32 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -498,7 +498,25 @@ FROM ( }, ) - def test_rename_table(self): + def test_alter_table(self): + self.validate_identity("ALTER TABLE s.t ALTER SORTKEY (c)") + self.validate_identity("ALTER TABLE t ALTER SORTKEY AUTO") + self.validate_identity("ALTER TABLE t ALTER SORTKEY NONE") + self.validate_identity("ALTER TABLE t ALTER SORTKEY (c1, c2)") + self.validate_identity("ALTER TABLE t ALTER SORTKEY (c1, c2)") + self.validate_identity("ALTER TABLE t ALTER COMPOUND SORTKEY (c1, c2)") + self.validate_identity("ALTER TABLE t ALTER DISTSTYLE ALL") + self.validate_identity("ALTER TABLE t ALTER DISTSTYLE EVEN") + self.validate_identity("ALTER TABLE t ALTER DISTSTYLE AUTO") + self.validate_identity("ALTER TABLE t ALTER DISTSTYLE KEY DISTKEY c") + self.validate_identity("ALTER TABLE t SET TABLE PROPERTIES ('a' = '5', 'b' = 'c')") + self.validate_identity("ALTER TABLE t SET LOCATION 's3://bucket/folder/'") + self.validate_identity("ALTER TABLE t SET FILE FORMAT AVRO") + self.validate_identity( + "ALTER TABLE t ALTER DISTKEY c", + "ALTER TABLE t ALTER DISTSTYLE KEY DISTKEY c", + ) + self.validate_all( "ALTER TABLE db.t1 RENAME TO db.t2", write={ diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index dae8355..d3c47af 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -10,14 +10,14 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): - self.validate_identity( - "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)" - ) - self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)") - self.validate_identity( - "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1" + self.validate_all( + "ARRAY_CONSTRUCT_COMPACT(1, null, 2)", + write={ + "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))", + "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)", + }, ) - self.validate_identity("SELECT rename, replace") + expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t") expr.selects[0].assert_is(exp.AggFunc) self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t") @@ -43,6 +43,8 @@ WHERE )""", ) + self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)") + self.validate_identity("SELECT rename, replace") self.validate_identity("SELECT TIMEADD(HOUR, 2, CAST('09:05:03' AS TIME))") self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS MAP(VARCHAR, INT))") self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS OBJECT(a CHAR NOT NULL))") @@ -86,20 +88,19 @@ WHERE self.validate_identity("a$b") # valid snowflake identifier self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") - self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'") - self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c") - self.validate_identity("ALTER TABLE foo SET COMMENT = 'bar'") - self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING = FALSE") - self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING") self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'") self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)") self.validate_identity("ALTER TABLE a SWAP WITH b") self.validate_identity("SELECT MATCH_CONDITION") + self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t") self.validate_identity( - 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' + "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)" + ) + self.validate_identity( + "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1" ) self.validate_identity( - "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)" + 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' ) self.validate_identity( "SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID" @@ -146,10 +147,6 @@ WHERE "SELECT TIMESTAMP_FROM_PARTS(d, t)", ) self.validate_identity( - "SELECT user_id, value FROM table_name SAMPLE ($s) SEED (0)", - "SELECT user_id, value FROM table_name TABLESAMPLE ($s) SEED (0)", - ) - self.validate_identity( "SELECT v:attr[0].name FROM vartab", "SELECT GET_PATH(v, 'attr[0].name') FROM vartab", ) @@ -236,6 +233,38 @@ WHERE "CAST(x AS NCHAR VARYING)", "CAST(x AS VARCHAR)", ) + self.validate_identity( + "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))", + "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)", + ) + self.validate_identity( + "CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))", + "CREATE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)", + ) + self.validate_identity( + "CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)", + "CREATE TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)", + ) + self.validate_identity( + "ALTER TABLE foo ADD COLUMN id INT identity(1, 1)", + "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1", + ) + self.validate_identity( + "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)", + "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))", + ) + self.validate_identity( + "SELECT * FROM xxx WHERE col ilike '%Don''t%'", + "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'", + ) + self.validate_identity( + "SELECT * EXCLUDE a, b FROM xxx", + "SELECT * EXCLUDE (a), b FROM xxx", + ) + self.validate_identity( + "SELECT * RENAME a AS b, c AS d FROM xxx", + "SELECT * RENAME (a AS b), c AS d FROM xxx", + ) self.validate_all( "OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)", @@ -250,18 +279,6 @@ WHERE }, ) self.validate_all( - "SELECT * FROM example TABLESAMPLE (3) SEED (82)", - read={ - "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)", - "duckdb": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)", - }, - write={ - "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)", - "duckdb": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)", - "snowflake": "SELECT * FROM example TABLESAMPLE (3) SEED (82)", - }, - ) - self.validate_all( "SELECT TIME_FROM_PARTS(12, 34, 56, 987654321)", write={ "duckdb": "SELECT MAKE_TIME(12, 34, 56 + (987654321 / 1000000000.0))", @@ -568,60 +585,12 @@ WHERE }, ) self.validate_all( - "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))", - write={ - "snowflake": "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)", - }, - ) - self.validate_all( - "CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))", - write={ - "snowflake": "CREATE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)", - }, - ) - self.validate_all( - "CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)", - write={ - "snowflake": "CREATE TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)", - }, - ) - self.validate_all( - "ALTER TABLE foo ADD COLUMN id INT identity(1, 1)", - write={ - "snowflake": "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1", - }, - ) - self.validate_all( - "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)", - write={ - "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))", - }, - ) - self.validate_all( - "SELECT * FROM xxx WHERE col ilike '%Don''t%'", - write={ - "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'", - }, - ) - self.validate_all( - "SELECT * EXCLUDE a, b FROM xxx", - write={ - "snowflake": "SELECT * EXCLUDE (a), b FROM xxx", - }, - ) - self.validate_all( - "SELECT * RENAME a AS b, c AS d FROM xxx", - write={ - "snowflake": "SELECT * RENAME (a AS b), c AS d FROM xxx", - }, - ) - self.validate_all( - "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx", + "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx", read={ "duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx", }, write={ - "snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx", + "snowflake": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx", "duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx", }, ) @@ -831,6 +800,18 @@ WHERE "snowflake": "SELECT LISTAGG(col1, ', ') WITHIN GROUP (ORDER BY col2) FROM t", }, ) + self.validate_all( + "SELECT APPROX_PERCENTILE(a, 0.5) FROM t", + read={ + "trino": "SELECT APPROX_PERCENTILE(a, 1, 0.5, 0.001) FROM t", + "presto": "SELECT APPROX_PERCENTILE(a, 1, 0.5, 0.001) FROM t", + }, + write={ + "trino": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t", + "presto": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t", + "snowflake": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t", + }, + ) def test_null_treatment(self): self.validate_all( @@ -881,6 +862,10 @@ WHERE self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz") self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')") self.validate_identity("PUT file:///dir/tmp.csv @%table", check_command_warning=True) + self.validate_identity("SELECT * FROM (SELECT a FROM @foo)") + self.validate_identity( + "SELECT * FROM (SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv'))" + ) self.validate_identity( "SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla" ) @@ -902,18 +887,27 @@ WHERE def test_sample(self): self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)") - self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)") self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)") - self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)") self.validate_identity( - "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i" + "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE BERNOULLI (0.1)" + ) + self.validate_identity( + "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE BERNOULLI (50) WHERE t2.j = t1.i" + ) + self.validate_identity( + "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE BERNOULLI (1)" + ) + self.validate_identity( + "SELECT * FROM testtable TABLESAMPLE (10 ROWS)", + "SELECT * FROM testtable TABLESAMPLE BERNOULLI (10 ROWS)", ) self.validate_identity( - "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)" + "SELECT * FROM testtable TABLESAMPLE (100)", + "SELECT * FROM testtable TABLESAMPLE BERNOULLI (100)", ) self.validate_identity( "SELECT * FROM testtable SAMPLE (10)", - "SELECT * FROM testtable TABLESAMPLE (10)", + "SELECT * FROM testtable TABLESAMPLE BERNOULLI (10)", ) self.validate_identity( "SELECT * FROM testtable SAMPLE ROW (0)", @@ -923,8 +917,30 @@ WHERE "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", ) + self.validate_identity( + "SELECT user_id, value FROM table_name SAMPLE BERNOULLI ($s) SEED (0)", + "SELECT user_id, value FROM table_name TABLESAMPLE BERNOULLI ($s) SEED (0)", + ) self.validate_all( + "SELECT * FROM example TABLESAMPLE BERNOULLI (3) SEED (82)", + read={ + "duckdb": "SELECT * FROM example TABLESAMPLE BERNOULLI (3 PERCENT) REPEATABLE (82)", + }, + write={ + "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)", + "duckdb": "SELECT * FROM example TABLESAMPLE BERNOULLI (3 PERCENT) REPEATABLE (82)", + "snowflake": "SELECT * FROM example TABLESAMPLE BERNOULLI (3) SEED (82)", + }, + ) + self.validate_all( + "SELECT * FROM test AS _tmp TABLESAMPLE (5)", + write={ + "postgres": "SELECT * FROM test AS _tmp TABLESAMPLE BERNOULLI (5)", + "snowflake": "SELECT * FROM test AS _tmp TABLESAMPLE BERNOULLI (5)", + }, + ) + self.validate_all( """ SELECT i, j FROM @@ -933,7 +949,7 @@ WHERE table2 AS t2 SAMPLE (50) -- 50% of rows in table2 WHERE t2.j = t1.i""", write={ - "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", + "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE BERNOULLI (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE BERNOULLI (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", }, ) self.validate_all( @@ -945,7 +961,7 @@ WHERE self.validate_all( "SELECT * FROM (SELECT * FROM t1 join t2 on t1.a = t2.c) SAMPLE (1)", write={ - "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)", + "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE BERNOULLI (1)", "spark": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1 PERCENT)", }, ) @@ -1172,7 +1188,7 @@ WHERE location=@s2/logs/ partition_type = user_specified file_format = (type = parquet)""", - "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL1') AS DATE)), col2 VARCHAR AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL2') AS VARCHAR)), col3 DECIMAL AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL3') AS DECIMAL))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)", + "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL1') AS DATE)), col2 VARCHAR AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL2') AS VARCHAR)), col3 DECIMAL(38, 0) AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL3') AS DECIMAL(38, 0)))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)", ) self.validate_identity("CREATE OR REPLACE VIEW foo (uid) COPY GRANTS AS (SELECT 1)") self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)") @@ -1181,8 +1197,12 @@ WHERE self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema") self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)") self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)") + self.validate_identity("CREATE TAG cost_center ALLOWED_VALUES 'a', 'b'") self.validate_identity( - "DROP function my_udf (OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN)))" + "CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'" + ).this.assert_is(exp.Identifier) + self.validate_identity( + "DROP FUNCTION my_udf (OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN)))" ) self.validate_identity( "CREATE TABLE orders_clone_restore CLONE orders AT (TIMESTAMP => TO_TIMESTAMP_TZ('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss'))" @@ -1203,13 +1223,27 @@ WHERE "CREATE ICEBERG TABLE my_iceberg_table (amount ARRAY(INT)) CATALOG='SNOWFLAKE' EXTERNAL_VOLUME='my_external_volume' BASE_LOCATION='my/relative/path/from/extvol'" ) self.validate_identity( - "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$", - "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '", + """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT RETURNS NULL ON NULL INPUT AS ' return Object.values(obj) '""" + ) + self.validate_identity( + """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT STRICT AS ' return Object.values(obj) '""" + ) + self.validate_identity( + "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$", + "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '", ) self.validate_identity( "CREATE OR REPLACE FUNCTION my_udtf(foo BOOLEAN) RETURNS TABLE(col1 ARRAY(INT)) AS $$ WITH t AS (SELECT CAST([1, 2, 3] AS ARRAY(INT)) AS c) SELECT c FROM t $$", "CREATE OR REPLACE FUNCTION my_udtf(foo BOOLEAN) RETURNS TABLE (col1 ARRAY(INT)) AS ' WITH t AS (SELECT CAST([1, 2, 3] AS ARRAY(INT)) AS c) SELECT c FROM t '", ) + self.validate_identity( + "CREATE SEQUENCE seq1 WITH START=1, INCREMENT=1 ORDER", + "CREATE SEQUENCE seq1 START=1 INCREMENT BY 1 ORDER", + ) + self.validate_identity( + "CREATE SEQUENCE seq1 WITH START=1 INCREMENT=1 ORDER", + "CREATE SEQUENCE seq1 START=1 INCREMENT=1 ORDER", + ) self.validate_all( "CREATE TABLE orders_clone CLONE orders", @@ -1435,6 +1469,9 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene ) def test_values(self): + select = exp.select("*").from_("values (map(['a'], [1]))") + self.assertEqual(select.sql("snowflake"), "SELECT * FROM (SELECT OBJECT_CONSTRUCT('a', 1))") + self.validate_all( 'SELECT "c0", "c1" FROM (VALUES (1, 2), (3, 4)) AS "t0"("c0", "c1")', read={ @@ -1832,10 +1869,7 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""", self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)") def test_copy(self): - self.validate_identity( - """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' file_format = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""", - """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""", - ) + self.validate_identity("COPY INTO test (c1) FROM (SELECT $1.c1 FROM @mystage)") self.validate_identity( """COPY INTO temp FROM @random_stage/path/ FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64) VALIDATION_MODE = 'RETURN_3_ROWS'""" ) @@ -1845,3 +1879,81 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""", self.validate_identity( """COPY INTO mytable FROM 'azure://myaccount.blob.core.windows.net/mycontainer/data/files' CREDENTIALS = (AZURE_SAS_TOKEN = 'token') ENCRYPTION = (TYPE = 'AZURE_CSE' MASTER_KEY = 'kPx...') FILE_FORMAT = (FORMAT_NAME = my_csv_format)""" ) + self.validate_identity( + """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""" + ) + self.validate_all( + """COPY INTO 's3://example/data.csv' + FROM EXTRA.EXAMPLE.TABLE + credentials = (x) + STORAGE_INTEGRATION = S3_INTEGRATION + FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"') + HEADER = TRUE + OVERWRITE = TRUE + SINGLE = TRUE + """, + write={ + "": """COPY INTO 's3://example/data.csv' +FROM EXTRA.EXAMPLE.TABLE +CREDENTIALS = (x) WITH ( + STORAGE_INTEGRATION S3_INTEGRATION, + FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ( + '' + ) FIELD_OPTIONALLY_ENCLOSED_BY = '"'), + HEADER TRUE, + OVERWRITE TRUE, + SINGLE TRUE +)""", + "snowflake": """COPY INTO 's3://example/data.csv' +FROM EXTRA.EXAMPLE.TABLE +CREDENTIALS = (x) +STORAGE_INTEGRATION = S3_INTEGRATION +FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ( + '' +) FIELD_OPTIONALLY_ENCLOSED_BY = '"') +HEADER = TRUE +OVERWRITE = TRUE +SINGLE = TRUE""", + }, + pretty=True, + ) + self.validate_all( + """COPY INTO 's3://example/data.csv' + FROM EXTRA.EXAMPLE.TABLE + credentials = (x) + STORAGE_INTEGRATION = S3_INTEGRATION + FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"') + HEADER = TRUE + OVERWRITE = TRUE + SINGLE = TRUE + """, + write={ + "": """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE CREDENTIALS = (x) WITH (STORAGE_INTEGRATION S3_INTEGRATION, FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"'), HEADER TRUE, OVERWRITE TRUE, SINGLE TRUE)""", + "snowflake": """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE CREDENTIALS = (x) STORAGE_INTEGRATION = S3_INTEGRATION FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"') HEADER = TRUE OVERWRITE = TRUE SINGLE = TRUE""", + }, + ) + + def test_querying_semi_structured_data(self): + self.validate_identity("SELECT $1") + self.validate_identity("SELECT $1.elem") + + self.validate_identity("SELECT $1:a.b", "SELECT GET_PATH($1, 'a.b')") + self.validate_identity("SELECT t.$23:a.b", "SELECT GET_PATH(t.$23, 'a.b')") + self.validate_identity("SELECT t.$17:a[0].b[0].c", "SELECT GET_PATH(t.$17, 'a[0].b[0].c')") + + def test_alter_set_unset(self): + self.validate_identity("ALTER TABLE tbl SET DATA_RETENTION_TIME_IN_DAYS=1") + self.validate_identity("ALTER TABLE tbl SET DEFAULT_DDL_COLLATION='test'") + self.validate_identity("ALTER TABLE foo SET COMMENT='bar'") + self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING=FALSE") + self.validate_identity("ALTER TABLE table1 SET TAG foo.bar = 'baz'") + self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'") + self.validate_identity( + """ALTER TABLE tbl SET STAGE_FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64)""", + ) + self.validate_identity( + """ALTER TABLE tbl SET STAGE_COPY_OPTIONS = (ON_ERROR = SKIP_FILE SIZE_LIMIT = 5 PURGE = TRUE MATCH_BY_COLUMN_NAME = CASE_SENSITIVE)""" + ) + + self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c") + self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 1538d47..45a4657 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,5 +1,4 @@ from sqlglot import exp, parse, parse_one -from sqlglot.parser import logger as parser_logger from tests.dialects.test_dialect import Validator from sqlglot.errors import ParseError @@ -8,6 +7,8 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c") + self.validate_identity("DROP view a.b.c", "DROP VIEW b.c") self.validate_identity("ROUND(x, 1, 0)") self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True) # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN @@ -191,16 +192,9 @@ class TestTSQL(Validator): ) self.validate_all( - """ - CREATE TABLE x( - [zip_cd] [varchar](5) NULL NOT FOR REPLICATION, - [zip_cd_mkey] [varchar](5) NOT NULL, - CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) - WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [INDEX] - ) ON [SECONDARY] - """, + """CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]""", write={ - "tsql": "CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]", + "tsql": "CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]", "spark2": "CREATE TABLE x (`zip_cd` VARCHAR(5), `zip_cd_mkey` VARCHAR(5) NOT NULL, CONSTRAINT `pk_mytable` PRIMARY KEY (`zip_cd_mkey`))", }, ) @@ -259,7 +253,7 @@ class TestTSQL(Validator): self.validate_identity("SELECT * FROM ##foo") self.validate_identity("SELECT a = 1", "SELECT 1 AS a") self.validate_identity( - "DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'", check_command_warning=True + "DECLARE @TestVariable AS VARCHAR(100) = 'Save Our Planet'", ) self.validate_identity( "SELECT a = 1 UNION ALL SELECT a = b", "SELECT 1 AS a UNION ALL SELECT b AS a" @@ -461,6 +455,7 @@ class TestTSQL(Validator): self.validate_identity("CAST(x AS IMAGE)") self.validate_identity("CAST(x AS SQL_VARIANT)") self.validate_identity("CAST(x AS BIT)") + self.validate_all( "CAST(x AS DATETIME2)", read={ @@ -488,7 +483,7 @@ class TestTSQL(Validator): }, ) - def test__types_ints(self): + def test_types_ints(self): self.validate_all( "CAST(X AS INT)", write={ @@ -521,10 +516,14 @@ class TestTSQL(Validator): self.validate_all( "CAST(X AS TINYINT)", + read={ + "duckdb": "CAST(X AS UTINYINT)", + }, write={ - "hive": "CAST(X AS TINYINT)", - "spark2": "CAST(X AS TINYINT)", - "spark": "CAST(X AS TINYINT)", + "duckdb": "CAST(X AS UTINYINT)", + "hive": "CAST(X AS SMALLINT)", + "spark2": "CAST(X AS SMALLINT)", + "spark": "CAST(X AS SMALLINT)", "tsql": "CAST(X AS TINYINT)", }, ) @@ -764,19 +763,33 @@ class TestTSQL(Validator): expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B" ) - for clusterd_keyword in ("CLUSTERED", "NONCLUSTERED"): + for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"): self.validate_identity( 'CREATE TABLE "dbo"."benchmark" (' '"name" CHAR(7) NOT NULL, ' '"internal_id" VARCHAR(10) NOT NULL, ' - f'UNIQUE {clusterd_keyword} ("internal_id" ASC))', + f'UNIQUE {clustered_keyword} ("internal_id" ASC))', "CREATE TABLE [dbo].[benchmark] (" "[name] CHAR(7) NOT NULL, " "[internal_id] VARCHAR(10) NOT NULL, " - f"UNIQUE {clusterd_keyword} ([internal_id] ASC))", + f"UNIQUE {clustered_keyword} ([internal_id] ASC))", ) self.validate_identity( + "ALTER TABLE tbl SET SYSTEM_VERSIONING=ON(HISTORY_TABLE=db.tbl, DATA_CONSISTENCY_CHECK=OFF, HISTORY_RETENTION_PERIOD=5 DAYS)" + ) + self.validate_identity( + "ALTER TABLE tbl SET SYSTEM_VERSIONING=ON(HISTORY_TABLE=db.tbl, HISTORY_RETENTION_PERIOD=INFINITE)" + ) + self.validate_identity("ALTER TABLE tbl SET SYSTEM_VERSIONING=OFF") + self.validate_identity("ALTER TABLE tbl SET FILESTREAM_ON = 'test'") + self.validate_identity( + "ALTER TABLE tbl SET DATA_DELETION=ON(FILTER_COLUMN=col, RETENTION_PERIOD=5 MONTHS)" + ) + self.validate_identity("ALTER TABLE tbl SET DATA_DELETION=ON") + self.validate_identity("ALTER TABLE tbl SET DATA_DELETION=OFF") + + self.validate_identity( "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END", "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END", ) @@ -900,8 +913,7 @@ class TestTSQL(Validator): def test_udf(self): self.validate_identity( - "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", - check_command_warning=True, + "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)", ) self.validate_identity( "CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar" @@ -973,9 +985,9 @@ WHERE BEGIN SET XACT_ABORT ON; - DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104); - DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104); - DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER); + DECLARE @DWH_DateCreated AS DATETIME = CONVERT(DATETIME, getdate(), 104); + DECLARE @DWH_DateModified DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104); + DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (CURRENT_USER()); DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER); DECLARE @SalesAmountBefore float; @@ -985,18 +997,17 @@ WHERE expected_sqls = [ "CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON", - "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)", - "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)", - "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)", - "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)", - "DECLARE @SalesAmountBefore float", + "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)", + "DECLARE @DWH_DateModified AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)", + "DECLARE @DWH_IdUserCreated AS INTEGER = SUSER_ID(CURRENT_USER())", + "DECLARE @DWH_IdUserModified AS INTEGER = SUSER_ID(CURRENT_USER())", + "DECLARE @SalesAmountBefore AS FLOAT", "SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] AS S", "END", ] - with self.assertLogs(parser_logger): - for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): - self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) sql = """ CREATE PROC [dbo].[transform_proc] AS @@ -1010,14 +1021,13 @@ WHERE """ expected_sqls = [ - "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate VARCHAR(20)", + "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate AS VARCHAR(20)", "SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120)", "CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)", ] - with self.assertLogs(parser_logger): - for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): - self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) def test_charindex(self): self.validate_identity( @@ -1823,3 +1833,35 @@ FROM OPENJSON(@json) WITH ( "duckdb": "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) FROM t1) SELECT * FROM t2", }, ) + + def test_declare(self): + # supported cases + self.validate_identity("DECLARE @X INT", "DECLARE @X AS INTEGER") + self.validate_identity("DECLARE @X INT = 1", "DECLARE @X AS INTEGER = 1") + self.validate_identity( + "DECLARE @X INT, @Y VARCHAR(10)", "DECLARE @X AS INTEGER, @Y AS VARCHAR(10)" + ) + self.validate_identity( + "declare @X int = (select col from table where id = 1)", + "DECLARE @X AS INTEGER = (SELECT col FROM table WHERE id = 1)", + ) + self.validate_identity( + "declare @X TABLE (Id INT NOT NULL, Name VARCHAR(100) NOT NULL)", + "DECLARE @X AS TABLE (Id INTEGER NOT NULL, Name VARCHAR(100) NOT NULL)", + ) + self.validate_identity( + "declare @X TABLE (Id INT NOT NULL, constraint PK_Id primary key (Id))", + "DECLARE @X AS TABLE (Id INTEGER NOT NULL, CONSTRAINT PK_Id PRIMARY KEY (Id))", + ) + self.validate_identity( + "declare @X UserDefinedTableType", + "DECLARE @X AS UserDefinedTableType", + ) + self.validate_identity( + "DECLARE @MyTableVar TABLE (EmpID INT NOT NULL, PRIMARY KEY CLUSTERED (EmpID), UNIQUE NONCLUSTERED (EmpID), INDEX CustomNonClusteredIndex NONCLUSTERED (EmpID))", + check_command_warning=True, + ) + self.validate_identity( + "DECLARE vendor_cursor CURSOR FOR SELECT VendorID, Name FROM Purchasing.Vendor WHERE PreferredVendorStatus = 1 ORDER BY VendorID", + check_command_warning=True, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 6b742c3..13a6153 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -869,4 +869,5 @@ TRUNCATE(a, b) SELECT enum SELECT unlogged SELECT name -SELECT copy
\ No newline at end of file +SELECT copy +SELECT rollup
\ No newline at end of file diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index a357b07..5b004fa 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -1409,31 +1409,11 @@ WITH "_u_0" AS ( "store_sales"."ss_quantity" <= 80 AND "store_sales"."ss_quantity" >= 61 ) SELECT - CASE - WHEN MAX("_u_0"."_col_0") > 3672 - THEN MAX("_u_1"."_col_0") - ELSE MAX("_u_2"."_col_0") - END AS "bucket1", - CASE - WHEN MAX("_u_3"."_col_0") > 3392 - THEN MAX("_u_4"."_col_0") - ELSE MAX("_u_5"."_col_0") - END AS "bucket2", - CASE - WHEN MAX("_u_6"."_col_0") > 32784 - THEN MAX("_u_7"."_col_0") - ELSE MAX("_u_8"."_col_0") - END AS "bucket3", - CASE - WHEN MAX("_u_9"."_col_0") > 26032 - THEN MAX("_u_10"."_col_0") - ELSE MAX("_u_11"."_col_0") - END AS "bucket4", - CASE - WHEN MAX("_u_12"."_col_0") > 23982 - THEN MAX("_u_13"."_col_0") - ELSE MAX("_u_14"."_col_0") - END AS "bucket5" + CASE WHEN "_u_0"."_col_0" > 3672 THEN "_u_1"."_col_0" ELSE "_u_2"."_col_0" END AS "bucket1", + CASE WHEN "_u_3"."_col_0" > 3392 THEN "_u_4"."_col_0" ELSE "_u_5"."_col_0" END AS "bucket2", + CASE WHEN "_u_6"."_col_0" > 32784 THEN "_u_7"."_col_0" ELSE "_u_8"."_col_0" END AS "bucket3", + CASE WHEN "_u_9"."_col_0" > 26032 THEN "_u_10"."_col_0" ELSE "_u_11"."_col_0" END AS "bucket4", + CASE WHEN "_u_12"."_col_0" > 23982 THEN "_u_13"."_col_0" ELSE "_u_14"."_col_0" END AS "bucket5" FROM "reason" AS "reason" CROSS JOIN "_u_0" AS "_u_0" CROSS JOIN "_u_1" AS "_u_1" diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index 45e462b..a5a35b1 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -1,243 +1,69 @@ ---SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x; --------------------------------------- --- Unnest Subqueries --------------------------------------- -SELECT * -FROM x AS x -WHERE - x.a = (SELECT SUM(y.a) AS a FROM y) - AND x.a IN (SELECT y.a AS a FROM y) - AND x.a IN (SELECT y.b AS b FROM y) - AND x.a = ANY (SELECT y.a AS a FROM y) - AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) - AND x.a > (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) - AND x.a <> ANY (SELECT y.a AS a FROM y WHERE y.a = x.a) - AND x.a NOT IN (SELECT y.a AS a FROM y WHERE y.a = x.a) - AND x.a IN (SELECT y.a AS a FROM y WHERE y.b = x.a) - AND x.a < (SELECT SUM(y.a) AS a FROM y WHERE y.a = x.a and y.a = x.b and y.b <> x.d) - AND EXISTS (SELECT y.a AS a, y.b AS b FROM y WHERE x.a = y.a) - AND x.a IN (SELECT y.a AS a FROM y LIMIT 10) - AND x.a IN (SELECT y.a AS a FROM y OFFSET 10) - AND x.a IN (SELECT y.a AS a, y.b AS b FROM y) - AND x.a > ANY (SELECT y.a FROM y) - AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10) - AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10) - AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a) - AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a) - AND x.a = SUM(SELECT 1) -- invalid statement left alone - AND x.a IN (SELECT max(y.b) AS b FROM y GROUP BY y.a) -; -SELECT - * -FROM x AS x -CROSS JOIN ( - SELECT - SUM(y.a) AS a - FROM y -) AS _u_0 -LEFT JOIN ( - SELECT - y.a AS a - FROM y - GROUP BY - y.a -) AS _u_1 - ON x.a = _u_1.a -LEFT JOIN ( - SELECT - y.b AS b - FROM y - GROUP BY - y.b -) AS _u_2 - ON x.a = _u_2.b -LEFT JOIN ( - SELECT - y.a AS a - FROM y - GROUP BY - y.a -) AS _u_3 - ON x.a = _u_3.a -LEFT JOIN ( - SELECT - SUM(y.b) AS b, - y.a AS _u_5 - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_4 - ON x.a = _u_4._u_5 -LEFT JOIN ( - SELECT - SUM(y.b) AS b, - y.a AS _u_7 - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_6 - ON x.a = _u_6._u_7 -LEFT JOIN ( - SELECT - y.a AS a - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_8 - ON _u_8.a = x.a -LEFT JOIN ( - SELECT - y.a AS a - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_9 - ON _u_9.a = x.a -LEFT JOIN ( - SELECT - ARRAY_AGG(y.a) AS a, - y.b AS _u_11 - FROM y - WHERE - TRUE - GROUP BY - y.b -) AS _u_10 - ON _u_10._u_11 = x.a -LEFT JOIN ( - SELECT - SUM(y.a) AS a, - y.a AS _u_13, - ARRAY_AGG(y.b) AS _u_14 - FROM y - WHERE - TRUE AND TRUE AND TRUE - GROUP BY - y.a -) AS _u_12 - ON _u_12._u_13 = x.a AND _u_12._u_13 = x.b -LEFT JOIN ( - SELECT - y.a AS a - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_15 - ON x.a = _u_15.a -LEFT JOIN ( - SELECT - ARRAY_AGG(c), - y.a AS _u_20 - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_19 - ON _u_19._u_20 = x.a -LEFT JOIN ( - SELECT - COUNT(*) AS d, - y.a AS _u_22 - FROM y - WHERE - TRUE - GROUP BY - y.a -) AS _u_21 - ON _u_21._u_22 = x.a -LEFT JOIN ( - SELECT - _q.b - FROM ( - SELECT - MAX(y.b) AS b - FROM y - GROUP BY - y.a - ) AS _q - GROUP BY - _q.b -) AS _u_24 - ON x.a = _u_24.b -WHERE - x.a = _u_0.a - AND NOT _u_1.a IS NULL - AND NOT _u_2.b IS NULL - AND NOT _u_3.a IS NULL - AND x.a = _u_4.b - AND x.a > _u_6.b - AND x.a = _u_8.a - AND NOT x.a = _u_9.a - AND ARRAY_ANY(_u_10.a, _x -> _x = x.a) - AND ( - x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, _x -> _x <> x.d) - ) - AND NOT _u_15.a IS NULL - AND x.a IN ( - SELECT - y.a AS a - FROM y - LIMIT 10 - ) - AND x.a IN ( - SELECT - y.a AS a - FROM y - OFFSET 10 - ) - AND x.a IN ( - SELECT - y.a AS a, - y.b AS b - FROM y - ) - AND x.a > ANY ( - SELECT - y.a - FROM y - ) - AND x.a = ( - SELECT - SUM(y.c) AS c - FROM y - WHERE - y.a = x.a - LIMIT 10 - ) - AND x.a = ( - SELECT - SUM(y.c) AS c - FROM y - WHERE - y.a = x.a - OFFSET 10 - ) - AND ARRAY_ALL(_u_19."", _x -> _x = x.a) - AND x.a > COALESCE(_u_21.d, 0) - AND x.a = SUM(SELECT - 1) /* invalid statement left alone */ - AND NOT _u_24.b IS NULL -; -SELECT - CAST(( - SELECT - x.a AS a - FROM x - ) AS TEXT) AS a; -SELECT - CAST(( - SELECT - x.a AS a - FROM x - ) AS TEXT) AS a; +SELECT * FROM x WHERE x.a = (SELECT SUM(y.a) AS a FROM y); +SELECT * FROM x CROSS JOIN (SELECT SUM(y.a) AS a FROM y) AS _u_0 WHERE x.a = _u_0.a; + +SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y); +SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE NOT _u_0.a IS NULL; + +SELECT * FROM x WHERE x.a IN (SELECT y.b AS b FROM y); +SELECT * FROM x LEFT JOIN (SELECT y.b AS b FROM y GROUP BY y.b) AS _u_0 ON x.a = _u_0.b WHERE NOT _u_0.b IS NULL; + +SELECT * FROM x WHERE x.a = ANY (SELECT y.a AS a FROM y); +SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE NOT _u_0.a IS NULL; + +SELECT * FROM x WHERE x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a); +SELECT * FROM x LEFT JOIN (SELECT SUM(y.b) AS b, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0._u_1 WHERE x.a = _u_0.b; + +SELECT * FROM x WHERE x.a > (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a); +SELECT * FROM x LEFT JOIN (SELECT SUM(y.b) AS b, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0._u_1 WHERE x.a > _u_0.b; + +SELECT * FROM x WHERE x.a <> ANY (SELECT y.a AS a FROM y WHERE y.a = x.a); +SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0.a = x.a WHERE x.a <> _u_0.a; + +SELECT * FROM x WHERE x.a NOT IN (SELECT y.a AS a FROM y WHERE y.a = x.a); +SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0.a = x.a WHERE NOT x.a = _u_0.a; + +SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y WHERE y.b = x.a); +SELECT * FROM x LEFT JOIN (SELECT ARRAY_AGG(y.a) AS a, y.b AS _u_1 FROM y WHERE TRUE GROUP BY y.b) AS _u_0 ON _u_0._u_1 = x.a WHERE ARRAY_ANY(_u_0.a, _x -> _x = x.a); + +SELECT * FROM x WHERE x.a < (SELECT SUM(y.a) AS a FROM y WHERE y.a = x.a and y.a = x.b and y.b <> x.d); +SELECT * FROM x LEFT JOIN (SELECT SUM(y.a) AS a, y.a AS _u_1, ARRAY_AGG(y.b) AS _u_2 FROM y WHERE TRUE AND TRUE AND TRUE GROUP BY y.a) AS _u_0 ON _u_0._u_1 = x.a AND _u_0._u_1 = x.b WHERE (x.a < _u_0.a AND ARRAY_ANY(_u_0._u_2, _x -> _x <> x.d)); + +SELECT * FROM x WHERE EXISTS (SELECT y.a AS a, y.b AS b FROM y WHERE x.a = y.a); +SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE NOT _u_0.a IS NULL; + +SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y LIMIT 10); +SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y LIMIT 10); + +SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a FROM y OFFSET 10); +SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a FROM y OFFSET 10); + +SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a, y.b AS b FROM y); +SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a, y.b AS b FROM y); + +SELECT * FROM x.a WHERE x.a > ANY (SELECT y.a FROM y); +SELECT * FROM x.a WHERE x.a > ANY (SELECT y.a FROM y); + +SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10); +SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10); + +SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10); +SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10); + +SELECT * FROM x WHERE x.a > ALL (SELECT y.c AS c FROM y WHERE y.a = x.a); +SELECT * FROM x LEFT JOIN (SELECT ARRAY_AGG(y.c) AS c, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0._u_1 = x.a WHERE ARRAY_ALL(_u_0.c, _x -> x.a > _x); + +SELECT * FROM x WHERE x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a); +SELECT * FROM x LEFT JOIN (SELECT COUNT(*) AS d, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0._u_1 = x.a WHERE x.a > COALESCE(_u_0.d, 0); + +# title: invalid statement left alone +SELECT * FROM x WHERE x.a = SUM(SELECT 1); +SELECT * FROM x WHERE x.a = SUM(SELECT 1); + +SELECT * FROM x WHERE x.a IN (SELECT max(y.b) AS b FROM y GROUP BY y.a); +SELECT * FROM x LEFT JOIN (SELECT _q.b AS b FROM (SELECT MAX(y.b) AS b FROM y GROUP BY y.a) AS _q GROUP BY _q.b) AS _u_0 ON x.a = _u_0.b WHERE NOT _u_0.b IS NULL; + +SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x; +SELECT x.a > _u_0.b FROM x CROSS JOIN (SELECT SUM(y.a) AS b FROM y) AS _u_0; + +SELECT (SELECT MAX(t2.c1) AS c1 FROM t2 WHERE t2.c2 = t1.c2 AND t2.c3 <= TRUNC(t1.c3)) AS c FROM t1; +SELECT _u_0.c1 AS c FROM t1 LEFT JOIN (SELECT MAX(t2.c1) AS c1, t2.c2 AS _u_1, MAX(t2.c3) AS _u_2 FROM t2 WHERE TRUE AND TRUE GROUP BY t2.c2) AS _u_0 ON _u_0._u_1 = t1.c2 WHERE _u_0._u_2 <= TRUNC(t1.c3); diff --git a/tests/test_executor.py b/tests/test_executor.py index 1eaca14..317b930 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -707,9 +707,15 @@ class TestExecutor(unittest.TestCase): ("ROUND(1.2)", 1), ("ROUND(1.2345, 2)", 1.23), ("ROUND(NULL)", None), - ("UNIXTOTIME(1659981729)", datetime.datetime(2022, 8, 8, 18, 2, 9)), + ( + "UNIXTOTIME(1659981729)", + datetime.datetime(2022, 8, 8, 18, 2, 9, tzinfo=datetime.timezone.utc), + ), ("TIMESTRTOTIME('2013-04-05 01:02:03')", datetime.datetime(2013, 4, 5, 1, 2, 3)), - ("UNIXTOTIME(40 * 365 * 86400)", datetime.datetime(2009, 12, 22, 00, 00, 00)), + ( + "UNIXTOTIME(40 * 365 * 86400)", + datetime.datetime(2009, 12, 22, 00, 00, 00, tzinfo=datetime.timezone.utc), + ), ( "STRTOTIME('08/03/2024 12:34:56', '%d/%m/%Y %H:%M:%S')", datetime.datetime(2024, 3, 8, 12, 34, 56), diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 81cfb86..1395b24 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -674,7 +674,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("STANDARD_HASH('hello', 'sha256')"), exp.StandardHash) self.assertIsInstance(parse_one("DATE(foo)"), exp.Date) self.assertIsInstance(parse_one("HEX(foo)"), exp.Hex) - self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex) + self.assertIsInstance(parse_one("LOWER(HEX(foo))"), exp.LowerHex) + self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.LowerHex) + self.assertIsInstance(parse_one("UPPER(TO_HEX(foo))", read="bigquery"), exp.Hex) self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5) self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform) self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths) @@ -834,21 +836,22 @@ class TestExpressions(unittest.TestCase): b AS B, c, /*comment*/ d AS D, -- another comment - CAST(x AS INT) -- final comment + CAST(x AS INT), -- yet another comment + y AND /* foo */ w AS E -- final comment FROM foo """ expression = parse_one(sql) self.assertEqual( [e.alias_or_name for e in expression.expressions], - ["a", "B", "c", "D", "x"], + ["a", "B", "c", "D", "x", "E"], ) self.assertEqual( expression.sql(), - "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo", + "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* yet another comment */, y AND /* foo */ w AS E /* final comment */ FROM foo", ) self.assertEqual( expression.sql(comments=False), - "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo", + "SELECT a, b AS B, c, d AS D, CAST(x AS INT), y AND w AS E FROM foo", ) self.assertEqual( expression.sql(pretty=True, comments=False), @@ -857,7 +860,8 @@ class TestExpressions(unittest.TestCase): b AS B, c, d AS D, - CAST(x AS INT) + CAST(x AS INT), + y AND w AS E FROM foo""", ) self.assertEqual( @@ -867,7 +871,8 @@ FROM foo""", b AS B, c, /* comment */ d AS D, /* another comment */ - CAST(x AS INT) /* final comment */ + CAST(x AS INT), /* yet another comment */ + y AND /* foo */ w AS E /* final comment */ FROM foo""", ) diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 3e17f95..036f146 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -487,3 +487,11 @@ class TestLineage(unittest.TestCase): downstream = node.downstream[0] self.assertEqual(downstream.name, "z.a") self.assertEqual(downstream.source.sql(), "SELECT y.a AS a, y.b AS b, y.c AS c FROM y AS y") + + def test_node_name_doesnt_contain_comment(self) -> None: + sql = "SELECT * FROM (SELECT x /* c */ FROM t1) AS t2" + node = lineage("x", sql) + + self.assertEqual(len(node.downstream), 1) + self.assertEqual(len(node.downstream[0].downstream), 1) + self.assertEqual(node.downstream[0].downstream[0].name, "t1.x") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 758b60c..36768f8 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -428,11 +428,7 @@ SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expr ) def test_unnest_subqueries(self): - self.check_file( - "unnest_subqueries", - optimizer.unnest_subqueries.unnest_subqueries, - pretty=True, - ) + self.check_file("unnest_subqueries", optimizer.unnest_subqueries.unnest_subqueries) def test_pushdown_predicates(self): self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates) diff --git a/tests/test_parser.py b/tests/test_parser.py index 6bcdb64..2cefc07 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -503,7 +503,7 @@ class TestParser(unittest.TestCase): self.assertIsInstance(set_item, exp.SetItem) self.assertIsInstance(set_item.this, exp.EQ) - self.assertIsInstance(set_item.this.this, exp.Identifier) + self.assertIsInstance(set_item.this.this, exp.Column) self.assertIsInstance(set_item.this.expression, exp.Literal) self.assertEqual(set_item.args.get("kind"), "SESSION") @@ -856,5 +856,38 @@ class TestParser(unittest.TestCase): with self.subTest(dialect): self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql) + def test_alter_set(self): + sqls = [ + "ALTER TABLE tbl SET TBLPROPERTIES ('x'='1', 'Z'='2')", + "ALTER TABLE tbl SET SERDE 'test' WITH SERDEPROPERTIES ('k'='v', 'kay'='vee')", + "ALTER TABLE tbl SET SERDEPROPERTIES ('k'='v', 'kay'='vee')", + "ALTER TABLE tbl SET LOCATION 'new_location'", + "ALTER TABLE tbl SET FILEFORMAT file_format", + "ALTER TABLE tbl SET TAGS ('tag1' = 't1', 'tag2' = 't2')", + ] + + for dialect in ( + "hive", + "spark2", + "spark", + "databricks", + ): + for sql in sqls: + with self.subTest(f"Testing query '{sql}' for dialect {dialect}"): + self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql) + def test_distinct_from(self): self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or) + + def test_trailing_comments(self): + expressions = parse(""" + select * from x; + -- my comment + """) + + self.assertEqual( + ";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */" + ) + + def test_parse_prop_eq(self): + self.assertIsInstance(parse_one("x(a := b and c)").expressions[0], exp.PropertyEQ) diff --git a/tests/test_schema.py b/tests/test_schema.py index 32686d7..5b50867 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -271,6 +271,34 @@ class TestSchema(unittest.TestCase): "Table z must match the schema's nesting level: 2.", ) + with self.assertRaises(SchemaError) as ctx: + MappingSchema( + { + "catalog": { + "db": {"tbl": {"col": "a"}}, + }, + "tbl2": {"col": "b"}, + }, + ) + self.assertEqual( + str(ctx.exception), + "Table tbl2 must match the schema's nesting level: 3.", + ) + + with self.assertRaises(SchemaError) as ctx: + MappingSchema( + { + "tbl2": {"col": "b"}, + "catalog": { + "db": {"tbl": {"col": "a"}}, + }, + }, + ) + self.assertEqual( + str(ctx.exception), + "Table catalog.db.tbl must match the schema's nesting level: 1.", + ) + def test_has_column(self): schema = MappingSchema({"x": {"c": "int"}}) self.assertTrue(schema.has_column("x", exp.column("c"))) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 22085b3..dea9985 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -551,6 +551,13 @@ FROM x""", pretty=True, ) + self.validate( + """SELECT X FROM catalog.db.table WHERE Y + -- + AND Z""", + """SELECT X FROM catalog.db.table WHERE Y AND Z""", + ) + def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") self.validate("VARCHAR 'x' y", "CAST('x' AS VARCHAR) AS y") |