diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_athena.py | 46 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 64 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 28 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 141 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 23 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 113 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 37 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 67 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_starrocks.py | 65 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 35 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 3 | ||||
-rw-r--r-- | tests/fixtures/optimizer/annotate_types.sql | 14 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 2 | ||||
-rw-r--r-- | tests/fixtures/pretty.sql | 13 | ||||
-rw-r--r-- | tests/test_build.py | 36 | ||||
-rw-r--r-- | tests/test_optimizer.py | 23 | ||||
-rw-r--r-- | tests/test_parser.py | 11 |
22 files changed, 683 insertions, 83 deletions
diff --git a/tests/dialects/test_athena.py b/tests/dialects/test_athena.py index 5522976..6ec870b 100644 --- a/tests/dialects/test_athena.py +++ b/tests/dialects/test_athena.py @@ -23,3 +23,49 @@ class TestAthena(Validator): some_function(1)""", check_command_warning=True, ) + + def test_ddl_quoting(self): + self.validate_identity("CREATE SCHEMA `foo`") + self.validate_identity("CREATE SCHEMA foo") + self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True) + + self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'") + self.validate_identity("CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'") + self.validate_identity( + "CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'", + write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'", + identify=True, + ) + + self.validate_identity("DROP TABLE `foo`") + self.validate_identity("DROP TABLE foo") + self.validate_identity("DROP TABLE foo", write_sql="DROP TABLE `foo`", identify=True) + + self.validate_identity('CREATE VIEW "foo" AS SELECT "id" FROM "tbl"') + self.validate_identity("CREATE VIEW foo AS SELECT id FROM tbl") + self.validate_identity( + "CREATE VIEW foo AS SELECT id FROM tbl", + write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"', + identify=True, + ) + + # As a side effect of being able to parse both quote types, we can also fix the quoting on incorrectly quoted source queries + self.validate_identity('CREATE SCHEMA "foo"', write_sql="CREATE SCHEMA `foo`") + self.validate_identity( + 'CREATE EXTERNAL TABLE "foo" ("id" INTEGER) LOCATION \'s3://foo/\'', + write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'", + ) + self.validate_identity('DROP TABLE "foo"', write_sql="DROP TABLE `foo`") + self.validate_identity( + 'CREATE VIEW `foo` AS SELECT "id" FROM `tbl`', + write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"', + ) + + def test_dml_quoting(self): + self.validate_identity("SELECT a AS foo FROM tbl") + self.validate_identity('SELECT "a" AS "foo" FROM "tbl"') + self.validate_identity( + 'SELECT `a` AS `foo` FROM "tbl"', + write_sql='SELECT "a" AS "foo" FROM "tbl"', + identify=True, + ) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index f6e8fe8..c8c2176 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -108,7 +108,6 @@ LANGUAGE js AS self.validate_identity("SELECT * FROM READ_CSV('bla.csv')") self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)") self.validate_identity("assert.true(1 = 1)") - self.validate_identity("SELECT ARRAY_TO_STRING(list, '--') AS text") self.validate_identity("SELECT jsondoc['some_key']") self.validate_identity("SELECT `p.d.UdF`(data).* FROM `p.d.t`") self.validate_identity("SELECT * FROM `my-project.my-dataset.my-table`") @@ -631,9 +630,9 @@ LANGUAGE js AS self.validate_all( "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", write={ - "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", - "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1, '2023-01-01T00:00:00')", - "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) + INTERVAL 1 MILLISECOND", + "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL '1' MILLISECOND)", + "databricks": "SELECT TIMESTAMPADD(MILLISECOND, '1', '2023-01-01T00:00:00')", + "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) + INTERVAL '1' MILLISECOND", }, ), ) @@ -641,9 +640,9 @@ LANGUAGE js AS self.validate_all( "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", write={ - "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", - "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1 * -1, '2023-01-01T00:00:00')", - "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) - INTERVAL 1 MILLISECOND", + "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL '1' MILLISECOND)", + "databricks": "SELECT TIMESTAMPADD(MILLISECOND, '1' * -1, '2023-01-01T00:00:00')", + "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) - INTERVAL '1' MILLISECOND", }, ), ) @@ -661,17 +660,24 @@ LANGUAGE js AS self.validate_all( 'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', write={ - "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", - "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", - "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", - "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL '10' MINUTE)", + "databricks": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)", + "spark": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", }, ) self.validate_all( 'SELECT TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', write={ - "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", - "mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)", + "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL '10' MINUTE)", + "mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)", + }, + ) + self.validate_all( + "SELECT TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL 2 HOUR)", + write={ + "bigquery": "SELECT TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL '2' HOUR)", + "duckdb": "SELECT CAST('09:05:03' AS TIME) + INTERVAL '2' HOUR", }, ) self.validate_all( @@ -1237,19 +1243,19 @@ LANGUAGE js AS "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)", write={ "postgres": "CURRENT_DATE - INTERVAL '1 DAY'", - "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL 1 DAY)", + "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL '1' DAY)", }, ) self.validate_all( - "DATE_ADD(CURRENT_DATE(), INTERVAL 1 DAY)", + "DATE_ADD(CURRENT_DATE(), INTERVAL -1 DAY)", write={ - "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", - "duckdb": "CURRENT_DATE + INTERVAL 1 DAY", - "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", - "postgres": "CURRENT_DATE + INTERVAL '1 DAY'", - "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", - "hive": "DATE_ADD(CURRENT_DATE, 1)", - "spark": "DATE_ADD(CURRENT_DATE, 1)", + "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL '-1' DAY)", + "duckdb": "CURRENT_DATE + INTERVAL '-1' DAY", + "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL '-1' DAY)", + "postgres": "CURRENT_DATE + INTERVAL '-1 DAY'", + "presto": "DATE_ADD('DAY', CAST('-1' AS BIGINT), CURRENT_DATE)", + "hive": "DATE_ADD(CURRENT_DATE, '-1')", + "spark": "DATE_ADD(CURRENT_DATE, '-1')", }, ) self.validate_all( @@ -1478,6 +1484,20 @@ WHERE "duckdb": "SELECT CAST(STRPTIME('Thursday Dec 25 2008', '%A %b %-d %Y') AS DATE)", }, ) + self.validate_all( + "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text", + write={ + "bigquery": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text", + "duckdb": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text", + }, + ) + self.validate_all( + "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--', 'MISSING') AS text", + write={ + "bigquery": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--', 'MISSING') AS text", + "duckdb": "SELECT ARRAY_TO_STRING(LIST_TRANSFORM(['cake', 'pie', NULL], x -> COALESCE(x, 'MISSING')), '--') AS text", + }, + ) def test_errors(self): with self.assertRaises(TokenError): diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index b4ba09e..ea6064a 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -160,6 +160,27 @@ class TestClickhouse(Validator): ) self.validate_all( + "char(67) || char(65) || char(84)", + read={ + "clickhouse": "char(67) || char(65) || char(84)", + "oracle": "chr(67) || chr(65) || chr(84)", + }, + ) + self.validate_all( + "SELECT lagInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees", + read={ + "clickhouse": "SELECT lagInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees", + "oracle": "SELECT LAG(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees", + }, + ) + self.validate_all( + "SELECT leadInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees", + read={ + "clickhouse": "SELECT leadInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees", + "oracle": "SELECT LEAD(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees", + }, + ) + self.validate_all( "SELECT CAST(STR_TO_DATE('05 12 2000', '%d %m %Y') AS DATE)", read={ "clickhouse": "SELECT CAST(STR_TO_DATE('05 12 2000', '%d %m %Y') AS DATE)", @@ -494,6 +515,7 @@ class TestClickhouse(Validator): ) self.validate_identity("SELECT TRIM(TRAILING ')' FROM '( Hello, world! )')") self.validate_identity("SELECT TRIM(LEADING '(' FROM '( Hello, world! )')") + self.validate_identity("current_timestamp").assert_is(exp.Column) def test_clickhouse_values(self): values = exp.select("*").from_( @@ -995,6 +1017,12 @@ LIFETIME(MIN 0 MAX 0)""", pretty=True, ) + self.assertIsNotNone( + self.validate_identity("CREATE TABLE t1 (a String MATERIALIZED func())").find( + exp.ColumnConstraint + ) + ) + def test_agg_functions(self): def extract_agg_func(query): return parse_one(query, read="clickhouse").selects[0].this diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 190d044..f0faccb 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -20,7 +20,9 @@ class Validator(unittest.TestCase): def parse_one(self, sql, **kwargs): return parse_one(sql, read=self.dialect, **kwargs) - def validate_identity(self, sql, write_sql=None, pretty=False, check_command_warning=False): + def validate_identity( + self, sql, write_sql=None, pretty=False, check_command_warning=False, identify=False + ): if check_command_warning: with self.assertLogs(parser_logger) as cm: expression = self.parse_one(sql) @@ -28,7 +30,9 @@ class Validator(unittest.TestCase): else: expression = self.parse_one(sql) - self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect, pretty=pretty)) + self.assertEqual( + write_sql or sql, expression.sql(dialect=self.dialect, pretty=pretty, identify=identify) + ) return expression def validate_all(self, sql, read=None, write=None, pretty=False, identify=False): @@ -1408,6 +1412,13 @@ class TestDialect(Validator): }, ) + for dialect in ("duckdb", "starrocks"): + with self.subTest(f"Generating json extraction with digit-prefixed key ({dialect})"): + self.assertEqual( + parse_one("""select '{"0": "v"}' -> '0'""", read=dialect).sql(dialect=dialect), + """SELECT '{"0": "v"}' -> '0'""", + ) + def test_cross_join(self): self.validate_all( "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)", @@ -1422,7 +1433,7 @@ class TestDialect(Validator): write={ "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", - "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", + "spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b", }, ) self.validate_all( @@ -1488,12 +1499,14 @@ class TestDialect(Validator): "SELECT * FROM a INTERSECT SELECT * FROM b", read={ "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", }, write={ "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", @@ -1503,12 +1516,14 @@ class TestDialect(Validator): "SELECT * FROM a EXCEPT SELECT * FROM b", read={ "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", }, write={ "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", @@ -1527,6 +1542,7 @@ class TestDialect(Validator): "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", write={ "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", @@ -1536,6 +1552,7 @@ class TestDialect(Validator): "SELECT * FROM a INTERSECT ALL SELECT * FROM b", write={ "bigquery": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "clickhouse": "SELECT * FROM a INTERSECT SELECT * FROM b", "duckdb": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", "presto": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", "spark": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", @@ -1545,6 +1562,7 @@ class TestDialect(Validator): "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", write={ "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "clickhouse": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", @@ -1554,6 +1572,7 @@ class TestDialect(Validator): "SELECT * FROM a EXCEPT ALL SELECT * FROM b", read={ "bigquery": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "clickhouse": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", "duckdb": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", "presto": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", "spark": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", @@ -2354,7 +2373,7 @@ SELECT "mysql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", "oracle": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1", "postgres": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", - "tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + "tsql": "SELECT * FROM (SELECT *, COUNT_BIG(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", }, ) self.validate_all( @@ -2366,7 +2385,7 @@ SELECT "mysql": "SELECT `user id`, some_id, other_id, `2 nd id` FROM (SELECT `user id`, some_id, 1 AS other_id, 2 AS `2 nd id`, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", "oracle": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1', "postgres": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1', - "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", + "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT_BIG(*) OVER () AS _w FROM t) AS _t WHERE _w > 1", }, ) @@ -2722,3 +2741,115 @@ FROM subquery2""", "tsql": "WITH _generated_dates(date_week) AS (SELECT CAST('2020-01-01' AS DATE) AS date_week UNION ALL SELECT CAST(DATEADD(WEEK, 1, date_week) AS DATE) FROM _generated_dates WHERE CAST(DATEADD(WEEK, 1, date_week) AS DATE) <= CAST('2020-02-01' AS DATE)) SELECT * FROM (SELECT date_week AS date_week FROM _generated_dates) AS _generated_dates", }, ) + + def test_set_operation_specifiers(self): + self.validate_all( + "SELECT 1 EXCEPT ALL SELECT 1", + write={ + "": "SELECT 1 EXCEPT ALL SELECT 1", + "bigquery": UnsupportedError, + "clickhouse": "SELECT 1 EXCEPT SELECT 1", + "databricks": "SELECT 1 EXCEPT ALL SELECT 1", + "duckdb": "SELECT 1 EXCEPT ALL SELECT 1", + "mysql": "SELECT 1 EXCEPT ALL SELECT 1", + "oracle": "SELECT 1 EXCEPT ALL SELECT 1", + "postgres": "SELECT 1 EXCEPT ALL SELECT 1", + "presto": UnsupportedError, + "redshift": UnsupportedError, + "snowflake": UnsupportedError, + "spark": "SELECT 1 EXCEPT ALL SELECT 1", + "sqlite": UnsupportedError, + "starrocks": UnsupportedError, + "trino": UnsupportedError, + "tsql": UnsupportedError, + }, + ) + + def test_normalize(self): + for form in ("", ", nfkc"): + with self.subTest(f"Testing NORMALIZE('str'{form}) roundtrip"): + self.validate_all( + f"SELECT NORMALIZE('str'{form})", + read={ + "presto": f"SELECT NORMALIZE('str'{form})", + "trino": f"SELECT NORMALIZE('str'{form})", + "bigquery": f"SELECT NORMALIZE('str'{form})", + }, + write={ + "presto": f"SELECT NORMALIZE('str'{form})", + "trino": f"SELECT NORMALIZE('str'{form})", + "bigquery": f"SELECT NORMALIZE('str'{form})", + }, + ) + + self.assertIsInstance(parse_one("NORMALIZE('str', NFD)").args.get("form"), exp.Var) + + def test_coalesce(self): + """ + Validate that "expressions" is a list for all the exp.Coalesce instances; This is important + as some optimizer rules are coalesce specific and will iterate on "expressions" + """ + + # Check the 2-arg aliases + for func in ("COALESCE", "IFNULL", "NVL"): + self.assertIsInstance(self.parse_one(f"{func}(1, 2)").expressions, list) + + # Check the varlen case + coalesce = self.parse_one("COALESCE(x, y, z)") + self.assertIsInstance(coalesce.expressions, list) + self.assertIsNone(coalesce.args.get("is_nvl")) + + # Check Oracle's NVL which is decoupled from COALESCE + oracle_nvl = parse_one("NVL(x, y)", read="oracle") + self.assertIsInstance(oracle_nvl.expressions, list) + self.assertTrue(oracle_nvl.args.get("is_nvl")) + + # Check T-SQL's ISNULL which is parsed into exp.Coalesce + self.assertIsInstance(parse_one("ISNULL(x, y)", read="tsql").expressions, list) + + def test_trim(self): + self.validate_all( + "TRIM('abc', 'a')", + read={ + "bigquery": "TRIM('abc', 'a')", + "snowflake": "TRIM('abc', 'a')", + }, + write={ + "bigquery": "TRIM('abc', 'a')", + "snowflake": "TRIM('abc', 'a')", + }, + ) + + self.validate_all( + "LTRIM('Hello World', 'H')", + read={ + "oracle": "LTRIM('Hello World', 'H')", + "clickhouse": "TRIM(LEADING 'H' FROM 'Hello World')", + "snowflake": "LTRIM('Hello World', 'H')", + "bigquery": "LTRIM('Hello World', 'H')", + "": "LTRIM('Hello World', 'H')", + }, + write={ + "clickhouse": "TRIM(LEADING 'H' FROM 'Hello World')", + "oracle": "LTRIM('Hello World', 'H')", + "snowflake": "LTRIM('Hello World', 'H')", + "bigquery": "LTRIM('Hello World', 'H')", + }, + ) + + self.validate_all( + "RTRIM('Hello World', 'd')", + read={ + "clickhouse": "TRIM(TRAILING 'd' FROM 'Hello World')", + "oracle": "RTRIM('Hello World', 'd')", + "snowflake": "RTRIM('Hello World', 'd')", + "bigquery": "RTRIM('Hello World', 'd')", + "": "RTRIM('Hello World', 'd')", + }, + write={ + "clickhouse": "TRIM(TRAILING 'd' FROM 'Hello World')", + "oracle": "RTRIM('Hello World', 'd')", + "snowflake": "RTRIM('Hello World', 'd')", + "bigquery": "RTRIM('Hello World', 'd')", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 5d2d044..18a030c 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -696,11 +696,11 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "SELECT CAST('2020-05-06' AS DATE) - INTERVAL 5 DAY", + "SELECT CAST('2020-05-06' AS DATE) - INTERVAL '5' DAY", read={"bigquery": "SELECT DATE_SUB(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, ) self.validate_all( - "SELECT CAST('2020-05-06' AS DATE) + INTERVAL 5 DAY", + "SELECT CAST('2020-05-06' AS DATE) + INTERVAL '5' DAY", read={"bigquery": "SELECT DATE_ADD(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, ) self.validate_identity( @@ -879,7 +879,7 @@ class TestDuckDB(Validator): write={"duckdb": "SELECT (90 * INTERVAL '1' DAY)"}, ) self.validate_all( - "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1", + "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((ISODOW(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1", read={ "presto": "SELECT ((DATE_ADD('week', -5, DATE_TRUNC('DAY', DATE_ADD('day', (0 - MOD((DAY_OF_WEEK(CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7, 7)), CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)))))) AS t1", }, @@ -1100,7 +1100,6 @@ class TestDuckDB(Validator): self.validate_all( "SELECT CAST('09:05:03' AS TIME) + INTERVAL 2 HOUR", read={ - "bigquery": "SELECT TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL 2 HOUR)", "snowflake": "SELECT TIMEADD(HOUR, 2, TO_TIME('09:05:03'))", }, write={ diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index c2768bf..136ea60 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -412,6 +412,7 @@ class TestHive(Validator): ) def test_hive(self): + self.validate_identity("SELECT * FROM t WHERE col IN ('stream')") self.validate_identity("SET hiveconf:some_var = 5", check_command_warning=True) self.validate_identity("(VALUES (1 AS a, 2 AS b, 3))") self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)") @@ -715,8 +716,8 @@ class TestHive(Validator): "presto": "ARRAY_AGG(x)", }, write={ - "duckdb": "ARRAY_AGG(x)", - "presto": "ARRAY_AGG(x)", + "duckdb": "ARRAY_AGG(x) FILTER(WHERE x IS NOT NULL)", + "presto": "ARRAY_AGG(x) FILTER(WHERE x IS NOT NULL)", "hive": "COLLECT_LIST(x)", "spark": "COLLECT_LIST(x)", }, @@ -764,6 +765,24 @@ class TestHive(Validator): "presto": "SELECT DATE_TRUNC('MONTH', TRY_CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'", }, ) + self.validate_all( + "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + read={ + "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + }, + write={ + "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 1)", + "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 1)", + "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 1)", + }, + ) def test_escapes(self) -> None: self.validate_identity("'\n'", "'\\n'") diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 45b79bf..2fd9ef0 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -241,7 +241,7 @@ class TestMySQL(Validator): "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" ) self.validate_identity("INTERVAL '1' YEAR") - self.validate_identity("DATE_ADD(x, INTERVAL 1 YEAR)") + self.validate_identity("DATE_ADD(x, INTERVAL '1' YEAR)") self.validate_identity("CHAR(0)") self.validate_identity("CHAR(77, 121, 83, 81, '76')") self.validate_identity("CHAR(77, 77.3, '77.3' USING utf8mb4)") @@ -539,9 +539,16 @@ class TestMySQL(Validator): }, ) self.validate_all( - "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", + "SELECT DATE_FORMAT('2024-08-22 14:53:12', '%a')", write={ - "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", + "mysql": "SELECT DATE_FORMAT('2024-08-22 14:53:12', '%a')", + "snowflake": "SELECT TO_CHAR(CAST('2024-08-22 14:53:12' AS TIMESTAMP), 'DY')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%a %M %Y')", + write={ + "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%a %M %Y')", "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMP), 'DY mmmm yyyy')", }, ) @@ -555,7 +562,7 @@ class TestMySQL(Validator): self.validate_all( "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')", write={ - "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')", + "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')", "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMP), 'DD yy DY DD mm mon')", }, ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 77f46e4..8bdc4af 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -99,6 +99,13 @@ class TestOracle(Validator): ) self.validate_all( + "TRUNC(SYSDATE, 'YEAR')", + write={ + "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())", + "oracle": "TRUNC(SYSDATE, 'YEAR')", + }, + ) + self.validate_all( "SELECT * FROM test WHERE MOD(col1, 4) = 3", read={ "duckdb": "SELECT * FROM test WHERE col1 % 4 = 3", @@ -260,22 +267,6 @@ class TestOracle(Validator): }, ) self.validate_all( - "LTRIM('Hello World', 'H')", - write={ - "": "LTRIM('Hello World', 'H')", - "oracle": "LTRIM('Hello World', 'H')", - "clickhouse": "TRIM(LEADING 'H' FROM 'Hello World')", - }, - ) - self.validate_all( - "RTRIM('Hello World', 'd')", - write={ - "": "RTRIM('Hello World', 'd')", - "oracle": "RTRIM('Hello World', 'd')", - "clickhouse": "TRIM(TRAILING 'd' FROM 'Hello World')", - }, - ) - self.validate_all( "TRIM(BOTH 'h' FROM 'Hello World')", write={ "oracle": "TRIM(BOTH 'h' FROM 'Hello World')", @@ -461,3 +452,93 @@ WHERE self.validate_identity( f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}" ) + + def test_multitable_inserts(self): + self.maxDiff = None + self.validate_identity( + "INSERT ALL " + "INTO dest_tab1 (id, description) VALUES (id, description) " + "INTO dest_tab2 (id, description) VALUES (id, description) " + "INTO dest_tab3 (id, description) VALUES (id, description) " + "SELECT id, description FROM source_tab" + ) + + self.validate_identity( + "INSERT ALL " + "INTO pivot_dest (id, day, val) VALUES (id, 'mon', mon_val) " + "INTO pivot_dest (id, day, val) VALUES (id, 'tue', tue_val) " + "INTO pivot_dest (id, day, val) VALUES (id, 'wed', wed_val) " + "INTO pivot_dest (id, day, val) VALUES (id, 'thu', thu_val) " + "INTO pivot_dest (id, day, val) VALUES (id, 'fri', fri_val) " + "SELECT * " + "FROM pivot_source" + ) + + self.validate_identity( + "INSERT ALL " + "WHEN id <= 3 THEN " + "INTO dest_tab1 (id, description) VALUES (id, description) " + "WHEN id BETWEEN 4 AND 7 THEN " + "INTO dest_tab2 (id, description) VALUES (id, description) " + "WHEN id >= 8 THEN " + "INTO dest_tab3 (id, description) VALUES (id, description) " + "SELECT id, description " + "FROM source_tab" + ) + + self.validate_identity( + "INSERT ALL " + "WHEN id <= 3 THEN " + "INTO dest_tab1 (id, description) VALUES (id, description) " + "WHEN id BETWEEN 4 AND 7 THEN " + "INTO dest_tab2 (id, description) VALUES (id, description) " + "WHEN 1 = 1 THEN " + "INTO dest_tab3 (id, description) VALUES (id, description) " + "SELECT id, description " + "FROM source_tab" + ) + + self.validate_identity( + "INSERT FIRST " + "WHEN id <= 3 THEN " + "INTO dest_tab1 (id, description) VALUES (id, description) " + "WHEN id <= 5 THEN " + "INTO dest_tab2 (id, description) VALUES (id, description) " + "ELSE " + "INTO dest_tab3 (id, description) VALUES (id, description) " + "SELECT id, description " + "FROM source_tab" + ) + + self.validate_identity( + "INSERT FIRST " + "WHEN id <= 3 THEN " + "INTO dest_tab1 (id, description) VALUES (id, description) " + "ELSE " + "INTO dest_tab2 (id, description) VALUES (id, description) " + "INTO dest_tab3 (id, description) VALUES (id, description) " + "SELECT id, description " + "FROM source_tab" + ) + + self.validate_identity( + "/* COMMENT */ INSERT FIRST " + "WHEN salary > 4000 THEN INTO emp2 " + "WHEN salary > 5000 THEN INTO emp3 " + "WHEN salary > 6000 THEN INTO emp4 " + "SELECT salary FROM employees" + ) + + def test_json_functions(self): + for format_json in ("", " FORMAT JSON"): + for on_cond in ( + "", + " TRUE ON ERROR", + " NULL ON EMPTY", + " DEFAULT 1 ON ERROR TRUE ON EMPTY", + ): + for passing in ("", " PASSING 'name1' AS \"var1\", 'name2' AS \"var2\""): + with self.subTest("Testing JSON_EXISTS()"): + self.validate_identity( + f"SELECT * FROM t WHERE JSON_EXISTS(name{format_json}, '$[1].middle'{passing}{on_cond})" + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index c628db4..f3f21a9 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -846,6 +846,9 @@ class TestPostgres(Validator): self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace") self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)") self.validate_identity( + "ALTER TABLE tested_table ADD CONSTRAINT unique_example UNIQUE (column_name) NOT VALID" + ) + self.validate_identity( "CREATE FUNCTION pymax(a INT, b INT) RETURNS INT LANGUAGE plpython3u AS $$\n if a > b:\n return a\n return b\n$$", ) self.validate_identity( @@ -1023,6 +1026,10 @@ class TestPostgres(Validator): self.validate_identity( "CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_table_id ON tbl USING btree(id)" ) + self.validate_identity("DROP INDEX ix_table_id") + self.validate_identity("DROP INDEX IF EXISTS ix_table_id") + self.validate_identity("DROP INDEX CONCURRENTLY ix_table_id") + self.validate_identity("DROP INDEX CONCURRENTLY IF EXISTS ix_table_id") self.validate_identity( """ diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index c8e616e..9c61f62 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -329,11 +329,20 @@ class TestPresto(Validator): }, ) self.validate_all( - "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')", - write={ + "((DAY_OF_WEEK(CAST(TRY_CAST('2012-08-08 01:00:00' AS TIMESTAMP) AS DATE)) % 7) + 1)", + read={ "spark": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + }, + ) + self.validate_all( + "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + read={ + "duckdb": "ISODOW(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + }, + write={ + "spark": "((DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP)) % 7) + 1)", "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", - "duckdb": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "duckdb": "ISODOW(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", }, ) @@ -522,6 +531,9 @@ class TestPresto(Validator): }, ) + self.validate_identity("""CREATE OR REPLACE VIEW v SECURITY DEFINER AS SELECT id FROM t""") + self.validate_identity("""CREATE OR REPLACE VIEW v SECURITY INVOKER AS SELECT id FROM t""") + def test_quotes(self): self.validate_all( "''''", @@ -1022,6 +1034,25 @@ class TestPresto(Validator): "spark": "SELECT REGEXP_EXTRACT(TO_JSON(FROM_JSON('[[1, 2, 3]]', SCHEMA_OF_JSON('[[1, 2, 3]]'))), '^.(.*).$', 1)", }, ) + self.validate_all( + "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + read={ + "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "snowflake": "REGEXP_SUBSTR('abc', '(a)(b)(c)')", + }, + write={ + "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)')", + "snowflake": "REGEXP_SUBSTR('abc', '(a)(b)(c)')", + "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)", + "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)", + "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)", + "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)", + }, + ) def test_encode_decode(self): self.validate_identity("FROM_UTF8(x, y)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 7837cc9..3e0d600 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -11,30 +11,12 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): - self.validate_identity("1 /* /* */") - self.validate_identity( - "SELECT * FROM table AT (TIMESTAMP => '2024-07-24') UNPIVOT(a FOR b IN (c)) AS pivot_table" - ) - self.assertEqual( # Ensures we don't fail when generating ParseJSON with the `safe` arg set to `True` self.validate_identity("""SELECT TRY_PARSE_JSON('{"x: 1}')""").sql(), """SELECT PARSE_JSON('{"x: 1}')""", ) - self.validate_identity( - "transform(x, a int -> a + a + 1)", - "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)", - ) - - self.validate_all( - "ARRAY_CONSTRUCT_COMPACT(1, null, 2)", - write={ - "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))", - "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)", - }, - ) - expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t") expr.selects[0].assert_is(exp.AggFunc) self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t") @@ -98,7 +80,6 @@ WHERE self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x") self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')") - self.validate_identity("CAST(x AS GEOMETRY)") self.validate_identity("OBJECT_CONSTRUCT(*)") self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'") self.validate_identity("SELECT HLL(*)") @@ -115,6 +96,10 @@ WHERE self.validate_identity("ALTER TABLE a SWAP WITH b") self.validate_identity("SELECT MATCH_CONDITION") self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t") + self.validate_identity("1 /* /* */") + self.validate_identity( + "SELECT * FROM table AT (TIMESTAMP => '2024-07-24') UNPIVOT(a FOR b IN (c)) AS pivot_table" + ) self.validate_identity( "SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4', '2024_Q1') DEFAULT ON NULL (0)) ORDER BY empid" ) @@ -140,6 +125,18 @@ WHERE "SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID" ) self.validate_identity( + "CAST(x AS GEOGRAPHY)", + "TO_GEOGRAPHY(x)", + ) + self.validate_identity( + "CAST(x AS GEOMETRY)", + "TO_GEOMETRY(x)", + ) + self.validate_identity( + "transform(x, a int -> a + a + 1)", + "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)", + ) + self.validate_identity( "SELECT * FROM s WHERE c NOT IN (1, 2, 3)", "SELECT * FROM s WHERE NOT c IN (1, 2, 3)", ) @@ -309,6 +306,13 @@ WHERE ) self.validate_all( + "ARRAY_CONSTRUCT_COMPACT(1, null, 2)", + write={ + "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))", + "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)", + }, + ) + self.validate_all( "OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)", read={ "bigquery": "JSON_OBJECT(['key_1', 'key_2'], ['one', NULL])", @@ -1419,6 +1423,12 @@ WHERE }, ) + self.assertIsNotNone( + self.validate_identity("CREATE TABLE foo (bar INT AS (foo))").find( + exp.TransformColumnConstraint + ) + ) + def test_user_defined_functions(self): self.validate_all( "CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$", @@ -1697,16 +1707,27 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene "REGEXP_SUBSTR(subject, pattern)", read={ "bigquery": "REGEXP_EXTRACT(subject, pattern)", + "snowflake": "REGEXP_EXTRACT(subject, pattern)", + }, + write={ + "bigquery": "REGEXP_EXTRACT(subject, pattern)", + "snowflake": "REGEXP_SUBSTR(subject, pattern)", + }, + ) + self.validate_all( + "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)", + read={ "hive": "REGEXP_EXTRACT(subject, pattern)", - "presto": "REGEXP_EXTRACT(subject, pattern)", + "spark2": "REGEXP_EXTRACT(subject, pattern)", "spark": "REGEXP_EXTRACT(subject, pattern)", + "databricks": "REGEXP_EXTRACT(subject, pattern)", }, write={ - "bigquery": "REGEXP_EXTRACT(subject, pattern)", "hive": "REGEXP_EXTRACT(subject, pattern)", - "presto": "REGEXP_EXTRACT(subject, pattern)", - "snowflake": "REGEXP_SUBSTR(subject, pattern)", + "spark2": "REGEXP_EXTRACT(subject, pattern)", "spark": "REGEXP_EXTRACT(subject, pattern)", + "databricks": "REGEXP_EXTRACT(subject, pattern)", + "snowflake": "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)", }, ) self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index cbaa169..4fed68c 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -309,6 +309,13 @@ TBLPROPERTIES ( ) self.validate_all( + "SELECT ARRAY_AGG(x) FILTER (WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)", + write={ + "duckdb": "SELECT ARRAY_AGG(x) FILTER(WHERE x = 5 AND NOT x IS NULL) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)", + "spark": "SELECT COLLECT_LIST(x) FILTER(WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)", + }, + ) + self.validate_all( "SELECT DATE_FORMAT(DATE '2020-01-01', 'EEEE') AS weekday", write={ "presto": "SELECT DATE_FORMAT(CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP), '%W') AS weekday", diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index fa9a2cc..ee4dc90 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -1,13 +1,41 @@ +from sqlglot.errors import UnsupportedError from tests.dialects.test_dialect import Validator class TestStarrocks(Validator): dialect = "starrocks" + def test_ddl(self): + ddl_sqls = [ + "DISTRIBUTED BY HASH (col1) BUCKETS 1", + "DISTRIBUTED BY HASH (col1)", + "DISTRIBUTED BY RANDOM BUCKETS 1", + "DISTRIBUTED BY RANDOM", + "DISTRIBUTED BY HASH (col1) ORDER BY (col1)", + "DISTRIBUTED BY HASH (col1) PROPERTIES ('replication_num'='1')", + "PRIMARY KEY (col1) DISTRIBUTED BY HASH (col1)", + "DUPLICATE KEY (col1, col2) DISTRIBUTED BY HASH (col1)", + ] + + for properties in ddl_sqls: + with self.subTest(f"Testing create scheme: {properties}"): + self.validate_identity(f"CREATE TABLE foo (col1 BIGINT, col2 BIGINT) {properties}") + self.validate_identity( + f"CREATE TABLE foo (col1 BIGINT, col2 BIGINT) ENGINE=OLAP {properties}" + ) + + # Test the different wider DECIMAL types + self.validate_identity( + "CREATE TABLE foo (col0 DECIMAL(9, 1), col1 DECIMAL32(9, 1), col2 DECIMAL64(18, 10), col3 DECIMAL128(38, 10)) DISTRIBUTED BY HASH (col1) BUCKETS 1" + ) + def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x") self.validate_identity("SELECT [1, 2, 3]") + self.validate_identity( + """SELECT CAST(PARSE_JSON(fieldvalue) -> '00000000-0000-0000-0000-00000000' AS VARCHAR) AS `code` FROM (SELECT '{"00000000-0000-0000-0000-00000000":"code01"}') AS t(fieldvalue)""" + ) def test_time(self): self.validate_identity("TIMESTAMP('2022-01-01')") @@ -35,6 +63,43 @@ class TestStarrocks(Validator): "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t", "SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)", ) + self.validate_all( + "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)", + write={ + "spark": "SELECT student, score, unnest FROM tests LATERAL VIEW EXPLODE(scores) unnest AS unnest", + "starrocks": "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)", + }, + ) + self.validate_all( + r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""", + write={ + "postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)", + "spark": "SELECT * FROM INLINE(ARRAYS_ZIP(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27))) AS t(name, age)", + "starrocks": "SELECT * FROM UNNEST(['John', 'Jane', 'Jim', 'Jamie'], [24, 25, 26, 27]) AS t(name, age)", + }, + ) + + # Use UNNEST to convert into multiple columns + # see: https://docs.starrocks.io/docs/sql-reference/sql-functions/array-functions/unnest/ + self.validate_all( + r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""", + write={ + "postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)", + "spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", + "databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", + "starrocks": r"""SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""", + "hive": UnsupportedError, + }, + ) + + self.validate_all( + r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""", + write={ + "spark": r"""SELECT id, t.type, t.scores FROM example_table_2 LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""", + "starrocks": r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""", + "hive": UnsupportedError, + }, + ) lateral_explode_sqls = [ "SELECT id, t.col FROM tbl, UNNEST(scores) AS t(col)", diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 598cb53..466f5d5 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -155,6 +155,15 @@ class TestTeradata(Validator): "tsql": "CREATE TABLE a", }, ) + self.validate_identity( + "CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD', measurement INT COMPRESS)" + ) + self.validate_identity( + "CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD', measurement INT COMPRESS (1, 2, 3))" + ) + self.validate_identity( + "CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD' COMPRESS (CAST('9999-09-09' AS DATE)), measurement INT)" + ) def test_insert(self): self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index ecb83da..7114750 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -420,6 +420,11 @@ class TestTSQL(Validator): "SELECT val FROM (VALUES ((TRUE), (FALSE), (NULL))) AS t(val)", write_sql="SELECT val FROM (VALUES ((1), (0), (NULL))) AS t(val)", ) + self.validate_identity("'a' + 'b'") + self.validate_identity( + "'a' || 'b'", + "'a' + 'b'", + ) def test_option(self): possible_options = [ @@ -1701,7 +1706,7 @@ WHERE "duckdb": "LAST_DAY(CAST(CURRENT_TIMESTAMP AS DATE) + INTERVAL (-1) MONTH)", "mysql": "LAST_DAY(DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL -1 MONTH))", "postgres": "CAST(DATE_TRUNC('MONTH', CAST(CURRENT_TIMESTAMP AS DATE) + INTERVAL '-1 MONTH') + INTERVAL '1 MONTH' - INTERVAL '1 DAY' AS DATE)", - "presto": "LAST_DAY_OF_MONTH(DATE_ADD('MONTH', CAST(-1 AS BIGINT), CAST(CAST(CURRENT_TIMESTAMP AS TIMESTAMP) AS DATE)))", + "presto": "LAST_DAY_OF_MONTH(DATE_ADD('MONTH', -1, CAST(CAST(CURRENT_TIMESTAMP AS TIMESTAMP) AS DATE)))", "redshift": "LAST_DAY(DATEADD(MONTH, -1, CAST(GETDATE() AS DATE)))", "snowflake": "LAST_DAY(DATEADD(MONTH, -1, TO_DATE(CURRENT_TIMESTAMP())))", "spark": "LAST_DAY(ADD_MONTHS(TO_DATE(CURRENT_TIMESTAMP()), -1))", @@ -1965,3 +1970,31 @@ FROM OPENJSON(@json) WITH ( base_sql = expr.sql() self.assertEqual(base_sql, f"SCOPE_RESOLUTION({lhs + ', ' if lhs else ''}{rhs})") self.assertEqual(parse_one(base_sql).sql("tsql"), f"{lhs}::{rhs}") + + def test_count(self): + count = annotate_types(self.validate_identity("SELECT COUNT(1) FROM x")) + self.assertEqual(count.expressions[0].type.this, exp.DataType.Type.INT) + + count_big = annotate_types(self.validate_identity("SELECT COUNT_BIG(1) FROM x")) + self.assertEqual(count_big.expressions[0].type.this, exp.DataType.Type.BIGINT) + + self.validate_all( + "SELECT COUNT_BIG(1) FROM x", + read={ + "duckdb": "SELECT COUNT(1) FROM x", + "spark": "SELECT COUNT(1) FROM x", + }, + write={ + "duckdb": "SELECT COUNT(1) FROM x", + "spark": "SELECT COUNT(1) FROM x", + "tsql": "SELECT COUNT_BIG(1) FROM x", + }, + ) + self.validate_all( + "SELECT COUNT(1) FROM x", + write={ + "duckdb": "SELECT COUNT(1) FROM x", + "spark": "SELECT COUNT(1) FROM x", + "tsql": "SELECT COUNT(1) FROM x", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 20cbe7f..013eed8 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -596,6 +596,7 @@ CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (CYCLE)) CREATE TABLE customer (period INT NOT NULL) CREATE TABLE foo (baz_id INT REFERENCES baz (id) DEFERRABLE) CREATE TABLE foo (baz CHAR(4) CHARACTER SET LATIN UPPERCASE NOT CASESPECIFIC COMPRESS 'a') +CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD', measurement INT COMPRESS) CREATE TABLE foo (baz DATE FORMAT 'YYYY/MM/DD' TITLE 'title' INLINE LENGTH 1 COMPRESS ('a', 'b')) CREATE TABLE t (title TEXT) CREATE TABLE foo (baz INT, inline TEXT) @@ -877,4 +878,4 @@ SELECT * FROM a STRAIGHT_JOIN b SELECT COUNT(DISTINCT "foo bar") FROM (SELECT 1 AS "foo bar") AS t SELECT vector WITH all AS (SELECT 1 AS count) SELECT all.count FROM all -SELECT rename
\ No newline at end of file +SELECT rename diff --git a/tests/fixtures/optimizer/annotate_types.sql b/tests/fixtures/optimizer/annotate_types.sql index 0a5fc22..f608851 100644 --- a/tests/fixtures/optimizer/annotate_types.sql +++ b/tests/fixtures/optimizer/annotate_types.sql @@ -1,13 +1,25 @@ 5; INT; +-5; +INT; + +~5; +INT; + +(5); +INT; + 5.3; DOUBLE; 'bla'; VARCHAR; -True; +true; +bool; + +not true; bool; false; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index c5f8a4f..76fc16d 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -838,7 +838,7 @@ SELECT FROM `bigquery-public-data.GooGle_tReNDs.TOp_TeRmS` AS `TOp_TeRmS` WHERE `TOp_TeRmS`.`rank` = 1 - AND `TOp_TeRmS`.`refresh_date` >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) + AND `TOp_TeRmS`.`refresh_date` >= DATE_SUB(CURRENT_DATE, INTERVAL '2' WEEK) GROUP BY `day`, `top_term`, diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 8e10517..3e5619a 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -405,3 +405,16 @@ JOIN b 'eeeeeeeeeeeeeeeeeeeee' ); +/* COMMENT */ +INSERT FIRST WHEN salary > 4000 THEN INTO emp2 + WHEN salary > 5000 THEN INTO emp3 + WHEN salary > 6000 THEN INTO emp4 +SELECT salary FROM employees; +/* COMMENT */ +INSERT FIRST + WHEN salary > 4000 THEN INTO emp2 + WHEN salary > 5000 THEN INTO emp3 + WHEN salary > 6000 THEN INTO emp4 +SELECT + salary +FROM employees; diff --git a/tests/test_build.py b/tests/test_build.py index 150bb42..e074fea 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -361,10 +361,34 @@ class TestBuild(unittest.TestCase): ( lambda: select("x") .from_("tbl") + .with_("tbl", as_="SELECT x FROM tbl2", materialized=True), + "WITH tbl AS MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_="SELECT x FROM tbl2", materialized=False), + "WITH tbl AS NOT MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") .with_("tbl", as_="SELECT x FROM tbl2", recursive=True), "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x").from_("tbl2"), recursive=True, materialized=True), + "WITH RECURSIVE tbl AS MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x").from_("tbl2"), recursive=True, materialized=False), + "WITH RECURSIVE tbl AS NOT MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( lambda: select("x").from_("tbl").with_("tbl", as_=select("x").from_("tbl2")), "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", ), @@ -677,6 +701,18 @@ class TestBuild(unittest.TestCase): "WITH cte AS (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte", ), ( + lambda: exp.insert("SELECT * FROM cte", "t").with_( + "cte", as_="SELECT x FROM tbl", materialized=True + ), + "WITH cte AS MATERIALIZED (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte", + ), + ( + lambda: exp.insert("SELECT * FROM cte", "t").with_( + "cte", as_="SELECT x FROM tbl", materialized=False + ), + "WITH cte AS NOT MATERIALIZED (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte", + ), + ( lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)), "(x, y) IN ((1, 2), (3, 4))", "postgres", diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index d6e11a9..fe5a4d7 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1345,3 +1345,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(4, normalization_distance(gen_expr(2), max_=100)) self.assertEqual(18, normalization_distance(gen_expr(3), max_=100)) self.assertEqual(110, normalization_distance(gen_expr(10), max_=100)) + + def test_custom_annotators(self): + # In Spark hierarchy, SUBSTRING result type is dependent on input expr type + for dialect in ("spark2", "spark", "databricks"): + for expr_type_pair in ( + ("col", "STRING"), + ("col", "BINARY"), + ("'str_literal'", "STRING"), + ("CAST('str_literal' AS BINARY)", "BINARY"), + ): + with self.subTest( + f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}" + ): + expr, type = expr_type_pair + ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect) + + subst_type = ( + optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect) + .expressions[0] + .type + ) + + self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect)) diff --git a/tests/test_parser.py b/tests/test_parser.py index ff82e08..9ff8373 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -854,3 +854,14 @@ class TestParser(unittest.TestCase): ).find(exp.Collate) self.assertIsInstance(collate_node, exp.Collate) self.assertIsInstance(collate_node.expression, collate_pair[1]) + + def test_odbc_date_literals(self): + for value, cls in [ + ("{d'2024-01-01'}", exp.Date), + ("{t'12:00:00'}", exp.Time), + ("{ts'2024-01-01 12:00:00'}", exp.Timestamp), + ]: + sql = f"INSERT INTO tab(ds) VALUES ({value})" + expr = parse_one(sql) + self.assertIsInstance(expr, exp.Insert) + self.assertIsInstance(expr.expression.expressions[0].expressions[0], cls) |