Diffstat (limited to 'tests/dialects')
-rw-r--r--  tests/dialects/test_bigquery.py     254
-rw-r--r--  tests/dialects/test_clickhouse.py    24
-rw-r--r--  tests/dialects/test_databricks.py    62
-rw-r--r--  tests/dialects/test_dialect.py      152
-rw-r--r--  tests/dialects/test_doris.py         22
-rw-r--r--  tests/dialects/test_duckdb.py       155
-rw-r--r--  tests/dialects/test_hive.py          18
-rw-r--r--  tests/dialects/test_materialize.py   77
-rw-r--r--  tests/dialects/test_mysql.py        153
-rw-r--r--  tests/dialects/test_oracle.py        70
-rw-r--r--  tests/dialects/test_postgres.py     128
-rw-r--r--  tests/dialects/test_presto.py        16
-rw-r--r--  tests/dialects/test_prql.py          13
-rw-r--r--  tests/dialects/test_redshift.py      51
-rw-r--r--  tests/dialects/test_risingwave.py    14
-rw-r--r--  tests/dialects/test_snowflake.py    398
-rw-r--r--  tests/dialects/test_spark.py         34
-rw-r--r--  tests/dialects/test_sqlite.py         1
-rw-r--r--  tests/dialects/test_teradata.py       2
-rw-r--r--  tests/dialects/test_trino.py         18
-rw-r--r--  tests/dialects/test_tsql.py         192
21 files changed, 1524 insertions, 330 deletions
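
The hunks below all follow the same pattern from sqlglot's dialect test harness: validate_identity(sql[, expected]) parses a statement with the class's dialect and checks the regenerated SQL, while validate_all(sql, read={...}, write={...}) checks transpilation to and from other dialects. As a rough, non-authoritative sketch of what those assertions boil down to in user code (expected strings taken from the hunks below, not re-verified here):

import sqlglot

# validate_all(..., write={"duckdb": ...}) is roughly: transpile from the
# test dialect and compare against the expected string per target dialect.
print(sqlglot.transpile("TO_HEX(x)", read="bigquery", write="duckdb"))
# ['LOWER(HEX(x))']

# validate_identity(sql, expected) is roughly: parse with the dialect and
# check that regenerating SQL yields the expected (or the original) string.
print(sqlglot.parse_one("SELECT 1 AS row", read="mysql").sql(dialect="mysql"))
# SELECT 1 AS `row`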
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 301cd57..ae8ed16 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -20,6 +20,21 @@ class TestBigQuery(Validator):
maxDiff = None
def test_bigquery(self):
+ self.validate_all(
+ "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
+ write={
+ "bigquery": "EXTRACT(HOUR FROM DATETIME(2008, 12, 25, 15, 30, 00))",
+ "duckdb": "EXTRACT(HOUR FROM MAKE_TIMESTAMP(2008, 12, 25, 15, 30, 00))",
+ "snowflake": "DATE_PART(HOUR, TIMESTAMP_FROM_PARTS(2008, 12, 25, 15, 30, 00))",
+ },
+ )
+ self.validate_identity(
+ """CREATE TEMPORARY FUNCTION FOO()
+RETURNS STRING
+LANGUAGE js AS
+'return "Hello world!"'""",
+ pretty=True,
+ )
self.validate_identity(
"[a, a(1, 2,3,4444444444444444, tttttaoeunthaoentuhaoentuheoantu, toheuntaoheutnahoeunteoahuntaoeh), b(3, 4,5), c, d, tttttttttttttttteeeeeeeeeeeeeett, 12312312312]",
"""[
@@ -279,6 +294,13 @@ class TestBigQuery(Validator):
)
self.validate_all(
+ "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s",
+ write={
+ "bigquery": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s",
+ "duckdb": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS _t0(h), UNNEST(h.t3) AS _t1(s)",
+ },
+ )
+ self.validate_all(
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
write={
"bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)",
@@ -289,7 +311,7 @@ class TestBigQuery(Validator):
"SELECT results FROM Coordinates, Coordinates.position AS results",
write={
"bigquery": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS results",
- "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t(results)",
+ "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t0(results)",
},
)
self.validate_all(
@@ -307,7 +329,7 @@ class TestBigQuery(Validator):
},
write={
"bigquery": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results",
- "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t(results)",
+ "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t0(results)",
"redshift": "SELECT results FROM Coordinates AS c, c.position AS results",
},
)
@@ -525,7 +547,7 @@ class TestBigQuery(Validator):
"SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)",
write={
"bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)",
- "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)",
+ "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t0(x) WHERE x > 1)",
},
)
self.validate_all(
@@ -605,9 +627,9 @@ class TestBigQuery(Validator):
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
"bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
- "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
- "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -618,12 +640,87 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "LOWER(TO_HEX(x))",
+ write={
+ "": "LOWER(HEX(x))",
+ "bigquery": "TO_HEX(x)",
+ "clickhouse": "LOWER(HEX(x))",
+ "duckdb": "LOWER(HEX(x))",
+ "hive": "LOWER(HEX(x))",
+ "mysql": "LOWER(HEX(x))",
+ "spark": "LOWER(HEX(x))",
+ "sqlite": "LOWER(HEX(x))",
+ "presto": "LOWER(TO_HEX(x))",
+ "trino": "LOWER(TO_HEX(x))",
+ },
+ )
+ self.validate_all(
+ "TO_HEX(x)",
+ read={
+ "": "LOWER(HEX(x))",
+ "clickhouse": "LOWER(HEX(x))",
+ "duckdb": "LOWER(HEX(x))",
+ "hive": "LOWER(HEX(x))",
+ "mysql": "LOWER(HEX(x))",
+ "spark": "LOWER(HEX(x))",
+ "sqlite": "LOWER(HEX(x))",
+ "presto": "LOWER(TO_HEX(x))",
+ "trino": "LOWER(TO_HEX(x))",
+ },
+ write={
+ "": "LOWER(HEX(x))",
+ "bigquery": "TO_HEX(x)",
+ "clickhouse": "LOWER(HEX(x))",
+ "duckdb": "LOWER(HEX(x))",
+ "hive": "LOWER(HEX(x))",
+ "mysql": "LOWER(HEX(x))",
+ "presto": "LOWER(TO_HEX(x))",
+ "spark": "LOWER(HEX(x))",
+ "sqlite": "LOWER(HEX(x))",
+ "trino": "LOWER(TO_HEX(x))",
+ },
+ )
+ self.validate_all(
+ "UPPER(TO_HEX(x))",
+ read={
+ "": "HEX(x)",
+ "clickhouse": "HEX(x)",
+ "duckdb": "HEX(x)",
+ "hive": "HEX(x)",
+ "mysql": "HEX(x)",
+ "presto": "TO_HEX(x)",
+ "spark": "HEX(x)",
+ "sqlite": "HEX(x)",
+ "trino": "TO_HEX(x)",
+ },
+ write={
+ "": "HEX(x)",
+ "bigquery": "UPPER(TO_HEX(x))",
+ "clickhouse": "HEX(x)",
+ "duckdb": "HEX(x)",
+ "hive": "HEX(x)",
+ "mysql": "HEX(x)",
+ "presto": "TO_HEX(x)",
+ "spark": "HEX(x)",
+ "sqlite": "HEX(x)",
+ "trino": "TO_HEX(x)",
+ },
+ )
+ self.validate_all(
"MD5(x)",
+ read={
+ "clickhouse": "MD5(x)",
+ "presto": "MD5(x)",
+ "trino": "MD5(x)",
+ },
write={
"": "MD5_DIGEST(x)",
"bigquery": "MD5(x)",
+ "clickhouse": "MD5(x)",
"hive": "UNHEX(MD5(x))",
+ "presto": "MD5(x)",
"spark": "UNHEX(MD5(x))",
+ "trino": "MD5(x)",
},
)
self.validate_all(
@@ -631,25 +728,72 @@ class TestBigQuery(Validator):
read={
"duckdb": "SELECT MD5(some_string)",
"spark": "SELECT MD5(some_string)",
+ "clickhouse": "SELECT LOWER(HEX(MD5(some_string)))",
+ "presto": "SELECT LOWER(TO_HEX(MD5(some_string)))",
+ "trino": "SELECT LOWER(TO_HEX(MD5(some_string)))",
},
write={
"": "SELECT MD5(some_string)",
"bigquery": "SELECT TO_HEX(MD5(some_string))",
"duckdb": "SELECT MD5(some_string)",
+ "clickhouse": "SELECT LOWER(HEX(MD5(some_string)))",
+ "presto": "SELECT LOWER(TO_HEX(MD5(some_string)))",
+ "trino": "SELECT LOWER(TO_HEX(MD5(some_string)))",
+ },
+ )
+ self.validate_all(
+ "SHA1(x)",
+ read={
+ "clickhouse": "SHA1(x)",
+ "presto": "SHA1(x)",
+ "trino": "SHA1(x)",
+ },
+ write={
+ "clickhouse": "SHA1(x)",
+ "bigquery": "SHA1(x)",
+ "": "SHA(x)",
+ "presto": "SHA1(x)",
+ "trino": "SHA1(x)",
+ },
+ )
+ self.validate_all(
+ "SHA1(x)",
+ write={
+ "bigquery": "SHA1(x)",
+ "": "SHA(x)",
},
)
self.validate_all(
"SHA256(x)",
+ read={
+ "clickhouse": "SHA256(x)",
+ "presto": "SHA256(x)",
+ "trino": "SHA256(x)",
+ "postgres": "SHA256(x)",
+ },
write={
"bigquery": "SHA256(x)",
"spark2": "SHA2(x, 256)",
+ "clickhouse": "SHA256(x)",
+ "postgres": "SHA256(x)",
+ "presto": "SHA256(x)",
+ "redshift": "SHA2(x, 256)",
+ "trino": "SHA256(x)",
},
)
self.validate_all(
"SHA512(x)",
+ read={
+ "clickhouse": "SHA512(x)",
+ "presto": "SHA512(x)",
+ "trino": "SHA512(x)",
+ },
write={
+ "clickhouse": "SHA512(x)",
"bigquery": "SHA512(x)",
"spark2": "SHA2(x, 512)",
+ "presto": "SHA512(x)",
+ "trino": "SHA512(x)",
},
)
self.validate_all(
@@ -860,8 +1004,8 @@ class TestBigQuery(Validator):
},
write={
"bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x",
- "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t(x)",
- "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t(x)",
+ "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t0(x)",
+ "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t0(x)",
},
)
self.validate_all(
@@ -982,6 +1126,69 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "DELETE db.example_table WHERE x = 1",
+ write={
+ "bigquery": "DELETE db.example_table WHERE x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE db.example_table tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE db.example_table AS tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table WHERE x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table WHERE x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table AS tb WHERE tb.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table AS tb WHERE tb.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table AS tb WHERE example_table.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table AS tb WHERE example_table.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.example_table WHERE example_table.x = 1",
+ write={
+ "bigquery": "DELETE FROM db.example_table WHERE example_table.x = 1",
+ "presto": "DELETE FROM db.example_table WHERE example_table.x = 1",
+ },
+ )
+ self.validate_all(
+ "DELETE FROM db.t1 AS t1 WHERE NOT t1.c IN (SELECT db.t2.c FROM db.t2)",
+ write={
+ "bigquery": "DELETE FROM db.t1 AS t1 WHERE NOT t1.c IN (SELECT db.t2.c FROM db.t2)",
+ "presto": "DELETE FROM db.t1 WHERE NOT c IN (SELECT c FROM db.t2)",
+ },
+ )
+ self.validate_all(
"SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
write={
"bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])",
@@ -1073,7 +1280,7 @@ class TestBigQuery(Validator):
"SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table",
write={
"bigquery": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table",
- "duckdb": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM table",
+ "duckdb": '''SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM "table"''',
},
)
self.validate_all(
@@ -1328,6 +1535,26 @@ WHERE
"SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)",
)
+ def test_gap_fill(self):
+ self.validate_identity(
+ "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'locf')]) ORDER BY time"
+ )
+ self.validate_identity(
+ "SELECT a, b, c, d, e FROM GAP_FILL(TABLE foo, ts_column => 'b', partitioning_columns => ['a'], value_columns => [('c', 'bar'), ('d', 'baz'), ('e', 'bla')], bucket_width => INTERVAL '1' DAY)"
+ )
+ self.validate_identity(
+ "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'linear')], ignore_null_values => FALSE) ORDER BY time"
+ )
+ self.validate_identity(
+ "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE) ORDER BY time"
+ )
+ self.validate_identity(
+ "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'null')], origin => CAST('2023-11-01 09:30:01' AS DATETIME)) ORDER BY time"
+ )
+ self.validate_identity(
+ "SELECT * FROM GAP_FILL(TABLE (SELECT * FROM UNNEST(ARRAY<STRUCT<device_id INT64, time DATETIME, signal INT64, state STRING>>[STRUCT(1, CAST('2023-11-01 09:34:01' AS DATETIME), 74, 'INACTIVE'), STRUCT(2, CAST('2023-11-01 09:36:00' AS DATETIME), 77, 'ACTIVE'), STRUCT(3, CAST('2023-11-01 09:37:00' AS DATETIME), 78, 'ACTIVE'), STRUCT(4, CAST('2023-11-01 09:38:01' AS DATETIME), 80, 'ACTIVE')])), ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'linear')]) ORDER BY time"
+ )
+
def test_models(self):
self.validate_identity(
"SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))"
@@ -1464,3 +1691,14 @@ OPTIONS (
with self.assertRaises(ParseError):
transpile("SELECT JSON_OBJECT('a', 1, 'b') AS json_data", read="bigquery")
+
+ def test_mod(self):
+ for sql in ("MOD(a, b)", "MOD('a', b)", "MOD(5, 2)", "MOD((a + 1) * 8, 5 - 1)"):
+ with self.subTest(f"Testing BigQuery roundtrip of modulo operation: {sql}"):
+ self.validate_identity(sql)
+
+ self.validate_identity("SELECT MOD((SELECT 1), 2)")
+ self.validate_identity(
+ "MOD((a + 1), b)",
+ "MOD(a + 1, b)",
+ )
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index af552d1..72634a8 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -42,6 +42,7 @@ class TestClickhouse(Validator):
self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
self.assertIsNone(expr._meta)
+ self.validate_identity("SELECT CAST(x AS Tuple(String, Array(Nullable(Float64))))")
self.validate_identity("countIf(x, y)")
self.validate_identity("x = y")
self.validate_identity("x <> y")
@@ -425,6 +426,21 @@ class TestClickhouse(Validator):
},
)
+ self.validate_identity("ALTER TABLE visits DROP PARTITION 201901")
+ self.validate_identity("ALTER TABLE visits DROP PARTITION ALL")
+ self.validate_identity(
+ "ALTER TABLE visits DROP PARTITION tuple(toYYYYMM(toDate('2019-01-25')))"
+ )
+ self.validate_identity("ALTER TABLE visits DROP PARTITION ID '201901'")
+
+ self.validate_identity("ALTER TABLE visits REPLACE PARTITION 201901 FROM visits_tmp")
+ self.validate_identity("ALTER TABLE visits REPLACE PARTITION ALL FROM visits_tmp")
+ self.validate_identity(
+ "ALTER TABLE visits REPLACE PARTITION tuple(toYYYYMM(toDate('2019-01-25'))) FROM visits_tmp"
+ )
+ self.validate_identity("ALTER TABLE visits REPLACE PARTITION ID '201901' FROM visits_tmp")
+ self.validate_identity("ALTER TABLE visits ON CLUSTER test_cluster DROP COLUMN col1")
+
def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
self.validate_identity("WITH ['c'] AS field_names SELECT field_names")
@@ -829,6 +845,9 @@ LIFETIME(MIN 0 MAX 0)""",
self.validate_identity(
"CREATE TABLE t1 (a String EPHEMERAL, b String EPHEMERAL func(), c String MATERIALIZED func(), d String ALIAS func()) ENGINE=TinyLog()"
)
+ self.validate_identity(
+ "CREATE TABLE t (a String, b String, c UInt64, PROJECTION p1 (SELECT a, sum(c) GROUP BY a, b), PROJECTION p2 (SELECT b, sum(c) GROUP BY b)) ENGINE=MergeTree()"
+ )
def test_agg_functions(self):
def extract_agg_func(query):
@@ -856,3 +875,8 @@ LIFETIME(MIN 0 MAX 0)""",
)
parse_one("foobar(x)").assert_is(exp.Anonymous)
+
+ def test_drop_on_cluster(self):
+ for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"):
+ with self.subTest(f"Test DROP {creatable} ON CLUSTER"):
+ self.validate_identity(f"DROP {creatable} test ON CLUSTER test_cluster")
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index c15cf09..9ef3b86 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -9,6 +9,7 @@ class TestDatabricks(Validator):
def test_databricks(self):
self.validate_identity("DESCRIBE HISTORY a.b")
self.validate_identity("DESCRIBE history.tbl")
+ self.validate_identity("CREATE TABLE t (a STRUCT<c: MAP<STRING, STRING>>)")
self.validate_identity("CREATE TABLE t (c STRUCT<interval: DOUBLE COMMENT 'aaa'>)")
self.validate_identity("CREATE TABLE my_table TBLPROPERTIES (a.b=15)")
self.validate_identity("CREATE TABLE my_table TBLPROPERTIES ('a.b'=15)")
@@ -20,12 +21,14 @@ class TestDatabricks(Validator):
self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL MINUTE TO SECOND)")
self.validate_identity("CREATE TABLE target SHALLOW CLONE source")
self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)")
- self.validate_identity("SELECT c1 : price")
self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1")
self.validate_identity("CREATE FUNCTION a AS b")
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))")
self.validate_identity(
+ "CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA"
+ )
+ self.validate_identity(
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
)
self.validate_identity(
@@ -47,6 +50,9 @@ class TestDatabricks(Validator):
self.validate_identity(
"TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')"
)
+ self.validate_identity(
+ "COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS ('opt1'='true', 'opt2'='test') COPY_OPTIONS ('mergeSchema'='true')"
+ )
self.validate_all(
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
@@ -62,6 +68,20 @@ class TestDatabricks(Validator):
},
)
+ self.validate_all(
+ "SELECT X'1A2B'",
+ read={
+ "spark2": "SELECT X'1A2B'",
+ "spark": "SELECT X'1A2B'",
+ "databricks": "SELECT x'1A2B'",
+ },
+ write={
+ "spark2": "SELECT X'1A2B'",
+ "spark": "SELECT X'1A2B'",
+ "databricks": "SELECT X'1A2B'",
+ },
+ )
+
with self.assertRaises(ParseError):
transpile(
"CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $foo$def add_one(x):\n return x+1$$",
@@ -76,37 +96,33 @@ class TestDatabricks(Validator):
# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
- self.validate_identity("""SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""")
-
- self.validate_all(
+ self.validate_identity(
+ """SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ )
+ self.validate_identity(
"""SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""",
- write={
- "databricks": """SELECT c1 : ARRAY('price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
- },
+ """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT c1 : item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT GET_JSON_OBJECT(c1, '$.item[1].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT c1 : item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT GET_JSON_OBJECT(c1, '$.item[*].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT from_json(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT FROM_JSON(c1 : item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*].price'), 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
)
- self.validate_all(
+ self.validate_identity(
"""SELECT inline(from_json(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- write={
- "databricks": """SELECT INLINE(FROM_JSON(c1 : item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- },
+ """SELECT INLINE(FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*]'), 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
+ )
+ self.validate_identity(
+ "SELECT c1 : price",
+ "SELECT GET_JSON_OBJECT(c1, '$.price')",
)
def test_datediff(self):
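
The Databricks changes above switch the colon-sign JSON tests from round-trip expectations to GET_JSON_OBJECT output. A minimal sketch of that behaviour, with the expected string taken from the validate_identity pairs in the hunk:

import sqlglot

sql = "SELECT c1 : price"
print(sqlglot.parse_one(sql, read="databricks").sql(dialect="databricks"))
# SELECT GET_JSON_OBJECT(c1, '$.price')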
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index ea38521..aaeb7b0 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -155,6 +155,7 @@ class TestDialect(Validator):
"clickhouse": "CAST(a AS String)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
+ "materialize": "CAST(a AS TEXT)",
"mysql": "CAST(a AS CHAR)",
"hive": "CAST(a AS STRING)",
"oracle": "CAST(a AS CLOB)",
@@ -175,6 +176,7 @@ class TestDialect(Validator):
"clickhouse": "CAST(a AS BINARY(4))",
"drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BLOB(4))",
+ "materialize": "CAST(a AS BYTEA(4))",
"mysql": "CAST(a AS BINARY(4))",
"hive": "CAST(a AS BINARY(4))",
"oracle": "CAST(a AS BLOB(4))",
@@ -193,6 +195,7 @@ class TestDialect(Validator):
"bigquery": "CAST(a AS BYTES)",
"clickhouse": "CAST(a AS String)",
"duckdb": "CAST(a AS BLOB(4))",
+ "materialize": "CAST(a AS BYTEA(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
"oracle": "CAST(a AS BLOB(4))",
@@ -236,6 +239,7 @@ class TestDialect(Validator):
"bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
+ "materialize": "CAST(a AS TEXT)",
"mysql": "CAST(a AS CHAR)",
"hive": "CAST(a AS STRING)",
"oracle": "CAST(a AS CLOB)",
@@ -255,6 +259,7 @@ class TestDialect(Validator):
"bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
+ "materialize": "CAST(a AS VARCHAR)",
"mysql": "CAST(a AS CHAR)",
"hive": "CAST(a AS STRING)",
"oracle": "CAST(a AS VARCHAR2)",
@@ -274,6 +279,7 @@ class TestDialect(Validator):
"bigquery": "CAST(a AS STRING)",
"drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))",
+ "materialize": "CAST(a AS VARCHAR(3))",
"mysql": "CAST(a AS CHAR(3))",
"hive": "CAST(a AS VARCHAR(3))",
"oracle": "CAST(a AS VARCHAR2(3))",
@@ -293,7 +299,8 @@ class TestDialect(Validator):
"bigquery": "CAST(a AS INT64)",
"drill": "CAST(a AS INTEGER)",
"duckdb": "CAST(a AS SMALLINT)",
- "mysql": "CAST(a AS SMALLINT)",
+ "materialize": "CAST(a AS SMALLINT)",
+ "mysql": "CAST(a AS SIGNED)",
"hive": "CAST(a AS SMALLINT)",
"oracle": "CAST(a AS NUMBER)",
"postgres": "CAST(a AS SMALLINT)",
@@ -328,6 +335,7 @@ class TestDialect(Validator):
"clickhouse": "CAST(a AS Float64)",
"drill": "CAST(a AS DOUBLE)",
"duckdb": "CAST(a AS DOUBLE)",
+ "materialize": "CAST(a AS DOUBLE PRECISION)",
"mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)",
"oracle": "CAST(a AS DOUBLE PRECISION)",
@@ -599,6 +607,7 @@ class TestDialect(Validator):
"drill": "TO_TIMESTAMP(x, 'yy')",
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
+ "materialize": "TO_TIMESTAMP(x, 'YY')",
"presto": "DATE_PARSE(x, '%y')",
"oracle": "TO_TIMESTAMP(x, 'YY')",
"postgres": "TO_TIMESTAMP(x, 'YY')",
@@ -655,6 +664,7 @@ class TestDialect(Validator):
"drill": "TO_CHAR(x, 'yyyy-MM-dd')",
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
+ "materialize": "TO_CHAR(x, 'YYYY-MM-DD')",
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
"postgres": "TO_CHAR(x, 'YYYY-MM-DD')",
"presto": "DATE_FORMAT(x, '%Y-%m-%d')",
@@ -698,6 +708,7 @@ class TestDialect(Validator):
"bigquery": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
+ "materialize": "CAST(x AS DATE)",
"postgres": "CAST(x AS DATE)",
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
"snowflake": "TO_DATE(x)",
@@ -730,6 +741,7 @@ class TestDialect(Validator):
"duckdb": "TO_TIMESTAMP(x)",
"hive": "FROM_UNIXTIME(x)",
"oracle": "TO_DATE('1970-01-01', 'YYYY-MM-DD') + (x / 86400)",
+ "materialize": "TO_TIMESTAMP(x)",
"postgres": "TO_TIMESTAMP(x)",
"presto": "FROM_UNIXTIME(x)",
"starrocks": "FROM_UNIXTIME(x)",
@@ -790,6 +802,7 @@ class TestDialect(Validator):
"drill": "DATE_ADD(x, INTERVAL 1 DAY)",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
+ "materialize": "x + INTERVAL '1 DAY'",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"postgres": "x + INTERVAL '1 DAY'",
"presto": "DATE_ADD('DAY', 1, x)",
@@ -826,6 +839,7 @@ class TestDialect(Validator):
"duckdb": "DATE_TRUNC('DAY', x)",
"mysql": "DATE(x)",
"presto": "DATE_TRUNC('DAY', x)",
+ "materialize": "DATE_TRUNC('DAY', x)",
"postgres": "DATE_TRUNC('DAY', x)",
"snowflake": "DATE_TRUNC('DAY', x)",
"starrocks": "DATE_TRUNC('DAY', x)",
@@ -838,6 +852,7 @@ class TestDialect(Validator):
read={
"bigquery": "TIMESTAMP_TRUNC(x, day)",
"duckdb": "DATE_TRUNC('day', x)",
+ "materialize": "DATE_TRUNC('day', x)",
"presto": "DATE_TRUNC('day', x)",
"postgres": "DATE_TRUNC('day', x)",
"snowflake": "DATE_TRUNC('day', x)",
@@ -899,6 +914,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "DATE_TRUNC(x, YEAR)",
+ "materialize": "DATE_TRUNC('YEAR', x)",
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
"postgres": "DATE_TRUNC('YEAR', x)",
"snowflake": "DATE_TRUNC('YEAR', x)",
@@ -911,6 +927,7 @@ class TestDialect(Validator):
"TIMESTAMP_TRUNC(x, YEAR)",
read={
"bigquery": "TIMESTAMP_TRUNC(x, year)",
+ "materialize": "DATE_TRUNC('YEAR', x)",
"postgres": "DATE_TRUNC(year, x)",
"spark": "DATE_TRUNC('year', x)",
"snowflake": "DATE_TRUNC(year, x)",
@@ -1019,6 +1036,20 @@ class TestDialect(Validator):
},
)
+ self.validate_all(
+ "TIMESTAMP_TRUNC(x, DAY, 'UTC')",
+ write={
+ "": "TIMESTAMP_TRUNC(x, DAY, 'UTC')",
+ "duckdb": "DATE_TRUNC('DAY', x)",
+ "materialize": "DATE_TRUNC('DAY', x, 'UTC')",
+ "presto": "DATE_TRUNC('DAY', x)",
+ "postgres": "DATE_TRUNC('DAY', x, 'UTC')",
+ "snowflake": "DATE_TRUNC('DAY', x)",
+ "databricks": "DATE_TRUNC('DAY', x)",
+ "clickhouse": "DATE_TRUNC('DAY', x, 'UTC')",
+ },
+ )
+
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
f"{unit}(x)",
@@ -1150,7 +1181,7 @@ class TestDialect(Validator):
read={
"bigquery": "JSON_EXTRACT(x, '$.y')",
"duckdb": "x -> 'y'",
- "doris": "x -> '$.y'",
+ "doris": "JSON_EXTRACT(x, '$.y')",
"mysql": "JSON_EXTRACT(x, '$.y')",
"postgres": "x->'y'",
"presto": "JSON_EXTRACT(x, '$.y')",
@@ -1161,7 +1192,7 @@ class TestDialect(Validator):
write={
"bigquery": "JSON_EXTRACT(x, '$.y')",
"clickhouse": "JSONExtractString(x, 'y')",
- "doris": "x -> '$.y'",
+ "doris": "JSON_EXTRACT(x, '$.y')",
"duckdb": "x -> '$.y'",
"mysql": "JSON_EXTRACT(x, '$.y')",
"oracle": "JSON_EXTRACT(x, '$.y')",
@@ -1205,7 +1236,7 @@ class TestDialect(Validator):
read={
"bigquery": "JSON_EXTRACT(x, '$.y[0].z')",
"duckdb": "x -> '$.y[0].z'",
- "doris": "x -> '$.y[0].z'",
+ "doris": "JSON_EXTRACT(x, '$.y[0].z')",
"mysql": "JSON_EXTRACT(x, '$.y[0].z')",
"presto": "JSON_EXTRACT(x, '$.y[0].z')",
"snowflake": "GET_PATH(x, 'y[0].z')",
@@ -1215,7 +1246,7 @@ class TestDialect(Validator):
write={
"bigquery": "JSON_EXTRACT(x, '$.y[0].z')",
"clickhouse": "JSONExtractString(x, 'y', 1, 'z')",
- "doris": "x -> '$.y[0].z'",
+ "doris": "JSON_EXTRACT(x, '$.y[0].z')",
"duckdb": "x -> '$.y[0].z'",
"mysql": "JSON_EXTRACT(x, '$.y[0].z')",
"oracle": "JSON_EXTRACT(x, '$.y[0].z')",
@@ -1472,21 +1503,21 @@ class TestDialect(Validator):
"snowflake": "x ILIKE '%y'",
},
write={
- "bigquery": "LOWER(x) LIKE '%y'",
+ "bigquery": "LOWER(x) LIKE LOWER('%y')",
"clickhouse": "x ILIKE '%y'",
"drill": "x `ILIKE` '%y'",
"duckdb": "x ILIKE '%y'",
- "hive": "LOWER(x) LIKE '%y'",
- "mysql": "LOWER(x) LIKE '%y'",
- "oracle": "LOWER(x) LIKE '%y'",
+ "hive": "LOWER(x) LIKE LOWER('%y')",
+ "mysql": "LOWER(x) LIKE LOWER('%y')",
+ "oracle": "LOWER(x) LIKE LOWER('%y')",
"postgres": "x ILIKE '%y'",
- "presto": "LOWER(x) LIKE '%y'",
+ "presto": "LOWER(x) LIKE LOWER('%y')",
"snowflake": "x ILIKE '%y'",
"spark": "x ILIKE '%y'",
- "sqlite": "LOWER(x) LIKE '%y'",
- "starrocks": "LOWER(x) LIKE '%y'",
- "trino": "LOWER(x) LIKE '%y'",
- "doris": "LOWER(x) LIKE '%y'",
+ "sqlite": "LOWER(x) LIKE LOWER('%y')",
+ "starrocks": "LOWER(x) LIKE LOWER('%y')",
+ "trino": "LOWER(x) LIKE LOWER('%y')",
+ "doris": "LOWER(x) LIKE LOWER('%y')",
},
)
self.validate_all(
@@ -1681,6 +1712,26 @@ class TestDialect(Validator):
"tsql": "CAST(a AS FLOAT) / b",
},
)
+ self.validate_all(
+ "MOD(8 - 1 + 7, 7)",
+ write={
+ "": "(8 - 1 + 7) % 7",
+ "hive": "(8 - 1 + 7) % 7",
+ "presto": "(8 - 1 + 7) % 7",
+ "snowflake": "(8 - 1 + 7) % 7",
+ "bigquery": "MOD(8 - 1 + 7, 7)",
+ },
+ )
+ self.validate_all(
+ "MOD(a, b + 1)",
+ write={
+ "": "a % (b + 1)",
+ "hive": "a % (b + 1)",
+ "presto": "a % (b + 1)",
+ "snowflake": "a % (b + 1)",
+ "bigquery": "MOD(a, b + 1)",
+ },
+ )
def test_typeddiv(self):
typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True)
@@ -2186,6 +2237,8 @@ SELECT
)
def test_cast_to_user_defined_type(self):
+ self.validate_identity("CAST(x AS some_udt(1234))")
+
self.validate_all(
"CAST(x AS some_udt)",
write={
@@ -2214,6 +2267,18 @@ SELECT
"tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
},
)
+ self.validate_all(
+ 'SELECT "user id", some_id, 1 as other_id, 2 as "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1',
+ write={
+ "duckdb": 'SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1',
+ "snowflake": 'SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id" FROM t QUALIFY COUNT(*) OVER () > 1',
+ "clickhouse": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1',
+ "mysql": "SELECT `user id`, some_id, other_id, `2 nd id` FROM (SELECT `user id`, some_id, 1 AS other_id, 2 AS `2 nd id`, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "oracle": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1',
+ "postgres": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1',
+ "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ },
+ )
def test_nested_ctes(self):
self.validate_all(
@@ -2249,7 +2314,7 @@ SELECT
"WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
write={
"duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
- "tsql": "WITH t2(y) AS (SELECT 2), t1(x) AS (SELECT 1) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
+ "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
},
)
self.validate_all(
@@ -2273,6 +2338,59 @@ FROM c""",
"hive": "WITH a1 AS (SELECT 1), a2 AS (SELECT 2), b AS (SELECT * FROM a1, a2), c AS (SELECT * FROM b) SELECT * FROM c",
},
)
+ self.validate_all(
+ """
+WITH subquery1 AS (
+ WITH tmp AS (
+ SELECT
+ *
+ FROM table0
+ )
+ SELECT
+ *
+ FROM tmp
+), subquery2 AS (
+ WITH tmp2 AS (
+ SELECT
+ *
+ FROM table1
+ WHERE
+ a IN subquery1
+ )
+ SELECT
+ *
+ FROM tmp2
+)
+SELECT
+ *
+FROM subquery2
+""",
+ write={
+ "hive": """WITH tmp AS (
+ SELECT
+ *
+ FROM table0
+), subquery1 AS (
+ SELECT
+ *
+ FROM tmp
+), tmp2 AS (
+ SELECT
+ *
+ FROM table1
+ WHERE
+ a IN subquery1
+), subquery2 AS (
+ SELECT
+ *
+ FROM tmp2
+)
+SELECT
+ *
+FROM subquery2""",
+ },
+ pretty=True,
+ )
def test_unsupported_null_ordering(self):
# We'll transpile a portable query from the following dialects to MySQL / T-SQL, which
@@ -2372,7 +2490,7 @@ FROM c""",
"hive": UnsupportedError,
"mysql": UnsupportedError,
"oracle": UnsupportedError,
- "postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t(x) WHERE pred), 1) <> 0)",
+ "postgres": "(ARRAY_LENGTH(arr, 1) = 0 OR ARRAY_LENGTH(ARRAY(SELECT x FROM UNNEST(arr) AS _t0(x) WHERE pred), 1) <> 0)",
"presto": "ANY_MATCH(arr, x -> pred)",
"redshift": UnsupportedError,
"snowflake": UnsupportedError,
@@ -2430,7 +2548,7 @@ FROM c""",
def test_reserved_keywords(self):
order = exp.select("*").from_("order")
- for dialect in ("presto", "redshift"):
+ for dialect in ("duckdb", "presto", "redshift"):
dialect = Dialect.get_or_raise(dialect)
self.assertEqual(
order.sql(dialect=dialect),
diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py
index 035289b..8180d05 100644
--- a/tests/dialects/test_doris.py
+++ b/tests/dialects/test_doris.py
@@ -14,7 +14,9 @@ class TestDoris(Validator):
)
self.validate_all(
"SELECT MAX_BY(a, b), MIN_BY(c, d)",
- read={"clickhouse": "SELECT argMax(a, b), argMin(c, d)"},
+ read={
+ "clickhouse": "SELECT argMax(a, b), argMin(c, d)",
+ },
)
self.validate_all(
"SELECT ARRAY_SUM(x -> x * x, ARRAY(2, 3))",
@@ -36,6 +38,24 @@ class TestDoris(Validator):
"oracle": "ADD_MONTHS(d, n)",
},
)
+ self.validate_all(
+ """SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')""",
+ read={
+ "postgres": """SELECT '{"key": 1}'::jsonb ->> 'key'""",
+ },
+ write={
+ "doris": """SELECT JSON_EXTRACT(CAST('{"key": 1}' AS JSONB), '$.key')""",
+ "postgres": """SELECT JSON_EXTRACT_PATH(CAST('{"key": 1}' AS JSONB), 'key')""",
+ },
+ )
+ self.validate_all(
+ "SELECT GROUP_CONCAT('aa', ',')",
+ read={
+ "doris": "SELECT GROUP_CONCAT('aa', ',')",
+ "mysql": "SELECT GROUP_CONCAT('aa' SEPARATOR ',')",
+ "postgres": "SELECT STRING_AGG('aa', ',')",
+ },
+ )
def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 9105a49..2bde478 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -1,4 +1,4 @@
-from sqlglot import ErrorLevel, UnsupportedError, exp, parse_one, transpile
+from sqlglot import ErrorLevel, ParseError, UnsupportedError, exp, parse_one, transpile
from sqlglot.helper import logger as helper_logger
from sqlglot.optimizer.annotate_types import annotate_types
from tests.dialects.test_dialect import Validator
@@ -8,6 +8,9 @@ class TestDuckDB(Validator):
dialect = "duckdb"
def test_duckdb(self):
+ with self.assertRaises(ParseError):
+ parse_one("1 //", read="duckdb")
+
query = "WITH _data AS (SELECT [{'a': 1, 'b': 2}, {'a': 2, 'b': 3}] AS col) SELECT t.col['b'] FROM _data, UNNEST(_data.col) AS t(col) WHERE t.col['a'] = 1"
expr = annotate_types(self.validate_identity(query))
self.assertEqual(
@@ -16,6 +19,20 @@ class TestDuckDB(Validator):
)
self.validate_all(
+ "SELECT straight_join",
+ write={
+ "duckdb": "SELECT straight_join",
+ "mysql": "SELECT `straight_join`",
+ },
+ )
+ self.validate_all(
+ "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
+ read={
+ "duckdb": "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)",
+ "snowflake": "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMPNTZ)",
+ },
+ )
+ self.validate_all(
"SELECT CAST('2020-01-01' AS DATE) + INTERVAL (day_offset) DAY FROM t",
read={
"duckdb": "SELECT CAST('2020-01-01' AS DATE) + INTERVAL (day_offset) DAY FROM t",
@@ -247,7 +264,7 @@ class TestDuckDB(Validator):
self.validate_identity("SELECT EPOCH_MS(10) AS t")
self.validate_identity("SELECT MAKE_TIMESTAMP(10) AS t")
self.validate_identity("SELECT TO_TIMESTAMP(10) AS t")
- self.validate_identity("SELECT UNNEST(column, recursive := TRUE) FROM table")
+ self.validate_identity("SELECT UNNEST(col, recursive := TRUE) FROM t")
self.validate_identity("VAR_POP(a)")
self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population)")
@@ -268,10 +285,23 @@ class TestDuckDB(Validator):
self.validate_identity("FROM tbl", "SELECT * FROM tbl")
self.validate_identity("x -> '$.family'")
self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))")
+ self.validate_identity("SELECT * FROM foo WHERE bar > $baz AND bla = $bob")
self.validate_identity(
"SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE"
)
self.validate_identity(
+ "SELECT a, LOGICAL_OR(b) FROM foo GROUP BY a",
+ "SELECT a, BOOL_OR(b) FROM foo GROUP BY a",
+ )
+ self.validate_identity(
+ "SELECT JSON_EXTRACT_STRING(c, '$.k1') = 'v1'",
+ "SELECT (c ->> '$.k1') = 'v1'",
+ )
+ self.validate_identity(
+ "SELECT JSON_EXTRACT(c, '$.k1') = 'v1'",
+ "SELECT (c -> '$.k1') = 'v1'",
+ )
+ self.validate_identity(
"""SELECT '{"foo": [1, 2, 3]}' -> 'foo' -> 0""",
"""SELECT '{"foo": [1, 2, 3]}' -> '$.foo' -> '$[0]'""",
)
@@ -324,6 +354,8 @@ class TestDuckDB(Validator):
self.validate_identity(
"SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias"
)
+ self.validate_identity("DATE_SUB('YEAR', col, '2020-01-01')").assert_is(exp.Anonymous)
+ self.validate_identity("DATESUB('YEAR', col, '2020-01-01')").assert_is(exp.Anonymous)
self.validate_all("0b1010", write={"": "0 AS b1010"})
self.validate_all("0x1010", write={"": "0 AS x1010"})
@@ -414,15 +446,15 @@ class TestDuckDB(Validator):
write={"duckdb": 'WITH "x" AS (SELECT 1) SELECT * FROM x'},
)
self.validate_all(
- "CREATE TABLE IF NOT EXISTS table (cola INT, colb STRING) USING ICEBERG PARTITIONED BY (colb)",
+ "CREATE TABLE IF NOT EXISTS t (cola INT, colb STRING) USING ICEBERG PARTITIONED BY (colb)",
write={
- "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)",
+ "duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)",
},
)
self.validate_all(
- "CREATE TABLE IF NOT EXISTS table (cola INT COMMENT 'cola', colb STRING) USING ICEBERG PARTITIONED BY (colb)",
+ "CREATE TABLE IF NOT EXISTS t (cola INT COMMENT 'cola', colb STRING) USING ICEBERG PARTITIONED BY (colb)",
write={
- "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)",
+ "duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)",
},
)
self.validate_all(
@@ -724,6 +756,36 @@ class TestDuckDB(Validator):
"""SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC""",
)
+ self.validate_identity(
+ "COPY lineitem FROM 'lineitem.ndjson' WITH (FORMAT JSON, DELIMITER ',', AUTO_DETECT TRUE, COMPRESSION SNAPPY, CODEC ZSTD, FORCE_NOT_NULL (col1, col2))"
+ )
+ self.validate_identity(
+ "COPY (SELECT 42 AS a, 'hello' AS b) TO 'query.json' WITH (FORMAT JSON, ARRAY TRUE)"
+ )
+ self.validate_identity("COPY lineitem (l_orderkey) TO 'orderkey.tbl' WITH (DELIMITER '|')")
+
+ self.validate_all(
+ "VARIANCE(a)",
+ write={
+ "duckdb": "VARIANCE(a)",
+ "clickhouse": "varSamp(a)",
+ },
+ )
+ self.validate_all(
+ "STDDEV(a)",
+ write={
+ "duckdb": "STDDEV(a)",
+ "clickhouse": "stddevSamp(a)",
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC('DAY', x)",
+ write={
+ "duckdb": "DATE_TRUNC('DAY', x)",
+ "clickhouse": "DATE_TRUNC('DAY', x)",
+ },
+ )
+
def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
@@ -793,7 +855,7 @@ class TestDuckDB(Validator):
write={"duckdb": "SELECT (90 * INTERVAL '1' DAY)"},
)
self.validate_all(
- "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - (DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7 % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1",
+ "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1",
read={
"presto": "SELECT ((DATE_ADD('week', -5, DATE_TRUNC('DAY', DATE_ADD('day', (0 - MOD((DAY_OF_WEEK(CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7, 7)), CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)))))) AS t1",
},
@@ -817,6 +879,9 @@ class TestDuckDB(Validator):
"duckdb": "EPOCH_MS(x)",
"presto": "FROM_UNIXTIME(CAST(x AS DOUBLE) / POW(10, 3))",
"spark": "TIMESTAMP_MILLIS(x)",
+ "clickhouse": "fromUnixTimestamp64Milli(CAST(x AS Int64))",
+ "postgres": "TO_TIMESTAMP(CAST(x AS DOUBLE PRECISION) / 10 ^ 3)",
+ "mysql": "FROM_UNIXTIME(x / POWER(10, 3))",
},
)
self.validate_all(
@@ -882,11 +947,11 @@ class TestDuckDB(Validator):
def test_sample(self):
self.validate_identity(
"SELECT * FROM tbl USING SAMPLE 5",
- "SELECT * FROM tbl USING SAMPLE (5 ROWS)",
+ "SELECT * FROM tbl USING SAMPLE RESERVOIR (5 ROWS)",
)
self.validate_identity(
"SELECT * FROM tbl USING SAMPLE 10%",
- "SELECT * FROM tbl USING SAMPLE (10 PERCENT)",
+ "SELECT * FROM tbl USING SAMPLE SYSTEM (10 PERCENT)",
)
self.validate_identity(
"SELECT * FROM tbl USING SAMPLE 10 PERCENT (bernoulli)",
@@ -910,14 +975,13 @@ class TestDuckDB(Validator):
)
self.validate_all(
- "SELECT * FROM example TABLESAMPLE (3 ROWS) REPEATABLE (82)",
+ "SELECT * FROM example TABLESAMPLE RESERVOIR (3 ROWS) REPEATABLE (82)",
read={
"duckdb": "SELECT * FROM example TABLESAMPLE (3) REPEATABLE (82)",
"snowflake": "SELECT * FROM example SAMPLE (3 ROWS) SEED (82)",
},
write={
- "duckdb": "SELECT * FROM example TABLESAMPLE (3 ROWS) REPEATABLE (82)",
- "snowflake": "SELECT * FROM example TABLESAMPLE (3 ROWS) SEED (82)",
+ "duckdb": "SELECT * FROM example TABLESAMPLE RESERVOIR (3 ROWS) REPEATABLE (82)",
},
)
@@ -936,10 +1000,6 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS DOUBLE)")
self.validate_identity("CAST(x AS DECIMAL(15, 4))")
self.validate_identity("CAST(x AS STRUCT(number BIGINT))")
- self.validate_identity(
- "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))"
- )
-
self.validate_identity("CAST(x AS INT64)", "CAST(x AS BIGINT)")
self.validate_identity("CAST(x AS INT32)", "CAST(x AS INT)")
self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)")
@@ -948,6 +1008,7 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)")
+ self.validate_identity("CAST(x AS VARCHAR)", "CAST(x AS TEXT)")
self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)")
self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)")
self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)")
@@ -959,6 +1020,39 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)")
self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)")
self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)")
+ self.validate_identity(
+ "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))"
+ )
+ self.validate_identity(
+ "123::CHARACTER VARYING",
+ "CAST(123 AS TEXT)",
+ )
+ self.validate_identity(
+ "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])",
+ "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])",
+ )
+ self.validate_identity(
+ "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])",
+ "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
+ )
+
+ self.validate_all(
+ "CAST(x AS VARCHAR(5))",
+ write={
+ "duckdb": "CAST(x AS TEXT)",
+ "postgres": "CAST(x AS TEXT)",
+ },
+ )
+ self.validate_all(
+ "CAST(x AS DECIMAL(38, 0))",
+ read={
+ "snowflake": "CAST(x AS NUMBER)",
+ "duckdb": "CAST(x AS DECIMAL(38, 0))",
+ },
+ write={
+ "snowflake": "CAST(x AS DECIMAL(38, 0))",
+ },
+ )
self.validate_all(
"CAST(x AS NUMERIC)",
write={
@@ -984,12 +1078,6 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "123::CHARACTER VARYING",
- write={
- "duckdb": "CAST(123 AS TEXT)",
- },
- )
- self.validate_all(
"cast([[1]] as int[][])",
write={
"duckdb": "CAST([[1]] AS INT[][])",
@@ -997,7 +1085,10 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "CAST(x AS DATE) + INTERVAL (7 * -1) DAY", read={"spark": "DATE_SUB(x, 7)"}
+ "CAST(x AS DATE) + INTERVAL (7 * -1) DAY",
+ read={
+ "spark": "DATE_SUB(x, 7)",
+ },
)
self.validate_all(
"TRY_CAST(1 AS DOUBLE)",
@@ -1024,24 +1115,6 @@ class TestDuckDB(Validator):
"snowflake": "CAST(COL AS ARRAY(BIGINT))",
},
)
- self.validate_all(
- "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])",
- write={
- "duckdb": "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
- },
- )
- self.validate_all(
- "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])",
- write={
- "duckdb": "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])",
- },
- )
-
- def test_bool_or(self):
- self.validate_all(
- "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
- write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
- )
def test_encode_decode(self):
self.validate_all(
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 9215f05..0311336 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -334,7 +334,7 @@ class TestHive(Validator):
"hive": "DATE_ADD('2020-01-01', 1)",
"presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"redshift": "DATEADD(DAY, 1, '2020-01-01')",
- "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))",
+ "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2020-01-01', 1)",
"tsql": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))",
},
@@ -348,7 +348,7 @@ class TestHive(Validator):
"hive": "DATE_ADD('2020-01-01', 1 * -1)",
"presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"redshift": "DATEADD(DAY, 1 * -1, '2020-01-01')",
- "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))",
+ "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2020-01-01', 1 * -1)",
"tsql": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))",
},
@@ -406,7 +406,9 @@ class TestHive(Validator):
self.validate_identity("(VALUES (1 AS a, 2 AS b, 3))")
self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)")
self.validate_identity("SELECT * FROM my_table VERSION AS OF DATE_ADD(CURRENT_DATE, -1)")
-
+ self.validate_identity(
+ "SELECT WEEKOFYEAR('2024-05-22'), DAYOFMONTH('2024-05-22'), DAYOFWEEK('2024-05-22')"
+ )
self.validate_identity(
"SELECT ROW() OVER (DISTRIBUTE BY x SORT BY y)",
"SELECT ROW() OVER (PARTITION BY x ORDER BY y)",
@@ -742,6 +744,16 @@ class TestHive(Validator):
"hive": "SELECT a, SUM(c) FROM t GROUP BY a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy'), GROUPING SETS ((a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy')), a)",
},
)
+ self.validate_all(
+ "SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH') AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
+ read={
+ "hive": "SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH') AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
+ "presto": "SELECT DATE_TRUNC('MONTH', CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
+ },
+ write={
+ "presto": "SELECT DATE_TRUNC('MONTH', TRY_CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
+ },
+ )
def test_escapes(self) -> None:
self.validate_identity("'\n'", "'\\n'")
diff --git a/tests/dialects/test_materialize.py b/tests/dialects/test_materialize.py
new file mode 100644
index 0000000..617a9b5
--- /dev/null
+++ b/tests/dialects/test_materialize.py
@@ -0,0 +1,77 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestMaterialize(Validator):
+ dialect = "materialize"
+
+ def test_materialize(self):
+ self.validate_all(
+ "CREATE TABLE example (id INT PRIMARY KEY, name TEXT)",
+ write={
+ "materialize": "CREATE TABLE example (id INT, name TEXT)",
+ "postgres": "CREATE TABLE example (id INT PRIMARY KEY, name TEXT)",
+ },
+ )
+ self.validate_all(
+ "INSERT INTO example (id, name) VALUES (1, 'Alice') ON CONFLICT(id) DO NOTHING",
+ write={
+ "materialize": "INSERT INTO example (id, name) VALUES (1, 'Alice')",
+ "postgres": "INSERT INTO example (id, name) VALUES (1, 'Alice') ON CONFLICT(id) DO NOTHING",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE example (id SERIAL, name TEXT)",
+ write={
+ "materialize": "CREATE TABLE example (id INT NOT NULL, name TEXT)",
+ "postgres": "CREATE TABLE example (id INT GENERATED BY DEFAULT AS IDENTITY NOT NULL, name TEXT)",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE example (id INT AUTO_INCREMENT, name TEXT)",
+ write={
+ "materialize": "CREATE TABLE example (id INT NOT NULL, name TEXT)",
+ "postgres": "CREATE TABLE example (id INT GENERATED BY DEFAULT AS IDENTITY NOT NULL, name TEXT)",
+ },
+ )
+ self.validate_all(
+ 'SELECT JSON_EXTRACT_PATH_TEXT(\'{ "farm": {"barn": { "color": "red", "feed stocked": true }}}\', \'farm\', \'barn\', \'color\')',
+ write={
+ "materialize": 'SELECT JSON_EXTRACT_PATH_TEXT(\'{ "farm": {"barn": { "color": "red", "feed stocked": true }}}\', \'farm\', \'barn\', \'color\')',
+ "postgres": 'SELECT JSON_EXTRACT_PATH_TEXT(\'{ "farm": {"barn": { "color": "red", "feed stocked": true }}}\', \'farm\', \'barn\', \'color\')',
+ },
+ )
+ self.validate_all(
+ "SELECT MAP['a' => 1]",
+ write={
+ "duckdb": "SELECT MAP {'a': 1}",
+ "materialize": "SELECT MAP['a' => 1]",
+ },
+ )
+
+ # Test now functions.
+ self.validate_identity("CURRENT_TIMESTAMP")
+ self.validate_identity("NOW()", write_sql="CURRENT_TIMESTAMP")
+ self.validate_identity("MZ_NOW()")
+
+ # Test custom timestamp type.
+ self.validate_identity("SELECT CAST(1 AS mz_timestamp)")
+
+ # Test DDL.
+ self.validate_identity("CREATE TABLE example (id INT, name LIST)")
+
+ # Test list types.
+ self.validate_identity("SELECT LIST[]")
+ self.validate_identity("SELECT LIST[1, 2, 3]")
+ self.validate_identity("SELECT LIST[LIST[1], LIST[2], NULL]")
+ self.validate_identity("SELECT CAST(LIST[1, 2, 3] AS INT LIST)")
+ self.validate_identity("SELECT CAST(NULL AS INT LIST)")
+ self.validate_identity("SELECT CAST(NULL AS INT LIST LIST LIST)")
+ self.validate_identity("SELECT LIST(SELECT 1)")
+
+ # Test map types.
+ self.validate_identity("SELECT MAP[]")
+ self.validate_identity("SELECT MAP['a' => MAP['b' => 'c']]")
+ self.validate_identity("SELECT CAST(MAP['a' => 1] AS MAP[TEXT => INT])")
+ self.validate_identity("SELECT CAST(NULL AS MAP[TEXT => INT])")
+ self.validate_identity("SELECT CAST(NULL AS MAP[TEXT => MAP[TEXT => INT]])")
+ self.validate_identity("SELECT MAP(SELECT 'a', 1)")
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index e8af5c6..280ebbf 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -1,4 +1,5 @@
from sqlglot import expressions as exp
+from sqlglot.dialects.mysql import MySQL
from tests.dialects.test_dialect import Validator
@@ -6,21 +7,11 @@ class TestMySQL(Validator):
dialect = "mysql"
def test_ddl(self):
- int_types = {"BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"}
-
- for t in int_types:
+ for t in ("BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"):
self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)")
self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)")
self.validate_identity("CREATE TABLE t (id DECIMAL(20, 4) UNSIGNED)")
-
- self.validate_all(
- "CREATE TABLE t (id INT UNSIGNED)",
- write={
- "duckdb": "CREATE TABLE t (id UINTEGER)",
- },
- )
-
self.validate_identity("CREATE TABLE foo (a BIGINT, UNIQUE (b) USING BTREE)")
self.validate_identity("CREATE TABLE foo (id BIGINT)")
self.validate_identity("CREATE TABLE 00f (1d BIGINT)")
@@ -30,6 +21,9 @@ class TestMySQL(Validator):
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE")
+ self.validate_identity("ALTER TABLE t ADD INDEX `i` (`c`)")
+ self.validate_identity("ALTER TABLE t ADD UNIQUE `i` (`c`)")
+ self.validate_identity("ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT")
self.validate_identity(
"CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))"
)
@@ -70,6 +64,10 @@ class TestMySQL(Validator):
"CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q",
)
self.validate_identity(
+ "ALTER TABLE t ADD KEY `i` (`c`)",
+ "ALTER TABLE t ADD INDEX `i` (`c`)",
+ )
+ self.validate_identity(
"CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))",
"CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))",
)
@@ -86,9 +84,6 @@ class TestMySQL(Validator):
"ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
)
self.validate_identity(
- "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT",
- )
- self.validate_identity(
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC",
"CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC",
)
@@ -98,6 +93,13 @@ class TestMySQL(Validator):
)
self.validate_all(
+ "CREATE TABLE t (id INT UNSIGNED)",
+ write={
+ "duckdb": "CREATE TABLE t (id UINTEGER)",
+ "mysql": "CREATE TABLE t (id INT UNSIGNED)",
+ },
+ )
+ self.validate_all(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
write={
"duckdb": "CREATE TABLE z (a INT)",
@@ -109,17 +111,13 @@ class TestMySQL(Validator):
self.validate_all(
"CREATE TABLE x (id int not null auto_increment, primary key (id))",
write={
+ "mysql": "CREATE TABLE x (id INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id))",
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)",
},
)
- self.validate_all(
- "CREATE TABLE x (id int not null auto_increment)",
- write={
- "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)",
- },
- )
def test_identity(self):
+ self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
self.validate_identity("SELECT @var1 := 1, @var2")
@@ -135,8 +133,6 @@ class TestMySQL(Validator):
self.validate_identity("SELECT CAST('[4,5]' AS JSON) MEMBER OF('[[3,4],[4,5]]')")
self.validate_identity("""SELECT 'ab' MEMBER OF('[23, "abc", 17, "ab", 10]')""")
self.validate_identity("""SELECT * FROM foo WHERE 'ab' MEMBER OF(content)""")
- self.validate_identity("CAST(x AS ENUM('a', 'b'))")
- self.validate_identity("CAST(x AS SET('a', 'b'))")
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
self.validate_identity("x ->> '$.name'")
self.validate_identity("SELECT CAST(`a`.`b` AS CHAR) FROM foo")
@@ -158,9 +154,16 @@ class TestMySQL(Validator):
"SELECT * FROM t1, t2, t3 FOR SHARE OF t1 NOWAIT FOR UPDATE OF t2, t3 SKIP LOCKED"
)
self.validate_identity(
+ "REPLACE INTO table SELECT id FROM table2 WHERE cnt > 100", check_command_warning=True
+ )
+ self.validate_identity(
"""SELECT * FROM foo WHERE 3 MEMBER OF(info->'$.value')""",
"""SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""",
)
+ self.validate_identity(
+ "SELECT 1 AS row",
+ "SELECT 1 AS `row`",
+ )
# Index hints
self.validate_identity(
@@ -224,31 +227,52 @@ class TestMySQL(Validator):
self.validate_identity("CHAR(77, 121, 83, 81, '76')")
self.validate_identity("CHAR(77, 77.3, '77.3' USING utf8mb4)")
self.validate_identity("SELECT * FROM t1 PARTITION(p0)")
+ self.validate_identity("SELECT @var1 := 1, @var2")
+ self.validate_identity("SELECT @var1, @var2 := @var1")
+ self.validate_identity("SELECT @var1 := COUNT(*) FROM t1")
def test_types(self):
- self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))")
+ for char_type in MySQL.Generator.CHAR_CAST_MAPPING:
+ with self.subTest(f"MySQL cast into {char_type}"):
+ self.validate_identity(f"CAST(x AS {char_type.value})", "CAST(x AS CHAR)")
+
+ for signed_type in MySQL.Generator.SIGNED_CAST_MAPPING:
+ with self.subTest(f"MySQL cast into {signed_type}"):
+ self.validate_identity(f"CAST(x AS {signed_type.value})", "CAST(x AS SIGNED)")
+
+ self.validate_identity("CAST(x AS ENUM('a', 'b'))")
+ self.validate_identity("CAST(x AS SET('a', 'b'))")
+ self.validate_identity(
+ "CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))",
+ "CAST(x AS SIGNED) + CAST(y AS YEAR(4))",
+ )
+ self.validate_identity(
+ "CAST(x AS TIMESTAMP)",
+ "CAST(x AS DATETIME)",
+ )
+ self.validate_identity(
+ "CAST(x AS TIMESTAMPTZ)",
+ "TIMESTAMP(x)",
+ )
+ self.validate_identity(
+ "CAST(x AS TIMESTAMPLTZ)",
+ "TIMESTAMP(x)",
+ )
self.validate_all(
"CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT) + CAST(z AS TINYTEXT)",
- read={
- "mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT) + CAST(z AS TINYTEXT)",
- },
write={
+ "mysql": "CAST(x AS CHAR) + CAST(y AS CHAR) + CAST(z AS CHAR)",
"spark": "CAST(x AS TEXT) + CAST(y AS TEXT) + CAST(z AS TEXT)",
},
)
self.validate_all(
"CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB) + CAST(z AS TINYBLOB)",
- read={
- "mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB) + CAST(z AS TINYBLOB)",
- },
write={
+ "mysql": "CAST(x AS CHAR) + CAST(y AS CHAR) + CAST(z AS CHAR)",
"spark": "CAST(x AS BLOB) + CAST(y AS BLOB) + CAST(z AS BLOB)",
},
)
- self.validate_all("CAST(x AS TIMESTAMP)", write={"mysql": "CAST(x AS DATETIME)"})
- self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"mysql": "TIMESTAMP(x)"})
- self.validate_all("CAST(x AS TIMESTAMPLTZ)", write={"mysql": "TIMESTAMP(x)"})
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)")
@@ -322,7 +346,7 @@ class TestMySQL(Validator):
write_CC = {
"bigquery": "SELECT 0xCC",
"clickhouse": "SELECT 0xCC",
- "databricks": "SELECT 204",
+ "databricks": "SELECT X'CC'",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
@@ -343,7 +367,7 @@ class TestMySQL(Validator):
write_CC_with_leading_zeros = {
"bigquery": "SELECT 0x0000CC",
"clickhouse": "SELECT 0x0000CC",
- "databricks": "SELECT 204",
+ "databricks": "SELECT X'0000CC'",
"drill": "SELECT 204",
"duckdb": "SELECT 204",
"hive": "SELECT 204",
@@ -457,63 +481,63 @@ class TestMySQL(Validator):
"SELECT DATE_FORMAT('2017-06-15', '%Y')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'yyyy')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%m')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%m')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'mm')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'mm')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%d')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%d')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'DD')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'DD')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy-mm-DD')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'yyyy-mm-DD')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMPNTZ), 'hh24')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMP), 'hh24')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%w')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%w')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'dy')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'dy')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')",
write={
"mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')",
- "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMPNTZ), 'DY mmmm yyyy')",
+ "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMP), 'DY mmmm yyyy')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2007-10-04 22:23:00', '%H:%i:%s')",
write={
"mysql": "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%T')",
- "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMPNTZ), 'hh24:mi:ss')",
+ "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMP), 'hh24:mi:ss')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')",
write={
"mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')",
- "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMPNTZ), 'DD yy DY DD mm mon')",
+ "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMP), 'DD yy DY DD mm mon')",
},
)
@@ -599,6 +623,29 @@ class TestMySQL(Validator):
def test_mysql(self):
self.validate_all(
+ "SELECT CONCAT('11', '22')",
+ read={
+ "postgres": "SELECT '11' || '22'",
+ },
+ write={
+ "mysql": "SELECT CONCAT('11', '22')",
+ "postgres": "SELECT CONCAT('11', '22')",
+ },
+ )
+ self.validate_all(
+ "SELECT department, GROUP_CONCAT(name) AS employee_names FROM data GROUP BY department",
+ read={
+ "postgres": "SELECT department, array_agg(name) AS employee_names FROM data GROUP BY department",
+ },
+ )
+ self.validate_all(
+ "SELECT UNIX_TIMESTAMP(CAST('2024-04-29 12:00:00' AS DATETIME))",
+ read={
+ "mysql": "SELECT UNIX_TIMESTAMP(CAST('2024-04-29 12:00:00' AS DATETIME))",
+ "postgres": "SELECT EXTRACT(epoch FROM TIMESTAMP '2024-04-29 12:00:00')",
+ },
+ )
+ self.validate_all(
"SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')",
read={
"sqlite": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')",
@@ -1109,3 +1156,23 @@ COMMENT='客户账户表'"""
"tsql": "CAST(a AS FLOAT) / NULLIF(b, 0)",
},
)
+
+ def test_timestamp_trunc(self):
+ for dialect in ("postgres", "snowflake", "duckdb", "spark", "databricks"):
+ for unit in (
+ "MILLISECOND",
+ "SECOND",
+ "DAY",
+ "MONTH",
+ "YEAR",
+ ):
+ with self.subTest(f"MySQL -> {dialect} Timestamp Trunc with unit {unit}: "):
+ self.validate_all(
+ f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
+ read={
+ dialect: f"DATE_TRUNC({unit}, TIMESTAMP '2001-02-16 20:38:40')",
+ },
+ write={
+ "mysql": f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
+ },
+ )
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 526b0b5..7cc4d72 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -1,5 +1,5 @@
-from sqlglot import exp
-from sqlglot.errors import UnsupportedError
+from sqlglot import exp, UnsupportedError
+from sqlglot.dialects.oracle import eliminate_join_marks
from tests.dialects.test_dialect import Validator
@@ -43,6 +43,7 @@ class TestOracle(Validator):
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT COUNT(*) * 10 FROM orders SAMPLE (10) SEED (1)")
self.validate_identity("SELECT * FROM V$SESSION")
+ self.validate_identity("SELECT TO_DATE('January 15, 1989, 11:00 A.M.')")
self.validate_identity(
"SELECT last_name, employee_id, manager_id, LEVEL FROM employees START WITH employee_id = 100 CONNECT BY PRIOR employee_id = manager_id ORDER SIBLINGS BY last_name"
)
@@ -249,7 +250,8 @@ class TestOracle(Validator):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
self.validate_all(
- "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)", write={"": UnsupportedError}
+ "SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
+ write={"": UnsupportedError},
)
self.validate_all(
"SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)",
@@ -413,3 +415,65 @@ WHERE
for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
self.validate_identity(query, pretty, pretty=True)
+
+ def test_eliminate_join_marks(self):
+ test_sql = [
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
+ ),
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
+ ),
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
+ ),
+ (
+ "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
+ "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
+ ),
+ (
+ "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
+ ),
+ (
+ "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
+ ),
+ (
+ "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
+ ),
+ (
+ "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
+ "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
+ ),
+ # 2 join marks on one side of predicate
+ (
+ "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
+ ),
+ # join mark and expression
+ (
+ "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
+ "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
+ ),
+ ]
+
+ for original, expected in test_sql:
+ with self.subTest(original):
+ self.assertEqual(
+ eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect),
+ expected,
+ )
+
+ def test_query_restrictions(self):
+ for restriction in ("READ ONLY", "CHECK OPTION"):
+ for constraint_name in (" CONSTRAINT name", ""):
+ with self.subTest(f"Restriction: {restriction}"):
+ self.validate_identity(f"SELECT * FROM tbl WITH {restriction}{constraint_name}")
+ self.validate_identity(
+ f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}"
+ )
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 5a55a7d..071677d 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -8,6 +8,10 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_postgres(self):
+ self.validate_identity("SHA384(x)")
+ self.validate_identity(
+ 'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)"
+ )
self.validate_identity("1.x", "1. AS x")
self.validate_identity("|/ x", "SQRT(x)")
self.validate_identity("||/ x", "CBRT(x)")
@@ -22,6 +26,7 @@ class TestPostgres(Validator):
self.assertIsInstance(expr, exp.AlterTable)
self.assertEqual(expr.sql(dialect="postgres"), alter_table_only)
+ self.validate_identity("STRING_TO_ARRAY('xx~^~yy~^~zz', '~^~', 'yy')")
self.validate_identity("SELECT x FROM t WHERE CAST($1 AS TEXT) = 'ok'")
self.validate_identity("SELECT * FROM t TABLESAMPLE SYSTEM (50) REPEATABLE (55)")
self.validate_identity("x @@ y")
@@ -38,8 +43,6 @@ class TestPostgres(Validator):
self.validate_identity("CAST(x AS TSTZMULTIRANGE)")
self.validate_identity("CAST(x AS DATERANGE)")
self.validate_identity("CAST(x AS DATEMULTIRANGE)")
- self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]")
- self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]")
self.validate_identity("x$")
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY(SELECT 1)")
@@ -65,6 +68,10 @@ class TestPostgres(Validator):
self.validate_identity("SELECT CURRENT_USER")
self.validate_identity("SELECT * FROM ONLY t1")
self.validate_identity(
+ "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]",
+ "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]",
+ )
+ self.validate_identity(
"""UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)"""
)
self.validate_identity(
@@ -312,9 +319,54 @@ class TestPostgres(Validator):
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
)
- self.validate_identity("SELECT * FROM t1*", "SELECT * FROM t1")
+ self.validate_identity(
+ "SELECT * FROM t1*",
+ "SELECT * FROM t1",
+ )
+ self.validate_identity(
+ "SELECT SUBSTRING('afafa' for 1)",
+ "SELECT SUBSTRING('afafa' FROM 1 FOR 1)",
+ )
+ self.validate_identity(
+ "CAST(x AS INT8)",
+ "CAST(x AS BIGINT)",
+ )
self.validate_all(
+ "STRING_TO_ARRAY('xx~^~yy~^~zz', '~^~', 'yy')",
+ read={
+ "doris": "SPLIT_BY_STRING('xx~^~yy~^~zz', '~^~', 'yy')",
+ },
+ write={
+ "doris": "SPLIT_BY_STRING('xx~^~yy~^~zz', '~^~', 'yy')",
+ "postgres": "STRING_TO_ARRAY('xx~^~yy~^~zz', '~^~', 'yy')",
+ },
+ )
+ self.validate_all(
+ "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
+ read={
+ "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
+ },
+ write={
+ "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])",
+ "mysql": UnsupportedError,
+ "postgres": "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]",
+ },
+ )
+ self.validate_all(
+ "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
+ write={
+ "duckdb": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
+ "postgres": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE t (c INT)",
+ read={
+ "mysql": "CREATE TABLE t (c INT COMMENT 'comment 1') COMMENT = 'comment 2'",
+ },
+ )
+ self.validate_all(
'SELECT * FROM "test_table" ORDER BY RANDOM() LIMIT 5',
write={
"bigquery": "SELECT * FROM `test_table` ORDER BY RAND() NULLS LAST LIMIT 5",
@@ -449,7 +501,7 @@ class TestPostgres(Validator):
write={
"postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
"redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
- "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -660,6 +712,41 @@ class TestPostgres(Validator):
)
self.assertIsInstance(self.parse_one("id::UUID"), exp.Cast)
+ self.validate_identity(
+ "COPY tbl (col1, col2) FROM 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
+ )
+ self.validate_identity(
+ "COPY tbl (col1, col2) TO 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
+ )
+ self.validate_identity(
+ "COPY (SELECT * FROM t) TO 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
+ )
+ self.validate_identity("cast(a as FLOAT)", "CAST(a AS DOUBLE PRECISION)")
+ self.validate_identity("cast(a as FLOAT8)", "CAST(a AS DOUBLE PRECISION)")
+ self.validate_identity("cast(a as FLOAT4)", "CAST(a AS REAL)")
+
+ self.validate_all(
+ "1 / DIV(4, 2)",
+ read={
+ "postgres": "1 / DIV(4, 2)",
+ },
+ write={
+ "sqlite": "1 / CAST(CAST(CAST(4 AS REAL) / 2 AS INTEGER) AS REAL)",
+ "duckdb": "1 / CAST(4 // 2 AS DECIMAL)",
+ "bigquery": "1 / CAST(DIV(4, 2) AS NUMERIC)",
+ },
+ )
+ self.validate_all(
+ "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
+ read={
+ "duckdb": "CAST(4 // 2 AS DECIMAL(5, 3))",
+ },
+ write={
+ "duckdb": "CAST(CAST(4 // 2 AS DECIMAL) AS DECIMAL(5, 3))",
+ "postgres": "CAST(DIV(4, 2) AS DECIMAL(5, 3))",
+ },
+ )
+
def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
@@ -676,6 +763,9 @@ class TestPostgres(Validator):
cdef.args["kind"].assert_is(exp.DataType)
self.assertEqual(expr.sql(dialect="postgres"), "CREATE TABLE t (x INTERVAL DAY)")
+ self.validate_identity("CREATE TABLE t (col INT[3][5])")
+ self.validate_identity("CREATE TABLE t (col INT[3])")
+ self.validate_identity("CREATE INDEX IF NOT EXISTS ON t(c)")
self.validate_identity("CREATE INDEX et_vid_idx ON et(vid) INCLUDE (fid)")
self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)")
self.validate_identity("CREATE TABLE test (elems JSONB[])")
@@ -698,6 +788,16 @@ class TestPostgres(Validator):
self.validate_identity("TRUNCATE TABLE t1 RESTRICT")
self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY CASCADE")
self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY RESTRICT")
+ self.validate_identity("ALTER TABLE t1 SET LOGGED")
+ self.validate_identity("ALTER TABLE t1 SET UNLOGGED")
+ self.validate_identity("ALTER TABLE t1 SET WITHOUT CLUSTER")
+ self.validate_identity("ALTER TABLE t1 SET WITHOUT OIDS")
+ self.validate_identity("ALTER TABLE t1 SET ACCESS METHOD method")
+ self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace")
+ self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)")
+ self.validate_identity(
+ "CREATE FUNCTION pymax(a INT, b INT) RETURNS INT LANGUAGE plpython3u AS $$\n if a > b:\n return a\n return b\n$$",
+ )
self.validate_identity(
"CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))"
)
@@ -763,23 +863,21 @@ class TestPostgres(Validator):
"CREATE TABLE test (x TIMESTAMP[][])",
)
self.validate_identity(
- "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE",
- check_command_warning=True,
+ "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT AS 'select $1 + $2;'",
)
self.validate_identity(
- "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT",
- check_command_warning=True,
+ "CREATE FUNCTION add(integer, integer) RETURNS INT LANGUAGE SQL IMMUTABLE STRICT AS 'select $1 + $2;'"
)
self.validate_identity(
- "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT",
+ "CREATE FUNCTION add(INT, INT) RETURNS INT SET search_path TO 'public' AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE",
check_command_warning=True,
)
self.validate_identity(
- "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT",
+ "CREATE FUNCTION x(INT) RETURNS INT SET foo FROM CURRENT",
check_command_warning=True,
)
self.validate_identity(
- "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE STRICT",
+ "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select $1 + $2;' LANGUAGE SQL IMMUTABLE CALLED ON NULL INPUT",
check_command_warning=True,
)
self.validate_identity(
@@ -790,6 +888,14 @@ class TestPostgres(Validator):
"CREATE UNLOGGED TABLE foo AS WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT c AS c FROM t) AS temp"
)
self.validate_identity(
+ "CREATE TABLE t (col integer ARRAY[3])",
+ "CREATE TABLE t (col INT[3])",
+ )
+ self.validate_identity(
+ "CREATE TABLE t (col integer ARRAY)",
+ "CREATE TABLE t (col INT[])",
+ )
+ self.validate_identity(
"CREATE FUNCTION x(INT) RETURNS INT SET search_path TO 'public'",
"CREATE FUNCTION x(INT) RETURNS INT SET search_path = 'public'",
)
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 4bafc08..ebb270a 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -10,6 +10,8 @@ class TestPresto(Validator):
self.validate_identity("SELECT * FROM x qualify", "SELECT * FROM x AS qualify")
self.validate_identity("CAST(x AS IPADDRESS)")
self.validate_identity("CAST(x AS IPPREFIX)")
+ self.validate_identity("CAST(TDIGEST_AGG(1) AS TDIGEST)")
+ self.validate_identity("CAST(x AS HYPERLOGLOG)")
self.validate_all(
"CAST(x AS INTERVAL YEAR TO MONTH)",
@@ -208,6 +210,8 @@ class TestPresto(Validator):
"bigquery": f"SELECT INTERVAL '1' {expected}",
"presto": f"SELECT INTERVAL '1' {expected}",
"trino": f"SELECT INTERVAL '1' {expected}",
+ "mysql": f"SELECT INTERVAL '1' {expected}",
+ "doris": f"SELECT INTERVAL '1' {expected}",
},
)
@@ -560,6 +564,7 @@ class TestPresto(Validator):
self.validate_all(
f"{prefix}'Hello winter \\2603 !'",
write={
+ "oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter \\2603 !'",
"snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'",
@@ -568,6 +573,7 @@ class TestPresto(Validator):
self.validate_all(
f"{prefix}'Hello winter #2603 !' UESCAPE '#'",
write={
+ "oracle": "U'Hello winter \\2603 !'",
"presto": "U&'Hello winter #2603 !' UESCAPE '#'",
"snowflake": "'Hello winter \\u2603 !'",
"spark": "'Hello winter \\u2603 !'",
@@ -1059,6 +1065,15 @@ class TestPresto(Validator):
)
def test_json(self):
+ with self.assertLogs(helper_logger):
+ self.validate_all(
+ """SELECT JSON_EXTRACT_SCALAR(TRY(FILTER(CAST(JSON_EXTRACT('{"k1": [{"k2": "{\\"k3\\": 1}", "k4": "v"}]}', '$.k1') AS ARRAY(MAP(VARCHAR, VARCHAR))), x -> x['k4'] = 'v')[1]['k2']), '$.k3')""",
+ write={
+ "presto": """SELECT JSON_EXTRACT_SCALAR(TRY(FILTER(CAST(JSON_EXTRACT('{"k1": [{"k2": "{\\"k3\\": 1}", "k4": "v"}]}', '$.k1') AS ARRAY(MAP(VARCHAR, VARCHAR))), x -> x['k4'] = 'v')[1]['k2']), '$.k3')""",
+ "spark": """SELECT GET_JSON_OBJECT(FILTER(FROM_JSON(GET_JSON_OBJECT('{"k1": [{"k2": "{\\\\"k3\\\\": 1}", "k4": "v"}]}', '$.k1'), 'ARRAY<MAP<STRING, STRING>>'), x -> x['k4'] = 'v')[0]['k2'], '$.k3')""",
+ },
+ )
+
self.validate_all(
"SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))",
write={
@@ -1073,7 +1088,6 @@ class TestPresto(Validator):
"presto": 'SELECT CAST(JSON_PARSE(\'{"k1":1,"k2":23,"k3":456}\') AS MAP(VARCHAR, INTEGER))',
},
)
-
self.validate_all(
"SELECT CAST(ARRAY [1, 23, 456] AS JSON)",
write={
diff --git a/tests/dialects/test_prql.py b/tests/dialects/test_prql.py
index 1a0eec2..5b438f1 100644
--- a/tests/dialects/test_prql.py
+++ b/tests/dialects/test_prql.py
@@ -66,3 +66,16 @@ class TestPRQL(Validator):
"from x filter (a > 1 || null != b || c != null)",
"SELECT * FROM x WHERE (a > 1 OR NOT b IS NULL OR NOT c IS NULL)",
)
+ self.validate_identity("from a aggregate { average x }", "SELECT AVG(x) FROM a")
+ self.validate_identity(
+ "from a aggregate { average x, min y, ct = sum z }",
+ "SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) AS ct FROM a",
+ )
+ self.validate_identity(
+ "from a aggregate { average x, min y, sum z }",
+ "SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) FROM a",
+ )
+ self.validate_identity(
+ "from a aggregate { min y, b = stddev x, max z }",
+ "SELECT MIN(y), STDDEV(x) AS b, MAX(z) FROM a",
+ )
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index a91f4f9..69793c7 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -6,6 +6,14 @@ class TestRedshift(Validator):
dialect = "redshift"
def test_redshift(self):
+ self.validate_identity("1 div", "1 AS div")
+ self.validate_all(
+ "SELECT SPLIT_TO_ARRAY('12,345,6789')",
+ write={
+ "postgres": "SELECT STRING_TO_ARRAY('12,345,6789', ',')",
+ "redshift": "SELECT SPLIT_TO_ARRAY('12,345,6789', ',')",
+ },
+ )
self.validate_all(
"GETDATE()",
read={
@@ -162,7 +170,7 @@ class TestRedshift(Validator):
write={
"postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
"redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
- "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -271,8 +279,11 @@ class TestRedshift(Validator):
"postgres": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL '18 MONTH'",
"presto": "SELECT DATE_ADD('MONTH', 18, CAST('2008-02-28' AS TIMESTAMP))",
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
- "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
+ "spark": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
+ "spark2": "SELECT ADD_MONTHS('2008-02-28', 18)",
+ "databricks": "SELECT DATE_ADD(MONTH, 18, '2008-02-28')",
},
)
self.validate_all(
@@ -362,8 +373,10 @@ class TestRedshift(Validator):
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
)
self.validate_identity(
- "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'",
- check_command_warning=True,
+ "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole' REGION 'us-east-1' FORMAT orc",
+ )
+ self.validate_identity(
+ "COPY customer FROM 's3://mybucket/mydata' CREDENTIALS 'aws_iam_role=arn:aws:iam::<aws-account-id>:role/<role-name>;master_symmetric_key=<root-key>' emptyasnull blanksasnull timeformat 'YYYY-MM-DD HH:MI:SS'"
)
self.validate_identity(
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'",
@@ -471,6 +484,10 @@ FROM (
self.validate_identity("CREATE TABLE table_backup BACKUP YES AS SELECT * FROM event")
self.validate_identity("CREATE TABLE table_backup (i INTEGER, b VARCHAR) BACKUP NO")
self.validate_identity("CREATE TABLE table_backup (i INTEGER, b VARCHAR) BACKUP YES")
+ self.validate_identity(
+ "select foo, bar from table_1 minus select foo, bar from table_2",
+ "SELECT foo, bar FROM table_1 EXCEPT SELECT foo, bar FROM table_2",
+ )
def test_create_table_like(self):
self.validate_identity(
@@ -496,7 +513,25 @@ FROM (
},
)
- def test_rename_table(self):
+ def test_alter_table(self):
+ self.validate_identity("ALTER TABLE s.t ALTER SORTKEY (c)")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY AUTO")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY NONE")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY (c1, c2)")
+ self.validate_identity("ALTER TABLE t ALTER SORTKEY (c1, c2)")
+ self.validate_identity("ALTER TABLE t ALTER COMPOUND SORTKEY (c1, c2)")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE ALL")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE EVEN")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE AUTO")
+ self.validate_identity("ALTER TABLE t ALTER DISTSTYLE KEY DISTKEY c")
+ self.validate_identity("ALTER TABLE t SET TABLE PROPERTIES ('a' = '5', 'b' = 'c')")
+ self.validate_identity("ALTER TABLE t SET LOCATION 's3://bucket/folder/'")
+ self.validate_identity("ALTER TABLE t SET FILE FORMAT AVRO")
+ self.validate_identity(
+ "ALTER TABLE t ALTER DISTKEY c",
+ "ALTER TABLE t ALTER DISTSTYLE KEY DISTKEY c",
+ )
+
self.validate_all(
"ALTER TABLE db.t1 RENAME TO db.t2",
write={
@@ -553,3 +588,9 @@ FROM (
self.assertEqual(
ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l"
)
+
+ def test_join_markers(self):
+ self.validate_identity(
+ "select a.foo, b.bar, a.baz from a, b where a.baz = b.baz (+)",
+ "SELECT a.foo, b.bar, a.baz FROM a, b WHERE a.baz = b.baz (+)",
+ )
diff --git a/tests/dialects/test_risingwave.py b/tests/dialects/test_risingwave.py
new file mode 100644
index 0000000..7d6d50c
--- /dev/null
+++ b/tests/dialects/test_risingwave.py
@@ -0,0 +1,14 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestRisingWave(Validator):
+ dialect = "risingwave"
+ maxDiff = None
+
+ def test_risingwave(self):
+ self.validate_all(
+ "SELECT a FROM tbl",
+ read={
+ "": "SELECT a FROM tbl FOR UPDATE",
+ },
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 1cbf68c..1286436 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -2,6 +2,7 @@ from unittest import mock
from sqlglot import UnsupportedError, exp, parse_one
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
+from sqlglot.optimizer.qualify_columns import quote_identifiers
from tests.dialects.test_dialect import Validator
@@ -10,11 +11,19 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
- self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
self.validate_identity(
- "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1"
+ "transform(x, a int -> a + a + 1)",
+ "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
)
- self.validate_identity("SELECT rename, replace")
+
+ self.validate_all(
+ "ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
+ write={
+ "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))",
+ "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)",
+ },
+ )
+
expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
expr.selects[0].assert_is(exp.AggFunc)
self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
@@ -40,6 +49,10 @@ WHERE
)""",
)
+ self.validate_identity("SELECT number").selects[0].assert_is(exp.Column)
+ self.validate_identity("INTERVAL '4 years, 5 months, 3 hours'")
+ self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
+ self.validate_identity("SELECT rename, replace")
self.validate_identity("SELECT TIMEADD(HOUR, 2, CAST('09:05:03' AS TIME))")
self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS MAP(VARCHAR, INT))")
self.validate_identity("SELECT CAST(OBJECT_CONSTRUCT('a', 1) AS OBJECT(a CHAR NOT NULL))")
@@ -83,25 +96,28 @@ WHERE
self.validate_identity("a$b") # valid snowflake identifier
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)")
- self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'")
- self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
- self.validate_identity("ALTER TABLE foo SET COMMENT = 'bar'")
- self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING = FALSE")
- self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING")
self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'")
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity("SELECT MATCH_CONDITION")
+ self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
self.validate_identity(
- 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
+ "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
+ )
+ self.validate_identity(
+ "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1"
)
self.validate_identity(
- "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)"
+ 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)
self.validate_identity(
"SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID"
)
self.validate_identity(
+ "CURRENT_TIMESTAMP - INTERVAL '1 w' AND (1 = 1)",
+ "CURRENT_TIMESTAMP() - INTERVAL '1 WEEK' AND (1 = 1)",
+ )
+ self.validate_identity(
"REGEXP_REPLACE('target', 'pattern', '\n')",
"REGEXP_REPLACE('target', 'pattern', '\\n')",
)
@@ -109,6 +125,10 @@ WHERE
"SELECT a:from::STRING, a:from || ' test' ",
"SELECT CAST(GET_PATH(a, 'from') AS TEXT), GET_PATH(a, 'from') || ' test'",
)
+ self.validate_identity(
+ "SELECT a:select",
+ "SELECT GET_PATH(a, 'select')",
+ )
self.validate_identity("x:from", "GET_PATH(x, 'from')")
self.validate_identity(
"value:values::string::int",
@@ -143,10 +163,6 @@ WHERE
"SELECT TIMESTAMP_FROM_PARTS(d, t)",
)
self.validate_identity(
- "SELECT user_id, value FROM table_name SAMPLE ($s) SEED (0)",
- "SELECT user_id, value FROM table_name TABLESAMPLE ($s) SEED (0)",
- )
- self.validate_identity(
"SELECT v:attr[0].name FROM vartab",
"SELECT GET_PATH(v, 'attr[0].name') FROM vartab",
)
@@ -233,6 +249,38 @@ WHERE
"CAST(x AS NCHAR VARYING)",
"CAST(x AS VARCHAR)",
)
+ self.validate_identity(
+ "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))",
+ "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
+ )
+ self.validate_identity(
+ "CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))",
+ "CREATE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
+ )
+ self.validate_identity(
+ "CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)",
+ "CREATE TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
+ )
+ self.validate_identity(
+ "ALTER TABLE foo ADD COLUMN id INT identity(1, 1)",
+ "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1",
+ )
+ self.validate_identity(
+ "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
+ "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))",
+ )
+ self.validate_identity(
+ "SELECT * FROM xxx WHERE col ilike '%Don''t%'",
+ "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
+ )
+ self.validate_identity(
+ "SELECT * EXCLUDE a, b FROM xxx",
+ "SELECT * EXCLUDE (a), b FROM xxx",
+ )
+ self.validate_identity(
+ "SELECT * RENAME a AS b, c AS d FROM xxx",
+ "SELECT * RENAME (a AS b), c AS d FROM xxx",
+ )
self.validate_all(
"OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)",
@@ -247,18 +295,6 @@ WHERE
},
)
self.validate_all(
- "SELECT * FROM example TABLESAMPLE (3) SEED (82)",
- read={
- "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- "duckdb": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- },
- write={
- "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- "duckdb": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
- "snowflake": "SELECT * FROM example TABLESAMPLE (3) SEED (82)",
- },
- )
- self.validate_all(
"SELECT TIME_FROM_PARTS(12, 34, 56, 987654321)",
write={
"duckdb": "SELECT MAKE_TIME(12, 34, 56 + (987654321 / 1000000000.0))",
@@ -301,10 +337,12 @@ WHERE
"""SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""",
write={
"bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""",
+ "databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
"duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""",
"mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""",
"presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""",
"snowflake": """SELECT GET_PATH(PARSE_JSON('{"fruit":"banana"}'), 'fruit')""",
+ "spark": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
"tsql": """SELECT ISNULL(JSON_QUERY('{"fruit":"banana"}', '$.fruit'), JSON_VALUE('{"fruit":"banana"}', '$.fruit'))""",
},
)
@@ -388,7 +426,7 @@ WHERE
"SELECT DATE_PART('year', TIMESTAMP '2020-01-01')",
write={
"hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))",
- "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMP))",
"spark": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))",
},
)
@@ -565,60 +603,12 @@ WHERE
},
)
self.validate_all(
- "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))",
- write={
- "snowflake": "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)",
- },
- )
- self.validate_all(
- "CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))",
- write={
- "snowflake": "CREATE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)",
- },
- )
- self.validate_all(
- "CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)",
- write={
- "snowflake": "CREATE TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)",
- },
- )
- self.validate_all(
- "ALTER TABLE foo ADD COLUMN id INT identity(1, 1)",
- write={
- "snowflake": "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1",
- },
- )
- self.validate_all(
- "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
- write={
- "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMPNTZ))",
- },
- )
- self.validate_all(
- "SELECT * FROM xxx WHERE col ilike '%Don''t%'",
- write={
- "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
- },
- )
- self.validate_all(
- "SELECT * EXCLUDE a, b FROM xxx",
- write={
- "snowflake": "SELECT * EXCLUDE (a), b FROM xxx",
- },
- )
- self.validate_all(
- "SELECT * RENAME a AS b, c AS d FROM xxx",
- write={
- "snowflake": "SELECT * RENAME (a AS b), c AS d FROM xxx",
- },
- )
- self.validate_all(
- "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
+ "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
read={
"duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
},
write={
- "snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
+ "snowflake": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
"duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
},
)
@@ -689,7 +679,7 @@ WHERE
"SELECT TO_TIMESTAMP('2013-04-05 01:02:03')",
write={
"bigquery": "SELECT CAST('2013-04-05 01:02:03' AS DATETIME)",
- "snowflake": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMPNTZ)",
+ "snowflake": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMP)",
"spark": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMP)",
},
)
@@ -828,6 +818,18 @@ WHERE
"snowflake": "SELECT LISTAGG(col1, ', ') WITHIN GROUP (ORDER BY col2) FROM t",
},
)
+ self.validate_all(
+ "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ read={
+ "trino": "SELECT APPROX_PERCENTILE(a, 1, 0.5, 0.001) FROM t",
+ "presto": "SELECT APPROX_PERCENTILE(a, 1, 0.5, 0.001) FROM t",
+ },
+ write={
+ "trino": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ "presto": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ "snowflake": "SELECT APPROX_PERCENTILE(a, 0.5) FROM t",
+ },
+ )
def test_null_treatment(self):
self.validate_all(
@@ -878,9 +880,9 @@ WHERE
self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz")
self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')")
self.validate_identity("PUT file:///dir/tmp.csv @%table", check_command_warning=True)
+ self.validate_identity("SELECT * FROM (SELECT a FROM @foo)")
self.validate_identity(
- 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)',
- check_command_warning=True,
+ "SELECT * FROM (SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv'))"
)
self.validate_identity(
"SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla"
@@ -903,18 +905,27 @@ WHERE
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
- self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
- self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)")
self.validate_identity(
- "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i"
+ "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE BERNOULLI (0.1)"
+ )
+ self.validate_identity(
+ "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE BERNOULLI (50) WHERE t2.j = t1.i"
+ )
+ self.validate_identity(
+ "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE BERNOULLI (1)"
+ )
+ self.validate_identity(
+ "SELECT * FROM testtable TABLESAMPLE (10 ROWS)",
+ "SELECT * FROM testtable TABLESAMPLE BERNOULLI (10 ROWS)",
)
self.validate_identity(
- "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)"
+ "SELECT * FROM testtable TABLESAMPLE (100)",
+ "SELECT * FROM testtable TABLESAMPLE BERNOULLI (100)",
)
self.validate_identity(
"SELECT * FROM testtable SAMPLE (10)",
- "SELECT * FROM testtable TABLESAMPLE (10)",
+ "SELECT * FROM testtable TABLESAMPLE BERNOULLI (10)",
)
self.validate_identity(
"SELECT * FROM testtable SAMPLE ROW (0)",
@@ -924,8 +935,30 @@ WHERE
"SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)",
"SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)",
)
+ self.validate_identity(
+ "SELECT user_id, value FROM table_name SAMPLE BERNOULLI ($s) SEED (0)",
+ "SELECT user_id, value FROM table_name TABLESAMPLE BERNOULLI ($s) SEED (0)",
+ )
self.validate_all(
+ "SELECT * FROM example TABLESAMPLE BERNOULLI (3) SEED (82)",
+ read={
+ "duckdb": "SELECT * FROM example TABLESAMPLE BERNOULLI (3 PERCENT) REPEATABLE (82)",
+ },
+ write={
+ "databricks": "SELECT * FROM example TABLESAMPLE (3 PERCENT) REPEATABLE (82)",
+ "duckdb": "SELECT * FROM example TABLESAMPLE BERNOULLI (3 PERCENT) REPEATABLE (82)",
+ "snowflake": "SELECT * FROM example TABLESAMPLE BERNOULLI (3) SEED (82)",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM test AS _tmp TABLESAMPLE (5)",
+ write={
+ "postgres": "SELECT * FROM test AS _tmp TABLESAMPLE BERNOULLI (5)",
+ "snowflake": "SELECT * FROM test AS _tmp TABLESAMPLE BERNOULLI (5)",
+ },
+ )
+ self.validate_all(
"""
SELECT i, j
FROM
@@ -934,7 +967,7 @@ WHERE
table2 AS t2 SAMPLE (50) -- 50% of rows in table2
WHERE t2.j = t1.i""",
write={
- "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
+ "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE BERNOULLI (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE BERNOULLI (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i",
},
)
self.validate_all(
@@ -946,7 +979,7 @@ WHERE
self.validate_all(
"SELECT * FROM (SELECT * FROM t1 join t2 on t1.a = t2.c) SAMPLE (1)",
write={
- "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)",
+ "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE BERNOULLI (1)",
"spark": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1 PERCENT)",
},
)
@@ -955,12 +988,16 @@ WHERE
self.validate_identity("SELECT CAST('12:00:00' AS TIME)")
self.validate_identity("SELECT DATE_PART(month, a)")
- self.validate_all(
- "SELECT CAST(a AS TIMESTAMP)",
- write={
- "snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)",
- },
- )
+ for data_type in (
+ "TIMESTAMP",
+ "TIMESTAMPLTZ",
+ "TIMESTAMPNTZ",
+ ):
+ self.validate_identity(f"CAST(a AS {data_type})")
+
+ self.validate_identity("CAST(a AS TIMESTAMP_NTZ)", "CAST(a AS TIMESTAMPNTZ)")
+ self.validate_identity("CAST(a AS TIMESTAMP_LTZ)", "CAST(a AS TIMESTAMPLTZ)")
+
self.validate_all(
"SELECT a::TIMESTAMP_LTZ(9)",
write={
@@ -1000,14 +1037,14 @@ WHERE
self.validate_all(
"SELECT DATE_PART(epoch_second, foo) as ddate from table_name",
write={
- "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name",
+ "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
},
)
self.validate_all(
"SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name",
write={
- "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name",
+ "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
},
)
@@ -1138,7 +1175,7 @@ WHERE
)
self.validate_identity(
"SELECT * FROM my_table AT (TIMESTAMP => 'Fri, 01 May 2015 16:20:00 -0700'::timestamp)",
- "SELECT * FROM my_table AT (TIMESTAMP => CAST('Fri, 01 May 2015 16:20:00 -0700' AS TIMESTAMPNTZ))",
+ "SELECT * FROM my_table AT (TIMESTAMP => CAST('Fri, 01 May 2015 16:20:00 -0700' AS TIMESTAMP))",
)
self.validate_identity(
"SELECT * FROM my_table AT(TIMESTAMP => 'Fri, 01 May 2015 16:20:00 -0700'::timestamp_tz)",
@@ -1160,6 +1197,25 @@ WHERE
)
def test_ddl(self):
+ for constraint_prefix in ("WITH ", ""):
+ with self.subTest(f"Constraint prefix: {constraint_prefix}"):
+ self.validate_identity(
+ f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p.q.r)",
+ "CREATE TABLE t (id INT MASKING POLICY p.q.r)",
+ )
+ self.validate_identity(
+ f"CREATE TABLE t (id INT {constraint_prefix}MASKING POLICY p USING (c1, c2, c3))",
+ "CREATE TABLE t (id INT MASKING POLICY p USING (c1, c2, c3))",
+ )
+ self.validate_identity(
+ f"CREATE TABLE t (id INT {constraint_prefix}PROJECTION POLICY p.q.r)",
+ "CREATE TABLE t (id INT PROJECTION POLICY p.q.r)",
+ )
+ self.validate_identity(
+ f"CREATE TABLE t (id INT {constraint_prefix}TAG (key1='value_1', key2='value_2'))",
+ "CREATE TABLE t (id INT TAG (key1='value_1', key2='value_2'))",
+ )
+
self.validate_identity(
"""create external table et2(
col1 date as (parse_json(metadata$external_table_partition):COL1::date),
@@ -1169,7 +1225,7 @@ WHERE
location=@s2/logs/
partition_type = user_specified
file_format = (type = parquet)""",
- "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL1') AS DATE)), col2 VARCHAR AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL2') AS VARCHAR)), col3 DECIMAL AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL3') AS DECIMAL))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)",
+ "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL1') AS DATE)), col2 VARCHAR AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL2') AS VARCHAR)), col3 DECIMAL(38, 0) AS (CAST(GET_PATH(PARSE_JSON(metadata$external_table_partition), 'COL3') AS DECIMAL(38, 0)))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)",
)
self.validate_identity("CREATE OR REPLACE VIEW foo (uid) COPY GRANTS AS (SELECT 1)")
self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)")
@@ -1178,8 +1234,17 @@ WHERE
self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema")
self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)")
self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)")
+ self.validate_identity("CREATE TAG cost_center ALLOWED_VALUES 'a', 'b'")
+ self.validate_identity("CREATE WAREHOUSE x").this.assert_is(exp.Identifier)
+ self.validate_identity("CREATE STREAMLIT x").this.assert_is(exp.Identifier)
+ self.validate_identity(
+ "CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'"
+ ).this.assert_is(exp.Identifier)
self.validate_identity(
- "DROP function my_udf (OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN)))"
+ "ALTER TABLE db_name.schmaName.tblName ADD COLUMN COLUMN_1 VARCHAR NOT NULL TAG (key1='value_1')"
+ )
+ self.validate_identity(
+ "DROP FUNCTION my_udf (OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN)))"
)
self.validate_identity(
"CREATE TABLE orders_clone_restore CLONE orders AT (TIMESTAMP => TO_TIMESTAMP_TZ('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss'))"
@@ -1200,13 +1265,27 @@ WHERE
"CREATE ICEBERG TABLE my_iceberg_table (amount ARRAY(INT)) CATALOG='SNOWFLAKE' EXTERNAL_VOLUME='my_external_volume' BASE_LOCATION='my/relative/path/from/extvol'"
)
self.validate_identity(
- "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$",
- "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL, val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '",
+ """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT RETURNS NULL ON NULL INPUT AS ' return Object.values(obj) '"""
+ )
+ self.validate_identity(
+ """CREATE OR REPLACE FUNCTION ibis_udfs.public.object_values("obj" OBJECT) RETURNS ARRAY LANGUAGE JAVASCRIPT STRICT AS ' return Object.values(obj) '"""
+ )
+ self.validate_identity(
+ "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS $$ SELECT 'foo' $$",
+ "CREATE OR REPLACE FUNCTION my_udf(location OBJECT(city VARCHAR, zipcode DECIMAL(38, 0), val ARRAY(BOOLEAN))) RETURNS VARCHAR AS ' SELECT \\'foo\\' '",
)
self.validate_identity(
"CREATE OR REPLACE FUNCTION my_udtf(foo BOOLEAN) RETURNS TABLE(col1 ARRAY(INT)) AS $$ WITH t AS (SELECT CAST([1, 2, 3] AS ARRAY(INT)) AS c) SELECT c FROM t $$",
"CREATE OR REPLACE FUNCTION my_udtf(foo BOOLEAN) RETURNS TABLE (col1 ARRAY(INT)) AS ' WITH t AS (SELECT CAST([1, 2, 3] AS ARRAY(INT)) AS c) SELECT c FROM t '",
)
+ self.validate_identity(
+ "CREATE SEQUENCE seq1 WITH START=1, INCREMENT=1 ORDER",
+ "CREATE SEQUENCE seq1 START=1 INCREMENT BY 1 ORDER",
+ )
+ self.validate_identity(
+ "CREATE SEQUENCE seq1 WITH START=1 INCREMENT=1 ORDER",
+ "CREATE SEQUENCE seq1 START=1 INCREMENT=1 ORDER",
+ )
self.validate_all(
"CREATE TABLE orders_clone CLONE orders",
@@ -1237,6 +1316,20 @@ WHERE
write={"snowflake": "CREATE TABLE a (b INT)"},
)
+ for action in ("SET", "DROP"):
+ with self.subTest(f"ALTER COLUMN {action} NOT NULL"):
+ self.validate_all(
+ f"""
+ ALTER TABLE a
+ ALTER COLUMN my_column {action} NOT NULL;
+ """,
+ write={
+ "snowflake": f"ALTER TABLE a ALTER COLUMN my_column {action} NOT NULL",
+ "duckdb": f"ALTER TABLE a ALTER COLUMN my_column {action} NOT NULL",
+ "postgres": f"ALTER TABLE a ALTER COLUMN my_column {action} NOT NULL",
+ },
+ )
+
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
@@ -1432,6 +1525,9 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
)
def test_values(self):
+ select = exp.select("*").from_("values (map(['a'], [1]))")
+ self.assertEqual(select.sql("snowflake"), "SELECT * FROM (SELECT OBJECT_CONSTRUCT('a', 1))")
+
self.validate_all(
'SELECT "c0", "c1" FROM (VALUES (1, 2), (3, 4)) AS "t0"("c0", "c1")',
read={
@@ -1581,7 +1677,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
"REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)",
write={
"bigquery": "REGEXP_REPLACE(subject, pattern, replacement)",
- "duckdb": "REGEXP_REPLACE(subject, pattern, replacement)",
+ "duckdb": "REGEXP_REPLACE(subject, pattern, replacement, parameters)",
"hive": "REGEXP_REPLACE(subject, pattern, replacement)",
"snowflake": "REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)",
"spark": "REGEXP_REPLACE(subject, pattern, replacement, position)",
@@ -1785,7 +1881,7 @@ STORAGE_AWS_ROLE_ARN='arn:aws:iam::001234567890:role/myrole'
ENABLED=TRUE
STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
pretty=True,
- )
+ ).this.assert_is(exp.Identifier)
def test_swap(self):
ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake")
@@ -1827,3 +1923,101 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
expression = annotate_types(expression)
self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)")
+
+ def test_copy(self):
+ self.validate_identity("COPY INTO test (c1) FROM (SELECT $1.c1 FROM @mystage)")
+ self.validate_identity(
+ """COPY INTO temp FROM @random_stage/path/ FILE_FORMAT = (TYPE=CSV FIELD_DELIMITER='|' NULL_IF=('str1', 'str2') FIELD_OPTIONALLY_ENCLOSED_BY='"' TIMESTAMP_FORMAT='TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT='TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT=BASE64) VALIDATION_MODE = 'RETURN_3_ROWS'"""
+ )
+ self.validate_identity(
+ """COPY INTO load1 FROM @%load1/data1/ CREDENTIALS = (AWS_KEY_ID='id' AWS_SECRET_KEY='key' AWS_TOKEN='token') FILES = ('test1.csv', 'test2.csv') FORCE = TRUE"""
+ )
+ self.validate_identity(
+ """COPY INTO mytable FROM 'azure://myaccount.blob.core.windows.net/mycontainer/data/files' CREDENTIALS = (AZURE_SAS_TOKEN='token') ENCRYPTION = (TYPE='AZURE_CSE' MASTER_KEY='kPx...') FILE_FORMAT = (FORMAT_NAME=my_csv_format)"""
+ )
+ self.validate_identity(
+ """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' STORAGE_INTEGRATION = "storage" ENCRYPTION = (TYPE='NONE' MASTER_KEY='key') FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME=my_csv_format NULL_IF=('')) PARSE_HEADER = TRUE"""
+ )
+ self.validate_identity(
+ """COPY INTO @my_stage/result/data FROM (SELECT * FROM orderstiny) FILE_FORMAT = (TYPE='csv')"""
+ )
+ self.validate_all(
+ """COPY INTO 's3://example/data.csv'
+ FROM EXTRA.EXAMPLE.TABLE
+ CREDENTIALS = ()
+ FILE_FORMAT = (TYPE = CSV COMPRESSION = NONE NULL_IF = ('') FIELD_OPTIONALLY_ENCLOSED_BY = '"')
+ HEADER = TRUE
+ OVERWRITE = TRUE
+ SINGLE = TRUE
+ """,
+ write={
+ "": """COPY INTO 's3://example/data.csv'
+FROM EXTRA.EXAMPLE.TABLE
+CREDENTIALS = () WITH (
+ FILE_FORMAT = (TYPE=CSV COMPRESSION=NONE NULL_IF=(
+ ''
+ ) FIELD_OPTIONALLY_ENCLOSED_BY='"'),
+ HEADER TRUE,
+ OVERWRITE TRUE,
+ SINGLE TRUE
+)""",
+ "snowflake": """COPY INTO 's3://example/data.csv'
+FROM EXTRA.EXAMPLE.TABLE
+CREDENTIALS = ()
+FILE_FORMAT = (TYPE=CSV COMPRESSION=NONE NULL_IF=(
+ ''
+) FIELD_OPTIONALLY_ENCLOSED_BY='"')
+HEADER = TRUE
+OVERWRITE = TRUE
+SINGLE = TRUE""",
+ },
+ pretty=True,
+ )
+ self.validate_all(
+ """COPY INTO 's3://example/data.csv'
+ FROM EXTRA.EXAMPLE.TABLE
+ STORAGE_INTEGRATION = S3_INTEGRATION
+ FILE_FORMAT = (TYPE=CSV COMPRESSION=NONE NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='"')
+ HEADER = TRUE
+ OVERWRITE = TRUE
+ SINGLE = TRUE
+ """,
+ write={
+ "": """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE STORAGE_INTEGRATION = S3_INTEGRATION WITH (FILE_FORMAT = (TYPE=CSV COMPRESSION=NONE NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='"'), HEADER TRUE, OVERWRITE TRUE, SINGLE TRUE)""",
+ "snowflake": """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE STORAGE_INTEGRATION = S3_INTEGRATION FILE_FORMAT = (TYPE=CSV COMPRESSION=NONE NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='"') HEADER = TRUE OVERWRITE = TRUE SINGLE = TRUE""",
+ },
+ )
+
+ copy_ast = parse_one(
+ """COPY INTO 's3://example/contacts.csv' FROM db.tbl STORAGE_INTEGRATION = PROD_S3_SIDETRADE_INTEGRATION FILE_FORMAT = (FORMAT_NAME=my_csv_format TYPE=CSV COMPRESSION=NONE NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='"') MATCH_BY_COLUMN_NAME = CASE_SENSITIVE OVERWRITE = TRUE SINGLE = TRUE INCLUDE_METADATA = (col1 = METADATA$START_SCAN_TIME)""",
+ read="snowflake",
+ )
+ self.assertEqual(
+ quote_identifiers(copy_ast, dialect="snowflake").sql(dialect="snowflake"),
+ """COPY INTO 's3://example/contacts.csv' FROM "db"."tbl" STORAGE_INTEGRATION = "PROD_S3_SIDETRADE_INTEGRATION" FILE_FORMAT = (FORMAT_NAME="my_csv_format" TYPE=CSV COMPRESSION=NONE NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='"') MATCH_BY_COLUMN_NAME = CASE_SENSITIVE OVERWRITE = TRUE SINGLE = TRUE INCLUDE_METADATA = ("col1" = "METADATA$START_SCAN_TIME")""",
+ )
+
+ def test_querying_semi_structured_data(self):
+ self.validate_identity("SELECT $1")
+ self.validate_identity("SELECT $1.elem")
+
+ self.validate_identity("SELECT $1:a.b", "SELECT GET_PATH($1, 'a.b')")
+ self.validate_identity("SELECT t.$23:a.b", "SELECT GET_PATH(t.$23, 'a.b')")
+ self.validate_identity("SELECT t.$17:a[0].b[0].c", "SELECT GET_PATH(t.$17, 'a[0].b[0].c')")
+
+ def test_alter_set_unset(self):
+ self.validate_identity("ALTER TABLE tbl SET DATA_RETENTION_TIME_IN_DAYS=1")
+ self.validate_identity("ALTER TABLE tbl SET DEFAULT_DDL_COLLATION='test'")
+ self.validate_identity("ALTER TABLE foo SET COMMENT='bar'")
+ self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING=FALSE")
+ self.validate_identity("ALTER TABLE table1 SET TAG foo.bar = 'baz'")
+ self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'")
+ self.validate_identity(
+ """ALTER TABLE tbl SET STAGE_FILE_FORMAT = (TYPE=CSV FIELD_DELIMITER='|' NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='"' TIMESTAMP_FORMAT='TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT='TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT=BASE64)""",
+ )
+ self.validate_identity(
+ """ALTER TABLE tbl SET STAGE_COPY_OPTIONS = (ON_ERROR=SKIP_FILE SIZE_LIMIT=5 PURGE=TRUE MATCH_BY_COLUMN_NAME=CASE_SENSITIVE)"""
+ )
+
+ self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
+ self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING")
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 7534573..bff91bf 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -325,7 +325,7 @@ TBLPROPERTIES (
write={
"clickhouse": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
"databricks": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
- "doris": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
+ "doris": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS `name` UNION ALL SELECT NULL AS id, 'jake' AS `name`) SELECT COUNT(DISTINCT id, `name`) AS cnt FROM tbl",
"duckdb": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT CASE WHEN id IS NULL THEN NULL WHEN name IS NULL THEN NULL ELSE (id, name) END) AS cnt FROM tbl",
"hive": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
"mysql": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
@@ -343,7 +343,7 @@ TBLPROPERTIES (
"postgres": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'",
"presto": "SELECT WITH_TIMEZONE(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul') AT TIME ZONE 'UTC'",
"redshift": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'",
- "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMP))",
"spark": "SELECT TO_UTC_TIMESTAMP(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')",
},
)
@@ -523,7 +523,14 @@ TBLPROPERTIES (
},
)
- for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"):
+ for data_type in (
+ "BOOLEAN",
+ "DATE",
+ "DOUBLE",
+ "FLOAT",
+ "INT",
+ "TIMESTAMP",
+ ):
self.validate_all(
f"{data_type}(x)",
write={
@@ -531,6 +538,16 @@ TBLPROPERTIES (
"spark": f"CAST(x AS {data_type})",
},
)
+
+ for ts_suffix in ("NTZ", "LTZ"):
+ self.validate_all(
+ f"TIMESTAMP_{ts_suffix}(x)",
+ write={
+ "": f"CAST(x AS TIMESTAMP{ts_suffix})",
+ "spark": f"CAST(x AS TIMESTAMP_{ts_suffix})",
+ },
+ )
+
self.validate_all(
"STRING(x)",
write={
@@ -546,6 +563,7 @@ TBLPROPERTIES (
"SELECT DATE_ADD(my_date_column, 1)",
write={
"spark": "SELECT DATE_ADD(my_date_column, 1)",
+ "spark2": "SELECT DATE_ADD(my_date_column, 1)",
"bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)",
},
)
@@ -658,6 +676,16 @@ TBLPROPERTIES (
"spark": "SELECT ARRAY_SORT(x)",
},
)
+ self.validate_all(
+ "SELECT DATE_ADD(MONTH, 20, col)",
+ read={
+ "spark": "SELECT TIMESTAMPADD(MONTH, 20, col)",
+ },
+ write={
+ "spark": "SELECT DATE_ADD(MONTH, 20, col)",
+ "databricks": "SELECT DATE_ADD(MONTH, 20, col)",
+ },
+ )
 
def test_bool_or(self):
self.validate_all(
diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py
index f3cde0b..46bbadc 100644
--- a/tests/dialects/test_sqlite.py
+++ b/tests/dialects/test_sqlite.py
@@ -202,6 +202,7 @@ class TestSQLite(Validator):
"CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
read={
"mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)",
+ "postgres": "CREATE TABLE z (a INT GENERATED BY DEFAULT AS IDENTITY NOT NULL UNIQUE PRIMARY KEY)",
},
write={
"sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)",
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 010b683..74d5f88 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -38,7 +38,7 @@ class TestTeradata(Validator):
"UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
write={
"teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
- "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
+ "mysql": "UPDATE A SET col2 = '' FROM `schema`.tableA AS A, (SELECT col1 FROM `schema`.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
},
)
diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py
new file mode 100644
index 0000000..ccc1407
--- /dev/null
+++ b/tests/dialects/test_trino.py
@@ -0,0 +1,18 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestTrino(Validator):
+ dialect = "trino"
+
+ def test_trim(self):
+ self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
+ self.validate_identity("SELECT TRIM(BOTH '$' FROM '$var$')")
+ self.validate_identity("SELECT TRIM(TRAILING 'ER' FROM UPPER('worker'))")
+ self.validate_identity(
+ "SELECT TRIM(LEADING FROM ' abcd')",
+ "SELECT LTRIM(' abcd')",
+ )
+ self.validate_identity(
+ "SELECT TRIM('!foo!', '!')",
+ "SELECT TRIM('!' FROM '!foo!')",
+ )
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 4a475f6..7455650 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,13 +1,20 @@
-from sqlglot import exp, parse, parse_one
-from sqlglot.parser import logger as parser_logger
+from sqlglot import exp, parse
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
+from sqlglot.optimizer.annotate_types import annotate_types
 
 
class TestTSQL(Validator):
dialect = "tsql"
 
def test_tsql(self):
+ self.assertEqual(
+ annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
+ "SELECT 1 WHERE EXISTS(SELECT 1)",
+ )
+
+ self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
+ self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)")
self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True)
# https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
@@ -29,6 +36,9 @@ class TestTSQL(Validator):
self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)")
self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0")
self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))")
+ self.validate_identity(
+ "COPY INTO test_1 FROM 'path' WITH (FORMAT_NAME = test, FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY='Shared Access Signature', SECRET='token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')"
+ )
self.validate_all(
"SELECT IIF(cond <> 0, 'True', 'False')",
@@ -188,16 +198,9 @@ class TestTSQL(Validator):
)
self.validate_all(
- """
- CREATE TABLE x(
- [zip_cd] [varchar](5) NULL NOT FOR REPLICATION,
- [zip_cd_mkey] [varchar](5) NOT NULL,
- CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC)
- WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [INDEX]
- ) ON [SECONDARY]
- """,
+ """CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]""",
write={
- "tsql": "CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]",
+ "tsql": "CREATE TABLE x ([zip_cd] VARCHAR(5) NULL NOT FOR REPLICATION, [zip_cd_mkey] VARCHAR(5) NOT NULL, CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON [INDEX]) ON [SECONDARY]",
"spark2": "CREATE TABLE x (`zip_cd` VARCHAR(5), `zip_cd_mkey` VARCHAR(5) NOT NULL, CONSTRAINT `pk_mytable` PRIMARY KEY (`zip_cd_mkey`))",
},
)
@@ -220,9 +223,9 @@ class TestTSQL(Validator):
"CREATE TABLE [db].[tbl] ([a] INTEGER)",
)
- projection = parse_one("SELECT a = 1", read="tsql").selects[0]
- projection.assert_is(exp.Alias)
- projection.args["alias"].assert_is(exp.Identifier)
+ self.validate_identity("SELECT a = 1", "SELECT 1 AS a").selects[0].assert_is(
+ exp.Alias
+ ).args["alias"].assert_is(exp.Identifier)
self.validate_all(
"IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
@@ -256,7 +259,7 @@ class TestTSQL(Validator):
self.validate_identity("SELECT * FROM ##foo")
self.validate_identity("SELECT a = 1", "SELECT 1 AS a")
self.validate_identity(
- "DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'", check_command_warning=True
+ "DECLARE @TestVariable AS VARCHAR(100) = 'Save Our Planet'",
)
self.validate_identity(
"SELECT a = 1 UNION ALL SELECT a = b", "SELECT 1 AS a UNION ALL SELECT b AS a"
@@ -458,6 +461,7 @@ class TestTSQL(Validator):
self.validate_identity("CAST(x AS IMAGE)")
self.validate_identity("CAST(x AS SQL_VARIANT)")
self.validate_identity("CAST(x AS BIT)")
+
self.validate_all(
"CAST(x AS DATETIME2)",
read={
@@ -485,7 +489,7 @@ class TestTSQL(Validator):
},
)
 
- def test__types_ints(self):
+ def test_types_ints(self):
self.validate_all(
"CAST(X AS INT)",
write={
@@ -518,10 +522,14 @@ class TestTSQL(Validator):
self.validate_all(
"CAST(X AS TINYINT)",
+ read={
+ "duckdb": "CAST(X AS UTINYINT)",
+ },
write={
- "hive": "CAST(X AS TINYINT)",
- "spark2": "CAST(X AS TINYINT)",
- "spark": "CAST(X AS TINYINT)",
+ "duckdb": "CAST(X AS UTINYINT)",
+ "hive": "CAST(X AS SMALLINT)",
+ "spark2": "CAST(X AS SMALLINT)",
+ "spark": "CAST(X AS SMALLINT)",
"tsql": "CAST(X AS TINYINT)",
},
)
@@ -754,29 +762,48 @@ class TestTSQL(Validator):
for view_attr in ("ENCRYPTION", "SCHEMABINDING", "VIEW_METADATA"):
self.validate_identity(f"CREATE VIEW a.b WITH {view_attr} AS SELECT * FROM x")
- expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql")
- self.assertIsInstance(expression, exp.AlterTable)
- self.assertIsInstance(expression.args["actions"][0], exp.Drop)
- self.assertEqual(
- expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B"
- )
+ self.validate_identity("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B").assert_is(
+ exp.AlterTable
+ ).args["actions"][0].assert_is(exp.Drop)
- for clusterd_keyword in ("CLUSTERED", "NONCLUSTERED"):
+ for clustered_keyword in ("CLUSTERED", "NONCLUSTERED"):
self.validate_identity(
'CREATE TABLE "dbo"."benchmark" ('
'"name" CHAR(7) NOT NULL, '
'"internal_id" VARCHAR(10) NOT NULL, '
- f'UNIQUE {clusterd_keyword} ("internal_id" ASC))',
+ f'UNIQUE {clustered_keyword} ("internal_id" ASC))',
"CREATE TABLE [dbo].[benchmark] ("
"[name] CHAR(7) NOT NULL, "
"[internal_id] VARCHAR(10) NOT NULL, "
- f"UNIQUE {clusterd_keyword} ([internal_id] ASC))",
+ f"UNIQUE {clustered_keyword} ([internal_id] ASC))",
)
self.validate_identity(
+ "ALTER TABLE tbl SET SYSTEM_VERSIONING=ON(HISTORY_TABLE=db.tbl, DATA_CONSISTENCY_CHECK=OFF, HISTORY_RETENTION_PERIOD=5 DAYS)"
+ )
+ self.validate_identity(
+ "ALTER TABLE tbl SET SYSTEM_VERSIONING=ON(HISTORY_TABLE=db.tbl, HISTORY_RETENTION_PERIOD=INFINITE)"
+ )
+ self.validate_identity("ALTER TABLE tbl SET SYSTEM_VERSIONING=OFF")
+ self.validate_identity("ALTER TABLE tbl SET FILESTREAM_ON = 'test'")
+ self.validate_identity(
+ "ALTER TABLE tbl SET DATA_DELETION=ON(FILTER_COLUMN=col, RETENTION_PERIOD=5 MONTHS)"
+ )
+ self.validate_identity("ALTER TABLE tbl SET DATA_DELETION=ON")
+ self.validate_identity("ALTER TABLE tbl SET DATA_DELETION=OFF")
+
+ self.validate_identity(
"CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END",
"CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END",
)
+
+ self.validate_all(
+ "CREATE TABLE [#temptest] (name INTEGER)",
+ read={
+ "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name INTEGER)",
+ "tsql": "CREATE TABLE [#temptest] (name INTEGER)",
+ },
+ )
self.validate_all(
"CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)",
read={
@@ -889,8 +916,7 @@ class TestTSQL(Validator):
 
def test_udf(self):
self.validate_identity(
- "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
- check_command_warning=True,
+ "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
)
self.validate_identity(
"CREATE PROCEDURE foo @a INTEGER, @b INTEGER AS SELECT @a = SUM(bla) FROM baz AS bar"
@@ -962,9 +988,9 @@ WHERE
BEGIN
SET XACT_ABORT ON;
- DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104);
- DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104);
- DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER);
+ DECLARE @DWH_DateCreated AS DATETIME = CONVERT(DATETIME, getdate(), 104);
+ DECLARE @DWH_DateModified DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104);
+ DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (CURRENT_USER());
DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER);
DECLARE @SalesAmountBefore float;
@@ -974,18 +1000,17 @@ WHERE
expected_sqls = [
"CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON",
- "DECLARE @DWH_DateCreated DATETIME = CONVERT(DATETIME, getdate(), 104)",
- "DECLARE @DWH_DateModified DATETIME = CONVERT(DATETIME, getdate(), 104)",
- "DECLARE @DWH_IdUserCreated INTEGER = SUSER_ID (SYSTEM_USER)",
- "DECLARE @DWH_IdUserModified INTEGER = SUSER_ID (SYSTEM_USER)",
- "DECLARE @SalesAmountBefore float",
+ "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
+ "DECLARE @DWH_DateModified AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
+ "DECLARE @DWH_IdUserCreated AS INTEGER = SUSER_ID(CURRENT_USER())",
+ "DECLARE @DWH_IdUserModified AS INTEGER = SUSER_ID(CURRENT_USER())",
+ "DECLARE @SalesAmountBefore AS FLOAT",
"SELECT @SalesAmountBefore = SUM(SalesAmount) FROM TRANSF.[Pre_Merge_Sales_Real] AS S",
"END",
]
- with self.assertLogs(parser_logger):
- for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
- self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
sql = """
CREATE PROC [dbo].[transform_proc] AS
@@ -999,14 +1024,13 @@ WHERE
"""
expected_sqls = [
- "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate VARCHAR(20)",
+ "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate AS VARCHAR(20)",
"SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120)",
"CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)",
]
- with self.assertLogs(parser_logger):
- for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
- self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
+ for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls):
+ self.assertEqual(expr.sql(dialect="tsql"), expected_sql)
 
def test_charindex(self):
self.validate_identity(
@@ -1072,7 +1096,13 @@ WHERE
self.validate_all("LEN('x')", write={"tsql": "LEN('x')", "spark": "LENGTH('x')"})
 
def test_replicate(self):
- self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"})
+ self.validate_all(
+ "REPLICATE('x', 2)",
+ write={
+ "spark": "REPEAT('x', 2)",
+ "tsql": "REPLICATE('x', 2)",
+ },
+ )
 
def test_isnull(self):
self.validate_all("ISNULL(x, y)", write={"spark": "COALESCE(x, y)"})
@@ -1605,27 +1635,23 @@ WHERE
)
 
def test_identifier_prefixes(self):
- expr = parse_one("#x", read="tsql")
- self.assertIsInstance(expr, exp.Column)
- self.assertIsInstance(expr.this, exp.Identifier)
- self.assertTrue(expr.this.args.get("temporary"))
- self.assertEqual(expr.sql("tsql"), "#x")
-
- expr = parse_one("##x", read="tsql")
- self.assertIsInstance(expr, exp.Column)
- self.assertIsInstance(expr.this, exp.Identifier)
- self.assertTrue(expr.this.args.get("global"))
- self.assertEqual(expr.sql("tsql"), "##x")
-
- expr = parse_one("@x", read="tsql")
- self.assertIsInstance(expr, exp.Parameter)
- self.assertIsInstance(expr.this, exp.Var)
- self.assertEqual(expr.sql("tsql"), "@x")
+ self.assertTrue(
+ self.validate_identity("#x")
+ .assert_is(exp.Column)
+ .this.assert_is(exp.Identifier)
+ .args.get("temporary")
+ )
+ self.assertTrue(
+ self.validate_identity("##x")
+ .assert_is(exp.Column)
+ .this.assert_is(exp.Identifier)
+ .args.get("global")
+ )
- table = parse_one("select * from @x", read="tsql").args["from"].this
- self.assertIsInstance(table, exp.Table)
- self.assertIsInstance(table.this, exp.Parameter)
- self.assertIsInstance(table.this.this, exp.Var)
+ self.validate_identity("@x").assert_is(exp.Parameter).this.assert_is(exp.Var)
+ self.validate_identity("SELECT * FROM @x").args["from"].this.assert_is(
+ exp.Table
+ ).this.assert_is(exp.Parameter).this.assert_is(exp.Var)
self.validate_all(
"SELECT @x",
@@ -1636,8 +1662,6 @@ WHERE
"tsql": "SELECT @x",
},
)
-
- def test_temp_table(self):
self.validate_all(
"SELECT * FROM #mytemptable",
write={
@@ -1812,3 +1836,35 @@ FROM OPENJSON(@json) WITH (
"duckdb": "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) FROM t1) SELECT * FROM t2",
},
)
+
+ def test_declare(self):
+ # supported cases
+ self.validate_identity("DECLARE @X INT", "DECLARE @X AS INTEGER")
+ self.validate_identity("DECLARE @X INT = 1", "DECLARE @X AS INTEGER = 1")
+ self.validate_identity(
+ "DECLARE @X INT, @Y VARCHAR(10)", "DECLARE @X AS INTEGER, @Y AS VARCHAR(10)"
+ )
+ self.validate_identity(
+ "declare @X int = (select col from table where id = 1)",
+ "DECLARE @X AS INTEGER = (SELECT col FROM table WHERE id = 1)",
+ )
+ self.validate_identity(
+ "declare @X TABLE (Id INT NOT NULL, Name VARCHAR(100) NOT NULL)",
+ "DECLARE @X AS TABLE (Id INTEGER NOT NULL, Name VARCHAR(100) NOT NULL)",
+ )
+ self.validate_identity(
+ "declare @X TABLE (Id INT NOT NULL, constraint PK_Id primary key (Id))",
+ "DECLARE @X AS TABLE (Id INTEGER NOT NULL, CONSTRAINT PK_Id PRIMARY KEY (Id))",
+ )
+ self.validate_identity(
+ "declare @X UserDefinedTableType",
+ "DECLARE @X AS UserDefinedTableType",
+ )
+ self.validate_identity(
+ "DECLARE @MyTableVar TABLE (EmpID INT NOT NULL, PRIMARY KEY CLUSTERED (EmpID), UNIQUE NONCLUSTERED (EmpID), INDEX CustomNonClusteredIndex NONCLUSTERED (EmpID))",
+ check_command_warning=True,
+ )
+ self.validate_identity(
+ "DECLARE vendor_cursor CURSOR FOR SELECT VendorID, Name FROM Purchasing.Vendor WHERE PreferredVendorStatus = 1 ORDER BY VendorID",
+ check_command_warning=True,
+ )