summary | refs | log | tree | commit | diff | stats
path: root/tests/dialects/test_dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_dialect.py')
-rw-r--r--  tests/dialects/test_dialect.py  117
1 files changed, 111 insertions, 6 deletions
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 5faed51..76ab94b 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -17,8 +17,8 @@ from sqlglot.parser import logger as parser_logger
class Validator(unittest.TestCase):
dialect = None
- def parse_one(self, sql):
- return parse_one(sql, read=self.dialect)
+ def parse_one(self, sql, **kwargs):
+ return parse_one(sql, read=self.dialect, **kwargs)
def validate_identity(self, sql, write_sql=None, pretty=False, check_command_warning=False):
if check_command_warning:
@@ -611,7 +611,7 @@ class TestDialect(Validator):
write={
"duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%m-%d'))",
"hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
- "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d'))",
+ "presto": "TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST('2020-01-01' AS VARCHAR), '%Y-%m-%d')), PARSE_DATETIME(CAST('2020-01-01' AS VARCHAR), 'yyyy-MM-dd')))",
"starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')",
"doris": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')",
},
@@ -700,7 +700,7 @@ class TestDialect(Validator):
"hive": "TO_DATE(x)",
"postgres": "CAST(x AS DATE)",
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
- "snowflake": "CAST(x AS DATE)",
+ "snowflake": "TO_DATE(x)",
"doris": "TO_DATE(x)",
"mysql": "DATE(x)",
},
@@ -961,6 +961,7 @@ class TestDialect(Validator):
"presto": "CAST(x AS DATE)",
"spark": "CAST(x AS DATE)",
"sqlite": "x",
+ "tsql": "CAST(x AS DATE)",
},
)
self.validate_all(
@@ -1509,7 +1510,7 @@ class TestDialect(Validator):
"POSITION(needle, haystack, pos)",
write={
"drill": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1",
- "presto": "STRPOS(haystack, needle, pos)",
+ "presto": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1",
"spark": "LOCATE(needle, haystack, pos)",
"clickhouse": "position(haystack, needle, pos)",
"snowflake": "POSITION(needle, haystack, pos)",
@@ -1719,6 +1720,11 @@ class TestDialect(Validator):
with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"):
self.assertEqual(expected, expression.sql(dialect=dialect))
+ self.assertEqual(
+ parse_one("CAST(x AS DECIMAL) / y", read="mysql").sql(dialect="postgres"),
+ "CAST(x AS DECIMAL) / NULLIF(y, 0)",
+ )
+
def test_limit(self):
self.validate_all(
"SELECT * FROM data LIMIT 10, 20",
@@ -2054,6 +2060,44 @@ SELECT
)
def test_logarithm(self):
+ for base in (2, 10):
+ with self.subTest(f"Transpiling LOG base {base}"):
+ self.validate_all(
+ f"LOG({base}, a)",
+ read={
+ "": f"LOG{base}(a)",
+ "bigquery": f"LOG{base}(a)",
+ "clickhouse": f"LOG{base}(a)",
+ "databricks": f"LOG{base}(a)",
+ "duckdb": f"LOG{base}(a)",
+ "mysql": f"LOG{base}(a)",
+ "postgres": f"LOG{base}(a)",
+ "presto": f"LOG{base}(a)",
+ "spark": f"LOG{base}(a)",
+ "sqlite": f"LOG{base}(a)",
+ "trino": f"LOG{base}(a)",
+ "tsql": f"LOG{base}(a)",
+ },
+ write={
+ "bigquery": f"LOG(a, {base})",
+ "clickhouse": f"LOG{base}(a)",
+ "duckdb": f"LOG({base}, a)",
+ "mysql": f"LOG({base}, a)",
+ "oracle": f"LOG({base}, a)",
+ "postgres": f"LOG({base}, a)",
+ "presto": f"LOG{base}(a)",
+ "redshift": f"LOG({base}, a)",
+ "snowflake": f"LOG({base}, a)",
+ "spark2": f"LOG({base}, a)",
+ "spark": f"LOG({base}, a)",
+ "sqlite": f"LOG({base}, a)",
+ "starrocks": f"LOG({base}, a)",
+ "tableau": f"LOG(a, {base})",
+ "trino": f"LOG({base}, a)",
+ "tsql": f"LOG(a, {base})",
+ },
+ )
+
self.validate_all(
"LOG(x)",
read={
@@ -2082,6 +2126,7 @@ SELECT
"bigquery": "LOG(n, b)",
"databricks": "LOG(b, n)",
"drill": "LOG(b, n)",
+ "duckdb": "LOG(b, n)",
"hive": "LOG(b, n)",
"mysql": "LOG(b, n)",
"oracle": "LOG(b, n)",
@@ -2089,8 +2134,13 @@ SELECT
"snowflake": "LOG(b, n)",
"spark": "LOG(b, n)",
"sqlite": "LOG(b, n)",
+ "trino": "LOG(b, n)",
"tsql": "LOG(n, b)",
},
+ write={
+ "clickhouse": UnsupportedError,
+ "presto": UnsupportedError,
+ },
)
def test_count_if(self):
@@ -2190,7 +2240,28 @@ SELECT
"WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
write={
"duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq",
- "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
+ "tsql": "WITH t2(y) AS (SELECT 2), t1(x) AS (SELECT 1) SELECT * FROM (SELECT y AS y FROM t2) AS subq",
+ },
+ )
+ self.validate_all(
+ """
+WITH c AS (
+ WITH b AS (
+ WITH a1 AS (
+ SELECT 1
+ ), a2 AS (
+ SELECT 2
+ )
+ SELECT * FROM a1, a2
+ )
+ SELECT *
+ FROM b
+)
+SELECT *
+FROM c""",
+ write={
+ "duckdb": "WITH c AS (WITH b AS (WITH a1 AS (SELECT 1), a2 AS (SELECT 2) SELECT * FROM a1, a2) SELECT * FROM b) SELECT * FROM c",
+ "hive": "WITH a1 AS (SELECT 1), a2 AS (SELECT 2), b AS (SELECT * FROM a1, a2), c AS (SELECT * FROM b) SELECT * FROM c",
},
)
@@ -2312,3 +2383,37 @@ SELECT
self.validate_identity("TRUNCATE TABLE db.schema.test")
self.validate_identity("TRUNCATE TABLE IF EXISTS db.schema.test")
self.validate_identity("TRUNCATE TABLE t1, t2, t3")
+
+ def test_create_sequence(self):
+ self.validate_identity("CREATE SEQUENCE seq")
+ self.validate_identity(
+ "CREATE TEMPORARY SEQUENCE seq AS SMALLINT START WITH 3 INCREMENT BY 2 MINVALUE 1 MAXVALUE 10 CACHE 1 NO CYCLE OWNED BY table.col"
+ )
+ self.validate_identity(
+ "CREATE SEQUENCE seq START WITH 1 NO MINVALUE NO MAXVALUE CYCLE NO CACHE"
+ )
+ self.validate_identity("CREATE OR REPLACE TEMPORARY SEQUENCE seq INCREMENT BY 1 NO CYCLE")
+ self.validate_identity(
+ "CREATE OR REPLACE SEQUENCE IF NOT EXISTS seq COMMENT='test comment' ORDER"
+ )
+ self.validate_identity(
+ "CREATE SEQUENCE schema.seq SHARING=METADATA NOORDER NOKEEP SCALE EXTEND SHARD EXTEND SESSION"
+ )
+ self.validate_identity(
+ "CREATE SEQUENCE schema.seq SHARING=DATA ORDER KEEP NOSCALE NOSHARD GLOBAL"
+ )
+ self.validate_identity(
+ "CREATE SEQUENCE schema.seq SHARING=DATA NOCACHE NOCYCLE SCALE NOEXTEND"
+ )
+ self.validate_identity(
+ """CREATE TEMPORARY SEQUENCE seq AS BIGINT INCREMENT BY 2 MINVALUE 1 CACHE 1 NOMAXVALUE NO CYCLE OWNED BY NONE""",
+ """CREATE TEMPORARY SEQUENCE seq AS BIGINT INCREMENT BY 2 MINVALUE 1 CACHE 1 NOMAXVALUE NO CYCLE""",
+ )
+ self.validate_identity(
+ """CREATE TEMPORARY SEQUENCE seq START 1""",
+ """CREATE TEMPORARY SEQUENCE seq START WITH 1""",
+ )
+ self.validate_identity(
+ """CREATE TEMPORARY SEQUENCE seq START WITH = 1 INCREMENT BY = 2""",
+ """CREATE TEMPORARY SEQUENCE seq START WITH 1 INCREMENT BY 2""",
+ )