summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_bigquery.py217
-rw-r--r--tests/dialects/test_clickhouse.py23
-rw-r--r--tests/dialects/test_dialect.py104
-rw-r--r--tests/dialects/test_duckdb.py97
-rw-r--r--tests/dialects/test_mysql.py13
-rw-r--r--tests/dialects/test_postgres.py19
-rw-r--r--tests/dialects/test_redshift.py20
-rw-r--r--tests/dialects/test_snowflake.py302
-rw-r--r--tests/dialects/test_tsql.py112
-rw-r--r--tests/fixtures/identity.sql3
-rw-r--r--tests/fixtures/optimizer/tpc-ds/tpc-ds.sql30
-rw-r--r--tests/fixtures/optimizer/unnest_subqueries.sql312
-rw-r--r--tests/test_executor.py10
-rw-r--r--tests/test_expressions.py19
-rw-r--r--tests/test_lineage.py8
-rw-r--r--tests/test_optimizer.py6
-rw-r--r--tests/test_parser.py35
-rw-r--r--tests/test_schema.py28
-rw-r--r--tests/test_transpile.py7
19 files changed, 907 insertions, 458 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 301cd57..728785c 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -21,6 +21,13 @@ class TestBigQuery(Validator):
def test_bigquery(self):
self.validate_identity(
+ """CREATE TEMPORARY FUNCTION FOO()
+RETURNS STRING
+LANGUAGE js AS
+'return "Hello world!"'""",
+ pretty=True,
+ )
+ self.validate_identity(
"[a, a(1, 2,3,4444444444444444, tttttaoeunthaoentuhaoentuheoantu, toheuntaoheutnahoeunteoahuntaoeh), b(3, 4,5), c, d, tttttttttttttttteeeeeeeeeeeeeett, 12312312312]",
"""[
a,
@@ -279,6 +286,13 @@ class TestBigQuery(Validator):
)
self.validate_all(
+ "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s",
+ write={
+ "bigquery": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s",
+ "duckdb": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS _t0(h), UNNEST(h.t3) AS _t1(s)",
+ },
+ )
+ self.validate_all(
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
write={
"bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
@@ -289,7 +303,7 @@ class TestBigQuery(Validator):
"SELECT results FROM Coordinates, Coordinates.position AS results",
write={
"bigquery": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS results",
- "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t(results)",
+ "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t0(results)",
},
)
self.validate_all(
@@ -307,7 +321,7 @@ class TestBigQuery(Validator):
},
write={
"bigquery": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results",
- "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t(results)",
+ "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t0(results)",
"redshift": "SELECT results FROM Coordinates AS c, c.position AS results",
},
)
@@ -525,7 +539,7 @@ class TestBigQuery(Validator):
"SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)",
write={
"bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)",
- "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)",
+ "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t0(x) WHERE x > 1)",
},
)
self.validate_all(
@@ -618,12 +632,87 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "LOWER(TO_HEX(x))",
+ write={
+ "": "LOWER(HEX(x))",
+ "bigquery": "TO_HEX(x)",
+ "clickhouse": "LOWER(HEX(x))",
+ "duckdb": "LOWER(HEX(x))",
+ "hive": "LOWER(HEX(x))",
+ "mysql": "LOWER(HEX(x))",
+ "spark": "LOWER(HEX(x))",
+ "sqlite": "LOWER(HEX(x))",
+ "presto": "LOWER(TO_HEX(x))",
+ "trino": "LOWER(TO_HEX(x))",
+ },
+ )
+ self.validate_all(
+ "TO_HEX(x)",
+ read={
+ "": "LOWER(HEX(x))",
+ "clickhouse": "LOWER(HEX(x))",
+ "duckdb": "LOWER(HEX(x))",
+ "hive": "LOWER(HEX(x))",
+ "mysql": "LOWER(HEX(x))",
+ "spark": "LOWER(HEX(x))",
+ "sqlite": "LOWER(HEX(x))",
+ "presto": "LOWER(TO_HEX(x))",
+ "trino": "LOWER(TO_HEX(x))",
+ },
+ write={
+ "": "LOWER(HEX(x))",
+ "bigquery": "TO_HEX(x)",
+ "clickhouse": "LOWER(HEX(x))",
+ "duckdb": "LOWER(HEX(x))",
+ "hive": "LOWER(HEX(x))",
+ "mysql": "LOWER(HEX(x))",
+ "presto": "LOWER(TO_HEX(x))",
+ "spark": "LOWER(HEX(x))",
+ "sqlite": "LOWER(HEX(x))",
+ "trino": "LOWER(TO_HEX(x))",
+ },
+ )
+ self.validate_all(
+ "UPPER(TO_HEX(x))",
+ read={
+ "": "HEX(x)",
+ "clickhouse": "HEX(x)",
+ "duckdb": "HEX(x)",
+ "hive": "HEX(x)",
+ "mysql": "HEX(x)",
+ "presto": "TO_HEX(x)",
+ "spark": "HEX(x)",
+ "sqlite": "HEX(x)",
+ "trino": "TO_HEX(x)",
+ },
+ write={
+ "": "HEX(x)",
+ "bigquery": "UPPER(TO_HEX(x))",
+ "clickhouse": "HEX(x)",
+ "duckdb": "HEX(x)",
+ "hive": "HEX(x)",
+ "mysql": "HEX(x)",
+ "presto": "TO_HEX(x)",
+ "spark": "HEX(x)",
+ "sqlite": "HEX(x)",
+ "trino": "TO_HEX(x)",
+ },
+ )
+ self.validate_all(
"MD5(x)",
+ read={
+ "clickhouse": "MD5(x)",
+ "presto": "MD5(x)",
+ "trino": "MD5(x)",
+ },
write={
"": "MD5_DIGEST(x)",
"bigquery": "MD5(x)",
+ "clickhouse": "MD5(x)",
"hive": "UNHEX(MD5(x))",
+ "presto": "MD5(x)",
"spark": "UNHEX(MD5(x))",
+ "trino": "MD5(x)",
},
)
self.validate_all(
@@ -631,25 +720,69 @@ class TestBigQuery(Validator):
read={
"duckdb": "SELECT MD5(some_string)",
"spark": "SELECT MD5(some_string)",
+ "clickhouse": "SELECT LOWER(HEX(MD5(some_string)))",
+ "presto": "SELECT LOWER(TO_HEX(MD5(some_string)))",
+ "trino": "SELECT LOWER(TO_HEX(MD5(some_string)))",
},
write={
"": "SELECT MD5(some_string)",
"bigquery": "SELECT TO_HEX(MD5(some_string))",
"duckdb": "SELECT MD5(some_string)",
+ "clickhouse": "SELECT LOWER(HEX(MD5(some_string)))",
+ "presto": "SELECT LOWER(TO_HEX(MD5(some_string)))",
+ "trino": "SELECT LOWER(TO_HEX(MD5(some_string)))",
+ },
+ )
+ self.validate_all(
+ "SHA1(x)",
+ read={
+ "clickhouse": "SHA1(x)",
+ "presto": "SHA1(x)",
+ "trino": "SHA1(x)",
+ },
+ write={
+ "clickhouse": "SHA1(x)",
+ "bigquery": "SHA1(x)",
+ "": "SHA(x)",
+ "presto": "SHA1(x)",
+ "trino": "SHA1(x)",
+ },
+ )
+ self.validate_all(
+ "SHA1(x)",
+ write={
+ "bigquery": "SHA1(x)",
+ "": "SHA(x)",
},
)
self.validate_all(
"SHA256(x)",
+ read={
+ "clickhouse": "SHA256(x)",
+ "presto": "SHA256(x)",
+ "trino": "SHA256(x)",
+ },
write={
"bigquery": "SHA256(x)",
"spark2": "SHA2(x, 256)",
+ "clickhouse": "SHA256(x)",
+ "presto": "SHA256(x)",
+ "trino": "SHA256(x)",
},
)
self.validate_all(
"SHA512(x)",
+ read={
+ "clickhouse": "SHA512(x)",
+ "presto": "SHA512(x)",
+ "trino": "SHA512(x)",
+ },
write={
+ "clickhouse": "SHA512(x)",
"bigquery": "SHA512(x)",
"spark2": "SHA2(x, 512)",
+ "presto": "SHA512(x)",
+ "trino": "SHA512(x)",
},
)
self.validate_all(
@@ -860,8 +993,8 @@ class TestBigQuery(Validator):
},
write={
"bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x",
- "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t(x)",
- "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t(x)",
+ "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t0(x)",
+ "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t0(x)",
},
)
self.validate_all(
@@ -982,6 +1115,69 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "DELETE db.example_table WHERE x = 1",
+ write={
+ "bigquery": "DELETE db.example_table WHERE x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE db.example_table tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE db.example_table AS tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table WHERE x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table WHERE x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table AS tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table AS tb WHERE example_table.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table AS tb WHERE example_table.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table WHERE example_table.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table WHERE example_table.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE example_table.x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.t1 AS t1 WHERE NOT t1.c IN (SELECT db.t2.c FROM db.t2)",
+ write={
+ "bigquery": "DELETE FROM db.t1 AS t1 WHERE NOT t1.c IN (SELECT db.t2.c FROM db.t2)",
+ "presto": "DELETE FROM db.t1 WHERE NOT c IN (SELECT c FROM db.t2)",
+ },
+ )
+ self.validate_all(
"SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
write={
"bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
@@ -1464,3 +1660,14 @@ OPTIONS (
with self.assertRaises(ParseError):
transpile("SELECT JSON_OBJECT('a', 1, 'b') AS json_data", read="bigquery")
+
+ def test_mod(self):
+ for sql in ("MOD(a, b)", "MOD('a', b)", "MOD(5, 2)", "MOD((a + 1) * 8, 5 - 1)"):
+ with self.subTest(f"Testing BigQuery roundtrip of modulo operation: {sql}"):
+ self.validate_identity(sql)
+
+ self.validate_identity("SELECT MOD((SELECT 1), 2)")
+ self.validate_identity(
+ "MOD((a + 1), b)",
+ "MOD(a + 1, b)",
+ )
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index af552d1..15adda8 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -425,6 +425,21 @@ class TestClickhouse(Validator):
},
)
+ self.validate_identity("ALTER TABLE visits DROP PARTITION 201901")
+ self.validate_identity("ALTER TABLE visits DROP PARTITION ALL")
+ self.validate_identity(
+ "ALTER TABLE visits DROP PARTITION tuple(toYYYYMM(toDate('2019-01-25')))"
+ )
+ self.validate_identity("ALTER TABLE visits DROP PARTITION ID '201901'")
+
+ self.validate_identity("ALTER TABLE visits REPLACE PARTITION 201901 FROM visits_tmp")
+ self.validate_identity("ALTER TABLE visits REPLACE PARTITION ALL FROM visits_tmp")
+ self.validate_identity(
+ "ALTER TABLE visits REPLACE PARTITION tuple(toYYYYMM(toDate('2019-01-25'))) FROM visits_tmp"
+ )
+ self.validate_identity("ALTER TABLE visits REPLACE PARTITION ID '201901' FROM visits_tmp")
+ self.validate_identity("ALTER TABLE visits ON CLUSTER test_cluster DROP COLUMN col1")
+
def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
self.validate_identity("WITH ['c'] AS field_names SELECT field_names")
@@ -829,6 +844,9 @@ LIFETIME(MIN 0 MAX 0)""",
self.validate_identity(
"CREATE TABLE t1 (a String EPHEMERAL, b String EPHEMERAL func(), c String MATERIALIZED func(), d String ALIAS func()) ENGINE=TinyLog()"
)
+ self.validate_identity(
+ "CREATE TABLE t (a String, b String, c UInt64, PROJECTION p1 (SELECT a, sum(c) GROUP BY a, b), PROJECTION p2 (SELECT b, sum(c) GROUP BY b)) ENGINE=MergeTree()"
+ )
def test_agg_functions(self):
def extract_agg_func(query):
@@ -856,3 +874,8 @@ LIFETIME(MIN 0 MAX 0)""",
)
parse_one("foobar(x)").assert_is(exp.Anonymous)
+
+ def test_drop_on_cluster(self):
+ for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"):
+ with self.subTest(f"Test DROP {creatable} ON CLUSTER"):
+ self.validate_identity(f"DROP {creatable} test ON CLUSTER test_cluster")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index dda0eb2..77306dc 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1019,6 +1019,19 @@ class TestDialect(Validator):
},
)
+ self.validate_all(
+ "TIMESTAMP_TRUNC(x, DAY, 'UTC')",
+ write={
+ "": "TIMESTAMP_TRUNC(x, DAY, 'UTC')",
+ "duckdb": "DATE_TRUNC('DAY', x)",
+ "presto": "DATE_TRUNC('DAY', x)",
+ "postgres": "DATE_TRUNC('DAY', x, 'UTC')",
+ "snowflake": "DATE_TRUNC('DAY', x)",
+ "databricks": "DATE_TRUNC('DAY', x)",
+ "clickhouse": "DATE_TRUNC('DAY', x, 'UTC')",
+ },
+ )
+
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
f"{unit}(x)",
@@ -1681,6 +1694,26 @@ class TestDialect(Validator):
"tsql": "CAST(a AS FLOAT) / b",
},
)
+ self.validate_all(
+ "MOD(8 - 1 + 7, 7)",
+ write={
+ "": "(8 - 1 + 7) % 7",
+ "hive": "(8 - 1 + 7) % 7",
+ "presto": "(8 - 1 + 7) % 7",
+ "snowflake": "(8 - 1 + 7) % 7",
+ "bigquery": "MOD(8 - 1 + 7, 7)",
+ },
+ )
+ self.validate_all(
+ "MOD(a, b + 1)",
+ write={
+ "": "a % (b + 1)",
+ "hive": "a % (b + 1)",
+ "presto": "a % (b + 1)",
+ "snowflake": "a % (b + 1)",
+ "bigquery": "MOD(a, b + 1)",
+ },
+ )
def test_typeddiv(self):
typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True)
@@ -2186,6 +2219,8 @@ SELECT
)
def test_cast_to_user_defined_type(self):
+ self.validate_identity("CAST(x AS some_udt(1234))")
+
self.validate_all(
"CAST(x AS some_udt)",
write={
@@ -2214,6 +2249,18 @@ SELECT
"tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
},
)
+ self.validate_all(
+ 'SELECT "user id", some_id, 1 as other_id, 2 as "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1',
+ write={
+ "duckdb": 'SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1',
+ "snowflake": 'SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1',
+ "clickhouse": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1',
+ "mysql": "SELECT `user id`, some_id, other_id, `2 nd id` FROM (SELECT `user id`, some_id, 1 AS other_id, 2 AS `2 nd id`, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "oracle": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1',
+ "postgres": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1',
+ "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ },
+ )
def test_nested_ctes(self):
self.validate_all(
@@ -2249,7 +2296,7 @@ SELECT
"WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
write={
"duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
- "tsql": "WITH t2(y) AS (SELECT 2), t1(x) AS (SELECT 1) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
+ "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
},
)
self.validate_all(
@@ -2273,6 +2320,59 @@ FROM c""",
"hive": "WITH a1 AS (SELECT 1), a2 AS (SELECT 2), b AS (SELECT * FROM a1, a2), c AS (SELECT * FROM b) SELECT * FROM c",
},
)
+ self.validate_all(
+ """
+WITH subquery1 AS (
+ WITH tmp AS (
+ SELECT
+ *
+ FROM table0
+ )
+ SELECT
+ *
+ FROM tmp
+), subquery2 AS (
+ WITH tmp2 AS (
+ SELECT
+ *
+ FROM table1
+ WHERE
+ a IN subquery1
+ )
+ SELECT
+ *
+ FROM tmp2
+)
+SELECT
+ *
+FROM subquery2
+""",
+ write={
+ "hive": """WITH tmp AS (
+ SELECT
+ *
+ FROM table0
+), subquery1 AS (
+ SELECT
+ *
+ FROM tmp
+), tmp2 AS (
+ SELECT
+ *
+ FROM table1
+ WHERE
+ a IN subquery1
+), subquery2 AS (
+ SELECT
+ *
+ FROM tmp2
+)
+SELECT
+ *
+FROM subquery2""",
+ },
+ pretty=True,
+ )
def test_unsupported_null_ordering(self):
# We'll transpile a portable query from the following dialects to MySQL / T-SQL, which
@@ -2372,7 +2472,7 @@ FROM c""",
"hive": UnsupportedError,
"mysql": UnsupportedError,
"oracle": UnsupportedError,
- "postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t(x) WHERE pred), 1) <> 0)",
+ "postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t0(x) WHERE pred), 1) <> 0)",
"presto": "ANY_MATCH(arr, x -> pred)",
"redshift": UnsupportedError,
"snowflake": UnsupportedError,
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 2d0af13..03dea93 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -272,6 +272,14 @@ class TestDuckDB(Validator):
"SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
)
self.validate_identity(
+ "SELECT JSON_EXTRACT_STRING(c, '$.k1') = 'v1'",
+ "SELECT (c ->> '$.k1') = 'v1'",
+ )
+ self.validate_identity(
+ "SELECT JSON_EXTRACT(c, '$.k1') = 'v1'",
+ "SELECT (c -> '$.k1') = 'v1'",
+ )
+ self.validate_identity(
"""SELECT '{"foo": [1, 2, 3]}' -> 'foo' -> 0""",
"""SELECT '{"foo": [1, 2, 3]}' -> '$.foo' -> '$[0]'""",
)
@@ -734,6 +742,28 @@ class TestDuckDB(Validator):
)
self.validate_identity("COPY lineitem (l_orderkey) TO 'orderkey.tbl' WITH (DELIMITER '|')")
+ self.validate_all(
+ "VARIANCE(a)",
+ write={
+ "duckdb": "VARIANCE(a)",
+ "clickhouse": "varSamp(a)",
+ },
+ )
+ self.validate_all(
+ "STDDEV(a)",
+ write={
+ "duckdb": "STDDEV(a)",
+ "clickhouse": "stddevSamp(a)",
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC('DAY', x)",
+ write={
+ "duckdb": "DATE_TRUNC('DAY', x)",
+ "clickhouse": "DATE_TRUNC('DAY', x)",
+ },
+ )
+
def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
@@ -803,7 +833,7 @@ class TestDuckDB(Validator):
write={"duckdb": "SELECT (90 * INTERVAL '1' DAY)"},
)
self.validate_all(
- "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - (DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7 % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1",
+ "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1",
read={
"presto": "SELECT ((DATE_ADD('week', -5, DATE_TRUNC('DAY', DATE_ADD('day', (0 - MOD((DAY_OF_WEEK(CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7, 7)), CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)))))) AS t1",
},
@@ -827,6 +857,9 @@ class TestDuckDB(Validator):
"duckdb": "EPOCH_MS(x)",
"presto": "FROM_UNIXTIME(CAST(x AS DOUBLE) / POW(10, 3))",
"spark": "TIMESTAMP_MILLIS(x)",
+ "clickhouse": "fromUnixTimestamp64Milli(CAST(x AS Int64))",
+ "postgres": "TO_TIMESTAMP(CAST(x AS DOUBLE PRECISION) / 10 ^ 3)",
+ "mysql": "FROM_UNIXTIME(x / POWER(10, 3))",
},
)
self.validate_all(
@@ -892,11 +925,11 @@ class TestDuckDB(Validator):
def test_sample(self):
self.validate_identity(
"SELECT * FROM tbl USING SAMPLE 5",
- "SELECT * FROM tbl USING SAMPLE (5 ROWS)",
+ "SELECT * FROM tbl USING SAMPLE RESERVOIR (5 ROWS)",
)
self.validate_identity(
"SELECT * FROM tbl USING SAMPLE 10%",
- "SELECT * FROM tbl USING SAMPLE (10 PERCENT)",
+ "SELECT * FROM tbl USING SAMPLE SYSTEM (10 PERCENT)",
)
self.validate_identity(
"SELECT * FROM tbl USING SAMPLE 10 PERCENT (bernoulli)",
@@ -920,14 +953,13 @@ class TestDuckDB(Validator):
)
self.validate_all(
- "SELECT * FROM example TABLESAMPLE (3 ROWS) REPEATABLE (82)",
+ "SELECT * FROM example TABLESAMPLE RESERVOIR (3 ROWS) REPEATABLE (82)",
read={
"duckdb": "SELECT * FROM example TABLESAMPLE (3) REPEATABLE (82)",
"snowflake": "SELECT * FROM example SAMPLE (3 ROWS) SEED (82)",
},
write={
- "duckdb": "SELECT * FROM example TABLESAMPLE (3 ROWS) REPEATABLE (82)",
- "snowflake": "SELECT * FROM example TABLESAMPLE (3 ROWS) SEED (82)",
+ "duckdb": "SELECT * FROM example TABLESAMPLE RESERVOIR (3 ROWS) REPEATABLE (82)",
},
)
@@ -946,10 +978,6 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS DOUBLE)")
self.validate_identity("CAST(x AS DECIMAL(15, 4))")
self.validate_identity("CAST(x AS STRUCT(number BIGINT))")
- self.validate_identity(
- "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))"
- )
-
self.validate_identity("CAST(x AS INT64)", "CAST(x AS BIGINT)")
self.validate_identity("CAST(x AS INT32)", "CAST(x AS INT)")
self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)")
@@ -969,6 +997,32 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)")
self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)")
self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)")
+ self.validate_identity(
+ "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))"
+ )
+ self.validate_identity(
+ "123::CHARACTER VARYING",
+ "CAST(123 AS TEXT)",
+ )
+ self.validate_identity(
+ "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])",
+ "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])",
+ )
+ self.validate_identity(
+ "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])",
+ "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
+ )
+
+ self.validate_all(
+ "CAST(x AS DECIMAL(38, 0))",
+ read={
+ "snowflake": "CAST(x AS NUMBER)",
+ "duckdb": "CAST(x AS DECIMAL(38, 0))",
+ },
+ write={
+ "snowflake": "CAST(x AS DECIMAL(38, 0))",
+ },
+ )
self.validate_all(
"CAST(x AS NUMERIC)",
write={
@@ -994,12 +1048,6 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "123::CHARACTER VARYING",
- write={
- "duckdb": "CAST(123 AS TEXT)",
- },
- )
- self.validate_all(
"cast([[1]] as int[][])",
write={
"duckdb": "CAST([[1]] AS INT[][])",
@@ -1007,7 +1055,10 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "CAST(x AS DATE) + INTERVAL (7 * -1) DAY", read={"spark": "DATE_SUB(x, 7)"}
+ "CAST(x AS DATE) + INTERVAL (7 * -1) DAY",
+ read={
+ "spark": "DATE_SUB(x, 7)",
+ },
)
self.validate_all(
"TRY_CAST(1 AS DOUBLE)",
@@ -1034,18 +1085,6 @@ class TestDuckDB(Validator):
"snowflake": "CAST(COL AS ARRAY(BIGINT))",
},
)
- self.validate_all(
- "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])",
- write={
- "duckdb": "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
- },
- )
- self.validate_all(
- "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])",
- write={
- "duckdb": "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])",
- },
- )
def test_bool_or(self):
self.validate_all(
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 53e2dab..84fb3c2 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -149,6 +149,9 @@ class TestMySQL(Validator):
"SELECT * FROM t1, t2, t3 FOR SHARE OF t1 NOWAIT FOR UPDATE OF t2, t3 SKIP LOCKED"
)
self.validate_identity(
+ "REPLACE INTO table SELECT id FROM table2 WHERE cnt > 100", check_command_warning=True
+ )
+ self.validate_identity(
"""SELECT * FROM foo WHERE 3 MEMBER OF(info->'$.value')""",
"""SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""",
)
@@ -608,6 +611,16 @@ class TestMySQL(Validator):
def test_mysql(self):
self.validate_all(
+ "SELECT CONCAT('11', '22')",
+ read={
+ "postgres": "SELECT '11' || '22'",
+ },
+ write={
+ "mysql": "SELECT CONCAT('11', '22')",
+ "postgres": "SELECT CONCAT('11', '22')",
+ },
+ )
+ self.validate_all(
"SELECT department, GROUP_CONCAT(name) AS employee_names FROM data GROUP BY department",
read={
"postgres": "SELECT department, array_agg(name) AS employee_names FROM data GROUP BY department",
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 6b6117e..a8a6c12 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -733,6 +733,13 @@ class TestPostgres(Validator):
self.validate_identity("TRUNCATE TABLE t1 RESTRICT")
self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY CASCADE")
self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY RESTRICT")
+ self.validate_identity("ALTER TABLE t1 SET LOGGED")
+ self.validate_identity("ALTER TABLE t1 SET UNLOGGED")
+ self.validate_identity("ALTER TABLE t1 SET WITHOUT CLUSTER")
+ self.validate_identity("ALTER TABLE t1 SET WITHOUT OIDS")
+ self.validate_identity("ALTER TABLE t1 SET ACCESS METHOD method")
+ self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace")
+ self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)")
self.validate_identity(
"CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))"
)
@@ -798,23 +805,21 @@ class TestPostgres(Validator):
"CREATE TABLE test (x TIMESTAMP[][])",
)
self.validate_identity(
- "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE",
- check_command_warning=True,
+ "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT AS 'select $1 + $2;'",
)
self.validate_identity(
- "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT",
- check_command_warning=True,
+ "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE STRICT AS 'select $1 + $2;'"
)
self.validate_identity(
- "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT",
+ "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE",
check_command_warning=True,
)
self.validate_identity(
- "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT",
+ "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT",
check_command_warning=True,
)
self.validate_identity(
- "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE STRICT",
+ "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT",
check_command_warning=True,
)
self.validate_identity(
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index e227ea9..3925e32 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -498,7 +498,25 @@ FROM (
},
)
- def test_rename_table(self):
+ def test_alter_table(self):
+ self.validate_identity("ALTER TABLE s.t ALTER SORTKEY (c)")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY AUTO")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY NONE")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY (c1, c2)")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY (c1, c2)")
+ self.validate_identity("ALTER TABLE t ALTER COMPOUND SORTKEY (c1, c2)")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE ALL")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE EVEN")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE AUTO")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE KEY DISTKEY c")
+ self.validate_identity("ALTER TABLE t SET TABLE PROPERTIES ('a' = '5', 'b' = 'c')")
+ self.validate_identity("ALTER TABLE t SET LOCATION 's3://bucket/folder/'")
+ self.validate_identity("ALTER TABLE t SET FILE FORMAT AVRO")
+ self.validate_identity(
+ "ALTER TABLE t ALTER DISTKEY c",
+ "ALTER TABLE t ALTER DISTSTYLE KEY DISTKEY c",
+ )
+
self.validate_all(
"ALTER TABLE db.t1 RENAME TO db.t2",
write={
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index dae8355..d3c47af 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -10,14 +10,14 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
- self.validate_identity(
- "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
- )
- self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
- self.validate_identity(
- "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1"
+ self.validate_all(
+ "ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
+ write={
+ "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))",
+ "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)",
+ },
)
- self.validate_identity("SELECT rename, replace")
+
expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
expr.selects[0].assert_is(exp.AggFunc)
self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
@@ -43,6 +43,8 @@ WHERE
)""",
)
+ self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
+ self.validate_identity("SELECT rename, replace")
self.validate_identity("SELECT TIMEADD(HOUR, 2, CAST('09:05:03' AS TIME))")
self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS MAP(VARCHAR, INT))")
self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS OBJECT(a CHAR NOT NULL))")
@@ -86,20 +88,19 @@ WHERE
self.validate_identity("a$b") # valid snowflake identifier
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)")
- self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'")
- self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
- self.validate_identity("ALTER TABLE foo SET COMMENT = 'bar'")
- self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING = FALSE")
- self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING")
self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'")
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity("SELECT MATCH_CONDITION")
+ self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
self.validate_identity(
- 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
+ "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
+ )
+ self.validate_identity(
+ "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1"
)
self.validate_identity(
- "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)"
+ 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)
self.validate_identity(
"SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID"
@@ -146,10 +147,6 @@ WHERE
"SELECT TIMESTAMP_FROM_PARTS(d, t)",
)
self.validate_identity(
- "SELECT user_id, value FROM table_name SAMPLE ($s) SEED (0)",
- "SELECT user_id, value FROM table_name TABLESAMPLE ($s) SEED (0)",
- )
- self.validate_identity(
"SELECT v:attr[0].name FROM vartab",
"SELECT GET_PATH(v, 'attr[0].name') FROM vartab",
)
@@ -236,6 +233,38 @@ WHERE
"CAST(x AS NCHAR VARYING)",
"CAST(x AS VARCHAR)",
)
+ self.validate_identity(
+ "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))",
+ "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
+ )
+ self.validate_identity(
+ "CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))",
+ "CREATE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
+ )
+ self.validate_identity(
+ "CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)",
+ "CREATE TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
+ )
+ self.validate_identity(
+ "ALTER TABLE foo ADD COLUMN id INT identity(1, 1)",
+ "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1",
+ )
+ self.validate_identity(
+ "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
+ "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))",
+ )
+ self.validate_identity(
+ "SELECT * FROM xxx WHERE col ilike '%Don''t%'",
+ "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
+ )
+ self.validate_identity(
+ "SELECT * EXCLUDE a, b FROM xxx",
+ "SELECT * EXCLUDE (a), b FROM xxx",
+ )
+ self.validate_identity(
+ "SELECT * RENAME a AS b, c AS d FROM xxx",
+ "SELECT * RENAME (a AS b), c AS d FROM xxx",
+ )
self.validate_all(
"OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)",
@@ -250,18 +279,6 @@ WHERE
},
)
self.validate_all(
- "SELECT * FROM example TABLESAMPLE (3) SEED (82)",
- read={
- "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- "duckdb": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- },
- write={
- "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- "duckdb": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- "snowflake": "SELECT * FROM example TABLESAMPLE (3) SEED (82)",
- },
- )
- self.validate_all(
"SELECT TIME_FROM_PARTS(12, 34, 56, 987654321)",
write={
"duckdb": "SELECT MAKE_TIME(12, 34, 56 + (987654321 / 1000000000.0))",
@@ -568,60 +585,12 @@ WHERE
},
)
self.validate_all(
- "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))",
- write={
- "snowflake": "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)",
- },
- )
- self.validate_all(
- "CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))",
- write={
- "snowflake": "CREATE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)",
- },
- )
- self.validate_all(
- "CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)",
- write={
- "snowflake": "CREATE TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)",
- },
- )
- self.validate_all(
- "ALTER TABLE foo ADD COLUMN id INT identity(1, 1)",
- write={
- "snowflake": "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1",
- },
- )
- self.validate_all(
- "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
- write={
- "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))",
- },
- )
- self.validate_all(
- "SELECT * FROM xxx WHERE col ilike '%Don''t%'",
- write={
- "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
- },
- )
- self.validate_all(
- "SELECT * EXCLUDE a, b FROM xxx",
- write={
- "snowflake": "SELECT * EXCLUDE (a), b FROM xxx",
- },
- )
- self.validate_all(
- "SELECT * RENAME a AS b, c AS d FROM xxx",
- write={
- "snowflake": "SELECT * RENAME (a AS b), c AS d FROM xxx",
- },
- )
- self.validate_all(
- "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
+ "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
read={
"duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
},
write={
- "snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
+ "snowflake": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
"duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
},
)
@@ -831,6 +800,18 @@ WHERE
"snowflake": "SELECT LISTAGG(col1, ', ') WITHIN GROUP (ORDER BY col2) FROM t",
},
)
+ self.validate_all(
+ "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ read={
+ "trino": "SELECT APPROX_PERCENTILE(a, 1, 0.5, 0.001) FROM t",
+ "presto": "SELECT APPROX_PERCENTILE(a, 1, 0.5, 0.001) FROM t",
+ },
+ write={
+ "trino": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ "presto": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ "snowflake": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ },
+ )
def test_null_treatment(self):
self.validate_all(
@@ -881,6 +862,10 @@ WHERE
self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz")
self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')")
self.validate_identity("PUT file:///dir/tmp.csv @%table", check_command_warning=True)
+ self.validate_identity("SELECT * FROM (SELECT a FROM @foo)")
+ self.validate_identity(
+ "SELECT * FROM (SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv'))"
+ )
self.validate_identity(
"SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla"
)
@@ -902,18 +887,27 @@ WHERE
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
- self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
- self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
self.validate_identity(
- "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i"
+ "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE BERNOULLI (0.1)"
+ )
+ self.validate_identity(
+ "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE BERNOULLI (50) WHERE t2.j = t1.i"
+ )
+ self.validate_identity(
+ "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE BERNOULLI (1)"
+ )
+ self.validate_identity(
+ "SELECT * FROM testtable TABLESAMPLE (10 ROWS)",
+ "SELECT * FROM testtable TABLESAMPLE BERNOULLI (10 ROWS)",
)
self.validate_identity(
- "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)"
+ "SELECT * FROM testtable TABLESAMPLE (100)",
+ "SELECT * FROM testtable TABLESAMPLE BERNOULLI (100)",
)
self.validate_identity(
"SELECT * FROM testtable SAMPLE (10)",
- "SELECT * FROM testtable TABLESAMPLE (10)",
+ "SELECT * FROM testtable TABLESAMPLE BERNOULLI (10)",
)
self.validate_identity(
"SELECT * FROM testtable SAMPLE ROW (0)",
@@ -923,8 +917,30 @@ WHERE
"SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
"SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
)
+ self.validate_identity(
+ "SELECT user_id, value FROM table_name SAMPLE BERNOULLI ($s) SEED (0)",
+ "SELECT user_id, value FROM table_name TABLESAMPLE BERNOULLI ($s) SEED (0)",
+ )
self.validate_all(
+ "SELECT * FROM example TABLESAMPLE BERNOULLI (3) SEED (82)",
+ read={
+ "duckdb": "SELECT * FROM example TABLESAMPLE BERNOULLI (3 PERCENT) REPEATABLE (82)",
+ },
+ write={
+ "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
+ "duckdb": "SELECT * FROM example TABLESAMPLE BERNOULLI (3 PERCENT) REPEATABLE (82)",
+ "snowflake": "SELECT * FROM example TABLESAMPLE BERNOULLI (3) SEED (82)",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM test AS _tmp TABLESAMPLE (5)",
+ write={
+ "postgres": "SELECT * FROM test AS _tmp TABLESAMPLE BERNOULLI (5)",
+ "snowflake": "SELECT * FROM test AS _tmp TABLESAMPLE BERNOULLI (5)",
+ },
+ )
+ self.validate_all(
"""
SELECT i, j
FROM
@@ -933,7 +949,7 @@ WHERE
table2 AS t2 SAMPLE (50) -- 50% of rows in table2
WHERE t2.j = t1.i""",
write={
- "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
+ "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE BERNOULLI (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE BERNOULLI (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
},
)
self.validate_all(
@@ -945,7 +961,7 @@ WHERE
self.validate_all(
"SELECT * FROM (SELECT * FROM t1 join t2 on t1.a = t2.c) SAMPLE (1)",
write={
- "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)",
+ "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE BERNOULLI (1)",
"spark": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1 PERCENT)",
},
)
@@ -1172,7 +1188,7 @@ WHERE
location=@s2/logs/
partition_type = user_specified
file_format = (type = parquet)""",
- "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL1') AS DATE)), col2 VARCHAR AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL2') AS VARCHAR)), col3 DECIMAL AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL3') AS DECIMAL))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)",
+ "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL1') AS DATE)), col2 VARCHAR AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL2') AS VARCHAR)), col3 DECIMAL(38, 0) AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL3') AS DECIMAL(38, 0)))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)",
)
self.validate_identity("CREATE OR REPLACE VIEW foo (uid) COPY GRANTS AS (SELECT 1)")
self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)")
@@ -1181,8 +1197,12 @@ WHERE
self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema")
self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)")
self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)")
+ self.validate_identity("CREATE TAG cost_center ALLOWED_VALUES 'a', 'b'")
self.validate_identity(
- "DROP function my_udf (OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN)))"
+ "CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'"
+ ).this.assert_is(exp.Identifier)
+ self.validate_identity(
+ "DROP FUNCTION my_udf (OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN)))"
)
self.validate_identity(
"CREATE TABLE orders_clone_restore CLONE orders AT (TIMESTAMP => TO_TIMESTAMP_TZ('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss'))"
@@ -1203,13 +1223,27 @@ WHERE
"CREATE ICEBERG TABLE my_iceberg_table (amount ARRAY(INT)) CATALOG='SNOWFLAKE' EXTERNAL_VOLUME='my_external_volume' BASE_LOCATION='my/relative/path/from/extvol'"
)
self.validate_identity(
- "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$",
- "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '",
+ """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT RETURNS NULL ON NULL INPUT AS ' return Object.values(obj) '"""
+ )
+ self.validate_identity(
+ """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT STRICT AS ' return Object.values(obj) '"""
+ )
+ self.validate_identity(
+ "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$",
+ "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '",
)
self.validate_identity(
"CREATE OR REPLACE FUNCTION my_udtf(foo BOOLEAN) RETURNS TABLE(col1 ARRAY(INT)) AS $$ WITH t AS (SELECT CAST([1, 2, 3] AS ARRAY(INT)) AS c) SELECT c FROM t $$",
"CREATE OR REPLACE FUNCTION my_udtf(foo BOOLEAN) RETURNS TABLE (col1 ARRAY(INT)) AS ' WITH t AS (SELECT CAST([1, 2, 3] AS ARRAY(INT)) AS c) SELECT c FROM t '",
)
+ self.validate_identity(
+ "CREATE SEQUENCE seq1 WITH START=1, INCREMENT=1 ORDER",
+ "CREATE SEQUENCE seq1 START=1 INCREMENT BY 1 ORDER",
+ )
+ self.validate_identity(
+ "CREATE SEQUENCE seq1 WITH START=1 INCREMENT=1 ORDER",
+ "CREATE SEQUENCE seq1 START=1 INCREMENT=1 ORDER",
+ )
self.validate_all(
"CREATE TABLE orders_clone CLONE orders",
@@ -1435,6 +1469,9 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
)
def test_values(self):
+ select = exp.select("*").from_("values (map(['a'], [1]))")
+ self.assertEqual(select.sql("snowflake"), "SELECT * FROM (SELECT OBJECT_CONSTRUCT('a', 1))")
+
self.validate_all(
'SELECT "c0", "c1" FROM (VALUES (1, 2), (3, 4)) AS "t0"("c0", "c1")',
read={
@@ -1832,10 +1869,7 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)")
def test_copy(self):
- self.validate_identity(
- """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' file_format = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""",
- """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""",
- )
+ self.validate_identity("COPY INTO test (c1) FROM (SELECT $1.c1 FROM @mystage)")
self.validate_identity(
"""COPY INTO temp FROM @random_stage/path/ FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64) VALIDATION_MODE = 'RETURN_3_ROWS'"""
)
@@ -1845,3 +1879,81 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
self.validate_identity(
"""COPY INTO mytable FROM 'azure://myaccount.blob.core.windows.net/mycontainer/data/files' CREDENTIALS = (AZURE_SAS_TOKEN = 'token') ENCRYPTION = (TYPE = 'AZURE_CSE' MASTER_KEY = 'kPx...') FILE_FORMAT = (FORMAT_NAME = my_csv_format)"""
)
+ self.validate_identity(
+ """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE"""
+ )
+ self.validate_all(
+ """COPY INTO 's3://example/data.csv'
+ FROM EXTRA.EXAMPLE.TABLE
+ credentials = (x)
+ STORAGE_INTEGRATION = S3_INTEGRATION
+ FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"')
+ HEADER = TRUE
+ OVERWRITE = TRUE
+ SINGLE = TRUE
+ """,
+ write={
+ "": """COPY INTO 's3://example/data.csv'
+FROM EXTRA.EXAMPLE.TABLE
+CREDENTIALS = (x) WITH (
+ STORAGE_INTEGRATION S3_INTEGRATION,
+ FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = (
+ ''
+ ) FIELD_OPTIONALLY_ENCLOSED_BY = '"'),
+ HEADER TRUE,
+ OVERWRITE TRUE,
+ SINGLE TRUE
+)""",
+ "snowflake": """COPY INTO 's3://example/data.csv'
+FROM EXTRA.EXAMPLE.TABLE
+CREDENTIALS = (x)
+STORAGE_INTEGRATION = S3_INTEGRATION
+FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = (
+ ''
+) FIELD_OPTIONALLY_ENCLOSED_BY = '"')
+HEADER = TRUE
+OVERWRITE = TRUE
+SINGLE = TRUE""",
+ },
+ pretty=True,
+ )
+ self.validate_all(
+ """COPY INTO 's3://example/data.csv'
+ FROM EXTRA.EXAMPLE.TABLE
+ credentials = (x)
+ STORAGE_INTEGRATION = S3_INTEGRATION
+ FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"')
+ HEADER = TRUE
+ OVERWRITE = TRUE
+ SINGLE = TRUE
+ """,
+ write={
+ "": """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE CREDENTIALS = (x) WITH (STORAGE_INTEGRATION S3_INTEGRATION, FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"'), HEADER TRUE, OVERWRITE TRUE, SINGLE TRUE)""",
+ "snowflake": """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE CREDENTIALS = (x) STORAGE_INTEGRATION = S3_INTEGRATION FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"') HEADER = TRUE OVERWRITE = TRUE SINGLE = TRUE""",
+ },
+ )
+
+ def test_querying_semi_structured_data(self):
+ self.validate_identity("SELECT $1")
+ self.validate_identity("SELECT $1.elem")
+
+ self.validate_identity("SELECT $1:a.b", "SELECT GET_PATH($1, 'a.b')")
+ self.validate_identity("SELECT t.$23:a.b", "SELECT GET_PATH(t.$23, 'a.b')")
+ self.validate_identity("SELECT t.$17:a[0].b[0].c", "SELECT GET_PATH(t.$17, 'a[0].b[0].c')")
+
+ def test_alter_set_unset(self):
+ self.validate_identity("ALTER TABLE tbl SET DATA_RETENTION_TIME_IN_DAYS=1")
+ self.validate_identity("ALTER TABLE tbl SET DEFAULT_DDL_COLLATION='test'")
+ self.validate_identity("ALTER TABLE foo SET COMMENT='bar'")
+ self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING=FALSE")
+ self.validate_identity("ALTER TABLE table1 SET TAG foo.bar = 'baz'")
+ self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'")
+ self.validate_identity(
+ """ALTER TABLE tbl SET STAGE_FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64)""",
+ )
+ self.validate_identity(
+ """ALTER TABLE tbl SET STAGE_COPY_OPTIONS = (ON_ERROR = SKIP_FILE SIZE_LIMIT = 5 PURGE = TRUE MATCH_BY_COLUMN_NAME = CASE_SENSITIVE)"""
+ )
+
+ self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
+ self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING")
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 1538d47..45a4657 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,5 +1,4 @@
from sqlglot import exp, parse, parse_one
-from sqlglot.parser import logger as parser_logger
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
@@ -8,6 +7,8 @@ class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
+ self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
+ self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)")
self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True)
# https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
@@ -191,16 +192,9 @@ class TestTSQL(Validator):
)
self.validate_all(
- """
- CREATE TABLE x(
- [zip_cd] [varchar](5) NULL NOT FOR REPLICATION,
- [zip_cd_mkey] [varchar](5) NOT NULL,
- CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC)
- WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [INDEX]
- ) ON [SECONDARY]
- """,
+ """CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]""",
write={
- "tsql": "CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]",
+ "tsql": "CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]",
"spark2": "CREATE TABLE x (`zip_cd` VARCHAR(5), `zip_cd_mkey` VARCHAR(5) NOT NULL, CONSTRAINT `pk_mytable` PRIMARY KEY (`zip_cd_mkey`))",
},
)
@@ -259,7 +253,7 @@ class TestTSQL(Validator):
self.validate_identity("SELECT * FROM ##foo")
self.validate_identity("SELECT a = 1", "SELECT 1 AS a")
self.validate_identity(
- "DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'", check_command_warning=True
+ "DECLARE @TestVariable AS VARCHAR(100) = 'Save Our Planet'",
)
self.validate_identity(
"SELECT a = 1 UNION ALL SELECT a = b", "SELECT 1 AS a UNION ALL SELECT b AS a"
@@ -461,6 +455,7 @@ class TestTSQL(Validator):
self.validate_identity("CAST(x AS IMAGE)")
self.validate_identity("CAST(x AS SQL_VARIANT)")
self.validate_identity("CAST(x AS BIT)")
+
self.validate_all(
"CAST(x AS DATETIME2)",
read={
@@ -488,7 +483,7 @@ class TestTSQL(Validator):
},
)
- def test__types_ints(self):
+ def test_types_ints(self):
self.validate_all(
"CAST(X AS INT)",
write={
@@ -521,10 +516,14 @@ class TestTSQL(Validator):
self.validate_all(
"CAST(X AS TINYINT)",
+ read={
+ "duckdb": "CAST(X AS UTINYINT)",
+ },
write={
- "hive": "CAST(X AS TINYINT)",
- "spark2": "CAST(X AS TINYINT)",
- "spark": "CAST(X AS TINYINT)",
+ "duckdb": "CAST(X AS UTINYINT)",
+ "hive": "CAST(X AS SMALLINT)",
+ "spark2": "CAST(X AS SMALLINT)",
+ "spark": "CAST(X AS SMALLINT)",
"tsql": "CAST(X AS TINYINT)",
},
)
@@ -764,19 +763,33 @@ class TestTSQL(Validator):
expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
)
- for clusterd_keyword in ("CLUSTERED", "NONCLUSTERED"):
+ for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
self.validate_identity(
'CREATE TABLE "dbo"."benchmark" ('
'"name" CHAR(7) NOT NULL, '
'"internal_id" VARCHAR(10) NOT NULL, '
- f'UNIQUE {clusterd_keyword} ("internal_id" ASC))',
+ f'UNIQUE {clustered_keyword} ("internal_id" ASC))',
"CREATE TABLE [dbo].[benchmark] ("
"[name] CHAR(7) NOT NULL, "
"[internal_id] VARCHAR(10) NOT NULL, "
- f"UNIQUE {clusterd_keyword} ([internal_id] ASC))",
+ f"UNIQUE {clustered_keyword} ([internal_id] ASC))",
)
self.validate_identity(
+ "ALTER TABLE tbl SET SYSTEM_VERSIONING=ON(HISTORY_TABLE=db.tbl, DATA_CONSISTENCY_CHECK=OFF, HISTORY_RETENTION_PERIOD=5 DAYS)"
+ )
+ self.validate_identity(
+ "ALTER TABLE tbl SET SYSTEM_VERSIONING=ON(HISTORY_TABLE=db.tbl, HISTORY_RETENTION_PERIOD=INFINITE)"
+ )
+ self.validate_identity("ALTER TABLE tbl SET SYSTEM_VERSIONING=OFF")
+ self.validate_identity("ALTER TABLE tbl SET FILESTREAM_ON = 'test'")
+ self.validate_identity(
+ "ALTER TABLE tbl SET DATA_DELETION=ON(FILTER_COLUMN=col, RETENTION_PERIOD=5 MONTHS)"
+ )
+ self.validate_identity("ALTER TABLE tbl SET DATA_DELETION=ON")
+ self.validate_identity("ALTER TABLE tbl SET DATA_DELETION=OFF")
+
+ self.validate_identity(
"CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END",
"CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END",
)
@@ -900,8 +913,7 @@ class TestTSQL(Validator):
def test_udf(self):
self.validate_identity(
- "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
- check_command_warning=True,
+ "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
)
self.validate_identity(
"CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
@@ -973,9 +985,9 @@ WHERE
BEGIN
SET XACT_ABORT ON;
- DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104);
- DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104);
- DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER);
+ DECLARE @DWH_DateCreated AS DATETIME = CONVERT(DATETIME, getdate(), 104);
+ DECLARE @DWH_DateModified DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104);
+ DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (CURRENT_USER());
DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER);
DECLARE @SalesAmountBefore float;
@@ -985,18 +997,17 @@ WHERE
expected_sqls = [
"CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON",
- "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
- "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)",
- "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)",
- "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)",
- "DECLARE @SalesAmountBefore float",
+ "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
+ "DECLARE @DWH_DateModified AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
+ "DECLARE @DWH_IdUserCreated AS INTEGER = SUSER_ID(CURRENT_USER())",
+ "DECLARE @DWH_IdUserModified AS INTEGER = SUSER_ID(CURRENT_USER())",
+ "DECLARE @SalesAmountBefore AS FLOAT",
"SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] AS S",
"END",
]
- with self.assertLogs(parser_logger):
- for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
- self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
sql = """
CREATE PROC [dbo].[transform_proc] AS
@@ -1010,14 +1021,13 @@ WHERE
"""
expected_sqls = [
- "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate VARCHAR(20)",
+ "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate AS VARCHAR(20)",
"SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120)",
"CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)",
]
- with self.assertLogs(parser_logger):
- for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
- self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
def test_charindex(self):
self.validate_identity(
@@ -1823,3 +1833,35 @@ FROM OPENJSON(@json) WITH (
"duckdb": "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) FROM t1) SELECT * FROM t2",
},
)
+
+ def test_declare(self):
+ # supported cases
+ self.validate_identity("DECLARE @X INT", "DECLARE @X AS INTEGER")
+ self.validate_identity("DECLARE @X INT = 1", "DECLARE @X AS INTEGER = 1")
+ self.validate_identity(
+ "DECLARE @X INT, @Y VARCHAR(10)", "DECLARE @X AS INTEGER, @Y AS VARCHAR(10)"
+ )
+ self.validate_identity(
+ "declare @X int = (select col from table where id = 1)",
+ "DECLARE @X AS INTEGER = (SELECT col FROM table WHERE id = 1)",
+ )
+ self.validate_identity(
+ "declare @X TABLE (Id INT NOT NULL, Name VARCHAR(100) NOT NULL)",
+ "DECLARE @X AS TABLE (Id INTEGER NOT NULL, Name VARCHAR(100) NOT NULL)",
+ )
+ self.validate_identity(
+ "declare @X TABLE (Id INT NOT NULL, constraint PK_Id primary key (Id))",
+ "DECLARE @X AS TABLE (Id INTEGER NOT NULL, CONSTRAINT PK_Id PRIMARY KEY (Id))",
+ )
+ self.validate_identity(
+ "declare @X UserDefinedTableType",
+ "DECLARE @X AS UserDefinedTableType",
+ )
+ self.validate_identity(
+ "DECLARE @MyTableVar TABLE (EmpID INT NOT NULL, PRIMARY KEY CLUSTERED (EmpID), UNIQUE NONCLUSTERED (EmpID), INDEX CustomNonClusteredIndex NONCLUSTERED (EmpID))",
+ check_command_warning=True,
+ )
+ self.validate_identity(
+ "DECLARE vendor_cursor CURSOR FOR SELECT VendorID, Name FROM Purchasing.Vendor WHERE PreferredVendorStatus = 1 ORDER BY VendorID",
+ check_command_warning=True,
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 6b742c3..13a6153 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -869,4 +869,5 @@ TRUNCATE(a, b)
SELECT enum
SELECT unlogged
SELECT name
-SELECT copy \ No newline at end of file
+SELECT copy
+SELECT rollup \ No newline at end of file
diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
index a357b07..5b004fa 100644
--- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
+++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
@@ -1409,31 +1409,11 @@ WITH "_u_0" AS (
"store_sales"."ss_quantity" <= 80 AND "store_sales"."ss_quantity" >= 61
)
SELECT
- CASE
- WHEN MAX("_u_0"."_col_0") > 3672
- THEN MAX("_u_1"."_col_0")
- ELSE MAX("_u_2"."_col_0")
- END AS "bucket1",
- CASE
- WHEN MAX("_u_3"."_col_0") > 3392
- THEN MAX("_u_4"."_col_0")
- ELSE MAX("_u_5"."_col_0")
- END AS "bucket2",
- CASE
- WHEN MAX("_u_6"."_col_0") > 32784
- THEN MAX("_u_7"."_col_0")
- ELSE MAX("_u_8"."_col_0")
- END AS "bucket3",
- CASE
- WHEN MAX("_u_9"."_col_0") > 26032
- THEN MAX("_u_10"."_col_0")
- ELSE MAX("_u_11"."_col_0")
- END AS "bucket4",
- CASE
- WHEN MAX("_u_12"."_col_0") > 23982
- THEN MAX("_u_13"."_col_0")
- ELSE MAX("_u_14"."_col_0")
- END AS "bucket5"
+ CASE WHEN "_u_0"."_col_0" > 3672 THEN "_u_1"."_col_0" ELSE "_u_2"."_col_0" END AS "bucket1",
+ CASE WHEN "_u_3"."_col_0" > 3392 THEN "_u_4"."_col_0" ELSE "_u_5"."_col_0" END AS "bucket2",
+ CASE WHEN "_u_6"."_col_0" > 32784 THEN "_u_7"."_col_0" ELSE "_u_8"."_col_0" END AS "bucket3",
+ CASE WHEN "_u_9"."_col_0" > 26032 THEN "_u_10"."_col_0" ELSE "_u_11"."_col_0" END AS "bucket4",
+ CASE WHEN "_u_12"."_col_0" > 23982 THEN "_u_13"."_col_0" ELSE "_u_14"."_col_0" END AS "bucket5"
FROM "reason" AS "reason"
CROSS JOIN "_u_0" AS "_u_0"
CROSS JOIN "_u_1" AS "_u_1"
diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql
index 45e462b..a5a35b1 100644
--- a/tests/fixtures/optimizer/unnest_subqueries.sql
+++ b/tests/fixtures/optimizer/unnest_subqueries.sql
@@ -1,243 +1,69 @@
---SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x;
---------------------------------------
--- Unnest Subqueries
---------------------------------------
-SELECT *
-FROM x AS x
-WHERE
- x.a = (SELECT SUM(y.a) AS a FROM y)
- AND x.a IN (SELECT y.a AS a FROM y)
- AND x.a IN (SELECT y.b AS b FROM y)
- AND x.a = ANY (SELECT y.a AS a FROM y)
- AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a)
- AND x.a > (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a)
- AND x.a <> ANY (SELECT y.a AS a FROM y WHERE y.a = x.a)
- AND x.a NOT IN (SELECT y.a AS a FROM y WHERE y.a = x.a)
- AND x.a IN (SELECT y.a AS a FROM y WHERE y.b = x.a)
- AND x.a < (SELECT SUM(y.a) AS a FROM y WHERE y.a = x.a and y.a = x.b and y.b <> x.d)
- AND EXISTS (SELECT y.a AS a, y.b AS b FROM y WHERE x.a = y.a)
- AND x.a IN (SELECT y.a AS a FROM y LIMIT 10)
- AND x.a IN (SELECT y.a AS a FROM y OFFSET 10)
- AND x.a IN (SELECT y.a AS a, y.b AS b FROM y)
- AND x.a > ANY (SELECT y.a FROM y)
- AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10)
- AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10)
- AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a)
- AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a)
- AND x.a = SUM(SELECT 1) -- invalid statement left alone
- AND x.a IN (SELECT max(y.b) AS b FROM y GROUP BY y.a)
-;
-SELECT
- *
-FROM x AS x
-CROSS JOIN (
- SELECT
- SUM(y.a) AS a
- FROM y
-) AS _u_0
-LEFT JOIN (
- SELECT
- y.a AS a
- FROM y
- GROUP BY
- y.a
-) AS _u_1
- ON x.a = _u_1.a
-LEFT JOIN (
- SELECT
- y.b AS b
- FROM y
- GROUP BY
- y.b
-) AS _u_2
- ON x.a = _u_2.b
-LEFT JOIN (
- SELECT
- y.a AS a
- FROM y
- GROUP BY
- y.a
-) AS _u_3
- ON x.a = _u_3.a
-LEFT JOIN (
- SELECT
- SUM(y.b) AS b,
- y.a AS _u_5
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_4
- ON x.a = _u_4._u_5
-LEFT JOIN (
- SELECT
- SUM(y.b) AS b,
- y.a AS _u_7
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_6
- ON x.a = _u_6._u_7
-LEFT JOIN (
- SELECT
- y.a AS a
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_8
- ON _u_8.a = x.a
-LEFT JOIN (
- SELECT
- y.a AS a
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_9
- ON _u_9.a = x.a
-LEFT JOIN (
- SELECT
- ARRAY_AGG(y.a) AS a,
- y.b AS _u_11
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.b
-) AS _u_10
- ON _u_10._u_11 = x.a
-LEFT JOIN (
- SELECT
- SUM(y.a) AS a,
- y.a AS _u_13,
- ARRAY_AGG(y.b) AS _u_14
- FROM y
- WHERE
- TRUE AND TRUE AND TRUE
- GROUP BY
- y.a
-) AS _u_12
- ON _u_12._u_13 = x.a AND _u_12._u_13 = x.b
-LEFT JOIN (
- SELECT
- y.a AS a
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_15
- ON x.a = _u_15.a
-LEFT JOIN (
- SELECT
- ARRAY_AGG(c),
- y.a AS _u_20
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_19
- ON _u_19._u_20 = x.a
-LEFT JOIN (
- SELECT
- COUNT(*) AS d,
- y.a AS _u_22
- FROM y
- WHERE
- TRUE
- GROUP BY
- y.a
-) AS _u_21
- ON _u_21._u_22 = x.a
-LEFT JOIN (
- SELECT
- _q.b
- FROM (
- SELECT
- MAX(y.b) AS b
- FROM y
- GROUP BY
- y.a
- ) AS _q
- GROUP BY
- _q.b
-) AS _u_24
- ON x.a = _u_24.b
-WHERE
- x.a = _u_0.a
- AND NOT _u_1.a IS NULL
- AND NOT _u_2.b IS NULL
- AND NOT _u_3.a IS NULL
- AND x.a = _u_4.b
- AND x.a > _u_6.b
- AND x.a = _u_8.a
- AND NOT x.a = _u_9.a
- AND ARRAY_ANY(_u_10.a, _x -> _x = x.a)
- AND (
- x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, _x -> _x <> x.d)
- )
- AND NOT _u_15.a IS NULL
- AND x.a IN (
- SELECT
- y.a AS a
- FROM y
- LIMIT 10
- )
- AND x.a IN (
- SELECT
- y.a AS a
- FROM y
- OFFSET 10
- )
- AND x.a IN (
- SELECT
- y.a AS a,
- y.b AS b
- FROM y
- )
- AND x.a > ANY (
- SELECT
- y.a
- FROM y
- )
- AND x.a = (
- SELECT
- SUM(y.c) AS c
- FROM y
- WHERE
- y.a = x.a
- LIMIT 10
- )
- AND x.a = (
- SELECT
- SUM(y.c) AS c
- FROM y
- WHERE
- y.a = x.a
- OFFSET 10
- )
- AND ARRAY_ALL(_u_19."", _x -> _x = x.a)
- AND x.a > COALESCE(_u_21.d, 0)
- AND x.a = SUM(SELECT
- 1) /* invalid statement left alone */
- AND NOT _u_24.b IS NULL
-;
-SELECT
- CAST((
- SELECT
- x.a AS a
- FROM x
- ) AS TEXT) AS a;
-SELECT
- CAST((
- SELECT
- x.a AS a
- FROM x
- ) AS TEXT) AS a;
+SELECT * FROM x WHERE x.a = (SELECT SUM(y.a) AS a FROM y);
+SELECT * FROM x CROSS JOIN (SELECT SUM(y.a) AS a FROM y) AS _u_0 WHERE x.a = _u_0.a;
+
+SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y);
+SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE NOT _u_0.a IS NULL;
+
+SELECT * FROM x WHERE x.a IN (SELECT y.b AS b FROM y);
+SELECT * FROM x LEFT JOIN (SELECT y.b AS b FROM y GROUP BY y.b) AS _u_0 ON x.a = _u_0.b WHERE NOT _u_0.b IS NULL;
+
+SELECT * FROM x WHERE x.a = ANY (SELECT y.a AS a FROM y);
+SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE NOT _u_0.a IS NULL;
+
+SELECT * FROM x WHERE x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a);
+SELECT * FROM x LEFT JOIN (SELECT SUM(y.b) AS b, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0._u_1 WHERE x.a = _u_0.b;
+
+SELECT * FROM x WHERE x.a > (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a);
+SELECT * FROM x LEFT JOIN (SELECT SUM(y.b) AS b, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0._u_1 WHERE x.a > _u_0.b;
+
+SELECT * FROM x WHERE x.a <> ANY (SELECT y.a AS a FROM y WHERE y.a = x.a);
+SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0.a = x.a WHERE x.a <> _u_0.a;
+
+SELECT * FROM x WHERE x.a NOT IN (SELECT y.a AS a FROM y WHERE y.a = x.a);
+SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0.a = x.a WHERE NOT x.a = _u_0.a;
+
+SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y WHERE y.b = x.a);
+SELECT * FROM x LEFT JOIN (SELECT ARRAY_AGG(y.a) AS a, y.b AS _u_1 FROM y WHERE TRUE GROUP BY y.b) AS _u_0 ON _u_0._u_1 = x.a WHERE ARRAY_ANY(_u_0.a, _x -> _x = x.a);
+
+SELECT * FROM x WHERE x.a < (SELECT SUM(y.a) AS a FROM y WHERE y.a = x.a and y.a = x.b and y.b <> x.d);
+SELECT * FROM x LEFT JOIN (SELECT SUM(y.a) AS a, y.a AS _u_1, ARRAY_AGG(y.b) AS _u_2 FROM y WHERE TRUE AND TRUE AND TRUE GROUP BY y.a) AS _u_0 ON _u_0._u_1 = x.a AND _u_0._u_1 = x.b WHERE (x.a < _u_0.a AND ARRAY_ANY(_u_0._u_2, _x -> _x <> x.d));
+
+SELECT * FROM x WHERE EXISTS (SELECT y.a AS a, y.b AS b FROM y WHERE x.a = y.a);
+SELECT * FROM x LEFT JOIN (SELECT y.a AS a FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE NOT _u_0.a IS NULL;
+
+SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y LIMIT 10);
+SELECT * FROM x WHERE x.a IN (SELECT y.a AS a FROM y LIMIT 10);
+
+SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a FROM y OFFSET 10);
+SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a FROM y OFFSET 10);
+
+SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a, y.b AS b FROM y);
+SELECT * FROM x.a WHERE x.a IN (SELECT y.a AS a, y.b AS b FROM y);
+
+SELECT * FROM x.a WHERE x.a > ANY (SELECT y.a FROM y);
+SELECT * FROM x.a WHERE x.a > ANY (SELECT y.a FROM y);
+
+SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10);
+SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10);
+
+SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10);
+SELECT * FROM x WHERE x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10);
+
+SELECT * FROM x WHERE x.a > ALL (SELECT y.c AS c FROM y WHERE y.a = x.a);
+SELECT * FROM x LEFT JOIN (SELECT ARRAY_AGG(y.c) AS c, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0._u_1 = x.a WHERE ARRAY_ALL(_u_0.c, _x -> x.a > _x);
+
+SELECT * FROM x WHERE x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a);
+SELECT * FROM x LEFT JOIN (SELECT COUNT(*) AS d, y.a AS _u_1 FROM y WHERE TRUE GROUP BY y.a) AS _u_0 ON _u_0._u_1 = x.a WHERE x.a > COALESCE(_u_0.d, 0);
+
+# title: invalid statement left alone
+SELECT * FROM x WHERE x.a = SUM(SELECT 1);
+SELECT * FROM x WHERE x.a = SUM(SELECT 1);
+
+SELECT * FROM x WHERE x.a IN (SELECT max(y.b) AS b FROM y GROUP BY y.a);
+SELECT * FROM x LEFT JOIN (SELECT _q.b AS b FROM (SELECT MAX(y.b) AS b FROM y GROUP BY y.a) AS _q GROUP BY _q.b) AS _u_0 ON x.a = _u_0.b WHERE NOT _u_0.b IS NULL;
+
+SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x;
+SELECT x.a > _u_0.b FROM x CROSS JOIN (SELECT SUM(y.a) AS b FROM y) AS _u_0;
+
+SELECT (SELECT MAX(t2.c1) AS c1 FROM t2 WHERE t2.c2 = t1.c2 AND t2.c3 <= TRUNC(t1.c3)) AS c FROM t1;
+SELECT _u_0.c1 AS c FROM t1 LEFT JOIN (SELECT MAX(t2.c1) AS c1, t2.c2 AS _u_1, MAX(t2.c3) AS _u_2 FROM t2 WHERE TRUE AND TRUE GROUP BY t2.c2) AS _u_0 ON _u_0._u_1 = t1.c2 WHERE _u_0._u_2 <= TRUNC(t1.c3);
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 1eaca14..317b930 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -707,9 +707,15 @@ class TestExecutor(unittest.TestCase):
("ROUND(1.2)", 1),
("ROUND(1.2345, 2)", 1.23),
("ROUND(NULL)", None),
- ("UNIXTOTIME(1659981729)", datetime.datetime(2022, 8, 8, 18, 2, 9)),
+ (
+ "UNIXTOTIME(1659981729)",
+ datetime.datetime(2022, 8, 8, 18, 2, 9, tzinfo=datetime.timezone.utc),
+ ),
("TIMESTRTOTIME('2013-04-05 01:02:03')", datetime.datetime(2013, 4, 5, 1, 2, 3)),
- ("UNIXTOTIME(40 * 365 * 86400)", datetime.datetime(2009, 12, 22, 00, 00, 00)),
+ (
+ "UNIXTOTIME(40 * 365 * 86400)",
+ datetime.datetime(2009, 12, 22, 00, 00, 00, tzinfo=datetime.timezone.utc),
+ ),
(
"STRTOTIME('08/03/2024 12:34:56', '%d/%m/%Y %H:%M:%S')",
datetime.datetime(2024, 3, 8, 12, 34, 56),
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 81cfb86..1395b24 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -674,7 +674,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("STANDARD_HASH('hello', 'sha256')"), exp.StandardHash)
self.assertIsInstance(parse_one("DATE(foo)"), exp.Date)
self.assertIsInstance(parse_one("HEX(foo)"), exp.Hex)
- self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex)
+ self.assertIsInstance(parse_one("LOWER(HEX(foo))"), exp.LowerHex)
+ self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.LowerHex)
+ self.assertIsInstance(parse_one("UPPER(TO_HEX(foo))", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5)
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)
self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths)
@@ -834,21 +836,22 @@ class TestExpressions(unittest.TestCase):
b AS B,
c, /*comment*/
d AS D, -- another comment
- CAST(x AS INT) -- final comment
+ CAST(x AS INT), -- yet another comment
+ y AND /* foo */ w AS E -- final comment
FROM foo
"""
expression = parse_one(sql)
self.assertEqual(
[e.alias_or_name for e in expression.expressions],
- ["a", "B", "c", "D", "x"],
+ ["a", "B", "c", "D", "x", "E"],
)
self.assertEqual(
expression.sql(),
- "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* final comment */ FROM foo",
+ "SELECT a, b AS B, c /* comment */, d AS D /* another comment */, CAST(x AS INT) /* yet another comment */, y AND /* foo */ w AS E /* final comment */ FROM foo",
)
self.assertEqual(
expression.sql(comments=False),
- "SELECT a, b AS B, c, d AS D, CAST(x AS INT) FROM foo",
+ "SELECT a, b AS B, c, d AS D, CAST(x AS INT), y AND w AS E FROM foo",
)
self.assertEqual(
expression.sql(pretty=True, comments=False),
@@ -857,7 +860,8 @@ class TestExpressions(unittest.TestCase):
b AS B,
c,
d AS D,
- CAST(x AS INT)
+ CAST(x AS INT),
+ y AND w AS E
FROM foo""",
)
self.assertEqual(
@@ -867,7 +871,8 @@ FROM foo""",
b AS B,
c, /* comment */
d AS D, /* another comment */
- CAST(x AS INT) /* final comment */
+ CAST(x AS INT), /* yet another comment */
+ y AND /* foo */ w AS E /* final comment */
FROM foo""",
)
diff --git a/tests/test_lineage.py b/tests/test_lineage.py
index 3e17f95..036f146 100644
--- a/tests/test_lineage.py
+++ b/tests/test_lineage.py
@@ -487,3 +487,11 @@ class TestLineage(unittest.TestCase):
downstream = node.downstream[0]
self.assertEqual(downstream.name, "z.a")
self.assertEqual(downstream.source.sql(), "SELECT y.a AS a, y.b AS b, y.c AS c FROM y AS y")
+
+ def test_node_name_doesnt_contain_comment(self) -> None:
+ sql = "SELECT * FROM (SELECT x /* c */ FROM t1) AS t2"
+ node = lineage("x", sql)
+
+ self.assertEqual(len(node.downstream), 1)
+ self.assertEqual(len(node.downstream[0].downstream), 1)
+ self.assertEqual(node.downstream[0].downstream[0].name, "t1.x")
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 758b60c..36768f8 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -428,11 +428,7 @@ SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expr
)
def test_unnest_subqueries(self):
- self.check_file(
- "unnest_subqueries",
- optimizer.unnest_subqueries.unnest_subqueries,
- pretty=True,
- )
+ self.check_file("unnest_subqueries", optimizer.unnest_subqueries.unnest_subqueries)
def test_pushdown_predicates(self):
self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 6bcdb64..2cefc07 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -503,7 +503,7 @@ class TestParser(unittest.TestCase):
self.assertIsInstance(set_item, exp.SetItem)
self.assertIsInstance(set_item.this, exp.EQ)
- self.assertIsInstance(set_item.this.this, exp.Identifier)
+ self.assertIsInstance(set_item.this.this, exp.Column)
self.assertIsInstance(set_item.this.expression, exp.Literal)
self.assertEqual(set_item.args.get("kind"), "SESSION")
@@ -856,5 +856,38 @@ class TestParser(unittest.TestCase):
with self.subTest(dialect):
self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql)
+ def test_alter_set(self):
+ sqls = [
+ "ALTER TABLE tbl SET TBLPROPERTIES ('x'='1', 'Z'='2')",
+ "ALTER TABLE tbl SET SERDE 'test' WITH SERDEPROPERTIES ('k'='v', 'kay'='vee')",
+ "ALTER TABLE tbl SET SERDEPROPERTIES ('k'='v', 'kay'='vee')",
+ "ALTER TABLE tbl SET LOCATION 'new_location'",
+ "ALTER TABLE tbl SET FILEFORMAT file_format",
+ "ALTER TABLE tbl SET TAGS ('tag1' = 't1', 'tag2' = 't2')",
+ ]
+
+ for dialect in (
+ "hive",
+ "spark2",
+ "spark",
+ "databricks",
+ ):
+ for sql in sqls:
+ with self.subTest(f"Testing query '{sql}' for dialect {dialect}"):
+ self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql)
+
def test_distinct_from(self):
self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or)
+
+ def test_trailing_comments(self):
+ expressions = parse("""
+ select * from x;
+ -- my comment
+ """)
+
+ self.assertEqual(
+ ";\n".join(e.sql() for e in expressions), "SELECT * FROM x;\n/* my comment */"
+ )
+
+ def test_parse_prop_eq(self):
+ self.assertIsInstance(parse_one("x(a := b and c)").expressions[0], exp.PropertyEQ)
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 32686d7..5b50867 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -271,6 +271,34 @@ class TestSchema(unittest.TestCase):
"Table z must match the schema's nesting level: 2.",
)
+ with self.assertRaises(SchemaError) as ctx:
+ MappingSchema(
+ {
+ "catalog": {
+ "db": {"tbl": {"col": "a"}},
+ },
+ "tbl2": {"col": "b"},
+ },
+ )
+ self.assertEqual(
+ str(ctx.exception),
+ "Table tbl2 must match the schema's nesting level: 3.",
+ )
+
+ with self.assertRaises(SchemaError) as ctx:
+ MappingSchema(
+ {
+ "tbl2": {"col": "b"},
+ "catalog": {
+ "db": {"tbl": {"col": "a"}},
+ },
+ },
+ )
+ self.assertEqual(
+ str(ctx.exception),
+ "Table catalog.db.tbl must match the schema's nesting level: 1.",
+ )
+
def test_has_column(self):
schema = MappingSchema({"x": {"c": "int"}})
self.assertTrue(schema.has_column("x", exp.column("c")))
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 22085b3..dea9985 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -551,6 +551,13 @@ FROM x""",
pretty=True,
)
+ self.validate(
+ """SELECT X FROM catalog.db.table WHERE Y
+ --
+ AND Z""",
+ """SELECT X FROM catalog.db.table WHERE Y AND Z""",
+ )
+
def test_types(self):
self.validate("INT 1", "CAST(1 AS INT)")
self.validate("VARCHAR 'x' y", "CAST('x' AS VARCHAR) AS y")