diff options
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r-- | tests/dialects/test_dialect.py | 156 |
1 files changed, 127 insertions, 29 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3b837df..1913f53 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,20 +1,18 @@ import unittest -from sqlglot import ( - Dialect, - Dialects, - ErrorLevel, - UnsupportedError, - parse_one, - transpile, -) +from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one class Validator(unittest.TestCase): dialect = None - def validate_identity(self, sql): - self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) + def parse_one(self, sql): + return parse_one(sql, read=self.dialect) + + def validate_identity(self, sql, write_sql=None): + expression = self.parse_one(sql) + self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect)) + return expression def validate_all(self, sql, read=None, write=None, pretty=False): """ @@ -28,12 +26,14 @@ class Validator(unittest.TestCase): read (dict): Mapping of dialect -> SQL write (dict): Mapping of dialect -> SQL """ - expression = parse_one(sql, read=self.dialect) + expression = self.parse_one(sql) for read_dialect, read_sql in (read or {}).items(): with self.subTest(f"{read_dialect} -> {sql}"): self.assertEqual( - parse_one(read_sql, read_dialect).sql(self.dialect, unsupported_level=ErrorLevel.IGNORE), + parse_one(read_sql, read_dialect).sql( + self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty + ), sql, ) @@ -83,10 +83,6 @@ class TestDialect(Validator): ) self.validate_all( "CAST(a AS BINARY(4))", - read={ - "presto": "CAST(a AS VARBINARY(4))", - "sqlite": "CAST(a AS VARBINARY(4))", - }, write={ "bigquery": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))", @@ -104,6 +100,24 @@ class TestDialect(Validator): }, ) self.validate_all( + "CAST(a AS VARBINARY(4))", + write={ + "bigquery": "CAST(a AS VARBINARY(4))", + "clickhouse": "CAST(a AS VARBINARY(4))", + "duckdb": "CAST(a AS VARBINARY(4))", + "mysql": "CAST(a AS VARBINARY(4))", + "hive": "CAST(a AS BINARY(4))", + "oracle": "CAST(a AS BLOB(4))", + "postgres": "CAST(a AS BYTEA(4))", + "presto": "CAST(a AS VARBINARY(4))", + "redshift": "CAST(a AS VARBYTE(4))", + "snowflake": "CAST(a AS VARBINARY(4))", + "sqlite": "CAST(a AS BLOB(4))", + "spark": "CAST(a AS BINARY(4))", + "starrocks": "CAST(a AS VARBINARY(4))", + }, + ) + self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", @@ -472,45 +486,57 @@ class TestDialect(Validator): }, ) self.validate_all( - "DATE_TRUNC(x, 'day')", + "DATE_TRUNC('day', x)", write={ "mysql": "DATE(x)", - "starrocks": "DATE(x)", }, ) self.validate_all( - "DATE_TRUNC(x, 'week')", + "DATE_TRUNC('week', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", }, ) self.validate_all( - "DATE_TRUNC(x, 'month')", + "DATE_TRUNC('month', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'quarter')", + "DATE_TRUNC('quarter', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'year')", + "DATE_TRUNC('year', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", - "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", }, ) self.validate_all( - "DATE_TRUNC(x, 'millenium')", + "DATE_TRUNC('millenium', x)", write={ "mysql": UnsupportedError, - "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "DATE_TRUNC('year', x)", + read={ + "starrocks": "DATE_TRUNC('year', x)", + }, + write={ + "starrocks": "DATE_TRUNC('year', x)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, year)", + read={ + "bigquery": "DATE_TRUNC(x, year)", + }, + write={ + "bigquery": "DATE_TRUNC(x, year)", }, ) self.validate_all( @@ -564,6 +590,22 @@ class TestDialect(Validator): "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", }, ) + self.validate_all( + "TIMESTAMP '2022-01-01'", + write={ + "mysql": "CAST('2022-01-01' AS TIMESTAMP)", + "starrocks": "CAST('2022-01-01' AS DATETIME)", + "hive": "CAST('2022-01-01' AS TIMESTAMP)", + }, + ) + self.validate_all( + "TIMESTAMP('2022-01-01')", + write={ + "mysql": "TIMESTAMP('2022-01-01')", + "starrocks": "TIMESTAMP('2022-01-01')", + "hive": "TIMESTAMP('2022-01-01')", + }, + ) for unit in ("DAY", "MONTH", "YEAR"): self.validate_all( @@ -1002,7 +1044,10 @@ class TestDialect(Validator): ) def test_limit(self): - self.validate_all("SELECT * FROM data LIMIT 10, 20", write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}) + self.validate_all( + "SELECT * FROM data LIMIT 10, 20", + write={"sqlite": "SELECT * FROM data LIMIT 10 OFFSET 20"}, + ) self.validate_all( "SELECT x FROM y LIMIT 10", write={ @@ -1132,3 +1177,56 @@ class TestDialect(Validator): "sqlite": "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, e AS d FROM table2) SELECT b, d AS dd FROM cte1 AS t JOIN cte2 WHERE cte1.a = cte2.c", }, ) + + def test_nullsafe_eq(self): + self.validate_all( + "SELECT a IS NOT DISTINCT FROM b", + read={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + write={ + "mysql": "SELECT a <=> b", + "postgres": "SELECT a IS NOT DISTINCT FROM b", + }, + ) + + def test_nullsafe_neq(self): + self.validate_all( + "SELECT a IS DISTINCT FROM b", + read={ + "postgres": "SELECT a IS DISTINCT FROM b", + }, + write={ + "mysql": "SELECT NOT a <=> b", + "postgres": "SELECT a IS DISTINCT FROM b", + }, + ) + + def test_hash_comments(self): + self.validate_all( + "SELECT 1 /* arbitrary content,,, until end-of-line */", + read={ + "mysql": "SELECT 1 # arbitrary content,,, until end-of-line", + "bigquery": "SELECT 1 # arbitrary content,,, until end-of-line", + "clickhouse": "SELECT 1 #! arbitrary content,,, until end-of-line", + }, + ) + self.validate_all( + """/* comment1 */ +SELECT + x, -- comment2 + y -- comment3""", + read={ + "mysql": """SELECT # comment1 + x, # comment2 + y # comment3""", + "bigquery": """SELECT # comment1 + x, # comment2 + y # comment3""", + "clickhouse": """SELECT # comment1 + x, # comment2 + y # comment3""", + }, + pretty=True, + ) |