diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 71 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 315 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 72 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 28 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 22 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 54 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 24 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 44 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 26 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 19 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 52 |
15 files changed, 707 insertions, 77 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 87bba6f..99d8a3c 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,6 +6,14 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + self.validate_identity("SAFE_CAST(x AS STRING)") + self.validate_identity("SELECT * FROM a-b-c.mydataset.mytable") + self.validate_identity("SELECT * FROM abc-def-ghi") + self.validate_identity("SELECT * FROM a-b-c") + self.validate_identity("SELECT * FROM my-table") + self.validate_identity("SELECT * FROM my-project.mydataset.mytable") + self.validate_identity("SELECT * FROM pro-ject_id.c.d CROSS JOIN foo-bar") + self.validate_identity("x <> ''") self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))") self.validate_identity("SELECT b'abc'") self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""") @@ -13,6 +21,7 @@ class TestBigQuery(Validator): self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") + self.validate_identity("SELECT STRUCT<STRING>((SELECT a FROM b.c LIMIT 1)).*") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""") self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""") @@ -22,17 +31,34 @@ class TestBigQuery(Validator): self.validate_identity( "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" ) + self.validate_identity( + "CREATE TABLE IF NOT EXISTS foo AS SELECT * FROM bla EXCEPT DISTINCT (SELECT * FROM bar) LIMIT 0" + ) + self.validate_all('x <> ""', write={"bigquery": "x <> ''"}) + self.validate_all('x <> """"""', write={"bigquery": "x <> ''"}) + self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all( "CREATE TEMP TABLE foo AS SELECT 1", write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"}, ) + self.validate_all( + "SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`", + write={ + "bigquery": "SELECT * FROM SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW", + }, + ) + self.validate_all( + "SELECT * FROM `my-project.my-dataset.my-table`", + write={"bigquery": "SELECT * FROM `my-project`.`my-dataset`.`my-table`"}, + ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS TIMESTAMP)", write={"bigquery": "CAST(x AS DATETIME)"}) self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) + self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) self.validate_all( "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", write={ @@ -64,16 +90,6 @@ class TestBigQuery(Validator): "spark": "'x\\''", }, ) - self.validate_all( - r'r"""/\*.*\*/"""', - write={ - "bigquery": r"'/\*.*\*/'", - "duckdb": r"'/\*.*\*/'", - "presto": r"'/\*.*\*/'", - "hive": r"'/\*.*\*/'", - "spark": r"'/\*.*\*/'", - }, - ) with self.assertRaises(ValueError): transpile("'\\'", read="bigquery") @@ -87,13 +103,23 @@ class TestBigQuery(Validator): }, ) self.validate_all( + r'r"""/\*.*\*/"""', + write={ + "bigquery": r"r'/\*.*\*/'", + "duckdb": r"'/\\*.*\\*/'", + "presto": r"'/\\*.*\\*/'", + "hive": r"'/\\*.*\\*/'", + "spark": r"'/\\*.*\\*/'", + }, + ) + self.validate_all( r'R"""/\*.*\*/"""', write={ - "bigquery": r"'/\*.*\*/'", - "duckdb": r"'/\*.*\*/'", - "presto": r"'/\*.*\*/'", - "hive": r"'/\*.*\*/'", - "spark": r"'/\*.*\*/'", + "bigquery": r"r'/\*.*\*/'", + "duckdb": r"'/\\*.*\\*/'", + "presto": r"'/\\*.*\\*/'", + "hive": r"'/\\*.*\\*/'", + "spark": r"'/\\*.*\\*/'", }, ) self.validate_all( @@ -234,7 +260,7 @@ class TestBigQuery(Validator): "DIV(x, y)", write={ "bigquery": "DIV(x, y)", - "duckdb": "CAST(x / y AS INT)", + "duckdb": "x // y", }, ) @@ -308,7 +334,7 @@ class TestBigQuery(Validator): self.validate_all( "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)", write={ - "postgres": "CURRENT_DATE - INTERVAL '1' DAY", + "postgres": "CURRENT_DATE - INTERVAL '1 DAY'", "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL 1 DAY)", }, ) @@ -318,7 +344,7 @@ class TestBigQuery(Validator): "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", "duckdb": "CURRENT_DATE + INTERVAL 1 DAY", "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", - "postgres": "CURRENT_DATE + INTERVAL '1' DAY", + "postgres": "CURRENT_DATE + INTERVAL '1 DAY'", "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", "hive": "DATE_ADD(CURRENT_DATE, 1)", "spark": "DATE_ADD(CURRENT_DATE, 1)", @@ -470,3 +496,12 @@ class TestBigQuery(Validator): "snowflake": "MERGE INTO dataset.Inventory AS T USING dataset.NewArrivals AS S ON FALSE WHEN NOT MATCHED AND product LIKE '%a%' THEN DELETE WHEN NOT MATCHED AND product LIKE '%b%' THEN DELETE", }, ) + + def test_rename_table(self): + self.validate_all( + "ALTER TABLE db.t1 RENAME TO db.t2", + write={ + "snowflake": "ALTER TABLE db.t1 RENAME TO db.t2", + "bigquery": "ALTER TABLE db.t1 RENAME TO t2", + }, + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 1060881..b6a7765 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -1,3 +1,4 @@ +from sqlglot import exp, parse_one from tests.dialects.test_dialect import Validator @@ -5,6 +6,9 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_identity("ATTACH DATABASE DEFAULT ENGINE = ORDINARY") + self.validate_identity("CAST(['hello'], 'Array(Enum8(''hello'' = 1))')") + self.validate_identity("SELECT x, COUNT() FROM y GROUP BY x WITH TOTALS") self.validate_identity("SELECT INTERVAL t.days day") self.validate_identity("SELECT match('abc', '([a-z]+)')") self.validate_identity("dictGet(x, 'y')") @@ -16,17 +20,36 @@ class TestClickhouse(Validator): self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ANY JOIN bla") + self.validate_identity("SELECT * FROM foo GLOBAL ANY JOIN bla") + self.validate_identity("SELECT * FROM foo GLOBAL LEFT ANY JOIN bla") self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") self.validate_identity("SELECT quantiles(0.1, 0.2, 0.3)(a)") + self.validate_identity("SELECT quantileTiming(0.5)(RANGE(100))") self.validate_identity("SELECT histogram(5)(a)") self.validate_identity("SELECT groupUniqArray(2)(a)") self.validate_identity("SELECT exponentialTimeDecayedAvg(60)(a, b)") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_identity("position(haystack, needle)") self.validate_identity("position(haystack, needle, position)") + self.validate_identity("CAST(x AS DATETIME)") + self.validate_identity( + 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(TEXT))' + ) + self.validate_identity( + "CREATE TABLE test (id UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()" + ) self.validate_all( + "SELECT uniq(x) FROM (SELECT any(y) AS x FROM (SELECT 1 AS y))", + read={ + "bigquery": "SELECT APPROX_COUNT_DISTINCT(x) FROM (SELECT ANY_VALUE(y) x FROM (SELECT 1 y))", + }, + write={ + "bigquery": "SELECT APPROX_COUNT_DISTINCT(x) FROM (SELECT ANY_VALUE(y) AS x FROM (SELECT 1 AS y))", + }, + ) + self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", @@ -41,9 +64,7 @@ class TestClickhouse(Validator): ) self.validate_all( "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", - write={ - "clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))", - }, + write={"clickhouse": "CAST(1 AS Nullable(DateTime64(6, 'UTC')))"}, ) self.validate_all( "SELECT x #! comment", @@ -59,6 +80,40 @@ class TestClickhouse(Validator): "SELECT position(needle IN haystack)", write={"clickhouse": "SELECT position(haystack, needle)"}, ) + self.validate_identity( + "SELECT * FROM x LIMIT 10 SETTINGS max_results = 100, result = 'break'" + ) + self.validate_identity("SELECT * FROM x LIMIT 10 SETTINGS max_results = 100, result_") + self.validate_identity("SELECT * FROM x FORMAT PrettyCompact") + self.validate_identity( + "SELECT * FROM x LIMIT 10 SETTINGS max_results = 100, result_ FORMAT PrettyCompact" + ) + self.validate_all( + "SELECT * FROM foo JOIN bar USING id, name", + write={"clickhouse": "SELECT * FROM foo JOIN bar USING (id, name)"}, + ) + self.validate_all( + "SELECT * FROM foo ANY LEFT JOIN bla ON foo.c1 = bla.c2", + write={"clickhouse": "SELECT * FROM foo LEFT ANY JOIN bla ON foo.c1 = bla.c2"}, + ) + self.validate_all( + "SELECT * FROM foo GLOBAL ANY LEFT JOIN bla ON foo.c1 = bla.c2", + write={"clickhouse": "SELECT * FROM foo GLOBAL LEFT ANY JOIN bla ON foo.c1 = bla.c2"}, + ) + self.validate_all( + """ + SELECT + loyalty, + count() + FROM hits SEMI LEFT JOIN users USING (UserID) + GROUP BY loyalty + ORDER BY loyalty ASC + """, + write={ + "clickhouse": "SELECT loyalty, COUNT() FROM hits LEFT SEMI JOIN users USING (UserID)" + + " GROUP BY loyalty ORDER BY loyalty" + }, + ) def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") @@ -66,6 +121,57 @@ class TestClickhouse(Validator): self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") + def test_ternary(self): + self.validate_all("x ? 1 : 2", write={"clickhouse": "CASE WHEN x THEN 1 ELSE 2 END"}) + self.validate_all( + "IF(BAR(col), sign > 0 ? FOO() : 0, 1)", + write={ + "clickhouse": "CASE WHEN BAR(col) THEN CASE WHEN sign > 0 THEN FOO() ELSE 0 END ELSE 1 END" + }, + ) + self.validate_all( + "x AND FOO() > 3 + 2 ? 1 : 2", + write={"clickhouse": "CASE WHEN x AND FOO() > 3 + 2 THEN 1 ELSE 2 END"}, + ) + self.validate_all( + "x ? (y ? 1 : 2) : 3", + write={"clickhouse": "CASE WHEN x THEN (CASE WHEN y THEN 1 ELSE 2 END) ELSE 3 END"}, + ) + self.validate_all( + "x AND (foo() ? FALSE : TRUE) ? (y ? 1 : 2) : 3", + write={ + "clickhouse": "CASE WHEN x AND (CASE WHEN foo() THEN FALSE ELSE TRUE END) THEN (CASE WHEN y THEN 1 ELSE 2 END) ELSE 3 END" + }, + ) + + ternary = parse_one("x ? (y ? 1 : 2) : 3", read="clickhouse") + + self.assertIsInstance(ternary, exp.If) + self.assertIsInstance(ternary.this, exp.Column) + self.assertIsInstance(ternary.args["true"], exp.Paren) + self.assertIsInstance(ternary.args["false"], exp.Literal) + + nested_ternary = ternary.args["true"].this + + self.assertIsInstance(nested_ternary.this, exp.Column) + self.assertIsInstance(nested_ternary.args["true"], exp.Literal) + self.assertIsInstance(nested_ternary.args["false"], exp.Literal) + + parse_one("a and b ? 1 : 2", read="clickhouse").assert_is(exp.If).this.assert_is(exp.And) + + def test_parameterization(self): + self.validate_all( + "SELECT {abc: UInt32}, {b: String}, {c: DateTime},{d: Map(String, Array(UInt8))}, {e: Tuple(UInt8, String)}", + write={ + "clickhouse": "SELECT {abc: UInt32}, {b: TEXT}, {c: DATETIME}, {d: Map(TEXT, Array(UInt8))}, {e: Tuple(UInt8, String)}", + "": "SELECT :abc, :b, :c, :d, :e", + }, + ) + self.validate_all( + "SELECT * FROM {table: Identifier}", + write={"clickhouse": "SELECT * FROM {table: Identifier}"}, + ) + def test_signed_and_unsigned_types(self): data_types = [ "UInt8", @@ -86,3 +192,206 @@ class TestClickhouse(Validator): f"pow(2, 32)::{data_type}", write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"}, ) + + def test_ddl(self): + self.validate_identity( + "CREATE TABLE foo (x UInt32) TTL time_column + INTERVAL '1' MONTH DELETE WHERE column = 'value'" + ) + + self.validate_all( + """ + CREATE TABLE example1 ( + timestamp DateTime, + x UInt32 TTL now() + INTERVAL 1 MONTH, + y String TTL timestamp + INTERVAL 1 DAY, + z String + ) + ENGINE = MergeTree + ORDER BY tuple() + """, + write={ + "clickhouse": """CREATE TABLE example1 ( + timestamp DATETIME, + x UInt32 TTL now() + INTERVAL '1' MONTH, + y TEXT TTL timestamp + INTERVAL '1' DAY, + z TEXT +) +ENGINE=MergeTree +ORDER BY tuple()""", + }, + pretty=True, + ) + self.validate_all( + """ + CREATE TABLE test (id UInt64, timestamp DateTime64, data String, max_hits UInt64, sum_hits UInt64) ENGINE = MergeTree + PRIMARY KEY (id, toStartOfDay(timestamp), timestamp) + TTL timestamp + INTERVAL 1 DAY + GROUP BY id, toStartOfDay(timestamp) + SET + max_hits = max(max_hits), + sum_hits = sum(sum_hits) + """, + write={ + "clickhouse": """CREATE TABLE test ( + id UInt64, + timestamp DateTime64, + data TEXT, + max_hits UInt64, + sum_hits UInt64 +) +ENGINE=MergeTree +PRIMARY KEY (id, toStartOfDay(timestamp), timestamp) +TTL + timestamp + INTERVAL '1' DAY +GROUP BY + id, + toStartOfDay(timestamp) +SET + max_hits = MAX(max_hits), + sum_hits = SUM(sum_hits)""", + }, + pretty=True, + ) + self.validate_all( + """ + CREATE TABLE test (id String, data String) ENGINE = AggregatingMergeTree() + ORDER BY tuple() + SETTINGS + max_suspicious_broken_parts=500, + parts_to_throw_insert=100 + """, + write={ + "clickhouse": """CREATE TABLE test ( + id TEXT, + data TEXT +) +ENGINE=AggregatingMergeTree() +ORDER BY tuple() +SETTINGS + max_suspicious_broken_parts = 500, + parts_to_throw_insert = 100""", + }, + pretty=True, + ) + self.validate_all( + """ + CREATE TABLE example_table + ( + d DateTime, + a Int + ) + ENGINE = MergeTree + PARTITION BY toYYYYMM(d) + ORDER BY d + TTL d + INTERVAL 1 MONTH DELETE, + d + INTERVAL 1 WEEK TO VOLUME 'aaa', + d + INTERVAL 2 WEEK TO DISK 'bbb'; + """, + write={ + "clickhouse": """CREATE TABLE example_table ( + d DATETIME, + a Int32 +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(d) +ORDER BY d +TTL + d + INTERVAL '1' MONTH DELETE, + d + INTERVAL '1' WEEK TO VOLUME 'aaa', + d + INTERVAL '2' WEEK TO DISK 'bbb'""", + }, + pretty=True, + ) + self.validate_all( + """ + CREATE TABLE table_with_where + ( + d DateTime, + a Int + ) + ENGINE = MergeTree + PARTITION BY toYYYYMM(d) + ORDER BY d + TTL d + INTERVAL 1 MONTH DELETE WHERE toDayOfWeek(d) = 1; + """, + write={ + "clickhouse": """CREATE TABLE table_with_where ( + d DATETIME, + a Int32 +) +ENGINE=MergeTree +PARTITION BY toYYYYMM(d) +ORDER BY d +TTL + d + INTERVAL '1' MONTH DELETE +WHERE + toDayOfWeek(d) = 1""", + }, + pretty=True, + ) + self.validate_all( + """ + CREATE TABLE table_for_recompression + ( + d DateTime, + key UInt64, + value String + ) ENGINE MergeTree() + ORDER BY tuple() + PARTITION BY key + TTL d + INTERVAL 1 MONTH RECOMPRESS CODEC(ZSTD(17)), d + INTERVAL 1 YEAR RECOMPRESS CODEC(LZ4HC(10)) + SETTINGS min_rows_for_wide_part = 0, min_bytes_for_wide_part = 0; + """, + write={ + "clickhouse": """CREATE TABLE table_for_recompression ( + d DATETIME, + key UInt64, + value TEXT +) +ENGINE=MergeTree() +ORDER BY tuple() +PARTITION BY key +TTL + d + INTERVAL '1' MONTH RECOMPRESS CODEC(ZSTD(17)), + d + INTERVAL '1' YEAR RECOMPRESS CODEC(LZ4HC(10)) +SETTINGS + min_rows_for_wide_part = 0, + min_bytes_for_wide_part = 0""", + }, + pretty=True, + ) + self.validate_all( + """ + CREATE TABLE table_for_aggregation + ( + d DateTime, + k1 Int, + k2 Int, + x Int, + y Int + ) + ENGINE = MergeTree + ORDER BY (k1, k2) + TTL d + INTERVAL 1 MONTH GROUP BY k1, k2 SET x = max(x), y = min(y); + """, + write={ + "clickhouse": """CREATE TABLE table_for_aggregation ( + d DATETIME, + k1 Int32, + k2 Int32, + x Int32, + y Int32 +) +ENGINE=MergeTree +ORDER BY (k1, k2) +TTL + d + INTERVAL '1' MONTH +GROUP BY + k1, + k2 +SET + x = MAX(x), + y = MIN(y)""", + }, + pretty=True, + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 4619108..8239dec 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -9,6 +9,14 @@ class TestDatabricks(Validator): self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1") self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") + self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + + self.validate_all( + "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", + write={ + "databricks": "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(TO_DATE(y))))", + }, + ) # https://docs.databricks.com/sql/language-manual/functions/colonsign.html def test_json(self): diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index f12273b..e144e81 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -574,7 +574,7 @@ class TestDialect(Validator): "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "postgres": "x + INTERVAL '1' DAY", + "postgres": "x + INTERVAL '1 DAY'", "presto": "DATE_ADD('DAY', 1, x)", "snowflake": "DATEADD(DAY, 1, x)", "spark": "DATE_ADD(x, 1)", @@ -869,7 +869,7 @@ class TestDialect(Validator): "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", @@ -1354,7 +1354,7 @@ class TestDialect(Validator): }, write={ "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", - "postgres": "CREATE INDEX my_idx ON tbl (a, b)", + "postgres": "CREATE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)", "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", }, ) @@ -1366,7 +1366,7 @@ class TestDialect(Validator): }, write={ "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", - "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)", "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", }, ) @@ -1415,12 +1415,12 @@ class TestDialect(Validator): }, ) self.validate_all( - "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", + "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t CROSS JOIN cte2 WHERE cte1.a = cte2.c", write={ - "hive": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", - "oracle": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 t JOIN cte2 WHERE cte1.a = cte2.c", - "postgres": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", - "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", + "hive": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t CROSS JOIN cte2 WHERE cte1.a = cte2.c", + "oracle": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 t CROSS JOIN cte2 WHERE cte1.a = cte2.c", + "postgres": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t CROSS JOIN cte2 WHERE cte1.a = cte2.c", + "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t CROSS JOIN cte2 WHERE cte1.a = cte2.c", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 8c1b748..ce6b122 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1,4 +1,4 @@ -from sqlglot import ErrorLevel, UnsupportedError, transpile +from sqlglot import ErrorLevel, UnsupportedError, exp, parse_one, transpile from tests.dialects.test_dialect import Validator @@ -124,6 +124,20 @@ class TestDuckDB(Validator): ) def test_duckdb(self): + # https://github.com/duckdb/duckdb/releases/tag/v0.8.0 + self.assertEqual( + parse_one("a / b", read="duckdb").assert_is(exp.Div).sql(dialect="duckdb"), "a / b" + ) + self.assertEqual( + parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b" + ) + + self.validate_identity("PIVOT Cities ON Year USING SUM(Population)") + self.validate_identity("PIVOT Cities ON Year USING FIRST(Population)") + self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country") + self.validate_identity("PIVOT Cities ON Country, Name USING SUM(Population)") + self.validate_identity("PIVOT Cities ON Country || '_' || Name USING SUM(Population)") + self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country, Name") self.validate_identity("SELECT {'a': 1} AS x") self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x") self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}") @@ -138,9 +152,36 @@ class TestDuckDB(Validator): self.validate_identity( "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)" ) + self.validate_identity( + "PIVOT Cities ON Year IN (2000, 2010) USING SUM(Population) GROUP BY Country" + ) + self.validate_identity( + "PIVOT Cities ON Year USING SUM(Population) AS total, MAX(Population) AS max GROUP BY Country" + ) + self.validate_identity( + "WITH pivot_alias AS (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) SELECT * FROM pivot_alias" + ) + self.validate_identity( + "SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias" + ) + self.validate_all("FROM (FROM tbl)", write={"duckdb": "SELECT * FROM (SELECT * FROM tbl)"}) + self.validate_all("FROM tbl", write={"duckdb": "SELECT * FROM tbl"}) self.validate_all("0b1010", write={"": "0 AS b1010"}) self.validate_all("0x1010", write={"": "0 AS x1010"}) + self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) + self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) + self.validate_all( + "PIVOT_WIDER Cities ON Year USING SUM(Population)", + write={"duckdb": "PIVOT Cities ON Year USING SUM(Population)"}, + ) + self.validate_all( + "WITH t AS (SELECT 1) FROM t", write={"duckdb": "WITH t AS (SELECT 1) SELECT * FROM t"} + ) + self.validate_all( + "WITH t AS (SELECT 1) SELECT * FROM (FROM t)", + write={"duckdb": "WITH t AS (SELECT 1) SELECT * FROM (SELECT * FROM t)"}, + ) self.validate_all( """SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""", write={ @@ -155,8 +196,6 @@ class TestDuckDB(Validator): "trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))", }, ) - self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) - self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( "WITH 'x' AS (SELECT 1) SELECT * FROM x", write={"duckdb": 'WITH "x" AS (SELECT 1) SELECT * FROM x'}, @@ -341,7 +380,8 @@ class TestDuckDB(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", }, ) self.validate_all( @@ -370,13 +410,14 @@ class TestDuckDB(Validator): "hive": "SELECT DATE_ADD(TO_DATE(x), 1)", }, ) - - with self.assertRaises(UnsupportedError): - transpile( - "SELECT a FROM b PIVOT(SUM(x) FOR y IN ('z', 'q'))", - read="duckdb", - unsupported_level=ErrorLevel.IMMEDIATE, - ) + self.validate_all( + "SELECT CAST('2020-05-06' AS DATE) - INTERVAL 5 DAY", + read={"bigquery": "SELECT DATE_SUB(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, + ) + self.validate_all( + "SELECT CAST('2020-05-06' AS DATE) + INTERVAL 5 DAY", + read={"bigquery": "SELECT DATE_ADD(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, + ) with self.assertRaises(UnsupportedError): transpile( @@ -481,3 +522,12 @@ class TestDuckDB(Validator): "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, ) + + def test_rename_table(self): + self.validate_all( + "ALTER TABLE db.t1 RENAME TO db.t2", + write={ + "snowflake": "ALTER TABLE db.t1 RENAME TO db.t2", + "duckdb": "ALTER TABLE db.t1 RENAME TO t2", + }, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index c69368c..99b5602 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -4,6 +4,17 @@ from tests.dialects.test_dialect import Validator class TestHive(Validator): dialect = "hive" + def test_hive(self): + self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l") + self.validate_identity( + "SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND()" + ) + self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z") + self.validate_identity("(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x") + self.validate_identity("(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC") + self.validate_identity("SELECT * FROM test CLUSTER BY y") + self.validate_identity("(SELECT 1 UNION SELECT 2) SORT BY z") + def test_bits(self): self.validate_all( "x & 1", @@ -362,7 +373,7 @@ class TestHive(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", @@ -512,10 +523,10 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE(10 PERCENT) y", + "SELECT * FROM x.z TABLESAMPLE(10 PERCENT) y", write={ - "hive": "SELECT * FROM x TABLESAMPLE (10 PERCENT) AS y", - "spark": "SELECT * FROM x TABLESAMPLE (10 PERCENT) AS y", + "hive": "SELECT * FROM x.z TABLESAMPLE (10 PERCENT) AS y", + "spark": "SELECT * FROM x.z TABLESAMPLE (10 PERCENT) AS y", }, ) self.validate_all( @@ -549,6 +560,12 @@ class TestHive(Validator): }, ) self.validate_all( + "STRUCT(a = b, c = d)", + read={ + "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", + }, + ) + self.validate_all( "MAP(a, b, c, d)", read={ "": "VAR_MAP(a, b, c, d)", @@ -557,7 +574,6 @@ class TestHive(Validator): "hive": "MAP(a, b, c, d)", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "spark": "MAP(a, b, c, d)", - "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", }, write={ "": "MAP(ARRAY(a, c), ARRAY(b, d))", @@ -627,7 +643,7 @@ class TestHive(Validator): self.validate_all( "x div y", write={ - "duckdb": "CAST(x / y AS INT)", + "duckdb": "x // y", "presto": "CAST(x / y AS INTEGER)", "hive": "CAST(x / y AS INT)", "spark": "CAST(x / y AS INT)", diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index f31b1b9..a80153b 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -6,6 +6,10 @@ class TestMySQL(Validator): dialect = "mysql" def test_ddl(self): + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" + ) + self.validate_all( "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", write={ @@ -21,10 +25,6 @@ class TestMySQL(Validator): "mysql": "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", }, ) - self.validate_identity( - "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" - ) - self.validate_all( "CREATE TABLE x (id int not null auto_increment, primary key (id))", write={ @@ -37,6 +37,12 @@ class TestMySQL(Validator): "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)", }, ) + self.validate_all( + "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))", + write={ + "mysql": "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", + }, + ) def test_identity(self): self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") @@ -47,8 +53,11 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')") self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") self.validate_identity("@@GLOBAL.max_connections") - self.validate_identity("CREATE TABLE A LIKE B") + self.validate_identity("SELECT * FROM t1, t2 FOR SHARE OF t1, t2 SKIP LOCKED") + self.validate_identity( + "SELECT * FROM t1, t2, t3 FOR SHARE OF t1 NOWAIT FOR UPDATE OF t2, t3 SKIP LOCKED" + ) # SET Commands self.validate_identity("SET @var_name = expr") @@ -369,6 +378,9 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "SELECT * FROM t LOCK IN SHARE MODE", write={"mysql": "SELECT * FROM t FOR SHARE"} + ) + self.validate_all( "SELECT DATE(DATE_SUB(`dt`, INTERVAL DAYOFMONTH(`dt`) - 1 DAY)) AS __timestamp FROM tableT", write={ "mysql": "SELECT DATE(DATE_SUB(`dt`, INTERVAL (DAYOFMONTH(`dt`) - 1) DAY)) AS __timestamp FROM tableT", diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 88c79fd..12ff699 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -5,6 +5,17 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("SELECT * FROM t FOR UPDATE") + self.validate_identity("SELECT * FROM t FOR UPDATE WAIT 5") + self.validate_identity("SELECT * FROM t FOR UPDATE NOWAIT") + self.validate_identity("SELECT * FROM t FOR UPDATE SKIP LOCKED") + self.validate_identity("SELECT * FROM t FOR UPDATE OF s.t.c, s.t.v") + self.validate_identity("SELECT * FROM t FOR UPDATE OF s.t.c, s.t.v NOWAIT") + self.validate_identity("SELECT * FROM t FOR UPDATE OF s.t.c, s.t.v SKIP LOCKED") + self.validate_identity("SELECT STANDARD_HASH('hello')") + self.validate_identity("SELECT STANDARD_HASH('hello', 'MD5')") + self.validate_identity("SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1") + self.validate_identity("SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1") self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain") self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT * FROM V$SESSION") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index b535a84..1f288c6 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -85,6 +85,9 @@ class TestPostgres(Validator): ) def test_postgres(self): + self.validate_identity( + """LAST_VALUE("col1") OVER (ORDER BY "col2" RANGE BETWEEN INTERVAL '1 day' PRECEDING AND '1 month' FOLLOWING)""" + ) self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]") @@ -158,7 +161,7 @@ class TestPostgres(Validator): write={ "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", - "snowflake": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", }, ) self.validate_all( @@ -166,7 +169,7 @@ class TestPostgres(Validator): write={ "postgres": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", "redshift": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", - "snowflake": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "snowflake": "SELECT DATE_PART(month, CAST('20220502' AS DATE))", }, ) self.validate_all( @@ -193,7 +196,7 @@ class TestPostgres(Validator): self.validate_all( "GENERATE_SERIES(a, b, ' 2 days ')", write={ - "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' day)", + "postgres": "GENERATE_SERIES(a, b, INTERVAL '2 days')", "presto": "SEQUENCE(a, b, INTERVAL '2' day)", "trino": "SEQUENCE(a, b, INTERVAL '2' day)", }, @@ -201,7 +204,7 @@ class TestPostgres(Validator): self.validate_all( "GENERATE_SERIES('2019-01-01'::TIMESTAMP, NOW(), '1day')", write={ - "postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1' day)", + "postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1 day')", "presto": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", "trino": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", }, @@ -234,6 +237,15 @@ class TestPostgres(Validator): }, ) self.validate_all( + "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4) AS s", + write={ + "postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4) AS s", + "presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4)) AS _u(s)", + "trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4)) AS _u(s)", + "tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4) AS s", + }, + ) + self.validate_all( "END WORK AND NO CHAIN", write={"postgres": "COMMIT AND NO CHAIN"}, ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 15962cc..1f5953c 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -438,6 +438,36 @@ class TestPresto(Validator): self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"}) self.validate_all( + "WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t", + read={ + "postgres": "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 AS n FROM t WHERE n < 4) SELECT SUM(n) FROM t", + }, + ) + self.validate_all( + "WITH RECURSIVE t(n, k) AS (SELECT 1 AS n, 2 AS k) SELECT SUM(n) FROM t", + read={ + "postgres": "WITH RECURSIVE t AS (SELECT 1 AS n, 2 as k) SELECT SUM(n) FROM t", + }, + ) + self.validate_all( + "WITH RECURSIVE t1(n) AS (SELECT 1 AS n), t2(n) AS (SELECT 2 AS n) SELECT SUM(t1.n), SUM(t2.n) FROM t1, t2", + read={ + "postgres": "WITH RECURSIVE t1 AS (SELECT 1 AS n), t2 AS (SELECT 2 AS n) SELECT SUM(t1.n), SUM(t2.n) FROM t1, t2", + }, + ) + self.validate_all( + "WITH RECURSIVE t(n, _c_0) AS (SELECT 1 AS n, (1 + 2)) SELECT * FROM t", + read={ + "postgres": "WITH RECURSIVE t AS (SELECT 1 AS n, (1 + 2)) SELECT * FROM t", + }, + ) + self.validate_all( + 'WITH RECURSIVE t(n, "1") AS (SELECT n, 1 FROM tbl) SELECT * FROM t', + read={ + "postgres": "WITH RECURSIVE t AS (SELECT n, 1 FROM tbl) SELECT * FROM t", + }, + ) + self.validate_all( "SELECT JSON_OBJECT(KEY 'key1' VALUE 1, KEY 'key2' VALUE TRUE)", write={ "presto": "SELECT JSON_OBJECT('key1': 1, 'key2': TRUE)", @@ -757,3 +787,27 @@ class TestPresto(Validator): "SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)", read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"}, ) + + def test_match_recognize(self): + self.validate_identity( + """SELECT + * +FROM orders +MATCH_RECOGNIZE ( + PARTITION BY custkey + ORDER BY + orderdate + MEASURES + A.totalprice AS starting_price, + LAST(B.totalprice) AS bottom_price, + LAST(C.totalprice) AS top_price + ONE ROW PER MATCH + AFTER MATCH SKIP PAST LAST ROW + PATTERN (A B+ C+ D+) + DEFINE + B AS totalprice < PREV(totalprice), + C AS totalprice > PREV(totalprice) AND totalprice <= A.totalprice, + D AS totalprice > PREV(totalprice) +)""", + pretty=True, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index f75480e..6707b7a 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -11,10 +11,10 @@ class TestRedshift(Validator): self.validate_identity("$foo") self.validate_all( - "SELECT SNAPSHOT", + "SELECT SNAPSHOT, type", write={ - "": "SELECT SNAPSHOT", - "redshift": 'SELECT "SNAPSHOT"', + "": "SELECT SNAPSHOT, type", + "redshift": 'SELECT "SNAPSHOT", "type"', }, ) @@ -31,7 +31,7 @@ class TestRedshift(Validator): write={ "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", - "snowflake": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", }, ) self.validate_all( @@ -39,10 +39,10 @@ class TestRedshift(Validator): write={ "postgres": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", "redshift": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", - "snowflake": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "snowflake": "SELECT DATE_PART(month, CAST('20220502' AS DATE))", }, ) - self.validate_all("SELECT INTERVAL '5 day'", read={"": "SELECT INTERVAL '5' days"}) + self.validate_all("SELECT INTERVAL '5 days'", read={"": "SELECT INTERVAL '5' days"}) self.validate_all("CONVERT(INTEGER, x)", write={"redshift": "CAST(x AS INTEGER)"}) self.validate_all( "DATEADD('day', ndays, caldate)", write={"redshift": "DATEADD(day, ndays, caldate)"} @@ -65,7 +65,7 @@ class TestRedshift(Validator): "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)", write={ "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", - "bigquery": "SELECT ST_ASEWKT(TRY_CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", + "bigquery": "SELECT ST_ASEWKT(SAFE_CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", }, ) self.validate_all( @@ -182,6 +182,16 @@ class TestRedshift(Validator): def test_values(self): self.validate_all( + "SELECT * FROM (VALUES (1, 2)) AS t", + write={"redshift": "SELECT * FROM (SELECT 1, 2) AS t"}, + ) + self.validate_all( + "SELECT * FROM (VALUES (1)) AS t1(id) CROSS JOIN (VALUES (1)) AS t2(id)", + write={ + "redshift": "SELECT * FROM (SELECT 1 AS id) AS t1 CROSS JOIN (SELECT 1 AS id) AS t2", + }, + ) + self.validate_all( "SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)", write={ "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 57ee235..941f2aa 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,6 +6,8 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')") + self.validate_identity("CAST(x AS GEOMETRY)") self.validate_identity("OBJECT_CONSTRUCT(*)") self.validate_identity("SELECT TO_DATE('2019-02-28') + INTERVAL '1 day, 1 year'") self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'") @@ -25,7 +27,21 @@ class TestSnowflake(Validator): 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)' ) self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'") + self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)") + self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) + self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) + self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) + self.validate_all( + "OBJECT_CONSTRUCT(a, b, c, d)", + read={ + "": "STRUCT(a as b, c as d)", + }, + write={ + "duckdb": "{'a': b, 'c': d}", + "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", + }, + ) self.validate_all( "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", write={ @@ -284,7 +300,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", @@ -458,7 +474,8 @@ class TestSnowflake(Validator): ) def test_timestamps(self): - self.validate_identity("SELECT EXTRACT(month FROM a)") + self.validate_identity("SELECT CAST('12:00:00' AS TIME)") + self.validate_identity("SELECT DATE_PART(month, a)") self.validate_all( "SELECT CAST(a AS TIMESTAMP)", @@ -487,19 +504,19 @@ class TestSnowflake(Validator): self.validate_all( "SELECT EXTRACT('month', a)", write={ - "snowflake": "SELECT EXTRACT('month' FROM a)", + "snowflake": "SELECT DATE_PART('month', a)", }, ) self.validate_all( "SELECT DATE_PART('month', a)", write={ - "snowflake": "SELECT EXTRACT('month' FROM a)", + "snowflake": "SELECT DATE_PART('month', a)", }, ) self.validate_all( "SELECT DATE_PART(month, a::DATETIME)", write={ - "snowflake": "SELECT EXTRACT(month FROM CAST(a AS DATETIME))", + "snowflake": "SELECT DATE_PART(month, CAST(a AS DATETIME))", }, ) self.validate_all( @@ -554,10 +571,23 @@ class TestSnowflake(Validator): ) def test_ddl(self): + self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)") + self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") + self.validate_identity("CREATE DATABASE mytestdb_clone CLONE mytestdb") + self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema") + self.validate_identity("CREATE TABLE orders_clone CLONE orders") + self.validate_identity( + "CREATE TABLE orders_clone_restore CLONE orders AT (TIMESTAMP => TO_TIMESTAMP_TZ('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss'))" + ) + self.validate_identity( + "CREATE TABLE orders_clone_restore CLONE orders BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726')" + ) self.validate_identity( "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" ) - self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") + self.validate_identity( + "CREATE SCHEMA mytestschema_clone_restore CLONE testschema BEFORE (TIMESTAMP => TO_TIMESTAMP(40 * 365 * 86400))" + ) self.validate_all( "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", @@ -758,7 +788,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA def test_values(self): self.validate_all( - 'SELECT c0, c1 FROM (VALUES (1, 2), (3, 4)) AS "t0"(c0, c1)', + 'SELECT "c0", "c1" FROM (VALUES (1, 2), (3, 4)) AS "t0"("c0", "c1")', read={ "spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)", }, diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index be03b4e..bcfd984 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -39,8 +39,8 @@ class TestSpark(Validator): "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", write={ "duckdb": "CREATE TABLE x", - "presto": "CREATE TABLE x WITH (TABLE_FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", - "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + "presto": "CREATE TABLE x WITH (FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", + "hive": "CREATE TABLE x STORED AS ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", }, ) @@ -113,6 +113,13 @@ TBLPROPERTIES ( "spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)", }, ) + self.validate_all( + "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + identify=True, + write={ + "spark": "CREATE TABLE `x` USING ICEBERG PARTITIONED BY (MONTHS(`y`)) LOCATION 's3://z'", + }, + ) def test_to_date(self): self.validate_all( @@ -207,6 +214,7 @@ TBLPROPERTIES ( ) def test_spark(self): + self.validate_identity("INTERVAL -86 days") self.validate_identity("SELECT UNIX_TIMESTAMP()") self.validate_identity("TRIM(' SparkSQL ')") self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") @@ -215,6 +223,18 @@ TBLPROPERTIES ( self.validate_identity("SPLIT(str, pattern, lim)") self.validate_all( + "SELECT piv.Q1 FROM (SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))) AS piv", + read={ + "snowflake": "SELECT piv.Q1 FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2')) piv", + }, + ) + self.validate_all( + "SELECT piv.Q1 FROM (SELECT * FROM (SELECT * FROM produce) PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))) AS piv", + read={ + "snowflake": "SELECT piv.Q1 FROM (SELECT * FROM produce) PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2')) piv", + }, + ) + self.validate_all( "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1', 'Q2'))", read={ "snowflake": "SELECT * FROM produce PIVOT (SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))", @@ -301,7 +321,7 @@ TBLPROPERTIES ( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index dcb513d..03eb02f 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -24,6 +24,13 @@ class TestTeradata(Validator): def test_create(self): self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)") + self.validate_identity("CREATE TABLE x (y INT) PARTITION BY y INDEX (y)") + self.validate_identity( + "CREATE MULTISET VOLATILE TABLE my_table (id INT) PRIMARY INDEX (id) ON COMMIT PRESERVE ROWS" + ) + self.validate_identity( + "CREATE SET VOLATILE TABLE my_table (id INT) PRIMARY INDEX (id) ON COMMIT DELETE ROWS" + ) self.validate_identity( "CREATE TABLE a (b INT) PRIMARY INDEX (y) PARTITION BY RANGE_N(b BETWEEN 'a', 'b' AND 'c' EACH '1')" ) @@ -35,10 +42,20 @@ class TestTeradata(Validator): ) self.validate_all( + """ + CREATE SET TABLE test, NO FALLBACK, NO BEFORE JOURNAL, NO AFTER JOURNAL, + CHECKSUM = DEFAULT (x INT, y INT, z CHAR(30), a INT, b DATE, e INT) + PRIMARY INDEX (a), + INDEX(x, y) + """, + write={ + "teradata": "CREATE SET TABLE test, NO FALLBACK, NO BEFORE JOURNAL, NO AFTER JOURNAL, CHECKSUM=DEFAULT (x INT, y INT, z CHAR(30), a INT, b DATE, e INT) PRIMARY INDEX (a) INDEX (x, y)", + }, + ) + self.validate_all( "REPLACE VIEW a AS (SELECT b FROM c)", write={"teradata": "CREATE OR REPLACE VIEW a AS (SELECT b FROM c)"}, ) - self.validate_all( "CREATE VOLATILE TABLE a", write={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 3a3ac73..8789aed 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -467,6 +467,11 @@ WHERE def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") + + self.validate_all( + "DATEADD(year, 50, '2006-07-31')", + write={"bigquery": "DATE_ADD('2006-07-31', INTERVAL 50 YEAR)"}, + ) self.validate_all( "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}, @@ -525,7 +530,7 @@ WHERE self.validate_all( "SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)", write={ - "spark": "SELECT x.a, x.b, t.v, t.y FROM x JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", + "spark": "SELECT x.a, x.b, t.v, t.y FROM x, LATERAL (SELECT v, y FROM t) AS t(v, y)", }, ) self.validate_all( @@ -545,7 +550,7 @@ WHERE self.validate_all( "SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)", write={ - "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) AS y(z)", + "spark": "SELECT t.x, y.z FROM x, LATERAL TVFTEST(t.x) AS y(z)", }, ) self.validate_all( @@ -637,7 +642,7 @@ WHERE self.assertIsInstance(expr.this, exp.Var) self.assertEqual(expr.sql("tsql"), "@x") - table = parse_one("select * from @x", read="tsql").args["from"].expressions[0] + table = parse_one("select * from @x", read="tsql").args["from"].this self.assertIsInstance(table, exp.Table) self.assertIsInstance(table.this, exp.Parameter) self.assertIsInstance(table.this.this, exp.Var) @@ -724,3 +729,44 @@ WHERE }, ) self.validate_identity("SELECT x FROM a INNER LOOP JOIN b ON b.id = a.id") + + def test_openjson(self): + self.validate_identity("SELECT * FROM OPENJSON(@json)") + + self.validate_all( + """SELECT [key], value FROM OPENJSON(@json,'$.path.to."sub-object"')""", + write={ + "tsql": """SELECT "key", value FROM OPENJSON(@json, '$.path.to."sub-object"')""", + }, + ) + self.validate_all( + "SELECT * FROM OPENJSON(@array) WITH (month VARCHAR(3), temp int, month_id tinyint '$.sql:identity()') as months", + write={ + "tsql": "SELECT * FROM OPENJSON(@array) WITH (month VARCHAR(3), temp INTEGER, month_id TINYINT '$.sql:identity()') AS months", + }, + ) + self.validate_all( + """ + SELECT * + FROM OPENJSON ( @json ) + WITH ( + Number VARCHAR(200) '$.Order.Number', + Date DATETIME '$.Order.Date', + Customer VARCHAR(200) '$.AccountNumber', + Quantity INT '$.Item.Quantity', + [Order] NVARCHAR(MAX) AS JSON + ) + """, + write={ + "tsql": """SELECT + * +FROM OPENJSON(@json) WITH ( + Number VARCHAR(200) '$.Order.Number', + Date DATETIME2 '$.Order.Date', + Customer VARCHAR(200) '$.AccountNumber', + Quantity INTEGER '$.Item.Quantity', + "Order" TEXT AS JSON +)""" + }, + pretty=True, + ) |