From 8d36f5966675e23bee7026ba37ae0647fbf47300 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 8 Apr 2024 10:11:53 +0200 Subject: Merging upstream version 23.7.0. Signed-off-by: Daniel Baumann --- tests/dialects/test_dialect.py | 117 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 111 insertions(+), 6 deletions(-) (limited to 'tests/dialects/test_dialect.py') diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 5faed51..76ab94b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -17,8 +17,8 @@ from sqlglot.parser import logger as parser_logger class Validator(unittest.TestCase): dialect = None - def parse_one(self, sql): - return parse_one(sql, read=self.dialect) + def parse_one(self, sql, **kwargs): + return parse_one(sql, read=self.dialect, **kwargs) def validate_identity(self, sql, write_sql=None, pretty=False, check_command_warning=False): if check_command_warning: @@ -611,7 +611,7 @@ class TestDialect(Validator): write={ "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%m-%d'))", "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", - "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d'))", + "presto": "TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST('2020-01-01' AS VARCHAR), '%Y-%m-%d')), PARSE_DATETIME(CAST('2020-01-01' AS VARCHAR), 'yyyy-MM-dd')))", "starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')", "doris": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')", }, @@ -700,7 +700,7 @@ class TestDialect(Validator): "hive": "TO_DATE(x)", "postgres": "CAST(x AS DATE)", "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", - "snowflake": "CAST(x AS DATE)", + "snowflake": "TO_DATE(x)", "doris": "TO_DATE(x)", "mysql": "DATE(x)", }, @@ -961,6 +961,7 @@ class TestDialect(Validator): "presto": "CAST(x AS DATE)", "spark": "CAST(x AS DATE)", "sqlite": "x", + "tsql": "CAST(x AS DATE)", }, ) self.validate_all( @@ -1509,7 +1510,7 @@ class TestDialect(Validator): "POSITION(needle, haystack, pos)", write={ "drill": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1", - "presto": "STRPOS(haystack, needle, pos)", + "presto": "STRPOS(SUBSTR(haystack, pos), needle) + pos - 1", "spark": "LOCATE(needle, haystack, pos)", "clickhouse": "position(haystack, needle, pos)", "snowflake": "POSITION(needle, haystack, pos)", @@ -1719,6 +1720,11 @@ class TestDialect(Validator): with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"): self.assertEqual(expected, expression.sql(dialect=dialect)) + self.assertEqual( + parse_one("CAST(x AS DECIMAL) / y", read="mysql").sql(dialect="postgres"), + "CAST(x AS DECIMAL) / NULLIF(y, 0)", + ) + def test_limit(self): self.validate_all( "SELECT * FROM data LIMIT 10, 20", @@ -2054,6 +2060,44 @@ SELECT ) def test_logarithm(self): + for base in (2, 10): + with self.subTest(f"Transpiling LOG base {base}"): + self.validate_all( + f"LOG({base}, a)", + read={ + "": f"LOG{base}(a)", + "bigquery": f"LOG{base}(a)", + "clickhouse": f"LOG{base}(a)", + "databricks": f"LOG{base}(a)", + "duckdb": f"LOG{base}(a)", + "mysql": f"LOG{base}(a)", + "postgres": f"LOG{base}(a)", + "presto": f"LOG{base}(a)", + "spark": f"LOG{base}(a)", + "sqlite": f"LOG{base}(a)", + "trino": f"LOG{base}(a)", + "tsql": f"LOG{base}(a)", + }, + write={ + "bigquery": f"LOG(a, {base})", + "clickhouse": f"LOG{base}(a)", + "duckdb": f"LOG({base}, a)", + "mysql": f"LOG({base}, a)", + "oracle": f"LOG({base}, a)", + "postgres": f"LOG({base}, a)", + "presto": f"LOG{base}(a)", + "redshift": f"LOG({base}, a)", + "snowflake": f"LOG({base}, a)", + "spark2": f"LOG({base}, a)", + "spark": f"LOG({base}, a)", + "sqlite": f"LOG({base}, a)", + "starrocks": f"LOG({base}, a)", + "tableau": f"LOG(a, {base})", + "trino": f"LOG({base}, a)", + "tsql": f"LOG(a, {base})", + }, + ) + self.validate_all( "LOG(x)", read={ @@ -2082,6 +2126,7 @@ SELECT "bigquery": "LOG(n, b)", "databricks": "LOG(b, n)", "drill": "LOG(b, n)", + "duckdb": "LOG(b, n)", "hive": "LOG(b, n)", "mysql": "LOG(b, n)", "oracle": "LOG(b, n)", @@ -2089,8 +2134,13 @@ SELECT "snowflake": "LOG(b, n)", "spark": "LOG(b, n)", "sqlite": "LOG(b, n)", + "trino": "LOG(b, n)", "tsql": "LOG(n, b)", }, + write={ + "clickhouse": UnsupportedError, + "presto": UnsupportedError, + }, ) def test_count_if(self): @@ -2190,7 +2240,28 @@ SELECT "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", write={ "duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", - "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq", + "tsql": "WITH t2(y) AS (SELECT 2), t1(x) AS (SELECT 1) SELECT * FROM (SELECT y AS y FROM t2) AS subq", + }, + ) + self.validate_all( + """ +WITH c AS ( + WITH b AS ( + WITH a1 AS ( + SELECT 1 + ), a2 AS ( + SELECT 2 + ) + SELECT * FROM a1, a2 + ) + SELECT * + FROM b +) +SELECT * +FROM c""", + write={ + "duckdb": "WITH c AS (WITH b AS (WITH a1 AS (SELECT 1), a2 AS (SELECT 2) SELECT * FROM a1, a2) SELECT * FROM b) SELECT * FROM c", + "hive": "WITH a1 AS (SELECT 1), a2 AS (SELECT 2), b AS (SELECT * FROM a1, a2), c AS (SELECT * FROM b) SELECT * FROM c", }, ) @@ -2312,3 +2383,37 @@ SELECT self.validate_identity("TRUNCATE TABLE db.schema.test") self.validate_identity("TRUNCATE TABLE IF EXISTS db.schema.test") self.validate_identity("TRUNCATE TABLE t1, t2, t3") + + def test_create_sequence(self): + self.validate_identity("CREATE SEQUENCE seq") + self.validate_identity( + "CREATE TEMPORARY SEQUENCE seq AS SMALLINT START WITH 3 INCREMENT BY 2 MINVALUE 1 MAXVALUE 10 CACHE 1 NO CYCLE OWNED BY table.col" + ) + self.validate_identity( + "CREATE SEQUENCE seq START WITH 1 NO MINVALUE NO MAXVALUE CYCLE NO CACHE" + ) + self.validate_identity("CREATE OR REPLACE TEMPORARY SEQUENCE seq INCREMENT BY 1 NO CYCLE") + self.validate_identity( + "CREATE OR REPLACE SEQUENCE IF NOT EXISTS seq COMMENT='test comment' ORDER" + ) + self.validate_identity( + "CREATE SEQUENCE schema.seq SHARING=METADATA NOORDER NOKEEP SCALE EXTEND SHARD EXTEND SESSION" + ) + self.validate_identity( + "CREATE SEQUENCE schema.seq SHARING=DATA ORDER KEEP NOSCALE NOSHARD GLOBAL" + ) + self.validate_identity( + "CREATE SEQUENCE schema.seq SHARING=DATA NOCACHE NOCYCLE SCALE NOEXTEND" + ) + self.validate_identity( + """CREATE TEMPORARY SEQUENCE seq AS BIGINT INCREMENT BY 2 MINVALUE 1 CACHE 1 NOMAXVALUE NO CYCLE OWNED BY NONE""", + """CREATE TEMPORARY SEQUENCE seq AS BIGINT INCREMENT BY 2 MINVALUE 1 CACHE 1 NOMAXVALUE NO CYCLE""", + ) + self.validate_identity( + """CREATE TEMPORARY SEQUENCE seq START 1""", + """CREATE TEMPORARY SEQUENCE seq START WITH 1""", + ) + self.validate_identity( + """CREATE TEMPORARY SEQUENCE seq START WITH = 1 INCREMENT BY = 2""", + """CREATE TEMPORARY SEQUENCE seq START WITH 1 INCREMENT BY 2""", + ) -- cgit v1.2.3