summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r--tests/dialects/test_dialect.py156
1 files changed, 127 insertions, 29 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 3b837df..1913f53 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1,20 +1,18 @@
import unittest
-from sqlglot import (
- Dialect,
- Dialects,
- ErrorLevel,
- UnsupportedError,
- parse_one,
- transpile,
-)
+from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
class Validator(unittest.TestCase):
dialect = None
- def validate_identity(self, sql):
- self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql)
+ def parse_one(self, sql):
+ return parse_one(sql, read=self.dialect)
+
+ def validate_identity(self, sql, write_sql=None):
+ expression = self.parse_one(sql)
+ self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
+ return expression
def validate_all(self, sql, read=None, write=None, pretty=False):
"""
@@ -28,12 +26,14 @@ class Validator(unittest.TestCase):
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
- expression = parse_one(sql, read=self.dialect)
+ expression = self.parse_one(sql)
for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
- parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE),
+ parse_one(read_sql, read_dialect).sql(
+ self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
+ ),
sql,
)
@@ -83,10 +83,6 @@ class TestDialect(Validator):
)
self.validate_all(
"CAST(a AS BINARY(4))",
- read={
- "presto": "CAST(a AS VARBINARY(4))",
- "sqlite": "CAST(a AS VARBINARY(4))",
- },
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
@@ -104,6 +100,24 @@ class TestDialect(Validator):
},
)
self.validate_all(
+ "CAST(a AS VARBINARY(4))",
+ write={
+ "bigquery": "CAST(a AS VARBINARY(4))",
+ "clickhouse": "CAST(a AS VARBINARY(4))",
+ "duckdb": "CAST(a AS VARBINARY(4))",
+ "mysql": "CAST(a AS VARBINARY(4))",
+ "hive": "CAST(a AS BINARY(4))",
+ "oracle": "CAST(a AS BLOB(4))",
+ "postgres": "CAST(a AS BYTEA(4))",
+ "presto": "CAST(a AS VARBINARY(4))",
+ "redshift": "CAST(a AS VARBYTE(4))",
+ "snowflake": "CAST(a AS VARBINARY(4))",
+ "sqlite": "CAST(a AS BLOB(4))",
+ "spark": "CAST(a AS BINARY(4))",
+ "starrocks": "CAST(a AS VARBINARY(4))",
+ },
+ )
+ self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
"clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
@@ -472,45 +486,57 @@ class TestDialect(Validator):
},
)
self.validate_all(
- "DATE_TRUNC(x, 'day')",
+ "DATE_TRUNC('day', x)",
write={
"mysql": "DATE(x)",
- "starrocks": "DATE(x)",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'week')",
+ "DATE_TRUNC('week', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'month')",
+ "DATE_TRUNC('month', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'quarter')",
+ "DATE_TRUNC('quarter', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'year')",
+ "DATE_TRUNC('year', x)",
write={
"mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
- "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')",
},
)
self.validate_all(
- "DATE_TRUNC(x, 'millenium')",
+ "DATE_TRUNC('millenium', x)",
write={
"mysql": UnsupportedError,
- "starrocks": UnsupportedError,
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC('year', x)",
+ read={
+ "starrocks": "DATE_TRUNC('year', x)",
+ },
+ write={
+ "starrocks": "DATE_TRUNC('year', x)",
+ },
+ )
+ self.validate_all(
+ "DATE_TRUNC(x, year)",
+ read={
+ "bigquery": "DATE_TRUNC(x, year)",
+ },
+ write={
+ "bigquery": "DATE_TRUNC(x, year)",
},
)
self.validate_all(
@@ -564,6 +590,22 @@ class TestDialect(Validator):
"spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
},
)
+ self.validate_all(
+ "TIMESTAMP '2022-01-01'",
+ write={
+ "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
+ "starrocks": "CAST('2022-01-01' AS DATETIME)",
+ "hive": "CAST('2022-01-01' AS TIMESTAMP)",
+ },
+ )
+ self.validate_all(
+ "TIMESTAMP('2022-01-01')",
+ write={
+ "mysql": "TIMESTAMP('2022-01-01')",
+ "starrocks": "TIMESTAMP('2022-01-01')",
+ "hive": "TIMESTAMP('2022-01-01')",
+ },
+ )
for unit in ("DAY", "MONTH", "YEAR"):
self.validate_all(
@@ -1002,7 +1044,10 @@ class TestDialect(Validator):
)
def test_limit(self):
- self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"})
+ self.validate_all(
+ "SELECT * FROM data LIMIT 10, 20",
+ write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"},
+ )
self.validate_all(
"SELECT x FROM y LIMIT 10",
write={
@@ -1132,3 +1177,56 @@ class TestDialect(Validator):
"sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c",
},
)
+
+ def test_nullsafe_eq(self):
+ self.validate_all(
+ "SELECT a IS NOT DISTINCT FROM b",
+ read={
+ "mysql": "SELECT a <=> b",
+ "postgres": "SELECT a IS NOT DISTINCT FROM b",
+ },
+ write={
+ "mysql": "SELECT a <=> b",
+ "postgres": "SELECT a IS NOT DISTINCT FROM b",
+ },
+ )
+
+ def test_nullsafe_neq(self):
+ self.validate_all(
+ "SELECT a IS DISTINCT FROM b",
+ read={
+ "postgres": "SELECT a IS DISTINCT FROM b",
+ },
+ write={
+ "mysql": "SELECT NOT a <=> b",
+ "postgres": "SELECT a IS DISTINCT FROM b",
+ },
+ )
+
+ def test_hash_comments(self):
+ self.validate_all(
+ "SELECT 1 /* arbitrary content,,, until end-of-line */",
+ read={
+ "mysql": "SELECT 1 # arbitrary content,,, until end-of-line",
+ "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line",
+ "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line",
+ },
+ )
+ self.validate_all(
+ """/* comment1 */
+SELECT
+ x, -- comment2
+ y -- comment3""",
+ read={
+ "mysql": """SELECT # comment1
+ x, # comment2
+ y # comment3""",
+ "bigquery": """SELECT # comment1
+ x, # comment2
+ y # comment3""",
+ "clickhouse": """SELECT # comment1
+ x, # comment2
+ y # comment3""",
+ },
+ pretty=True,
+ )