summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_athena.py46
-rw-r--r--tests/dialects/test_bigquery.py64
-rw-r--r--tests/dialects/test_clickhouse.py28
-rw-r--r--tests/dialects/test_dialect.py141
-rw-r--r--tests/dialects/test_duckdb.py7
-rw-r--r--tests/dialects/test_hive.py23
-rw-r--r--tests/dialects/test_mysql.py15
-rw-r--r--tests/dialects/test_oracle.py113
-rw-r--r--tests/dialects/test_postgres.py7
-rw-r--r--tests/dialects/test_presto.py37
-rw-r--r--tests/dialects/test_snowflake.py67
-rw-r--r--tests/dialects/test_spark.py7
-rw-r--r--tests/dialects/test_starrocks.py65
-rw-r--r--tests/dialects/test_teradata.py9
-rw-r--r--tests/dialects/test_tsql.py35
-rw-r--r--tests/fixtures/identity.sql3
-rw-r--r--tests/fixtures/optimizer/annotate_types.sql14
-rw-r--r--tests/fixtures/optimizer/optimizer.sql2
-rw-r--r--tests/fixtures/pretty.sql13
-rw-r--r--tests/test_build.py36
-rw-r--r--tests/test_optimizer.py23
-rw-r--r--tests/test_parser.py11
22 files changed, 683 insertions, 83 deletions
diff --git a/tests/dialects/test_athena.py b/tests/dialects/test_athena.py
index 5522976..6ec870b 100644
--- a/tests/dialects/test_athena.py
+++ b/tests/dialects/test_athena.py
@@ -23,3 +23,49 @@ class TestAthena(Validator):
some_function(1)""",
check_command_warning=True,
)
+
+ def test_ddl_quoting(self):
+ self.validate_identity("CREATE SCHEMA `foo`")
+ self.validate_identity("CREATE SCHEMA foo")
+ self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True)
+
+ self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'")
+ self.validate_identity("CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'")
+ self.validate_identity(
+ "CREATE EXTERNAL TABLE foo (id INTEGER) LOCATION 's3://foo/'",
+ write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'",
+ identify=True,
+ )
+
+ self.validate_identity("DROP TABLE `foo`")
+ self.validate_identity("DROP TABLE foo")
+ self.validate_identity("DROP TABLE foo", write_sql="DROP TABLE `foo`", identify=True)
+
+ self.validate_identity('CREATE VIEW "foo" AS SELECT "id" FROM "tbl"')
+ self.validate_identity("CREATE VIEW foo AS SELECT id FROM tbl")
+ self.validate_identity(
+ "CREATE VIEW foo AS SELECT id FROM tbl",
+ write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"',
+ identify=True,
+ )
+
+ # As a side effect of being able to parse both quote types, we can also fix the quoting on incorrectly quoted source queries
+ self.validate_identity('CREATE SCHEMA "foo"', write_sql="CREATE SCHEMA `foo`")
+ self.validate_identity(
+ 'CREATE EXTERNAL TABLE "foo" ("id" INTEGER) LOCATION \'s3://foo/\'',
+ write_sql="CREATE EXTERNAL TABLE `foo` (`id` INTEGER) LOCATION 's3://foo/'",
+ )
+ self.validate_identity('DROP TABLE "foo"', write_sql="DROP TABLE `foo`")
+ self.validate_identity(
+ 'CREATE VIEW `foo` AS SELECT "id" FROM `tbl`',
+ write_sql='CREATE VIEW "foo" AS SELECT "id" FROM "tbl"',
+ )
+
+ def test_dml_quoting(self):
+ self.validate_identity("SELECT a AS foo FROM tbl")
+ self.validate_identity('SELECT "a" AS "foo" FROM "tbl"')
+ self.validate_identity(
+ 'SELECT `a` AS `foo` FROM "tbl"',
+ write_sql='SELECT "a" AS "foo" FROM "tbl"',
+ identify=True,
+ )
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index f6e8fe8..c8c2176 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -108,7 +108,6 @@ LANGUAGE js AS
self.validate_identity("SELECT * FROM READ_CSV('bla.csv')")
self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
self.validate_identity("assert.true(1 = 1)")
- self.validate_identity("SELECT ARRAY_TO_STRING(list, '--') AS text")
self.validate_identity("SELECT jsondoc['some_key']")
self.validate_identity("SELECT `p.d.UdF`(data).* FROM `p.d.t`")
self.validate_identity("SELECT * FROM `my-project.my-dataset.my-table`")
@@ -631,9 +630,9 @@ LANGUAGE js AS
self.validate_all(
"SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
write={
- "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
- "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1, '2023-01-01T00:00:00')",
- "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) + INTERVAL 1 MILLISECOND",
+ "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL '1' MILLISECOND)",
+ "databricks": "SELECT TIMESTAMPADD(MILLISECOND, '1', '2023-01-01T00:00:00')",
+ "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) + INTERVAL '1' MILLISECOND",
},
),
)
@@ -641,9 +640,9 @@ LANGUAGE js AS
self.validate_all(
"SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
write={
- "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
- "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1 * -1, '2023-01-01T00:00:00')",
- "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) - INTERVAL 1 MILLISECOND",
+ "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL '1' MILLISECOND)",
+ "databricks": "SELECT TIMESTAMPADD(MILLISECOND, '1' * -1, '2023-01-01T00:00:00')",
+ "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) - INTERVAL '1' MILLISECOND",
},
),
)
@@ -661,17 +660,24 @@ LANGUAGE js AS
self.validate_all(
'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
- "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
- "databricks": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
- "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
- "spark": "SELECT DATE_ADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL '10' MINUTE)",
+ "databricks": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)",
+ "spark": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
},
)
self.validate_all(
'SELECT TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)',
write={
- "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)",
- "mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
+ "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL '10' MINUTE)",
+ "mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)",
+ },
+ )
+ self.validate_all(
+ "SELECT TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL 2 HOUR)",
+ write={
+ "bigquery": "SELECT TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL '2' HOUR)",
+ "duckdb": "SELECT CAST('09:05:03' AS TIME) + INTERVAL '2' HOUR",
},
)
self.validate_all(
@@ -1237,19 +1243,19 @@ LANGUAGE js AS
"DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)",
write={
"postgres": "CURRENT_DATE - INTERVAL '1 DAY'",
- "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL 1 DAY)",
+ "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL '1' DAY)",
},
)
self.validate_all(
- "DATE_ADD(CURRENT_DATE(), INTERVAL 1 DAY)",
+ "DATE_ADD(CURRENT_DATE(), INTERVAL -1 DAY)",
write={
- "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
- "duckdb": "CURRENT_DATE + INTERVAL 1 DAY",
- "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)",
- "postgres": "CURRENT_DATE + INTERVAL '1 DAY'",
- "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)",
- "hive": "DATE_ADD(CURRENT_DATE, 1)",
- "spark": "DATE_ADD(CURRENT_DATE, 1)",
+ "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL '-1' DAY)",
+ "duckdb": "CURRENT_DATE + INTERVAL '-1' DAY",
+ "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL '-1' DAY)",
+ "postgres": "CURRENT_DATE + INTERVAL '-1 DAY'",
+ "presto": "DATE_ADD('DAY', CAST('-1' AS BIGINT), CURRENT_DATE)",
+ "hive": "DATE_ADD(CURRENT_DATE, '-1')",
+ "spark": "DATE_ADD(CURRENT_DATE, '-1')",
},
)
self.validate_all(
@@ -1478,6 +1484,20 @@ WHERE
"duckdb": "SELECT CAST(STRPTIME('Thursday Dec 25 2008', '%A %b %-d %Y') AS DATE)",
},
)
+ self.validate_all(
+ "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text",
+ write={
+ "bigquery": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text",
+ "duckdb": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text",
+ },
+ )
+ self.validate_all(
+ "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--', 'MISSING') AS text",
+ write={
+ "bigquery": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--', 'MISSING') AS text",
+ "duckdb": "SELECT ARRAY_TO_STRING(LIST_TRANSFORM(['cake', 'pie', NULL], x -> COALESCE(x, 'MISSING')), '--') AS text",
+ },
+ )
def test_errors(self):
with self.assertRaises(TokenError):
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index b4ba09e..ea6064a 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -160,6 +160,27 @@ class TestClickhouse(Validator):
)
self.validate_all(
+ "char(67) || char(65) || char(84)",
+ read={
+ "clickhouse": "char(67) || char(65) || char(84)",
+ "oracle": "chr(67) || chr(65) || chr(84)",
+ },
+ )
+ self.validate_all(
+ "SELECT lagInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees",
+ read={
+ "clickhouse": "SELECT lagInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees",
+ "oracle": "SELECT LAG(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees",
+ },
+ )
+ self.validate_all(
+ "SELECT leadInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees",
+ read={
+ "clickhouse": "SELECT leadInFrame(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees",
+ "oracle": "SELECT LEAD(salary, 1, 0) OVER (ORDER BY hire_date) AS prev_sal FROM employees",
+ },
+ )
+ self.validate_all(
"SELECT CAST(STR_TO_DATE('05 12 2000', '%d %m %Y') AS DATE)",
read={
"clickhouse": "SELECT CAST(STR_TO_DATE('05 12 2000', '%d %m %Y') AS DATE)",
@@ -494,6 +515,7 @@ class TestClickhouse(Validator):
)
self.validate_identity("SELECT TRIM(TRAILING ')' FROM '( Hello, world! )')")
self.validate_identity("SELECT TRIM(LEADING '(' FROM '( Hello, world! )')")
+ self.validate_identity("current_timestamp").assert_is(exp.Column)
def test_clickhouse_values(self):
values = exp.select("*").from_(
@@ -995,6 +1017,12 @@ LIFETIME(MIN 0 MAX 0)""",
pretty=True,
)
+ self.assertIsNotNone(
+ self.validate_identity("CREATE TABLE t1 (a String MATERIALIZED func())").find(
+ exp.ColumnConstraint
+ )
+ )
+
def test_agg_functions(self):
def extract_agg_func(query):
return parse_one(query, read="clickhouse").selects[0].this
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 190d044..f0faccb 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -20,7 +20,9 @@ class Validator(unittest.TestCase):
def parse_one(self, sql, **kwargs):
return parse_one(sql, read=self.dialect, **kwargs)
- def validate_identity(self, sql, write_sql=None, pretty=False, check_command_warning=False):
+ def validate_identity(
+ self, sql, write_sql=None, pretty=False, check_command_warning=False, identify=False
+ ):
if check_command_warning:
with self.assertLogs(parser_logger) as cm:
expression = self.parse_one(sql)
@@ -28,7 +30,9 @@ class Validator(unittest.TestCase):
else:
expression = self.parse_one(sql)
- self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect, pretty=pretty))
+ self.assertEqual(
+ write_sql or sql, expression.sql(dialect=self.dialect, pretty=pretty, identify=identify)
+ )
return expression
def validate_all(self, sql, read=None, write=None, pretty=False, identify=False):
@@ -1408,6 +1412,13 @@ class TestDialect(Validator):
},
)
+ for dialect in ("duckdb", "starrocks"):
+ with self.subTest(f"Generating json extraction with digit-prefixed key ({dialect})"):
+ self.assertEqual(
+ parse_one("""select '{"0": "v"}' -> '0'""", read=dialect).sql(dialect=dialect),
+ """SELECT '{"0": "v"}' -> '0'""",
+ )
+
def test_cross_join(self):
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
@@ -1422,7 +1433,7 @@ class TestDialect(Validator):
write={
"drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
- "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
+ "spark": "SELECT a, b FROM x LATERAL VIEW INLINE(ARRAYS_ZIP(y, z)) t AS a, b",
},
)
self.validate_all(
@@ -1488,12 +1499,14 @@ class TestDialect(Validator):
"SELECT * FROM a INTERSECT SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT SELECT * FROM b",
},
write={
"bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT SELECT * FROM b",
@@ -1503,12 +1516,14 @@ class TestDialect(Validator):
"SELECT * FROM a EXCEPT SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT SELECT * FROM b",
},
write={
"bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT SELECT * FROM b",
@@ -1527,6 +1542,7 @@ class TestDialect(Validator):
"SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT SELECT * FROM b",
@@ -1536,6 +1552,7 @@ class TestDialect(Validator):
"SELECT * FROM a INTERSECT ALL SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a INTERSECT SELECT * FROM b",
"duckdb": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
"presto": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
"spark": "SELECT * FROM a INTERSECT ALL SELECT * FROM b",
@@ -1545,6 +1562,7 @@ class TestDialect(Validator):
"SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
write={
"bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT SELECT * FROM b",
@@ -1554,6 +1572,7 @@ class TestDialect(Validator):
"SELECT * FROM a EXCEPT ALL SELECT * FROM b",
read={
"bigquery": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
+ "clickhouse": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
"duckdb": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
"presto": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
"spark": "SELECT * FROM a EXCEPT ALL SELECT * FROM b",
@@ -2354,7 +2373,7 @@ SELECT
"mysql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
"oracle": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1",
"postgres": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
- "tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "tsql": "SELECT * FROM (SELECT *, COUNT_BIG(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
},
)
self.validate_all(
@@ -2366,7 +2385,7 @@ SELECT
"mysql": "SELECT `user id`, some_id, other_id, `2 nd id` FROM (SELECT `user id`, some_id, 1 AS other_id, 2 AS `2 nd id`, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
"oracle": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1',
"postgres": 'SELECT "user id", some_id, other_id, "2 nd id" FROM (SELECT "user id", some_id, 1 AS other_id, 2 AS "2 nd id", COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1',
- "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "tsql": "SELECT [user id], some_id, other_id, [2 nd id] FROM (SELECT [user id] AS [user id], some_id AS some_id, 1 AS other_id, 2 AS [2 nd id], COUNT_BIG(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
},
)
@@ -2722,3 +2741,115 @@ FROM subquery2""",
"tsql": "WITH _generated_dates(date_week) AS (SELECT CAST('2020-01-01' AS DATE) AS date_week UNION ALL SELECT CAST(DATEADD(WEEK, 1, date_week) AS DATE) FROM _generated_dates WHERE CAST(DATEADD(WEEK, 1, date_week) AS DATE) <= CAST('2020-02-01' AS DATE)) SELECT * FROM (SELECT date_week AS date_week FROM _generated_dates) AS _generated_dates",
},
)
+
+ def test_set_operation_specifiers(self):
+ self.validate_all(
+ "SELECT 1 EXCEPT ALL SELECT 1",
+ write={
+ "": "SELECT 1 EXCEPT ALL SELECT 1",
+ "bigquery": UnsupportedError,
+ "clickhouse": "SELECT 1 EXCEPT SELECT 1",
+ "databricks": "SELECT 1 EXCEPT ALL SELECT 1",
+ "duckdb": "SELECT 1 EXCEPT ALL SELECT 1",
+ "mysql": "SELECT 1 EXCEPT ALL SELECT 1",
+ "oracle": "SELECT 1 EXCEPT ALL SELECT 1",
+ "postgres": "SELECT 1 EXCEPT ALL SELECT 1",
+ "presto": UnsupportedError,
+ "redshift": UnsupportedError,
+ "snowflake": UnsupportedError,
+ "spark": "SELECT 1 EXCEPT ALL SELECT 1",
+ "sqlite": UnsupportedError,
+ "starrocks": UnsupportedError,
+ "trino": UnsupportedError,
+ "tsql": UnsupportedError,
+ },
+ )
+
+ def test_normalize(self):
+ for form in ("", ", nfkc"):
+ with self.subTest(f"Testing NORMALIZE('str'{form}) roundtrip"):
+ self.validate_all(
+ f"SELECT NORMALIZE('str'{form})",
+ read={
+ "presto": f"SELECT NORMALIZE('str'{form})",
+ "trino": f"SELECT NORMALIZE('str'{form})",
+ "bigquery": f"SELECT NORMALIZE('str'{form})",
+ },
+ write={
+ "presto": f"SELECT NORMALIZE('str'{form})",
+ "trino": f"SELECT NORMALIZE('str'{form})",
+ "bigquery": f"SELECT NORMALIZE('str'{form})",
+ },
+ )
+
+ self.assertIsInstance(parse_one("NORMALIZE('str', NFD)").args.get("form"), exp.Var)
+
+ def test_coalesce(self):
+ """
+ Validate that "expressions" is a list for all the exp.Coalesce instances; This is important
+ as some optimizer rules are coalesce specific and will iterate on "expressions"
+ """
+
+ # Check the 2-arg aliases
+ for func in ("COALESCE", "IFNULL", "NVL"):
+ self.assertIsInstance(self.parse_one(f"{func}(1, 2)").expressions, list)
+
+ # Check the varlen case
+ coalesce = self.parse_one("COALESCE(x, y, z)")
+ self.assertIsInstance(coalesce.expressions, list)
+ self.assertIsNone(coalesce.args.get("is_nvl"))
+
+ # Check Oracle's NVL which is decoupled from COALESCE
+ oracle_nvl = parse_one("NVL(x, y)", read="oracle")
+ self.assertIsInstance(oracle_nvl.expressions, list)
+ self.assertTrue(oracle_nvl.args.get("is_nvl"))
+
+ # Check T-SQL's ISNULL which is parsed into exp.Coalesce
+ self.assertIsInstance(parse_one("ISNULL(x, y)", read="tsql").expressions, list)
+
+ def test_trim(self):
+ self.validate_all(
+ "TRIM('abc', 'a')",
+ read={
+ "bigquery": "TRIM('abc', 'a')",
+ "snowflake": "TRIM('abc', 'a')",
+ },
+ write={
+ "bigquery": "TRIM('abc', 'a')",
+ "snowflake": "TRIM('abc', 'a')",
+ },
+ )
+
+ self.validate_all(
+ "LTRIM('Hello World', 'H')",
+ read={
+ "oracle": "LTRIM('Hello World', 'H')",
+ "clickhouse": "TRIM(LEADING 'H' FROM 'Hello World')",
+ "snowflake": "LTRIM('Hello World', 'H')",
+ "bigquery": "LTRIM('Hello World', 'H')",
+ "": "LTRIM('Hello World', 'H')",
+ },
+ write={
+ "clickhouse": "TRIM(LEADING 'H' FROM 'Hello World')",
+ "oracle": "LTRIM('Hello World', 'H')",
+ "snowflake": "LTRIM('Hello World', 'H')",
+ "bigquery": "LTRIM('Hello World', 'H')",
+ },
+ )
+
+ self.validate_all(
+ "RTRIM('Hello World', 'd')",
+ read={
+ "clickhouse": "TRIM(TRAILING 'd' FROM 'Hello World')",
+ "oracle": "RTRIM('Hello World', 'd')",
+ "snowflake": "RTRIM('Hello World', 'd')",
+ "bigquery": "RTRIM('Hello World', 'd')",
+ "": "RTRIM('Hello World', 'd')",
+ },
+ write={
+ "clickhouse": "TRIM(TRAILING 'd' FROM 'Hello World')",
+ "oracle": "RTRIM('Hello World', 'd')",
+ "snowflake": "RTRIM('Hello World', 'd')",
+ "bigquery": "RTRIM('Hello World', 'd')",
+ },
+ )
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 5d2d044..18a030c 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -696,11 +696,11 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "SELECT CAST('2020-05-06' AS DATE) - INTERVAL 5 DAY",
+ "SELECT CAST('2020-05-06' AS DATE) - INTERVAL '5' DAY",
read={"bigquery": "SELECT DATE_SUB(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"},
)
self.validate_all(
- "SELECT CAST('2020-05-06' AS DATE) + INTERVAL 5 DAY",
+ "SELECT CAST('2020-05-06' AS DATE) + INTERVAL '5' DAY",
read={"bigquery": "SELECT DATE_ADD(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"},
)
self.validate_identity(
@@ -879,7 +879,7 @@ class TestDuckDB(Validator):
write={"duckdb": "SELECT (90 * INTERVAL '1' DAY)"},
)
self.validate_all(
- "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((DAYOFWEEK(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1",
+ "SELECT ((DATE_TRUNC('DAY', CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP) + INTERVAL (0 - ((ISODOW(CAST(CAST(DATE_TRUNC('DAY', CURRENT_TIMESTAMP) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7) % 7) DAY) + (7 * INTERVAL (-5) DAY))) AS t1",
read={
"presto": "SELECT ((DATE_ADD('week', -5, DATE_TRUNC('DAY', DATE_ADD('day', (0 - MOD((DAY_OF_WEEK(CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)) % 7) - 1 + 7, 7)), CAST(CAST(DATE_TRUNC('DAY', NOW()) AS DATE) AS TIMESTAMP)))))) AS t1",
},
@@ -1100,7 +1100,6 @@ class TestDuckDB(Validator):
self.validate_all(
"SELECT CAST('09:05:03' AS TIME) + INTERVAL 2 HOUR",
read={
- "bigquery": "SELECT TIME_ADD(CAST('09:05:03' AS TIME), INTERVAL 2 HOUR)",
"snowflake": "SELECT TIMEADD(HOUR, 2, TO_TIME('09:05:03'))",
},
write={
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index c2768bf..136ea60 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -412,6 +412,7 @@ class TestHive(Validator):
)
def test_hive(self):
+ self.validate_identity("SELECT * FROM t WHERE col IN ('stream')")
self.validate_identity("SET hiveconf:some_var = 5", check_command_warning=True)
self.validate_identity("(VALUES (1 AS a, 2 AS b, 3))")
self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)")
@@ -715,8 +716,8 @@ class TestHive(Validator):
"presto": "ARRAY_AGG(x)",
},
write={
- "duckdb": "ARRAY_AGG(x)",
- "presto": "ARRAY_AGG(x)",
+ "duckdb": "ARRAY_AGG(x) FILTER(WHERE x IS NOT NULL)",
+ "presto": "ARRAY_AGG(x) FILTER(WHERE x IS NOT NULL)",
"hive": "COLLECT_LIST(x)",
"spark": "COLLECT_LIST(x)",
},
@@ -764,6 +765,24 @@ class TestHive(Validator):
"presto": "SELECT DATE_TRUNC('MONTH', TRY_CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
},
)
+ self.validate_all(
+ "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ read={
+ "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ },
+ write={
+ "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 1)",
+ "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 1)",
+ "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 1)",
+ },
+ )
def test_escapes(self) -> None:
self.validate_identity("'\n'", "'\\n'")
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 45b79bf..2fd9ef0 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -241,7 +241,7 @@ class TestMySQL(Validator):
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("INTERVAL '1' YEAR")
- self.validate_identity("DATE_ADD(x, INTERVAL 1 YEAR)")
+ self.validate_identity("DATE_ADD(x, INTERVAL '1' YEAR)")
self.validate_identity("CHAR(0)")
self.validate_identity("CHAR(77, 121, 83, 81, '76')")
self.validate_identity("CHAR(77, 77.3, '77.3' USING utf8mb4)")
@@ -539,9 +539,16 @@ class TestMySQL(Validator):
},
)
self.validate_all(
- "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')",
+ "SELECT DATE_FORMAT('2024-08-22 14:53:12', '%a')",
write={
- "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')",
+ "mysql": "SELECT DATE_FORMAT('2024-08-22 14:53:12', '%a')",
+ "snowflake": "SELECT TO_CHAR(CAST('2024-08-22 14:53:12' AS TIMESTAMP), 'DY')",
+ },
+ )
+ self.validate_all(
+ "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%a %M %Y')",
+ write={
+ "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%a %M %Y')",
"snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMP), 'DY mmmm yyyy')",
},
)
@@ -555,7 +562,7 @@ class TestMySQL(Validator):
self.validate_all(
"SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')",
write={
- "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')",
+ "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')",
"snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMP), 'DD yy DY DD mm mon')",
},
)
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 77f46e4..8bdc4af 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -99,6 +99,13 @@ class TestOracle(Validator):
)
self.validate_all(
+ "TRUNC(SYSDATE, 'YEAR')",
+ write={
+ "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())",
+ "oracle": "TRUNC(SYSDATE, 'YEAR')",
+ },
+ )
+ self.validate_all(
"SELECT * FROM test WHERE MOD(col1, 4) = 3",
read={
"duckdb": "SELECT * FROM test WHERE col1 % 4 = 3",
@@ -260,22 +267,6 @@ class TestOracle(Validator):
},
)
self.validate_all(
- "LTRIM('Hello World', 'H')",
- write={
- "": "LTRIM('Hello World', 'H')",
- "oracle": "LTRIM('Hello World', 'H')",
- "clickhouse": "TRIM(LEADING 'H' FROM 'Hello World')",
- },
- )
- self.validate_all(
- "RTRIM('Hello World', 'd')",
- write={
- "": "RTRIM('Hello World', 'd')",
- "oracle": "RTRIM('Hello World', 'd')",
- "clickhouse": "TRIM(TRAILING 'd' FROM 'Hello World')",
- },
- )
- self.validate_all(
"TRIM(BOTH 'h' FROM 'Hello World')",
write={
"oracle": "TRIM(BOTH 'h' FROM 'Hello World')",
@@ -461,3 +452,93 @@ WHERE
self.validate_identity(
f"CREATE VIEW view AS SELECT * FROM tbl WITH {restriction}{constraint_name}"
)
+
+ def test_multitable_inserts(self):
+ self.maxDiff = None
+ self.validate_identity(
+ "INSERT ALL "
+ "INTO dest_tab1 (id, description) VALUES (id, description) "
+ "INTO dest_tab2 (id, description) VALUES (id, description) "
+ "INTO dest_tab3 (id, description) VALUES (id, description) "
+ "SELECT id, description FROM source_tab"
+ )
+
+ self.validate_identity(
+ "INSERT ALL "
+ "INTO pivot_dest (id, day, val) VALUES (id, 'mon', mon_val) "
+ "INTO pivot_dest (id, day, val) VALUES (id, 'tue', tue_val) "
+ "INTO pivot_dest (id, day, val) VALUES (id, 'wed', wed_val) "
+ "INTO pivot_dest (id, day, val) VALUES (id, 'thu', thu_val) "
+ "INTO pivot_dest (id, day, val) VALUES (id, 'fri', fri_val) "
+ "SELECT * "
+ "FROM pivot_source"
+ )
+
+ self.validate_identity(
+ "INSERT ALL "
+ "WHEN id <= 3 THEN "
+ "INTO dest_tab1 (id, description) VALUES (id, description) "
+ "WHEN id BETWEEN 4 AND 7 THEN "
+ "INTO dest_tab2 (id, description) VALUES (id, description) "
+ "WHEN id >= 8 THEN "
+ "INTO dest_tab3 (id, description) VALUES (id, description) "
+ "SELECT id, description "
+ "FROM source_tab"
+ )
+
+ self.validate_identity(
+ "INSERT ALL "
+ "WHEN id <= 3 THEN "
+ "INTO dest_tab1 (id, description) VALUES (id, description) "
+ "WHEN id BETWEEN 4 AND 7 THEN "
+ "INTO dest_tab2 (id, description) VALUES (id, description) "
+ "WHEN 1 = 1 THEN "
+ "INTO dest_tab3 (id, description) VALUES (id, description) "
+ "SELECT id, description "
+ "FROM source_tab"
+ )
+
+ self.validate_identity(
+ "INSERT FIRST "
+ "WHEN id <= 3 THEN "
+ "INTO dest_tab1 (id, description) VALUES (id, description) "
+ "WHEN id <= 5 THEN "
+ "INTO dest_tab2 (id, description) VALUES (id, description) "
+ "ELSE "
+ "INTO dest_tab3 (id, description) VALUES (id, description) "
+ "SELECT id, description "
+ "FROM source_tab"
+ )
+
+ self.validate_identity(
+ "INSERT FIRST "
+ "WHEN id <= 3 THEN "
+ "INTO dest_tab1 (id, description) VALUES (id, description) "
+ "ELSE "
+ "INTO dest_tab2 (id, description) VALUES (id, description) "
+ "INTO dest_tab3 (id, description) VALUES (id, description) "
+ "SELECT id, description "
+ "FROM source_tab"
+ )
+
+ self.validate_identity(
+ "/* COMMENT */ INSERT FIRST "
+ "WHEN salary > 4000 THEN INTO emp2 "
+ "WHEN salary > 5000 THEN INTO emp3 "
+ "WHEN salary > 6000 THEN INTO emp4 "
+ "SELECT salary FROM employees"
+ )
+
+ def test_json_functions(self):
+ for format_json in ("", " FORMAT JSON"):
+ for on_cond in (
+ "",
+ " TRUE ON ERROR",
+ " NULL ON EMPTY",
+ " DEFAULT 1 ON ERROR TRUE ON EMPTY",
+ ):
+ for passing in ("", " PASSING 'name1' AS \"var1\", 'name2' AS \"var2\""):
+ with self.subTest("Testing JSON_EXISTS()"):
+ self.validate_identity(
+ f"SELECT * FROM t WHERE JSON_EXISTS(name{format_json}, '$[1].middle'{passing}{on_cond})"
+ )
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index c628db4..f3f21a9 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -846,6 +846,9 @@ class TestPostgres(Validator):
self.validate_identity("ALTER TABLE t1 SET TABLESPACE tablespace")
self.validate_identity("ALTER TABLE t1 SET (fillfactor = 5, autovacuum_enabled = TRUE)")
self.validate_identity(
+ "ALTER TABLE tested_table ADD CONSTRAINT unique_example UNIQUE (column_name) NOT VALID"
+ )
+ self.validate_identity(
"CREATE FUNCTION pymax(a INT, b INT) RETURNS INT LANGUAGE plpython3u AS $$\n if a > b:\n return a\n return b\n$$",
)
self.validate_identity(
@@ -1023,6 +1026,10 @@ class TestPostgres(Validator):
self.validate_identity(
"CREATE INDEX CONCURRENTLY IF NOT EXISTS ix_table_id ON tbl USING btree(id)"
)
+ self.validate_identity("DROP INDEX ix_table_id")
+ self.validate_identity("DROP INDEX IF EXISTS ix_table_id")
+ self.validate_identity("DROP INDEX CONCURRENTLY ix_table_id")
+ self.validate_identity("DROP INDEX CONCURRENTLY IF EXISTS ix_table_id")
self.validate_identity(
"""
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index c8e616e..9c61f62 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -329,11 +329,20 @@ class TestPresto(Validator):
},
)
self.validate_all(
- "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')",
- write={
+ "((DAY_OF_WEEK(CAST(TRY_CAST('2012-08-08 01:00:00' AS TIMESTAMP) AS DATE)) % 7) + 1)",
+ read={
"spark": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+ },
+ )
+ self.validate_all(
+ "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+ read={
+ "duckdb": "ISODOW(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+ },
+ write={
+ "spark": "((DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP)) % 7) + 1)",
"presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
- "duckdb": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
+ "duckdb": "ISODOW(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
},
)
@@ -522,6 +531,9 @@ class TestPresto(Validator):
},
)
+ self.validate_identity("""CREATE OR REPLACE VIEW v SECURITY DEFINER AS SELECT id FROM t""")
+ self.validate_identity("""CREATE OR REPLACE VIEW v SECURITY INVOKER AS SELECT id FROM t""")
+
def test_quotes(self):
self.validate_all(
"''''",
@@ -1022,6 +1034,25 @@ class TestPresto(Validator):
"spark": "SELECT REGEXP_EXTRACT(TO_JSON(FROM_JSON('[[1, 2, 3]]', SCHEMA_OF_JSON('[[1, 2, 3]]'))), '^.(.*).$', 1)",
},
)
+ self.validate_all(
+ "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ read={
+ "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "snowflake": "REGEXP_SUBSTR('abc', '(a)(b)(c)')",
+ },
+ write={
+ "presto": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "trino": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "duckdb": "REGEXP_EXTRACT('abc', '(a)(b)(c)')",
+ "snowflake": "REGEXP_SUBSTR('abc', '(a)(b)(c)')",
+ "hive": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ "spark2": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ "spark": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ "databricks": "REGEXP_EXTRACT('abc', '(a)(b)(c)', 0)",
+ },
+ )
def test_encode_decode(self):
self.validate_identity("FROM_UTF8(x, y)")
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 7837cc9..3e0d600 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -11,30 +11,12 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
- self.validate_identity("1 /* /* */")
- self.validate_identity(
- "SELECT * FROM table AT (TIMESTAMP => '2024-07-24') UNPIVOT(a FOR b IN (c)) AS pivot_table"
- )
-
self.assertEqual(
# Ensures we don't fail when generating ParseJSON with the `safe` arg set to `True`
self.validate_identity("""SELECT TRY_PARSE_JSON('{"x: 1}')""").sql(),
"""SELECT PARSE_JSON('{"x: 1}')""",
)
- self.validate_identity(
- "transform(x, a int -> a + a + 1)",
- "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
- )
-
- self.validate_all(
- "ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
- write={
- "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))",
- "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)",
- },
- )
-
expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
expr.selects[0].assert_is(exp.AggFunc)
self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
@@ -98,7 +80,6 @@ WHERE
self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')")
self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x")
self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')")
- self.validate_identity("CAST(x AS GEOMETRY)")
self.validate_identity("OBJECT_CONSTRUCT(*)")
self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'")
self.validate_identity("SELECT HLL(*)")
@@ -115,6 +96,10 @@ WHERE
self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity("SELECT MATCH_CONDITION")
self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
+ self.validate_identity("1 /* /* */")
+ self.validate_identity(
+ "SELECT * FROM table AT (TIMESTAMP => '2024-07-24') UNPIVOT(a FOR b IN (c)) AS pivot_table"
+ )
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4', '2024_Q1') DEFAULT ON NULL (0)) ORDER BY empid"
)
@@ -140,6 +125,18 @@ WHERE
"SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID"
)
self.validate_identity(
+ "CAST(x AS GEOGRAPHY)",
+ "TO_GEOGRAPHY(x)",
+ )
+ self.validate_identity(
+ "CAST(x AS GEOMETRY)",
+ "TO_GEOMETRY(x)",
+ )
+ self.validate_identity(
+ "transform(x, a int -> a + a + 1)",
+ "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
+ )
+ self.validate_identity(
"SELECT * FROM s WHERE c NOT IN (1, 2, 3)",
"SELECT * FROM s WHERE NOT c IN (1, 2, 3)",
)
@@ -309,6 +306,13 @@ WHERE
)
self.validate_all(
+ "ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
+ write={
+ "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))",
+ "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)",
+ },
+ )
+ self.validate_all(
"OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)",
read={
"bigquery": "JSON_OBJECT(['key_1', 'key_2'], ['one', NULL])",
@@ -1419,6 +1423,12 @@ WHERE
},
)
+ self.assertIsNotNone(
+ self.validate_identity("CREATE TABLE foo (bar INT AS (foo))").find(
+ exp.TransformColumnConstraint
+ )
+ )
+
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
@@ -1697,16 +1707,27 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
"REGEXP_SUBSTR(subject, pattern)",
read={
"bigquery": "REGEXP_EXTRACT(subject, pattern)",
+ "snowflake": "REGEXP_EXTRACT(subject, pattern)",
+ },
+ write={
+ "bigquery": "REGEXP_EXTRACT(subject, pattern)",
+ "snowflake": "REGEXP_SUBSTR(subject, pattern)",
+ },
+ )
+ self.validate_all(
+ "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)",
+ read={
"hive": "REGEXP_EXTRACT(subject, pattern)",
- "presto": "REGEXP_EXTRACT(subject, pattern)",
+ "spark2": "REGEXP_EXTRACT(subject, pattern)",
"spark": "REGEXP_EXTRACT(subject, pattern)",
+ "databricks": "REGEXP_EXTRACT(subject, pattern)",
},
write={
- "bigquery": "REGEXP_EXTRACT(subject, pattern)",
"hive": "REGEXP_EXTRACT(subject, pattern)",
- "presto": "REGEXP_EXTRACT(subject, pattern)",
- "snowflake": "REGEXP_SUBSTR(subject, pattern)",
+ "spark2": "REGEXP_EXTRACT(subject, pattern)",
"spark": "REGEXP_EXTRACT(subject, pattern)",
+ "databricks": "REGEXP_EXTRACT(subject, pattern)",
+ "snowflake": "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index cbaa169..4fed68c 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -309,6 +309,13 @@ TBLPROPERTIES (
)
self.validate_all(
+ "SELECT ARRAY_AGG(x) FILTER (WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",
+ write={
+ "duckdb": "SELECT ARRAY_AGG(x) FILTER(WHERE x = 5 AND NOT x IS NULL) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",
+ "spark": "SELECT COLLECT_LIST(x) FILTER(WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)",
+ },
+ )
+ self.validate_all(
"SELECT DATE_FORMAT(DATE '2020-01-01', 'EEEE') AS weekday",
write={
"presto": "SELECT DATE_FORMAT(CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP), '%W') AS weekday",
diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py
index fa9a2cc..ee4dc90 100644
--- a/tests/dialects/test_starrocks.py
+++ b/tests/dialects/test_starrocks.py
@@ -1,13 +1,41 @@
+from sqlglot.errors import UnsupportedError
from tests.dialects.test_dialect import Validator
class TestStarrocks(Validator):
dialect = "starrocks"
+ def test_ddl(self):
+ ddl_sqls = [
+ "DISTRIBUTED BY HASH (col1) BUCKETS 1",
+ "DISTRIBUTED BY HASH (col1)",
+ "DISTRIBUTED BY RANDOM BUCKETS 1",
+ "DISTRIBUTED BY RANDOM",
+ "DISTRIBUTED BY HASH (col1) ORDER BY (col1)",
+ "DISTRIBUTED BY HASH (col1) PROPERTIES ('replication_num'='1')",
+ "PRIMARY KEY (col1) DISTRIBUTED BY HASH (col1)",
+ "DUPLICATE KEY (col1, col2) DISTRIBUTED BY HASH (col1)",
+ ]
+
+ for properties in ddl_sqls:
+ with self.subTest(f"Testing create scheme: {properties}"):
+ self.validate_identity(f"CREATE TABLE foo (col1 BIGINT, col2 BIGINT) {properties}")
+ self.validate_identity(
+ f"CREATE TABLE foo (col1 BIGINT, col2 BIGINT) ENGINE=OLAP {properties}"
+ )
+
+ # Test the different wider DECIMAL types
+ self.validate_identity(
+ "CREATE TABLE foo (col0 DECIMAL(9, 1), col1 DECIMAL32(9, 1), col2 DECIMAL64(18, 10), col3 DECIMAL128(38, 10)) DISTRIBUTED BY HASH (col1) BUCKETS 1"
+ )
+
def test_identity(self):
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x")
self.validate_identity("SELECT [1, 2, 3]")
+ self.validate_identity(
+ """SELECT CAST(PARSE_JSON(fieldvalue) -> '00000000-0000-0000-0000-00000000' AS VARCHAR) AS `code` FROM (SELECT '{"00000000-0000-0000-0000-00000000":"code01"}') AS t(fieldvalue)"""
+ )
def test_time(self):
self.validate_identity("TIMESTAMP('2022-01-01')")
@@ -35,6 +63,43 @@ class TestStarrocks(Validator):
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t",
"SELECT student, score, t.unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS t(unnest)",
)
+ self.validate_all(
+ "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores)",
+ write={
+ "spark": "SELECT student, score, unnest FROM tests LATERAL VIEW EXPLODE(scores) unnest AS unnest",
+ "starrocks": "SELECT student, score, unnest FROM tests CROSS JOIN LATERAL UNNEST(scores) AS unnest(unnest)",
+ },
+ )
+ self.validate_all(
+ r"""SELECT * FROM UNNEST(array['John','Jane','Jim','Jamie'], array[24,25,26,27]) AS t(name, age)""",
+ write={
+ "postgres": "SELECT * FROM UNNEST(ARRAY['John', 'Jane', 'Jim', 'Jamie'], ARRAY[24, 25, 26, 27]) AS t(name, age)",
+ "spark": "SELECT * FROM INLINE(ARRAYS_ZIP(ARRAY('John', 'Jane', 'Jim', 'Jamie'), ARRAY(24, 25, 26, 27))) AS t(name, age)",
+ "starrocks": "SELECT * FROM UNNEST(['John', 'Jane', 'Jim', 'Jamie'], [24, 25, 26, 27]) AS t(name, age)",
+ },
+ )
+
+ # Use UNNEST to convert into multiple columns
+ # see: https://docs.starrocks.io/docs/sql-reference/sql-functions/array-functions/unnest/
+ self.validate_all(
+ r"""SELECT id, t.type, t.scores FROM example_table, unnest(split(type, ";"), scores) AS t(type,scores)""",
+ write={
+ "postgres": "SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)",
+ "spark": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
+ "databricks": r"""SELECT id, t.type, t.scores FROM example_table LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
+ "starrocks": r"""SELECT id, t.type, t.scores FROM example_table, UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""",
+ "hive": UnsupportedError,
+ },
+ )
+
+ self.validate_all(
+ r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL unnest(split(type, ";"), scores) AS t(type,scores)""",
+ write={
+ "spark": r"""SELECT id, t.type, t.scores FROM example_table_2 LATERAL VIEW INLINE(ARRAYS_ZIP(SPLIT(type, CONCAT('\\Q', ';')), scores)) t AS type, scores""",
+ "starrocks": r"""SELECT id, t.type, t.scores FROM example_table_2 CROSS JOIN LATERAL UNNEST(SPLIT(type, ';'), scores) AS t(type, scores)""",
+ "hive": UnsupportedError,
+ },
+ )
lateral_explode_sqls = [
"SELECT id, t.col FROM tbl, UNNEST(scores) AS t(col)",
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 598cb53..466f5d5 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -155,6 +155,15 @@ class TestTeradata(Validator):
"tsql": "CREATE TABLE a",
},
)
+ self.validate_identity(
+ "CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD', measurement INT COMPRESS)"
+ )
+ self.validate_identity(
+ "CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD', measurement INT COMPRESS (1, 2, 3))"
+ )
+ self.validate_identity(
+ "CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD' COMPRESS (CAST('9999-09-09' AS DATE)), measurement INT)"
+ )
def test_insert(self):
self.validate_all(
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index ecb83da..7114750 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -420,6 +420,11 @@ class TestTSQL(Validator):
"SELECT val FROM (VALUES ((TRUE), (FALSE), (NULL))) AS t(val)",
write_sql="SELECT val FROM (VALUES ((1), (0), (NULL))) AS t(val)",
)
+ self.validate_identity("'a' + 'b'")
+ self.validate_identity(
+ "'a' || 'b'",
+ "'a' + 'b'",
+ )
def test_option(self):
possible_options = [
@@ -1701,7 +1706,7 @@ WHERE
"duckdb": "LAST_DAY(CAST(CURRENT_TIMESTAMP AS DATE) + INTERVAL (-1) MONTH)",
"mysql": "LAST_DAY(DATE_ADD(CURRENT_TIMESTAMP(), INTERVAL -1 MONTH))",
"postgres": "CAST(DATE_TRUNC('MONTH', CAST(CURRENT_TIMESTAMP AS DATE) + INTERVAL '-1 MONTH') + INTERVAL '1 MONTH' - INTERVAL '1 DAY' AS DATE)",
- "presto": "LAST_DAY_OF_MONTH(DATE_ADD('MONTH', CAST(-1 AS BIGINT), CAST(CAST(CURRENT_TIMESTAMP AS TIMESTAMP) AS DATE)))",
+ "presto": "LAST_DAY_OF_MONTH(DATE_ADD('MONTH', -1, CAST(CAST(CURRENT_TIMESTAMP AS TIMESTAMP) AS DATE)))",
"redshift": "LAST_DAY(DATEADD(MONTH, -1, CAST(GETDATE() AS DATE)))",
"snowflake": "LAST_DAY(DATEADD(MONTH, -1, TO_DATE(CURRENT_TIMESTAMP())))",
"spark": "LAST_DAY(ADD_MONTHS(TO_DATE(CURRENT_TIMESTAMP()), -1))",
@@ -1965,3 +1970,31 @@ FROM OPENJSON(@json) WITH (
base_sql = expr.sql()
self.assertEqual(base_sql, f"SCOPE_RESOLUTION({lhs + ', ' if lhs else ''}{rhs})")
self.assertEqual(parse_one(base_sql).sql("tsql"), f"{lhs}::{rhs}")
+
+ def test_count(self):
+ count = annotate_types(self.validate_identity("SELECT COUNT(1) FROM x"))
+ self.assertEqual(count.expressions[0].type.this, exp.DataType.Type.INT)
+
+ count_big = annotate_types(self.validate_identity("SELECT COUNT_BIG(1) FROM x"))
+ self.assertEqual(count_big.expressions[0].type.this, exp.DataType.Type.BIGINT)
+
+ self.validate_all(
+ "SELECT COUNT_BIG(1) FROM x",
+ read={
+ "duckdb": "SELECT COUNT(1) FROM x",
+ "spark": "SELECT COUNT(1) FROM x",
+ },
+ write={
+ "duckdb": "SELECT COUNT(1) FROM x",
+ "spark": "SELECT COUNT(1) FROM x",
+ "tsql": "SELECT COUNT_BIG(1) FROM x",
+ },
+ )
+ self.validate_all(
+ "SELECT COUNT(1) FROM x",
+ write={
+ "duckdb": "SELECT COUNT(1) FROM x",
+ "spark": "SELECT COUNT(1) FROM x",
+ "tsql": "SELECT COUNT(1) FROM x",
+ },
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 20cbe7f..013eed8 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -596,6 +596,7 @@ CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (CYCLE))
CREATE TABLE customer (period INT NOT NULL)
CREATE TABLE foo (baz_id INT REFERENCES baz (id) DEFERRABLE)
CREATE TABLE foo (baz CHAR(4) CHARACTER SET LATIN UPPERCASE NOT CASESPECIFIC COMPRESS 'a')
+CREATE TABLE db.foo (id INT NOT NULL, valid_date DATE FORMAT 'YYYY-MM-DD', measurement INT COMPRESS)
CREATE TABLE foo (baz DATE FORMAT 'YYYY/MM/DD' TITLE 'title' INLINE LENGTH 1 COMPRESS ('a', 'b'))
CREATE TABLE t (title TEXT)
CREATE TABLE foo (baz INT, inline TEXT)
@@ -877,4 +878,4 @@ SELECT * FROM a STRAIGHT_JOIN b
SELECT COUNT(DISTINCT "foo bar") FROM (SELECT 1 AS "foo bar") AS t
SELECT vector
WITH all AS (SELECT 1 AS count) SELECT all.count FROM all
-SELECT rename \ No newline at end of file
+SELECT rename
diff --git a/tests/fixtures/optimizer/annotate_types.sql b/tests/fixtures/optimizer/annotate_types.sql
index 0a5fc22..f608851 100644
--- a/tests/fixtures/optimizer/annotate_types.sql
+++ b/tests/fixtures/optimizer/annotate_types.sql
@@ -1,13 +1,25 @@
5;
INT;
+-5;
+INT;
+
+~5;
+INT;
+
+(5);
+INT;
+
5.3;
DOUBLE;
'bla';
VARCHAR;
-True;
+true;
+bool;
+
+not true;
bool;
false;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index c5f8a4f..76fc16d 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -838,7 +838,7 @@ SELECT
FROM `bigquery-public-data.GooGle_tReNDs.TOp_TeRmS` AS `TOp_TeRmS`
WHERE
`TOp_TeRmS`.`rank` = 1
- AND `TOp_TeRmS`.`refresh_date` >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK)
+ AND `TOp_TeRmS`.`refresh_date` >= DATE_SUB(CURRENT_DATE, INTERVAL '2' WEEK)
GROUP BY
`day`,
`top_term`,
diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql
index 8e10517..3e5619a 100644
--- a/tests/fixtures/pretty.sql
+++ b/tests/fixtures/pretty.sql
@@ -405,3 +405,16 @@ JOIN b
'eeeeeeeeeeeeeeeeeeeee'
);
+/* COMMENT */
+INSERT FIRST WHEN salary > 4000 THEN INTO emp2
+ WHEN salary > 5000 THEN INTO emp3
+ WHEN salary > 6000 THEN INTO emp4
+SELECT salary FROM employees;
+/* COMMENT */
+INSERT FIRST
+ WHEN salary > 4000 THEN INTO emp2
+ WHEN salary > 5000 THEN INTO emp3
+ WHEN salary > 6000 THEN INTO emp4
+SELECT
+ salary
+FROM employees;
diff --git a/tests/test_build.py b/tests/test_build.py
index 150bb42..e074fea 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -361,10 +361,34 @@ class TestBuild(unittest.TestCase):
(
lambda: select("x")
.from_("tbl")
+ .with_("tbl", as_="SELECT x FROM tbl2", materialized=True),
+ "WITH tbl AS MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl",
+ ),
+ (
+ lambda: select("x")
+ .from_("tbl")
+ .with_("tbl", as_="SELECT x FROM tbl2", materialized=False),
+ "WITH tbl AS NOT MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl",
+ ),
+ (
+ lambda: select("x")
+ .from_("tbl")
.with_("tbl", as_="SELECT x FROM tbl2", recursive=True),
"WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
(
+ lambda: select("x")
+ .from_("tbl")
+ .with_("tbl", as_=select("x").from_("tbl2"), recursive=True, materialized=True),
+ "WITH RECURSIVE tbl AS MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl",
+ ),
+ (
+ lambda: select("x")
+ .from_("tbl")
+ .with_("tbl", as_=select("x").from_("tbl2"), recursive=True, materialized=False),
+ "WITH RECURSIVE tbl AS NOT MATERIALIZED (SELECT x FROM tbl2) SELECT x FROM tbl",
+ ),
+ (
lambda: select("x").from_("tbl").with_("tbl", as_=select("x").from_("tbl2")),
"WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl",
),
@@ -677,6 +701,18 @@ class TestBuild(unittest.TestCase):
"WITH cte AS (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte",
),
(
+ lambda: exp.insert("SELECT * FROM cte", "t").with_(
+ "cte", as_="SELECT x FROM tbl", materialized=True
+ ),
+ "WITH cte AS MATERIALIZED (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte",
+ ),
+ (
+ lambda: exp.insert("SELECT * FROM cte", "t").with_(
+ "cte", as_="SELECT x FROM tbl", materialized=False
+ ),
+ "WITH cte AS NOT MATERIALIZED (SELECT x FROM tbl) INSERT INTO t SELECT * FROM cte",
+ ),
+ (
lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)),
"(x, y) IN ((1, 2), (3, 4))",
"postgres",
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index d6e11a9..fe5a4d7 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1345,3 +1345,26 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(4, normalization_distance(gen_expr(2), max_=100))
self.assertEqual(18, normalization_distance(gen_expr(3), max_=100))
self.assertEqual(110, normalization_distance(gen_expr(10), max_=100))
+
+ def test_custom_annotators(self):
+ # In Spark hierarchy, SUBSTRING result type is dependent on input expr type
+ for dialect in ("spark2", "spark", "databricks"):
+ for expr_type_pair in (
+ ("col", "STRING"),
+ ("col", "BINARY"),
+ ("'str_literal'", "STRING"),
+ ("CAST('str_literal' AS BINARY)", "BINARY"),
+ ):
+ with self.subTest(
+ f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}"
+ ):
+ expr, type = expr_type_pair
+ ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect)
+
+ subst_type = (
+ optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect)
+ .expressions[0]
+ .type
+ )
+
+ self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect))
diff --git a/tests/test_parser.py b/tests/test_parser.py
index ff82e08..9ff8373 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -854,3 +854,14 @@ class TestParser(unittest.TestCase):
).find(exp.Collate)
self.assertIsInstance(collate_node, exp.Collate)
self.assertIsInstance(collate_node.expression, collate_pair[1])
+
+ def test_odbc_date_literals(self):
+ for value, cls in [
+ ("{d'2024-01-01'}", exp.Date),
+ ("{t'12:00:00'}", exp.Time),
+ ("{ts'2024-01-01 12:00:00'}", exp.Timestamp),
+ ]:
+ sql = f"INSERT INTO tab(ds) VALUES ({value})"
+ expr = parse_one(sql)
+ self.assertIsInstance(expr, exp.Insert)
+ self.assertIsInstance(expr.expression.expressions[0].expressions[0], cls)