diff options
Diffstat (limited to 'tests')
26 files changed, 660 insertions, 168 deletions
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index 665cc91..117789e 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -95,16 +95,16 @@ class TestDataframeColumn(unittest.TestCase): self.assertEqual("cola IN (1, 2, 3)", F.col("cola").isin(1, 2, 3).sql()) def test_asc(self): - self.assertEqual("cola", F.col("cola").asc().sql()) + self.assertEqual("cola ASC", F.col("cola").asc().sql()) def test_desc(self): self.assertEqual("cola DESC", F.col("cola").desc().sql()) def test_asc_nulls_first(self): - self.assertEqual("cola", F.col("cola").asc_nulls_first().sql()) + self.assertEqual("cola ASC", F.col("cola").asc_nulls_first().sql()) def test_asc_nulls_last(self): - self.assertEqual("cola NULLS LAST", F.col("cola").asc_nulls_last().sql()) + self.assertEqual("cola ASC NULLS LAST", F.col("cola").asc_nulls_last().sql()) def test_desc_nulls_first(self): self.assertEqual("cola DESC NULLS FIRST", F.col("cola").desc_nulls_first().sql()) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 2fb5650..586b8fc 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -335,18 +335,18 @@ class TestFunctions(unittest.TestCase): def test_asc_nulls_first(self): col_str = SF.asc_nulls_first("cola") self.assertIsInstance(col_str.expression, exp.Ordered) - self.assertEqual("cola", col_str.sql()) + self.assertEqual("cola ASC", col_str.sql()) col = SF.asc_nulls_first(SF.col("cola")) self.assertIsInstance(col.expression, exp.Ordered) - self.assertEqual("cola", col.sql()) + self.assertEqual("cola ASC", col.sql()) def test_asc_nulls_last(self): col_str = SF.asc_nulls_last("cola") self.assertIsInstance(col_str.expression, exp.Ordered) - self.assertEqual("cola NULLS LAST", col_str.sql()) + self.assertEqual("cola ASC NULLS LAST", col_str.sql()) col = SF.asc_nulls_last(SF.col("cola")) self.assertIsInstance(col.expression, exp.Ordered) - self.assertEqual("cola NULLS LAST", col.sql()) + self.assertEqual("cola ASC NULLS LAST", col.sql()) def test_desc_nulls_first(self): col_str = SF.desc_nulls_first("cola") diff --git a/tests/dataframe/unit/test_session_case_sensitivity.py b/tests/dataframe/unit/test_session_case_sensitivity.py index f9119b0..462edb6 100644 --- a/tests/dataframe/unit/test_session_case_sensitivity.py +++ b/tests/dataframe/unit/test_session_case_sensitivity.py @@ -79,3 +79,9 @@ class TestSessionCaseSensitivity(DataFrameTestBase): df.sql() else: self.compare_sql(df, expected) + + def test_alias(self): + col = F.col('"Name"') + self.assertEqual(col.sql(dialect=self.spark.dialect), '"Name"') + self.assertEqual(col.alias("nAME").sql(dialect=self.spark.dialect), '"Name" AS NAME') + self.assertEqual(col.alias('"nAME"').sql(dialect=self.spark.dialect), '"Name" AS "nAME"') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 448a077..1f5f902 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -9,36 +9,6 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): - self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") - self.validate_identity( - """SELECT JSON '"foo"' AS json_data""", - """SELECT PARSE_JSON('"foo"') AS json_data""", - ) - - self.validate_all( - """SELECT - `u`.`harness_user_email` AS `harness_user_email`, - `d`.`harness_user_id` AS `harness_user_id`, - `harness_account_id` AS `harness_account_id` -FROM `analytics_staging`.`stg_mongodb__users` AS `u`, UNNEST(`u`.`harness_cluster_details`) AS `d`, UNNEST(`d`.`harness_account_ids`) AS `harness_account_id` -WHERE - NOT `harness_account_id` IS NULL""", - read={ - "": """ - SELECT - "u"."harness_user_email" AS "harness_user_email", - "_q_0"."d"."harness_user_id" AS "harness_user_id", - "_q_1"."harness_account_id" AS "harness_account_id" - FROM - "analytics_staging"."stg_mongodb__users" AS "u", - UNNEST("u"."harness_cluster_details") AS "_q_0"("d"), - UNNEST("_q_0"."d"."harness_account_ids") AS "_q_1"("harness_account_id") - WHERE - NOT "_q_1"."harness_account_id" IS NULL - """ - }, - pretty=True, - ) with self.assertRaises(TokenError): transpile("'\\'", read="bigquery") @@ -63,6 +33,9 @@ WHERE with self.assertRaises(ParseError): transpile("DATE_ADD(x, day)", read="bigquery") + self.validate_identity("SELECT test.Unknown FROM test") + self.validate_identity(r"SELECT '\n\r\a\v\f\t'") + self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") self.validate_identity("STRING_AGG(DISTINCT a ORDER BY b DESC, c DESC LIMIT 10)") self.validate_identity("SELECT PARSE_TIMESTAMP('%c', 'Thu Dec 25 07:30:00 2008', 'UTC')") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MAX sold) FROM fruits") @@ -111,6 +84,7 @@ WHERE self.validate_identity("COMMIT TRANSACTION") self.validate_identity("ROLLBACK TRANSACTION") self.validate_identity("CAST(x AS BIGNUMERIC)") + self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") self.validate_identity( "DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')" ) @@ -132,6 +106,22 @@ WHERE self.validate_identity( "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", ) + self.validate_identity( + "SELECT a overlaps", + "SELECT a AS overlaps", + ) + self.validate_identity( + "SELECT y + 1 z FROM x GROUP BY y + 1 ORDER BY z", + "SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z", + ) + self.validate_identity( + "SELECT y + 1 z FROM x GROUP BY y + 1", + "SELECT y + 1 AS z FROM x GROUP BY y + 1", + ) + self.validate_identity( + """SELECT JSON '"foo"' AS json_data""", + """SELECT PARSE_JSON('"foo"') AS json_data""", + ) self.validate_all("SELECT SPLIT(foo)", write={"bigquery": "SELECT SPLIT(foo, ',')"}) self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"}) @@ -246,7 +236,7 @@ WHERE }, ) self.validate_all( - "WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT col FROM cte CROSS JOIN UNNEST(arr) AS col", + "WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT IF(pos = pos_2, col, NULL) AS col FROM cte, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(arr)) - 1)) AS pos CROSS JOIN UNNEST(arr) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(arr) - 1) AND pos_2 = (ARRAY_LENGTH(arr) - 1))", read={ "spark": "WITH cte AS (SELECT ARRAY(1, 2, 3) AS arr) SELECT EXPLODE(arr) FROM cte" }, @@ -291,6 +281,10 @@ WHERE "bigquery": "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)", }, ) + self.validate_identity( + r"REGEXP_EXTRACT(svc_plugin_output, r'\\\((.*)')", + r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", + ) self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={ @@ -302,7 +296,7 @@ WHERE "mysql": "REGEXP_LIKE('foo', '.*')", "starrocks": "REGEXP('foo', '.*')", }, - ), + ) self.validate_all( '"""x"""', write={ @@ -453,7 +447,6 @@ WHERE "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)", write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"}, ) - self.validate_all( "x IS unknown", write={ @@ -465,6 +458,16 @@ WHERE }, ) self.validate_all( + "x IS NOT unknown", + write={ + "bigquery": "NOT x IS NULL", + "duckdb": "NOT x IS NULL", + "presto": "NOT x IS NULL", + "hive": "NOT x IS NULL", + "spark": "NOT x IS NULL", + }, + ) + self.validate_all( "CURRENT_TIMESTAMP()", read={ "tsql": "GETDATE()", @@ -682,16 +685,32 @@ WHERE "spark": "TO_JSON(x)", }, ) - - self.validate_identity( - "SELECT y + 1 z FROM x GROUP BY y + 1 ORDER BY z", - "SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z", - ) - self.validate_identity( - "SELECT y + 1 z FROM x GROUP BY y + 1", - "SELECT y + 1 AS z FROM x GROUP BY y + 1", + self.validate_all( + """SELECT + `u`.`harness_user_email` AS `harness_user_email`, + `d`.`harness_user_id` AS `harness_user_id`, + `harness_account_id` AS `harness_account_id` +FROM `analytics_staging`.`stg_mongodb__users` AS `u`, UNNEST(`u`.`harness_cluster_details`) AS `d`, UNNEST(`d`.`harness_account_ids`) AS `harness_account_id` +WHERE + NOT `harness_account_id` IS NULL""", + read={ + "": """ + SELECT + "u"."harness_user_email" AS "harness_user_email", + "_q_0"."d"."harness_user_id" AS "harness_user_id", + "_q_1"."harness_account_id" AS "harness_account_id" + FROM + "analytics_staging"."stg_mongodb__users" AS "u", + UNNEST("u"."harness_cluster_details") AS "_q_0"("d"), + UNNEST("_q_0"."d"."harness_account_ids") AS "_q_1"("harness_account_id") + WHERE + NOT "_q_1"."harness_account_id" IS NULL + """ + }, + pretty=True, ) - self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") + + self.validate_identity("LOG(n, b)") def test_user_defined_functions(self): self.validate_identity( @@ -702,6 +721,10 @@ WHERE self.validate_identity( "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" ) + self.validate_identity( + '''CREATE TEMPORARY FUNCTION string_length_0(strings ARRAY<STRING>) RETURNS FLOAT64 LANGUAGE js AS """'use strict'; function string_length(strings) { return _.sum(_.map(strings, ((x) => x.length))); } return string_length(strings);""" OPTIONS (library=['gs://ibis-testing-libraries/lodash.min.js'])''', + "CREATE TEMPORARY FUNCTION string_length_0(strings ARRAY<STRING>) RETURNS FLOAT64 LANGUAGE js OPTIONS (library=['gs://ibis-testing-libraries/lodash.min.js']) AS '\\'use strict\\'; function string_length(strings) { return _.sum(_.map(strings, ((x) => x.length))); } return string_length(strings);'", + ) def test_group_concat(self): self.validate_all( diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 2cda0dc..40a270e 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -157,8 +157,8 @@ class TestClickhouse(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", }, ) self.validate_all( @@ -216,7 +216,7 @@ class TestClickhouse(Validator): """, write={ "clickhouse": "SELECT loyalty, count() FROM hits LEFT SEMI JOIN users USING (UserID)" - + " GROUP BY loyalty ORDER BY loyalty" + " GROUP BY loyalty ORDER BY loyalty ASC" }, ) self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index d06e0f1..3df968b 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -5,6 +5,9 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("CREATE TABLE t (c STRUCT<interval: DOUBLE COMMENT 'aaa'>)") + self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES (a.b=15)") + self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES ('a.b'=15)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO HOUR)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO MINUTE)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO SECOND)") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 47e1ec7..3e0ffd5 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -964,12 +964,12 @@ class TestDialect(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", }, ) @@ -1354,7 +1354,7 @@ class TestDialect(Validator): self.validate_all( "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz", write={ - "bigquery": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", + "bigquery": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", "duckdb": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", "presto": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", "hive": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 36fca7c..dbf0a87 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -6,6 +6,92 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_duckdb(self): + self.assertEqual( + parse_one("select * from t limit (select 5)").sql(dialect="duckdb"), + exp.select("*").from_("t").limit(exp.select("5").subquery()).sql(dialect="duckdb"), + ) + + for struct_value in ("{'a': 1}", "struct_pack(a := 1)"): + self.validate_all(struct_value, write={"presto": UnsupportedError}) + + for join_type in ("SEMI", "ANTI"): + exists = "EXISTS" if join_type == "SEMI" else "NOT EXISTS" + + self.validate_all( + f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + write={ + "bigquery": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "clickhouse": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "databricks": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "doris": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "drill": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "duckdb": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "hive": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "mysql": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "oracle": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "postgres": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "presto": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "redshift": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "snowflake": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "spark": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "sqlite": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "starrocks": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "teradata": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "trino": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + "tsql": f"SELECT * FROM t1 WHERE {exists}(SELECT 1 FROM t2 WHERE t1.x = t2.x)", + }, + ) + self.validate_all( + f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + read={ + "duckdb": f"SELECT * FROM t1 {join_type} JOIN t2 ON t1.x = t2.x", + "spark": f"SELECT * FROM t1 LEFT {join_type} JOIN t2 ON t1.x = t2.x", + }, + ) + + self.validate_all( + "WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(x) FILTER (WHERE x > 1) FROM cte", + write={ + "duckdb": "WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(x) FILTER(WHERE x > 1) FROM cte", + "snowflake": "WITH cte(x) AS (SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3) SELECT AVG(IFF(x > 1, x, NULL)) FROM cte", + }, + ) + self.validate_all( + "SELECT AVG(x) FILTER (WHERE TRUE) FROM t", + write={ + "duckdb": "SELECT AVG(x) FILTER(WHERE TRUE) FROM t", + "snowflake": "SELECT AVG(IFF(TRUE, x, NULL)) FROM t", + }, + ) + self.validate_all( + "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6])", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", + "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + }, + ) + + self.validate_all( + "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6]) FROM x", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM x, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", + "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + }, + ) + self.validate_all( + "SELECT UNNEST(x) + 1", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) + 1 AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", + }, + ) + self.validate_all( + "SELECT UNNEST(x) + 1 AS y", + write={ + "bigquery": "SELECT IF(pos = pos_2, y, NULL) + 1 AS y FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS y WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", + }, + ) + + self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC") self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]") self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y") self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x") @@ -62,6 +148,13 @@ class TestDuckDB(Validator): self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( + "SELECT UNNEST([1, 2, 3])", + write={ + "duckdb": "SELECT UNNEST([1, 2, 3])", + "snowflake": "SELECT IFF(pos = pos_2, col, NULL) AS col FROM (SELECT value FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1)))) AS _u(pos) CROSS JOIN (SELECT value, index FROM TABLE(FLATTEN(INPUT => [1, 2, 3]))) AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", + }, + ) + self.validate_all( "VAR_POP(x)", read={ "": "VARIANCE_POP(x)", @@ -304,8 +397,8 @@ class TestDuckDB(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", }, ) self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 70a05fd..26f0189 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -382,10 +382,10 @@ class TestHive(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index e362e9e..20f872c 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -499,6 +499,21 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "a XOR b", + read={ + "mysql": "a XOR b", + "snowflake": "BOOLXOR(a, b)", + }, + write={ + "duckdb": "(a AND (NOT b)) OR ((NOT a) AND b)", + "mysql": "a XOR b", + "postgres": "(a AND (NOT b)) OR ((NOT a) AND b)", + "snowflake": "BOOLXOR(a, b)", + "trino": "(a AND (NOT b)) OR ((NOT a) AND b)", + }, + ) + + self.validate_all( "SELECT * FROM test LIMIT 0 + 1, 0 + 1", write={ "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 285496a..6a3df47 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -13,6 +13,7 @@ class TestPostgres(Validator): "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", "CREATE TABLE test (x TIMESTAMP[][])", ) + self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)") self.validate_identity("CREATE TABLE test (elems JSONB[])") self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)") self.validate_identity("CREATE TABLE test (foo HSTORE)") @@ -83,6 +84,28 @@ class TestPostgres(Validator): " CONSTRAINT valid_discount CHECK (price > discounted_price))" }, ) + self.validate_identity( + """ + CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial + ON public.ci_builds + USING btree (commit_id, artifacts_expire_at, id) + WHERE ( + ((type)::text = 'Ci::Build'::text) + AND ((retried = false) OR (retried IS NULL)) + AND ((name)::text = ANY (ARRAY[ + ('sast'::character varying)::text, + ('dependency_scanning'::character varying)::text, + ('sast:container'::character varying)::text, + ('container_scanning'::character varying)::text, + ('dast'::character varying)::text + ])) + ) + """, + "CREATE INDEX index_ci_builds_on_commit_id_and_artifacts_expireatandidpartial ON public.ci_builds USING btree(commit_id, artifacts_expire_at, id) WHERE ((CAST((type) AS TEXT) = CAST('Ci::Build' AS TEXT)) AND ((retried = FALSE) OR (retried IS NULL)) AND (CAST((name) AS TEXT) = ANY (ARRAY[CAST((CAST('sast' AS VARCHAR)) AS TEXT), CAST((CAST('dependency_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('sast:container' AS VARCHAR)) AS TEXT), CAST((CAST('container_scanning' AS VARCHAR)) AS TEXT), CAST((CAST('dast' AS VARCHAR)) AS TEXT)])))", + ) + self.validate_identity( + "CREATE INDEX index_ci_pipelines_on_project_idandrefandiddesc ON public.ci_pipelines USING btree(project_id, ref, id DESC)" + ) with self.assertRaises(ParseError): transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") @@ -102,7 +125,7 @@ class TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(c) FROM t", "postgres": "SELECT UNNEST(c) FROM t", - "presto": "SELECT col FROM t CROSS JOIN UNNEST(c) AS _u(col)", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(c) AND pos_2 = CARDINALITY(c))", }, ) self.validate_all( @@ -110,7 +133,7 @@ class TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(ARRAY(1))", "postgres": "SELECT UNNEST(ARRAY[1])", - "presto": "SELECT col FROM UNNEST(ARRAY[1]) AS _u(col)", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1]) AND pos_2 = CARDINALITY(ARRAY[1]))", }, ) @@ -139,6 +162,16 @@ class TestPostgres(Validator): self.assertIsInstance(expr, exp.AlterTable) self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) + self.validate_identity( + "SELECT ARRAY[]::INT[] AS foo", + "SELECT CAST(ARRAY[] AS INT[]) AS foo", + ) + self.validate_identity( + """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE CASCADE""" + ) + self.validate_identity( + """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE RESTRICT""" + ) self.validate_identity("x @@ y") self.validate_identity("CAST(x AS MONEY)") self.validate_identity("CAST(x AS INT4RANGE)") @@ -362,10 +395,10 @@ class TestPostgres(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", }, ) self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index a92f04f..a80013e 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -431,8 +431,8 @@ class TestPresto(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", }, ) @@ -947,44 +947,6 @@ class TestPresto(Validator): }, ) - def test_explode_to_unnest(self): - self.validate_all( - "SELECT col FROM tbl CROSS JOIN UNNEST(x) AS _u(col)", - read={"spark": "SELECT EXPLODE(x) FROM tbl"}, - ) - self.validate_all( - "SELECT col_2 FROM _u CROSS JOIN UNNEST(col) AS _u_2(col_2)", - read={"spark": "SELECT EXPLODE(col) FROM _u"}, - ) - self.validate_all( - "SELECT exploded FROM schema.tbl CROSS JOIN UNNEST(col) AS _u(exploded)", - read={"spark": "SELECT EXPLODE(col) AS exploded FROM schema.tbl"}, - ) - self.validate_all( - "SELECT col FROM UNNEST(SEQUENCE(1, 2)) AS _u(col)", - read={"spark": "SELECT EXPLODE(SEQUENCE(1, 2))"}, - ) - self.validate_all( - "SELECT col FROM tbl AS t CROSS JOIN UNNEST(t.c) AS _u(col)", - read={"spark": "SELECT EXPLODE(t.c) FROM tbl t"}, - ) - self.validate_all( - "SELECT pos, col FROM UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", - read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3))"}, - ) - self.validate_all( - "SELECT pos, col FROM tbl CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", - read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3)) FROM tbl"}, - ) - self.validate_all( - "SELECT pos, col FROM tbl AS t CROSS JOIN UNNEST(t.c) WITH ORDINALITY AS _u(col, pos)", - read={"spark": "SELECT POSEXPLODE(t.c) FROM tbl t"}, - ) - self.validate_all( - "SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)", - read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"}, - ) - def test_match_recognize(self): self.validate_identity( """SELECT diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index e261c01..c75654c 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -1,3 +1,4 @@ +from sqlglot import transpile from tests.dialects.test_dialect import Validator @@ -270,6 +271,26 @@ class TestRedshift(Validator): ) def test_values(self): + # Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError + values = [str(v) for v in range(0, 10000)] + values_query = f"SELECT * FROM (VALUES {', '.join('(' + v + ')' for v in values)})" + union_query = f"SELECT * FROM ({' UNION ALL '.join('SELECT ' + v for v in values)})" + self.assertEqual(transpile(values_query, write="redshift")[0], union_query) + + self.validate_identity( + "SELECT * FROM (VALUES (1), (2))", + """SELECT + * +FROM ( + SELECT + 1 + UNION ALL + SELECT + 2 +)""", + pretty=True, + ) + self.validate_all( "SELECT * FROM (VALUES (1, 2)) AS t", write={ @@ -291,9 +312,9 @@ class TestRedshift(Validator): }, ) self.validate_all( - "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", + 'SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS "t" (a, b)', write={ - "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + "redshift": 'SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS "t"', }, ) self.validate_all( @@ -320,6 +341,16 @@ class TestRedshift(Validator): "redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)", }, ) + self.validate_identity( + 'SELECT * FROM (VALUES (1)) AS "t"(a)', + '''SELECT + * +FROM ( + SELECT + 1 AS a +) AS "t"''', + pretty=True, + ) def test_create_table_like(self): self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 30a1f03..a217394 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -78,12 +78,33 @@ class TestSnowflake(Validator): r"SELECT $$a ' \ \t \x21 z $ $$", r"SELECT 'a \' \\ \\t \\x21 z $ '", ) + self.validate_identity( + "SELECT {'test': 'best'}::VARIANT", + "SELECT CAST(OBJECT_CONSTRUCT('test', 'best') AS VARIANT)", + ) self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"}) self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all( + "ARRAY_GENERATE_RANGE(0, 3)", + write={ + "bigquery": "GENERATE_ARRAY(0, 3 - 1)", + "postgres": "GENERATE_SERIES(0, 3 - 1)", + "presto": "SEQUENCE(0, 3 - 1)", + "snowflake": "ARRAY_GENERATE_RANGE(0, (3 - 1) + 1)", + }, + ) + self.validate_all( + "ARRAY_GENERATE_RANGE(0, 3 + 1)", + read={ + "bigquery": "GENERATE_ARRAY(0, 3)", + "postgres": "GENERATE_SERIES(0, 3)", + "presto": "SEQUENCE(0, 3)", + }, + ) + self.validate_all( "SELECT DATE_PART('year', TIMESTAMP '2020-01-01')", write={ "hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", @@ -258,13 +279,13 @@ class TestSnowflake(Validator): self.validate_all( "SELECT * EXCLUDE a, b FROM xxx", write={ - "snowflake": "SELECT * EXCLUDE (a, b) FROM xxx", + "snowflake": "SELECT * EXCLUDE (a), b FROM xxx", }, ) self.validate_all( "SELECT * RENAME a AS b, c AS d FROM xxx", write={ - "snowflake": "SELECT * RENAME (a AS b, c AS d) FROM xxx", + "snowflake": "SELECT * RENAME (a AS b), c AS d FROM xxx", }, ) self.validate_all( @@ -364,12 +385,12 @@ class TestSnowflake(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", - "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname NULLS LAST", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname", }, ) self.validate_all( @@ -867,7 +888,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA """SELECT $1 AS "_1" FROM VALUES ('a'), ('b')""", write={ "snowflake": """SELECT $1 AS "_1" FROM (VALUES ('a'), ('b'))""", - "spark": """SELECT @1 AS `_1` FROM VALUES ('a'), ('b')""", + "spark": """SELECT ${1} AS `_1` FROM VALUES ('a'), ('b')""", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index becb66a..2e43ba5 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -229,6 +229,7 @@ TBLPROPERTIES ( self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH table a.b.c") @@ -460,13 +461,13 @@ TBLPROPERTIES ( self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", - "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", - "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname NULLS FIRST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname ASC, lname NULLS FIRST", }, ) self.validate_all( @@ -583,3 +584,60 @@ TBLPROPERTIES ( "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", }, ) + + def test_explode_to_unnest(self): + self.validate_all( + "SELECT EXPLODE(x) FROM tbl", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(x) AND pos_2 = CARDINALITY(x))", + "spark": "SELECT EXPLODE(x) FROM tbl", + }, + ) + self.validate_all( + "SELECT EXPLODE(col) FROM _u", + write={ + "bigquery": "SELECT IF(pos = pos_2, col_2, NULL) AS col_2 FROM _u, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(col)) - 1)) AS pos CROSS JOIN UNNEST(col) AS col_2 WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(col) - 1) AND pos_2 = (ARRAY_LENGTH(col) - 1))", + "presto": "SELECT IF(pos = pos_2, col_2) AS col_2 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u_2(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_3(col_2, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + "spark": "SELECT EXPLODE(col) FROM _u", + }, + ) + self.validate_all( + "SELECT EXPLODE(col) AS exploded FROM schema.tbl", + write={ + "presto": "SELECT IF(pos = pos_2, exploded) AS exploded FROM schema.tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_2(exploded, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + }, + ) + self.validate_all( + "SELECT EXPLODE(ARRAY(1, 2))", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2])) - 1)) AS pos CROSS JOIN UNNEST([1, 2]) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2]) - 1))", + "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2]) AND pos_2 = CARDINALITY(ARRAY[1, 2]))", + }, + ) + self.validate_all( + "SELECT POSEXPLODE(ARRAY(2, 3)) AS x", + write={ + "bigquery": "SELECT IF(pos = pos_2, x, NULL) AS x, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS x WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))", + "presto": "SELECT IF(pos = pos_2, x) AS x, IF(pos = pos_2, pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))", + }, + ) + self.validate_all( + "SELECT POSEXPLODE(x) AS (a, b)", + write={ + "presto": "SELECT IF(pos = a, b) AS b, IF(pos = a, a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE pos = a OR (pos > CARDINALITY(x) AND a = CARDINALITY(x))", + }, + ) + self.validate_all( + "SELECT POSEXPLODE(ARRAY(2, 3)), EXPLODE(ARRAY(4, 5, 6)) FROM tbl", + write={ + "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_2, pos_2, NULL) AS pos_2, IF(pos = pos_3, col_2, NULL) AS col_2 FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3]), ARRAY_LENGTH([4, 5, 6])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5, 6]) AS col_2 WITH OFFSET AS pos_3 WHERE (pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5, 6]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5, 6]) - 1)))", + "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_2, pos_2) AS pos_2, IF(pos = pos_3, col_2) AS col_2 FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3]), CARDINALITY(ARRAY[4, 5, 6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5, 6]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE (pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5, 6]) AND pos_3 = CARDINALITY(ARRAY[4, 5, 6])))", + }, + ) + self.validate_all( + "SELECT col, pos, POSEXPLODE(ARRAY(2, 3)) FROM _u", + write={ + "presto": "SELECT col, pos, IF(pos_2 = pos_3, col_2) AS col_2, IF(pos_2 = pos_3, pos_3) AS pos_3 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE pos_2 = pos_3 OR (pos_2 > CARDINALITY(ARRAY[2, 3]) AND pos_3 = CARDINALITY(ARRAY[2, 3]))", + }, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 4cf0832..3df74c8 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -59,6 +59,23 @@ class TestSQLite(Validator): ) def test_sqlite(self): + self.validate_identity("SELECT DATE()") + self.validate_identity("SELECT DATE('now', 'start of month', '+1 month', '-1 day')") + self.validate_identity("SELECT DATETIME(1092941466, 'unixepoch')") + self.validate_identity("SELECT DATETIME(1092941466, 'auto')") + self.validate_identity("SELECT DATETIME(1092941466, 'unixepoch', 'localtime')") + self.validate_identity("SELECT UNIXEPOCH()") + self.validate_identity("SELECT STRFTIME('%s')") + self.validate_identity("SELECT JULIANDAY('now') - JULIANDAY('1776-07-04')") + self.validate_identity("SELECT UNIXEPOCH() - UNIXEPOCH('2004-01-01 02:34:56')") + self.validate_identity("SELECT DATE('now', 'start of year', '+9 months', 'weekday 2')") + self.validate_identity("SELECT (JULIANDAY('now') - 2440587.5) * 86400.0") + self.validate_identity("SELECT UNIXEPOCH('now', 'subsec')") + self.validate_identity("SELECT TIMEDIFF('now', '1809-02-12')") + self.validate_identity( + """SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0""" + ) + self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"}) self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"}) self.validate_all( @@ -112,8 +129,8 @@ class TestSQLite(Validator): self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ - "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", - "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", }, ) self.validate_all("x", read={"snowflake": "LEAST(x)"}) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index acf8b79..f76894d 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -18,16 +18,28 @@ class TestTSQL(Validator): 'CREATE TABLE x (CONSTRAINT "pk_mytable" UNIQUE NONCLUSTERED (a DESC)) ON b (c)' ) - self.validate_identity( + self.validate_all( """ CREATE TABLE x( - [zip_cd] [varchar](5) NULL NOT FOR REPLICATION - CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED - ([zip_cd_mkey] ASC) - WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [PRIMARY] - ) ON [PRIMARY] + [zip_cd] [varchar](5) NULL NOT FOR REPLICATION, + [zip_cd_mkey] [varchar](5) NOT NULL, + CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED ([zip_cd_mkey] ASC) + WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [INDEX] + ) ON [SECONDARY] """, - 'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey") WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "PRIMARY") ON "PRIMARY"', + write={ + "tsql": 'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION, "zip_cd_mkey" VARCHAR(5) NOT NULL, CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey" ASC) WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "INDEX") ON "SECONDARY"', + "spark2": "CREATE TABLE x (`zip_cd` VARCHAR(5), `zip_cd_mkey` VARCHAR(5) NOT NULL, CONSTRAINT `pk_mytable` PRIMARY KEY (`zip_cd_mkey`))", + }, + ) + + self.validate_identity("CREATE TABLE x (A INTEGER NOT NULL, B INTEGER NULL)") + + self.validate_all( + "CREATE TABLE x ( A INTEGER NOT NULL, B INTEGER NULL )", + write={ + "hive": "CREATE TABLE x (A INT NOT NULL, B INT)", + }, ) self.validate_identity( @@ -123,10 +135,10 @@ class TestTSQL(Validator): self.validate_all( "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)", write={ - "tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)", - "mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')", + "tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)", + "mysql": "GROUP_CONCAT(x ORDER BY z ASC SEPARATOR '|')", "sqlite": "GROUP_CONCAT(x, '|')", - "postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)", + "postgres": "STRING_AGG(x, '|' ORDER BY z ASC NULLS FIRST)", }, ) self.validate_all( @@ -186,6 +198,7 @@ class TestTSQL(Validator): }, ) self.validate_identity("HASHBYTES('MD2', 'x')") + self.validate_identity("LOG(n, b)") def test_types(self): self.validate_identity("CAST(x AS XML)") @@ -494,6 +507,12 @@ class TestTSQL(Validator): }, ) self.validate_all( + "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", + read={ + "": "CREATE TABLE foo.bar.baz AS SELECT * FROM a.b.c", + }, + ) + self.validate_all( "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id('db.tbl') AND name = 'idx') EXEC('CREATE INDEX idx ON db.tbl')", read={ "": "CREATE INDEX IF NOT EXISTS idx ON db.tbl", @@ -507,12 +526,17 @@ class TestTSQL(Validator): }, ) self.validate_all( - "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'foo') EXEC('CREATE TABLE foo (a INTEGER)')", + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'baz' AND table_schema = 'bar' AND table_catalog = 'foo') EXEC('CREATE TABLE foo.bar.baz (a INTEGER)')", read={ - "": "CREATE TABLE IF NOT EXISTS foo (a INTEGER)", + "": "CREATE TABLE IF NOT EXISTS foo.bar.baz (a INTEGER)", + }, + ) + self.validate_all( + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'baz' AND table_schema = 'bar' AND table_catalog = 'foo') EXEC('SELECT * INTO foo.bar.baz FROM (SELECT ''2020'' AS z FROM a.b.c) AS temp')", + read={ + "": "CREATE TABLE IF NOT EXISTS foo.bar.baz AS SELECT '2020' AS z FROM a.b.c", }, ) - self.validate_all( "CREATE OR ALTER VIEW a.b AS SELECT 1", read={ @@ -553,15 +577,11 @@ class TestTSQL(Validator): "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", }, ) + + def test_insert_cte(self): self.validate_all( - "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", - write={ - "duckdb": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", - "oracle": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", - "snowflake": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", - "spark": "CREATE TEMPORARY VIEW mytemptable AS SELECT a FROM Source_Table", - "tsql": "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", - }, + "INSERT INTO foo.bar WITH cte AS (SELECT 1 AS one) SELECT * FROM cte", + write={"tsql": "WITH cte AS (SELECT 1 AS one) INSERT INTO foo.bar SELECT * FROM cte"}, ) def test_transaction(self): @@ -709,18 +729,14 @@ WHERE SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120); CREATE TABLE [target_schema].[target_table] - WITH (DISTRIBUTION = REPLICATE, HEAP) - AS - - SELECT - @CurrentDate AS DWCreatedDate - FROM source_schema.sourcetable; + (a INTEGER) + WITH (DISTRIBUTION = REPLICATE, HEAP); """ expected_sqls = [ 'CREATE PROC "dbo"."transform_proc" AS DECLARE @CurrentDate VARCHAR(20)', "SET @CurrentDate = CAST(FORMAT(GETDATE(), 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(20))", - 'CREATE TABLE "target_schema"."target_table" WITH (DISTRIBUTION=REPLICATE, HEAP) AS SELECT @CurrentDate AS DWCreatedDate FROM source_schema.sourcetable', + 'CREATE TABLE "target_schema"."target_table" (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)', ] for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): @@ -1178,6 +1194,16 @@ WHERE self.assertIsInstance(table.this, exp.Parameter) self.assertIsInstance(table.this.this, exp.Var) + self.validate_all( + "SELECT @x", + write={ + "databricks": "SELECT ${x}", + "hive": "SELECT ${x}", + "spark": "SELECT ${x}", + "tsql": "SELECT @x", + }, + ) + def test_temp_table(self): self.validate_all( "SELECT * FROM #mytemptable", @@ -1319,3 +1345,21 @@ FROM OPENJSON(@json) WITH ( }, pretty=True, ) + + def test_set(self): + self.validate_all( + "SET KEY VALUE", + write={ + "tsql": "SET KEY VALUE", + "duckdb": "SET KEY = VALUE", + "spark": "SET KEY = VALUE", + }, + ) + self.validate_all( + "SET @count = (SELECT COUNT(1) FROM x)", + write={ + "databricks": "SET count = (SELECT COUNT(1) FROM x)", + "tsql": "SET @count = (SELECT COUNT(1) FROM x)", + "spark": "SET count = (SELECT COUNT(1) FROM x)", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index f999620..17506e4 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -401,6 +401,7 @@ SELECT 1 FROM a INNER JOIN b ON a.x = b.x SELECT 1 FROM a LEFT JOIN b ON a.x = b.x SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x SELECT 1 FROM a CROSS JOIN b ON a.x = b.x +SELECT 1 FROM a SEMI JOIN b ON a.x = b.x SELECT 1 FROM a LEFT SEMI JOIN b ON a.x = b.x SELECT 1 FROM a LEFT ANTI JOIN b ON a.x = b.x SELECT 1 FROM a RIGHT SEMI JOIN b ON a.x = b.x @@ -859,3 +860,7 @@ SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) /* comment */ CREATE TABLE foo AS SELECT 1 SELECT next, transform, if SELECT "any", "case", "if", "next" +SELECT x FROM y ORDER BY x ASC +KILL '123' +KILL CONNECTION 123 +KILL QUERY '123' diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 1fc44ef..e27b2d3 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -28,3 +28,15 @@ SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0 SELECT a FROM x WHERE 1; SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE 1 <> 0; + +-------------------------------------- +-- Replace date functions +-------------------------------------- +DATE('2023-01-01'); +CAST('2023-01-01' AS DATE); + +TIMESTAMP('2023-01-01'); +CAST('2023-01-01' AS TIMESTAMP); + +TIMESTAMP('2023-01-01', '12:00:00'); +TIMESTAMP('2023-01-01', '12:00:00'); diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 66fb19c..584e9d6 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -680,3 +680,137 @@ CONCAT('a', x, y, 'bc'); 'a' || 'b' || x; CONCAT('ab', x); + +-------------------------------------- +-- DATE_TRUNC +-------------------------------------- +DATE_TRUNC('year', x) = CAST('2021-01-01' AS DATE); +x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +DATE_TRUNC('quarter', x) = CAST('2021-01-01' AS DATE); +x < CAST('2021-04-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +DATE_TRUNC('month', x) = CAST('2021-01-01' AS DATE); +x < CAST('2021-02-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +DATE_TRUNC('week', x) = CAST('2021-01-04' AS DATE); +x < CAST('2021-01-11' AS DATE) AND x >= CAST('2021-01-04' AS DATE); + +DATE_TRUNC('day', x) = CAST('2021-01-01' AS DATE); +x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +CAST('2021-01-01' AS DATE) = DATE_TRUNC('year', x); +x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +-- Always false, except for nulls +DATE_TRUNC('quarter', x) = CAST('2021-01-02' AS DATE); +DATE_TRUNC('quarter', x) = CAST('2021-01-02' AS DATE); + +DATE_TRUNC('year', x) <> CAST('2021-01-01' AS DATE); +x < CAST('2021-01-01' AS DATE) AND x >= CAST('2022-01-01' AS DATE); + +-- Always true, except for nulls +DATE_TRUNC('year', x) <> CAST('2021-01-02' AS DATE); +DATE_TRUNC('year', x) <> CAST('2021-01-02' AS DATE); + +DATE_TRUNC('year', x) <= CAST('2021-01-01' AS DATE); +x < CAST('2022-01-01' AS DATE); + +DATE_TRUNC('year', x) <= CAST('2021-01-02' AS DATE); +x < CAST('2022-01-01' AS DATE); + +CAST('2021-01-01' AS DATE) >= DATE_TRUNC('year', x); +x < CAST('2022-01-01' AS DATE); + +DATE_TRUNC('year', x) < CAST('2021-01-01' AS DATE); +x < CAST('2021-01-01' AS DATE); + +DATE_TRUNC('year', x) < CAST('2021-01-02' AS DATE); +x < CAST('2021-01-01' AS DATE); + +DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE); +x >= CAST('2021-01-01' AS DATE); + +DATE_TRUNC('year', x) >= CAST('2021-01-02' AS DATE); +x >= CAST('2022-01-01' AS DATE); + +DATE_TRUNC('year', x) > CAST('2021-01-01' AS DATE); +x >= CAST('2022-01-01' AS DATE); + +DATE_TRUNC('year', x) > CAST('2021-01-02' AS DATE); +x >= CAST('2022-01-01' AS DATE); + +-- right is not a date +DATE_TRUNC('year', x) <> '2021-01-02'; +DATE_TRUNC('year', x) <> '2021-01-02'; + +DATE_TRUNC('year', x) IN (CAST('2021-01-01' AS DATE), CAST('2023-01-01' AS DATE)); +(x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE)) OR (x < CAST('2024-01-01' AS DATE) AND x >= CAST('2023-01-01' AS DATE)); + +-- merge ranges +DATE_TRUNC('year', x) IN (CAST('2021-01-01' AS DATE), CAST('2022-01-01' AS DATE)); +x < CAST('2023-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +-- one of the values will always be false +DATE_TRUNC('year', x) IN (CAST('2021-01-01' AS DATE), CAST('2022-01-02' AS DATE)); +x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE); + +TIMESTAMP_TRUNC(x, YEAR) = CAST('2021-01-01' AS DATETIME); +x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME); + +-------------------------------------- +-- EQUALITY +-------------------------------------- +x + 1 = 3; +x = 2; + +1 + x = 3; +x = 2; + +3 = x + 1; +x = 2; + +x - 1 = 3; +x = 4; + +x + 1 > 3; +x > 2; + +x + 1 >= 3; +x >= 2; + +x + 1 <= 3; +x <= 2; + +x + 1 <= 3; +x <= 2; + +x + 1 <> 3; +x <> 2; + +1 + x + 1 = 3 + 1; +x = 2; + +x - INTERVAL 1 DAY = CAST('2021-01-01' AS DATE); +x = CAST('2021-01-02' AS DATE); + +x - INTERVAL 1 HOUR > CAST('2021-01-01' AS DATETIME); +x > CAST('2021-01-01 01:00:00' AS DATETIME); + +DATETIME_ADD(x, 1, HOUR) < CAST('2021-01-01' AS DATETIME); +x < CAST('2020-12-31 23:00:00' AS DATETIME); + +DATETIME_SUB(x, 1, DAY) >= CAST('2021-01-01' AS DATETIME); +x >= CAST('2021-01-02 00:00:00' AS DATETIME); + +DATE_ADD(x, 1, DAY) <= CAST('2021-01-01' AS DATE); +x <= CAST('2020-12-31' AS DATE); + +DATE_SUB(x, 1, DAY) <> CAST('2021-01-01' AS DATE); +x <> CAST('2021-01-02' AS DATE); + +DATE_ADD(DATE_ADD(DATE_TRUNC('week', DATE_SUB(x, 1, DAY)), 1, DAY), 1, YEAR) < CAST('2021-01-08' AS DATE); +x < CAST('2020-01-07' AS DATE); + +x - INTERVAL '1' day = CAST(y AS DATE); +x - INTERVAL '1' day = CAST(y AS DATE); diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index 9d760e0..e78bed0 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -209,3 +209,15 @@ WHERE ) AND ARRAY_ALL(_u_19."", _x -> _x = x.a) AND x.a > COALESCE(_u_21.d, 0); +SELECT + CAST(( + SELECT + x.a AS a + FROM x + ) AS TEXT) AS a; +SELECT + CAST(( + SELECT + x.a AS a + FROM x + ) AS TEXT) AS a; diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 1a61334..23d9511 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -346,7 +346,7 @@ SELECT fruit, basket_index FROM table_data -CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index; +CROSS JOIN UNNEST(fruit_basket) WITH ORDINALITY AS fruit(basket_index); WITH table_data AS ( SELECT 'bob' AS name, @@ -357,11 +357,12 @@ SELECT fruit, basket_index FROM table_data -CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET AS basket_index; +CROSS JOIN UNNEST(fruit_basket) WITH ORDINALITY AS fruit(basket_index); SELECT A.* EXCEPT A.COL_1, A.COL_2 FROM TABLE_1 A; SELECT A.* - EXCEPT (A.COL_1, A.COL_2) + EXCEPT (A.COL_1), + A.COL_2 FROM TABLE_1 AS A; SELECT * diff --git a/tests/test_expressions.py b/tests/test_expressions.py index b3ce926..b1b5360 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -14,6 +14,13 @@ class TestExpressions(unittest.TestCase): def test_depth(self): self.assertEqual(parse_one("x(1)").find(exp.Literal).depth, 1) + def test_iter(self): + self.assertEqual([exp.Literal.number(1), exp.Literal.number(2)], list(parse_one("[1, 2]"))) + + with self.assertRaises(TypeError): + for x in parse_one("1"): + pass + def test_eq(self): self.assertNotEqual(exp.to_identifier("a"), exp.to_identifier("A")) diff --git a/tests/test_helper.py b/tests/test_helper.py index 7d63c34..a8872e9 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -1,7 +1,7 @@ import unittest from sqlglot.dialects import BigQuery, Dialect, Snowflake -from sqlglot.helper import name_sequence, tsort +from sqlglot.helper import merge_ranges, name_sequence, tsort class TestHelper(unittest.TestCase): @@ -66,3 +66,10 @@ class TestHelper(unittest.TestCase): self.assertEqual(s1(), "a2") self.assertEqual(s2(), "b1") self.assertEqual(s2(), "b2") + + def test_merge_ranges(self): + self.assertEqual([], merge_ranges([])) + self.assertEqual([(0, 1)], merge_ranges([(0, 1)])) + self.assertEqual([(0, 1), (2, 3)], merge_ranges([(0, 1), (2, 3)])) + self.assertEqual([(0, 3)], merge_ranges([(0, 1), (1, 3)])) + self.assertEqual([(0, 1), (2, 4)], merge_ranges([(2, 3), (0, 1), (3, 4)])) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a40f089..8775852 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -679,7 +679,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} - sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" + sql = "SELECT x.cola + SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN) @@ -702,7 +702,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) # NULL <op> UNKNOWN should yield NULL - sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result" + sql = "SELECT NULL + SOME_ANONYMOUS_FUNC() AS result" concat_expr_alias = annotate_types(parse_one(sql)).expressions[0] self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL) @@ -776,6 +776,17 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this) self.assertEqual(expression.selects[0].type.sql(), "ARRAY<INT>") + def test_type_annotation_cache(self): + sql = "SELECT 1 + 1" + expression = annotate_types(parse_one(sql)) + + self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this) + + expression.selects[0].this.replace(parse_one("1.2")) + expression = annotate_types(expression) + + self.assertEqual(exp.DataType.Type.DOUBLE, expression.selects[0].type.this) + def test_user_defined_type_annotation(self): schema = MappingSchema({"t": {"x": "int"}}, dialect="postgres") expression = annotate_types(parse_one("SELECT CAST(x AS IPADDRESS) FROM t"), schema=schema) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 2b51be2..a5b1977 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -44,9 +44,6 @@ class TestTranspile(unittest.TestCase): with self.assertRaises(ParseError): self.validate(f"SELECT x {key}", "") - def test_asc(self): - self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x") - def test_unary(self): self.validate("+++1", "1") self.validate("+-1", "-1") |