summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r--tests/dialects/test_dialect.py116
1 files changed, 97 insertions, 19 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 63f789f..6a41218 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1,6 +1,13 @@
import unittest
-from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
+from sqlglot import (
+ Dialect,
+ Dialects,
+ ErrorLevel,
+ ParseError,
+ UnsupportedError,
+ parse_one,
+)
from sqlglot.dialects import Hive
@@ -23,9 +30,10 @@ class Validator(unittest.TestCase):
Args:
sql (str): Main SQL expression
- dialect (str): dialect of `sql`
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
+ pretty (bool): prettify both read and write
+ identify (bool): quote identifiers in both read and write
"""
expression = self.parse_one(sql)
@@ -78,7 +86,7 @@ class TestDialect(Validator):
"CAST(a AS TEXT)",
write={
"bigquery": "CAST(a AS STRING)",
- "clickhouse": "CAST(a AS TEXT)",
+ "clickhouse": "CAST(a AS String)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS CHAR)",
@@ -116,7 +124,7 @@ class TestDialect(Validator):
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS BYTES)",
- "clickhouse": "CAST(a AS VARBINARY(4))",
+ "clickhouse": "CAST(a AS String)",
"duckdb": "CAST(a AS BLOB(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
@@ -133,7 +141,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
- "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
+ "clickhouse": "CAST(map('a', '1') AS Map(String, String))",
},
)
self.validate_all(
@@ -367,6 +375,60 @@ class TestDialect(Validator):
},
)
+ def test_nvl2(self):
+ self.validate_all(
+ "SELECT NVL2(a, b, c)",
+ write={
+ "": "SELECT NVL2(a, b, c)",
+ "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "databricks": "SELECT NVL2(a, b, c)",
+ "doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "hive": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "oracle": "SELECT NVL2(a, b, c)",
+ "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "presto": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "redshift": "SELECT NVL2(a, b, c)",
+ "snowflake": "SELECT NVL2(a, b, c)",
+ "spark": "SELECT NVL2(a, b, c)",
+ "spark2": "SELECT NVL2(a, b, c)",
+ "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "teradata": "SELECT NVL2(a, b, c)",
+ "trino": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ },
+ )
+ self.validate_all(
+ "SELECT NVL2(a, b)",
+ write={
+ "": "SELECT NVL2(a, b)",
+ "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "databricks": "SELECT NVL2(a, b)",
+ "doris": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "drill": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "hive": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "oracle": "SELECT NVL2(a, b)",
+ "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "presto": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "redshift": "SELECT NVL2(a, b)",
+ "snowflake": "SELECT NVL2(a, b)",
+ "spark": "SELECT NVL2(a, b)",
+ "spark2": "SELECT NVL2(a, b)",
+ "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "teradata": "SELECT NVL2(a, b)",
+ "trino": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ },
+ )
+
def test_time(self):
self.validate_all(
"STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')",
@@ -860,7 +922,7 @@ class TestDialect(Validator):
"ARRAY(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
- "duckdb": "LIST_VALUE(0, 1, 2)",
+ "duckdb": "[0, 1, 2]",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
@@ -879,7 +941,7 @@ class TestDialect(Validator):
"ARRAY_SUM(ARRAY(1, 2))",
write={
"trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)",
- "duckdb": "LIST_SUM(LIST_VALUE(1, 2))",
+ "duckdb": "LIST_SUM([1, 2])",
"hive": "ARRAY_SUM(ARRAY(1, 2))",
"presto": "ARRAY_SUM(ARRAY[1, 2])",
"spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)",
@@ -1403,27 +1465,27 @@ class TestDialect(Validator):
},
)
self.validate_all(
- "CREATE INDEX my_idx ON tbl (a, b)",
+ "CREATE INDEX my_idx ON tbl(a, b)",
read={
- "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
- "sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)",
+ "sqlite": "CREATE INDEX my_idx ON tbl(a, b)",
},
write={
- "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
- "postgres": "CREATE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)",
- "sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)",
+ "postgres": "CREATE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)",
+ "sqlite": "CREATE INDEX my_idx ON tbl(a, b)",
},
)
self.validate_all(
- "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
+ "CREATE UNIQUE INDEX my_idx ON tbl(a, b)",
read={
- "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
- "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)",
+ "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)",
},
write={
- "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
- "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)",
- "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)",
+ "postgres": "CREATE UNIQUE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)",
+ "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)",
},
)
self.validate_all(
@@ -1710,3 +1772,19 @@ SELECT
"tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
},
)
+
+ def test_cast_to_user_defined_type(self):
+ self.validate_all(
+ "CAST(x AS some_udt)",
+ write={
+ "": "CAST(x AS some_udt)",
+ "oracle": "CAST(x AS some_udt)",
+ "postgres": "CAST(x AS some_udt)",
+ "presto": "CAST(x AS some_udt)",
+ "teradata": "CAST(x AS some_udt)",
+ "tsql": "CAST(x AS some_udt)",
+ },
+ )
+
+ with self.assertRaises(ParseError):
+ parse_one("CAST(x AS some_udt)", read="bigquery")