diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 107 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 14 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 97 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 42 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 35 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 39 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 72 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 100 |
14 files changed, 452 insertions, 152 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 448a077..1f5f902 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -9,36 +9,6 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): - self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") - self.validate_identity( - """SELECT JSON '"foo"' AS json_data""", - """SELECT PARSE_JSON('"foo"') AS json_data""", - ) - - self.validate_all( - """SELECT - `u`.`harness_user_email` AS `harness_user_email`, - `d`.`harness_user_id` AS `harness_user_id`, - `harness_account_id` AS `harness_account_id` -FROM `analytics_staging`.`stg_mongodb__users` AS `u`, UNNEST(`u`.`harness_cluster_details`) AS `d`, UNNEST(`d`.`harness_account_ids`) AS `harness_account_id` -WHERE - NOT `harness_account_id` IS NULL""", - read={ - "": """ - SELECT - "u"."harness_user_email" AS "harness_user_email", - "_q_0"."d"."harness_user_id" AS "harness_user_id", - "_q_1"."harness_account_id" AS "harness_account_id" - FROM - "analytics_staging"."stg_mongodb__users" AS "u", - UNNEST("u"."harness_cluster_details") AS "_q_0"("d"), - UNNEST("_q_0"."d"."harness_account_ids") AS "_q_1"("harness_account_id") - WHERE - NOT "_q_1"."harness_account_id" IS NULL - """ - }, - pretty=True, - ) with self.assertRaises(TokenError): transpile("'\\'", read="bigquery") @@ -63,6 +33,9 @@ WHERE with self.assertRaises(ParseError): transpile("DATE_ADD(x, day)", read="bigquery") + self.validate_identity("SELECT test.Unknown FROM test") + self.validate_identity(r"SELECT '\n\r\a\v\f\t'") + self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") self.validate_identity("STRING_AGG(DISTINCT a ORDER BY b DESC, c DESC LIMIT 10)") self.validate_identity("SELECT PARSE_TIMESTAMP('%c', 'Thu Dec 25 07:30:00 2008', 'UTC')") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MAX sold) FROM fruits") @@ -111,6 +84,7 @@ WHERE self.validate_identity("COMMIT TRANSACTION") self.validate_identity("ROLLBACK TRANSACTION") self.validate_identity("CAST(x AS BIGNUMERIC)") + self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") self.validate_identity( "DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')" ) @@ -132,6 +106,22 @@ WHERE self.validate_identity( "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", ) + self.validate_identity( + "SELECT a overlaps", + "SELECT a AS overlaps", + ) + self.validate_identity( + "SELECT y + 1 z FROM x GROUP BY y + 1 ORDER BY z", + "SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z", + ) + self.validate_identity( + "SELECT y + 1 z FROM x GROUP BY y + 1", + "SELECT y + 1 AS z FROM x GROUP BY y + 1", + ) + self.validate_identity( + """SELECT JSON '"foo"' AS json_data""", + """SELECT PARSE_JSON('"foo"') AS json_data""", + ) self.validate_all("SELECT SPLIT(foo)", write={"bigquery": "SELECT SPLIT(foo, ',')"}) self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"}) @@ -246,7 +236,7 @@ WHERE }, ) self.validate_all( - "WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT col FROM cte CROSS JOIN UNNEST(arr) AS col", + "WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT IF(pos = pos_2, col, NULL) AS col FROM cte, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(arr)) - 1)) AS pos CROSS JOIN UNNEST(arr) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(arr) - 1) AND pos_2 = (ARRAY_LENGTH(arr) - 1))", read={ "spark": "WITH cte AS (SELECT ARRAY(1, 2, 3) AS arr) SELECT EXPLODE(arr) FROM cte" }, @@ -291,6 +281,10 @@ WHERE "bigquery": "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)", }, ) + self.validate_identity( + r"REGEXP_EXTRACT(svc_plugin_output, r'\\\((.*)')", + r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", + ) self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={ @@ -302,7 +296,7 @@ WHERE "mysql": "REGEXP_LIKE('foo', '.*')", "starrocks": "REGEXP('foo', '.*')", }, - ), + ) self.validate_all( '"""x"""', write={ @@ -453,7 +447,6 @@ WHERE "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)", write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"}, ) - self.validate_all( "x IS unknown", write={ @@ -465,6 +458,16 @@ WHERE }, ) self.validate_all( + "x IS NOT unknown", + write={ + "bigquery": "NOT x IS NULL", + "duckdb": "NOT x IS NULL", + "presto": "NOT x IS NULL", + "hive": "NOT x IS NULL", + "spark": "NOT x IS NULL", + }, + ) + self.validate_all( "CURRENT_TIMESTAMP()", read={ "tsql": "GETDATE()", @@ -682,16 +685,32 @@ WHERE "spark": "TO_JSON(x)", }, ) - - self.validate_identity( - "SELECT y + 1 z FROM x GROUP BY y + 1 ORDER BY z", - "SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z", - ) - self.validate_identity( - "SELECT y + 1 z FROM x GROUP BY y + 1", - "SELECT y + 1 AS z FROM x GROUP BY y + 1", + self.validate_all( + """SELECT + `u`.`harness_user_email` AS `harness_user_email`, + `d`.`harness_user_id` AS `harness_user_id`, + `harness_account_id` AS `harness_account_id` +FROM `analytics_staging`.`stg_mongodb__users` AS `u`, UNNEST(`u`.`harness_cluster_details`) AS `d`, UNNEST(`d`.`harness_account_ids`) AS `harness_account_id` +WHERE + NOT `harness_account_id` IS NULL""", + read={ + "": """ + SELECT + "u"."harness_user_email" AS "harness_user_email", + "_q_0"."d"."harness_user_id" AS "harness_user_id", + "_q_1"."harness_account_id" AS "harness_account_id" + FROM + "analytics_staging"."stg_mongodb__users" AS "u", + UNNEST("u"."harness_cluster_details") AS "_q_0"("d"), + UNNEST("_q_0"."d"."harness_account_ids") AS "_q_1"("harness_account_id") + WHERE + NOT "_q_1"."harness_account_id" IS NULL + """ + }, + pretty=True, ) - self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") + + self.validate_identity("LOG(n, b)") def test_user_defined_functions(self): self.validate_identity( @@ -702,6 +721,10 @@ WHERE self.validate_identity( "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" ) + self.validate_identity( + '''CREATE TEMPORARY FUNCTION string_length_0(strings ARRAY<STRING>) RETURNS FLOAT64 LANGUAGE js AS """'use strict'; function string_length(strings) { return _.sum(_.map(strings, ((x) => x.length))); } return string_length(strings);""" OPTIONS (library=['gs://ibis-testing-libraries/lodash.min.js'])''', + "CREATE TEMPORARY FUNCTION string_length_0(strings ARRAY<STRING>) RETURNS FLOAT64 LANGUAGE js OPTIONS (library=['gs://ibis-testing-libraries/lodash.min.js']) AS '\\'use strict\\'; function string_length(strings) { return _.sum(_.map(strings, ((x) => x.length))); } return string_length(strings);'", + ) def test_group_concat(self): self.validate_all( diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 2cda0dc..40a270e 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -157,8 +157,8 @@ class TestClickhouse(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", }, ) self.validate_all( @@ -216,7 +216,7 @@ class TestClickhouse(Validator): """, write={ "clickhouse": "SELECT loyalty, count() FROM hits LEFT SEMI JOIN users USING (UserID)" - + " GROUP BY loyalty ORDER BY loyalty" + " GROUP BY loyalty ORDER BY loyalty ASC" }, ) self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index d06e0f1..3df968b 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -5,6 +5,9 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("CREATE TABLE t (c STRUCT<interval: DOUBLE COMMENT 'aaa'>)") + self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES (a.b=15)") + self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES ('a.b'=15)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO HOUR)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO MINUTE)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO SECOND)") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 47e1ec7..3e0ffd5 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -964,12 +964,12 @@ class TestDialect(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", }, ) @@ -1354,7 +1354,7 @@ class TestDialect(Validator): self.validate_all( "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz", write={ - "bigquery": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", + "bigquery": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", "duckdb": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", "presto": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", "hive": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 36fca7c..dbf0a87 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -6,6 +6,92 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_duckdb(self): + self.assertEqual( + parse_one("select * from t limit (select 5)").sql(dialect="duckdb"), + exp.select("*").from_("t").limit(exp.select("5").subquery()).sql(dialect="duckdb"), + ) + + for struct_value in ("{'a': 1}", "struct_pack(a := 1)"): + self.validate_all(struct_value, write={"presto": UnsupportedError}) + + for join_type in ("SEMI", "ANTI"): + exists = "EXISTS" if join_type == "SEMI" else "NOT EXISTS" + + self.validate_all( + f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + write={ + "bigquery": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "clickhouse": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "databricks": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "doris": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "drill": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "duckdb": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "hive": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "mysql": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "oracle": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "postgres": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "presto": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "redshift": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "snowflake": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "spark": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "sqlite": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "starrocks": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "teradata": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "trino": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "tsql": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + }, + ) + self.validate_all( + f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + read={ + "duckdb": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "spark": f"SELECT * FROM t1 LEFT {join_type} JOIN t2 ON t1.x = t2.x", + }, + ) + + self.validate_all( + "WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(x) FILTER (WHERE x > 1) FROM cte", + write={ + "duckdb": "WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(x) FILTER(WHERE x > 1) FROM cte", + "snowflake": "WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(IFF(x > 1, x, NULL)) FROM cte", + }, + ) + self.validate_all( + "SELECT AVG(x) FILTER (WHERE TRUE) FROM t", + write={ + "duckdb": "SELECT AVG(x) FILTER(WHERE TRUE) FROM t", + "snowflake": "SELECT AVG(IFF(TRUE, x, NULL)) FROM t", + }, + ) + self.validate_all( + "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6])", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", + "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + }, + ) + + self.validate_all( + "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6]) FROM x", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM x, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", + "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + }, + ) + self.validate_all( + "SELECT UNNEST(x) + 1", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) + 1 AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", + }, + ) + self.validate_all( + "SELECT UNNEST(x) + 1 AS y", + write={ + "bigquery": "SELECT IF(pos = pos_2, y, NULL) + 1 AS y FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS y WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", + }, + ) + + self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC") self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]") self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y") self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x") @@ -62,6 +148,13 @@ class TestDuckDB(Validator): self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( + "SELECT UNNEST([1, 2, 3])", + write={ + "duckdb": "SELECT UNNEST([1, 2, 3])", + "snowflake": "SELECT IFF(pos = pos_2, col, NULL) AS col FROM (SELECT value FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1)))) AS _u(pos) CROSS JOIN (SELECT value, index FROM TABLE(FLATTEN(INPUT => [1, 2, 3]))) AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", + }, + ) + self.validate_all( "VAR_POP(x)", read={ "": "VARIANCE_POP(x)", @@ -304,8 +397,8 @@ class TestDuckDB(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", }, ) self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 70a05fd..26f0189 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -382,10 +382,10 @@ class TestHive(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index e362e9e..20f872c 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -499,6 +499,21 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "a XOR b", + read={ + "mysql": "a XOR b", + "snowflake": "BOOLXOR(a, b)", + }, + write={ + "duckdb": "(a AND (NOT b)) OR ((NOT a) AND b)", + "mysql": "a XOR b", + "postgres": "(a AND (NOT b)) OR ((NOT a) AND b)", + "snowflake": "BOOLXOR(a, b)", + "trino": "(a AND (NOT b)) OR ((NOT a) AND b)", + }, + ) + + self.validate_all( "SELECT * FROM test LIMIT 0 + 1, 0 + 1", write={ "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 285496a..6a3df47 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -13,6 +13,7 @@ class TestPostgres(Validator): "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", "CREATE TABLE test (x TIMESTAMP[][])", ) + self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)") self.validate_identity("CREATE TABLE test (elems JSONB[])") self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)") self.validate_identity("CREATE TABLE test (foo HSTORE)") @@ -83,6 +84,28 @@ class TestPostgres(Validator): " CONSTRAINT valid_discount CHECK (price > discounted_price))" }, ) + self.validate_identity( + """ + CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial + ON public.ci_builds + USING btree (commit_id, artifacts_expire_at, id) + WHERE ( + ((type)::text = 'Ci::Build'::text) + AND ((retried = false) OR (retried IS NULL)) + AND ((name)::text = ANY (ARRAY[ + ('sast'::character varying)::text, + ('dependency_scanning'::character varying)::text, + ('sast:container'::character varying)::text, + ('container_scanning'::character varying)::text, + ('dast'::character varying)::text + ])) + ) + """, + "CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial ON public.ci_builds USING btree(commit_id, artifacts_expire_at, id) WHERE ((CAST((type) AS TEXT) = CAST('Ci::Build' AS TEXT)) AND ((retried = FALSE) OR (retried IS NULL)) AND (CAST((name) AS TEXT) = ANY (ARRAY[CAST((CAST('sast' AS VARCHAR)) AS TEXT), CAST((CAST('dependency_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('sast:container' AS VARCHAR)) AS TEXT), CAST((CAST('container_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('dast' AS VARCHAR)) AS TEXT)])))", + ) + self.validate_identity( + "CREATE INDEX index_ci_pipelines_on_project_idandrefandiddesc ON public.ci_pipelines USING btree(project_id, ref, id DESC)" + ) with self.assertRaises(ParseError): transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") @@ -102,7 +125,7 @@ class TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(c) FROM t", "postgres": "SELECT UNNEST(c) FROM t", - "presto": "SELECT col FROM t CROSS JOIN UNNEST(c) AS _u(col)", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(c) AND pos_2 = CARDINALITY(c))", }, ) self.validate_all( @@ -110,7 +133,7 @@ class TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(ARRAY(1))", "postgres": "SELECT UNNEST(ARRAY[1])", - "presto": "SELECT col FROM UNNEST(ARRAY[1]) AS _u(col)", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1]) AND pos_2 = CARDINALITY(ARRAY[1]))", }, ) @@ -139,6 +162,16 @@ class TestPostgres(Validator): self.assertIsInstance(expr, exp.AlterTable) self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) + self.validate_identity( + "SELECT ARRAY[]::INT[] AS foo", + "SELECT CAST(ARRAY[] AS INT[]) AS foo", + ) + self.validate_identity( + """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE CASCADE""" + ) + self.validate_identity( + """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE RESTRICT""" + ) self.validate_identity("x @@ y") self.validate_identity("CAST(x AS MONEY)") self.validate_identity("CAST(x AS INT4RANGE)") @@ -362,10 +395,10 @@ class TestPostgres(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", }, ) self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index a92f04f..a80013e 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -431,8 +431,8 @@ class TestPresto(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", }, ) @@ -947,44 +947,6 @@ class TestPresto(Validator): }, ) - def test_explode_to_unnest(self): - self.validate_all( - "SELECT col FROM tbl CROSS JOIN UNNEST(x) AS _u(col)", - read={"spark": "SELECT EXPLODE(x) FROM tbl"}, - ) - self.validate_all( - "SELECT col_2 FROM _u CROSS JOIN UNNEST(col) AS _u_2(col_2)", - read={"spark": "SELECT EXPLODE(col) FROM _u"}, - ) - self.validate_all( - "SELECT exploded FROM schema.tbl CROSS JOIN UNNEST(col) AS _u(exploded)", - read={"spark": "SELECT EXPLODE(col) AS exploded FROM schema.tbl"}, - ) - self.validate_all( - "SELECT col FROM UNNEST(SEQUENCE(1, 2)) AS _u(col)", - read={"spark": "SELECT EXPLODE(SEQUENCE(1, 2))"}, - ) - self.validate_all( - "SELECT col FROM tbl AS t CROSS JOIN UNNEST(t.c) AS _u(col)", - read={"spark": "SELECT EXPLODE(t.c) FROM tbl t"}, - ) - self.validate_all( - "SELECT pos, col FROM UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", - read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3))"}, - ) - self.validate_all( - "SELECT pos, col FROM tbl CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", - read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3)) FROM tbl"}, - ) - self.validate_all( - "SELECT pos, col FROM tbl AS t CROSS JOIN UNNEST(t.c) WITH ORDINALITY AS _u(col, pos)", - read={"spark": "SELECT POSEXPLODE(t.c) FROM tbl t"}, - ) - self.validate_all( - "SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)", - read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"}, - ) - def test_match_recognize(self): self.validate_identity( """SELECT diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index e261c01..c75654c 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -1,3 +1,4 @@ +from sqlglot import transpile from tests.dialects.test_dialect import Validator @@ -270,6 +271,26 @@ class TestRedshift(Validator): ) def test_values(self): + # Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError + values = [str(v) for v in range(0, 10000)] + values_query = f"SELECT * FROM (VALUES {', '.join('(' + v + ')' for v in values)})" + union_query = f"SELECT * FROM ({' UNION ALL '.join('SELECT ' + v for v in values)})" + self.assertEqual(transpile(values_query, write="redshift")[0], union_query) + + self.validate_identity( + "SELECT * FROM (VALUES (1), (2))", + """SELECT + * +FROM ( + SELECT + 1 + UNION ALL + SELECT + 2 +)""", + pretty=True, + ) + self.validate_all( "SELECT * FROM (VALUES (1, 2)) AS t", write={ @@ -291,9 +312,9 @@ class TestRedshift(Validator): }, ) self.validate_all( - "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", + 'SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS "t" (a, b)', write={ - "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + "redshift": 'SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS "t"', }, ) self.validate_all( @@ -320,6 +341,16 @@ class TestRedshift(Validator): "redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)", }, ) + self.validate_identity( + 'SELECT * FROM (VALUES (1)) AS "t"(a)', + '''SELECT + * +FROM ( + SELECT + 1 AS a +) AS "t"''', + pretty=True, + ) def test_create_table_like(self): self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 30a1f03..a217394 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -78,12 +78,33 @@ class TestSnowflake(Validator): r"SELECT $$a ' \ \t \x21 z $ $$", r"SELECT 'a \' \\ \\t \\x21 z $ '", ) + self.validate_identity( + "SELECT {'test': 'best'}::VARIANT", + "SELECT CAST(OBJECT_CONSTRUCT('test', 'best') AS VARIANT)", + ) self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"}) self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all( + "ARRAY_GENERATE_RANGE(0, 3)", + write={ + "bigquery": "GENERATE_ARRAY(0, 3 - 1)", + "postgres": "GENERATE_SERIES(0, 3 - 1)", + "presto": "SEQUENCE(0, 3 - 1)", + "snowflake": "ARRAY_GENERATE_RANGE(0, (3 - 1) + 1)", + }, + ) + self.validate_all( + "ARRAY_GENERATE_RANGE(0, 3 + 1)", + read={ + "bigquery": "GENERATE_ARRAY(0, 3)", + "postgres": "GENERATE_SERIES(0, 3)", + "presto": "SEQUENCE(0, 3)", + }, + ) + self.validate_all( "SELECT DATE_PART('year', TIMESTAMP '2020-01-01')", write={ "hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", @@ -258,13 +279,13 @@ class TestSnowflake(Validator): self.validate_all( "SELECT * EXCLUDE a, b FROM xxx", write={ - "snowflake": "SELECT * EXCLUDE (a, b) FROM xxx", + "snowflake": "SELECT * EXCLUDE (a), b FROM xxx", }, ) self.validate_all( "SELECT * RENAME a AS b, c AS d FROM xxx", write={ - "snowflake": "SELECT * RENAME (a AS b, c AS d) FROM xxx", + "snowflake": "SELECT * RENAME (a AS b), c AS d FROM xxx", }, ) self.validate_all( @@ -364,12 +385,12 @@ class TestSnowflake(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname", }, ) self.validate_all( @@ -867,7 +888,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA """SELECT $1 AS "_1" FROM VALUES ('a'), ('b')""", write={ "snowflake": """SELECT $1 AS "_1" FROM (VALUES ('a'), ('b'))""", - "spark": """SELECT @1 AS `_1` FROM VALUES ('a'), ('b')""", + "spark": """SELECT ${1} AS `_1` FROM VALUES ('a'), ('b')""", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index becb66a..2e43ba5 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -229,6 +229,7 @@ TBLPROPERTIES ( self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH table a.b.c") @@ -460,13 +461,13 @@ TBLPROPERTIES ( self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname NULLS FIRST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname NULLS FIRST", }, ) self.validate_all( @@ -583,3 +584,60 @@ TBLPROPERTIES ( "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", }, ) + + def test_explode_to_unnest(self): + self.validate_all( + "SELECT EXPLODE(x) FROM tbl", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(x) AND pos_2 = CARDINALITY(x))", + "spark": "SELECT EXPLODE(x) FROM tbl", + }, + ) + self.validate_all( + "SELECT EXPLODE(col) FROM _u", + write={ + "bigquery": "SELECT IF(pos = pos_2, col_2, NULL) AS col_2 FROM _u, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(col)) - 1)) AS pos CROSS JOIN UNNEST(col) AS col_2 WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(col) - 1) AND pos_2 = (ARRAY_LENGTH(col) - 1))", + "presto": "SELECT IF(pos = pos_2, col_2) AS col_2 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u_2(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_3(col_2, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + "spark": "SELECT EXPLODE(col) FROM _u", + }, + ) + self.validate_all( + "SELECT EXPLODE(col) AS exploded FROM schema.tbl", + write={ + "presto": "SELECT IF(pos = pos_2, exploded) AS exploded FROM schema.tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_2(exploded, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + }, + ) + self.validate_all( + "SELECT EXPLODE(ARRAY(1, 2))", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2])) - 1)) AS pos CROSS JOIN UNNEST([1, 2]) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2]) - 1))", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2]) AND pos_2 = CARDINALITY(ARRAY[1, 2]))", + }, + ) + self.validate_all( + "SELECT POSEXPLODE(ARRAY(2, 3)) AS x", + write={ + "bigquery": "SELECT IF(pos = pos_2, x, NULL) AS x, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS x WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))", + "presto": "SELECT IF(pos = pos_2, x) AS x, IF(pos = pos_2, pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))", + }, + ) + self.validate_all( + "SELECT POSEXPLODE(x) AS (a, b)", + write={ + "presto": "SELECT IF(pos = a, b) AS b, IF(pos = a, a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE pos = a OR (pos > CARDINALITY(x) AND a = CARDINALITY(x))", + }, + ) + self.validate_all( + "SELECT POSEXPLODE(ARRAY(2, 3)), EXPLODE(ARRAY(4, 5, 6)) FROM tbl", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_2, pos_2, NULL) AS pos_2, IF(pos = pos_3, col_2, NULL) AS col_2 FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3]), ARRAY_LENGTH([4, 5, 6])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5, 6]) AS col_2 WITH OFFSET AS pos_3 WHERE (pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5, 6]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5, 6]) - 1)))", + "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_2, pos_2) AS pos_2, IF(pos = pos_3, col_2) AS col_2 FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3]), CARDINALITY(ARRAY[4, 5, 6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5, 6]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE (pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5, 6]) AND pos_3 = CARDINALITY(ARRAY[4, 5, 6])))", + }, + ) + self.validate_all( + "SELECT col, pos, POSEXPLODE(ARRAY(2, 3)) FROM _u", + write={ + "presto": "SELECT col, pos, IF(pos_2 = pos_3, col_2) AS col_2, IF(pos_2 = pos_3, pos_3) AS pos_3 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE pos_2 = pos_3 OR (pos_2 > CARDINALITY(ARRAY[2, 3]) AND pos_3 = CARDINALITY(ARRAY[2, 3]))", + }, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 4cf0832..3df74c8 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -59,6 +59,23 @@ class TestSQLite(Validator): ) def test_sqlite(self): + self.validate_identity("SELECT DATE()") + self.validate_identity("SELECT DATE('now', 'start of month', '+1 month', '-1 day')") + self.validate_identity("SELECT DATETIME(1092941466, 'unixepoch')") + self.validate_identity("SELECT DATETIME(1092941466, 'auto')") + self.validate_identity("SELECT DATETIME(1092941466, 'unixepoch', 'localtime')") + self.validate_identity("SELECT UNIXEPOCH()") + self.validate_identity("SELECT STRFTIME('%s')") + self.validate_identity("SELECT JULIANDAY('now') - JULIANDAY('1776-07-04')") + self.validate_identity("SELECT UNIXEPOCH() - UNIXEPOCH('2004-01-01 02:34:56')") + self.validate_identity("SELECT DATE('now', 'start of year', '+9 months', 'weekday 2')") + self.validate_identity("SELECT (JULIANDAY('now') - 2440587.5) * 86400.0") + self.validate_identity("SELECT UNIXEPOCH('now', 'subsec')") + self.validate_identity("SELECT TIMEDIFF('now', '1809-02-12')") + self.validate_identity( + """SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0""" + ) + self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"}) self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"}) self.validate_all( @@ -112,8 +129,8 @@ class TestSQLite(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", }, ) self.validate_all("x", read={"snowflake": "LEAST(x)"}) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index acf8b79..f76894d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -18,16 +18,28 @@ class TestTSQL(Validator): 'CREATE TABLE x (CONSTRAINT "pk_mytable" UNIQUE NONCLUSTERED (a DESC)) ON b (c)' ) - self.validate_identity( + self.validate_all( """ CREATE TABLE x( - [zip_cd] [varchar](5) NULL NOT FOR REPLICATION - CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED - ([zip_cd_mkey] ASC) - WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [PRIMARY] - ) ON [PRIMARY] + [zip_cd] [varchar](5) NULL NOT FOR REPLICATION, + [zip_cd_mkey] [varchar](5) NOT NULL, + CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) + WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [INDEX] + ) ON [SECONDARY] """, - 'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey") WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "PRIMARY") ON "PRIMARY"', + write={ + "tsql": 'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION, "zip_cd_mkey" VARCHAR(5) NOT NULL, CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey" ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "INDEX") ON "SECONDARY"', + "spark2": "CREATE TABLE x (`zip_cd` VARCHAR(5), `zip_cd_mkey` VARCHAR(5) NOT NULL, CONSTRAINT `pk_mytable` PRIMARY KEY (`zip_cd_mkey`))", + }, + ) + + self.validate_identity("CREATE TABLE x (A INTEGER NOT NULL, B INTEGER NULL)") + + self.validate_all( + "CREATE TABLE x ( A INTEGER NOT NULL, B INTEGER NULL )", + write={ + "hive": "CREATE TABLE x (A INT NOT NULL, B INT)", + }, ) self.validate_identity( @@ -123,10 +135,10 @@ class TestTSQL(Validator): self.validate_all( "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)", write={ - "tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)", - "mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')", + "tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)", + "mysql": "GROUP_CONCAT(x ORDER BY z ASC SEPARATOR '|')", "sqlite": "GROUP_CONCAT(x, '|')", - "postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)", + "postgres": "STRING_AGG(x, '|' ORDER BY z ASC NULLS FIRST)", }, ) self.validate_all( @@ -186,6 +198,7 @@ class TestTSQL(Validator): }, ) self.validate_identity("HASHBYTES('MD2', 'x')") + self.validate_identity("LOG(n, b)") def test_types(self): self.validate_identity("CAST(x AS XML)") @@ -494,6 +507,12 @@ class TestTSQL(Validator): }, ) self.validate_all( + "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", + read={ + "": "CREATE TABLE foo.bar.baz AS SELECT * FROM a.b.c", + }, + ) + self.validate_all( "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id('db.tbl') AND name = 'idx') EXEC('CREATE INDEX idx ON db.tbl')", read={ "": "CREATE INDEX IF NOT EXISTS idx ON db.tbl", @@ -507,12 +526,17 @@ class TestTSQL(Validator): }, ) self.validate_all( - "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'foo') EXEC('CREATE TABLE foo (a INTEGER)')", + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'baz' AND table_schema = 'bar' AND table_catalog = 'foo') EXEC('CREATE TABLE foo.bar.baz (a INTEGER)')", read={ - "": "CREATE TABLE IF NOT EXISTS foo (a INTEGER)", + "": "CREATE TABLE IF NOT EXISTS foo.bar.baz (a INTEGER)", + }, + ) + self.validate_all( + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'baz' AND table_schema = 'bar' AND table_catalog = 'foo') EXEC('SELECT * INTO foo.bar.baz FROM (SELECT ''2020'' AS z FROM a.b.c) AS temp')", + read={ + "": "CREATE TABLE IF NOT EXISTS foo.bar.baz AS SELECT '2020' AS z FROM a.b.c", }, ) - self.validate_all( "CREATE OR ALTER VIEW a.b AS SELECT 1", read={ @@ -553,15 +577,11 @@ class TestTSQL(Validator): "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", }, ) + + def test_insert_cte(self): self.validate_all( - "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", - write={ - "duckdb": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", - "oracle": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", - "snowflake": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", - "spark": "CREATE TEMPORARY VIEW mytemptable AS SELECT a FROM Source_Table", - "tsql": "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", - }, + "INSERT INTO foo.bar WITH cte AS (SELECT 1 AS one) SELECT * FROM cte", + write={"tsql": "WITH cte AS (SELECT 1 AS one) INSERT INTO foo.bar SELECT * FROM cte"}, ) def test_transaction(self): @@ -709,18 +729,14 @@ WHERE SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120); CREATE TABLE [target_schema].[target_table] - WITH (DISTRIBUTION = REPLICATE, HEAP) - AS - - SELECT - @CurrentDate AS DWCreatedDate - FROM source_schema.sourcetable; + (a INTEGER) + WITH (DISTRIBUTION = REPLICATE, HEAP); """ expected_sqls = [ 'CREATE PROC "dbo"."transform_proc" AS DECLARE @CurrentDate VARCHAR(20)', "SET @CurrentDate = CAST(FORMAT(GETDATE(), 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(20))", - 'CREATE TABLE "target_schema"."target_table" WITH (DISTRIBUTION=REPLICATE, HEAP) AS SELECT @CurrentDate AS DWCreatedDate FROM source_schema.sourcetable', + 'CREATE TABLE "target_schema"."target_table" (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)', ] for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): @@ -1178,6 +1194,16 @@ WHERE self.assertIsInstance(table.this, exp.Parameter) self.assertIsInstance(table.this.this, exp.Var) + self.validate_all( + "SELECT @x", + write={ + "databricks": "SELECT ${x}", + "hive": "SELECT ${x}", + "spark": "SELECT ${x}", + "tsql": "SELECT @x", + }, + ) + def test_temp_table(self): self.validate_all( "SELECT * FROM #mytemptable", @@ -1319,3 +1345,21 @@ FROM OPENJSON(@json) WITH ( }, pretty=True, ) + + def test_set(self): + self.validate_all( + "SET KEY VALUE", + write={ + "tsql": "SET KEY VALUE", + "duckdb": "SET KEY = VALUE", + "spark": "SET KEY = VALUE", + }, + ) + self.validate_all( + "SET @count = (SELECT COUNT(1) FROM x)", + write={ + "databricks": "SET count = (SELECT COUNT(1) FROM x)", + "tsql": "SET @count = (SELECT COUNT(1) FROM x)", + "spark": "SET count = (SELECT COUNT(1) FROM x)", + }, + ) |