diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_drill.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 123 |
10 files changed, 192 insertions, 26 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 69042be..8d01ebe 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -29,6 +29,10 @@ class TestBigQuery(Validator): with self.assertRaises(ParseError): transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery") + with self.assertRaises(ParseError): + transpile("DATE_ADD(x, day)", read="bigquery") + + self.validate_identity("STRING_AGG(DISTINCT a ORDER BY b DESC, c DESC LIMIT 10)") self.validate_identity("SELECT PARSE_TIMESTAMP('%c', 'Thu Dec 25 07:30:00 2008', 'UTC')") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MAX sold) FROM fruits") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MIN sold) FROM fruits") @@ -389,7 +393,7 @@ class TestBigQuery(Validator): }, ) self.validate_all( - "current_timestamp", + "CURRENT_TIMESTAMP", write={ "bigquery": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP", @@ -400,7 +404,7 @@ class TestBigQuery(Validator): }, ) self.validate_all( - "current_timestamp()", + "CURRENT_TIMESTAMP()", write={ "bigquery": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index bc82645..16c10fe 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -6,6 +6,12 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + expr = parse_one("count(x)") + self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") + self.assertIsNone(expr._meta) + + self.validate_identity("SELECT isNaN(1.0)") + self.validate_identity("SELECT startsWith('Spider-Man', 'Spi')") self.validate_identity("SELECT xor(TRUE, FALSE)") self.validate_identity("ATTACH DATABASE DEFAULT ENGINE = ORDINARY") self.validate_identity("CAST(['hello'], 'Array(Enum8(''hello'' = 1))')") @@ -162,7 +168,7 @@ class TestClickhouse(Validator): ORDER BY loyalty ASC """, write={ - "clickhouse": "SELECT loyalty, COUNT() FROM hits LEFT SEMI JOIN users USING (UserID)" + "clickhouse": "SELECT loyalty, count() FROM hits LEFT SEMI JOIN users USING (UserID)" + " GROUP BY loyalty ORDER BY loyalty" }, ) @@ -247,7 +253,7 @@ class TestClickhouse(Validator): for data_type in data_types: self.validate_all( f"pow(2, 32)::{data_type}", - write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"}, + write={"clickhouse": f"CAST(pow(2, 32) AS {data_type})"}, ) def test_ddl(self): @@ -304,8 +310,8 @@ GROUP BY id, toStartOfDay(timestamp) SET - max_hits = MAX(max_hits), - sum_hits = SUM(sum_hits)""", + max_hits = max(max_hits), + sum_hits = sum(sum_hits)""", }, pretty=True, ) @@ -447,8 +453,8 @@ GROUP BY k1, k2 SET - x = MAX(x), - y = MIN(y)""", + x = max(x), + y = min(y)""", }, pretty=True, ) diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py index a7f609a..41c02fb 100644 --- a/tests/dialects/test_drill.py +++ b/tests/dialects/test_drill.py @@ -66,7 +66,7 @@ class TestDrill(Validator): write={ "drill": "SELECT * FROM (SELECT education_level, salary, marital_status, " "EXTRACT(year FROM age(birth_date)) AS age FROM cp.`employee.json`) " - "PIVOT(AVG(salary) AS avg_salary, AVG(age) AS avg_age FOR marital_status " + "PIVOT(avg(salary) AS avg_salary, avg(age) AS avg_age FOR marital_status " "IN ('M' AS married, 'S' AS single))" }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index ae2fa41..d021d62 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -9,6 +9,12 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") + self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))") + self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") + self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") + self.validate_identity( + "CREATE TABLE foo (a BIGINT, INDEX b USING HASH (c) COMMENT 'd' VISIBLE ENGINE_ATTRIBUTE = 'e' WITH PARSER foo)" + ) self.validate_identity( "DELETE t1 FROM t1 LEFT JOIN t2 ON t1.id = t2.id WHERE t2.id IS NULL" ) @@ -67,6 +73,12 @@ class TestMySQL(Validator): "mysql": "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", }, ) + self.validate_all( + "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE KEY d (b), KEY e (b))", + write={ + "mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))", + }, + ) def test_identity(self): self.validate_identity("SELECT 1 XOR 0") @@ -437,6 +449,13 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "SELECT * FROM test LIMIT 0 + 1, 0 + 1", + write={ + "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", + "postgres": "SELECT * FROM test LIMIT 0 + 1 OFFSET 0 + 1", + }, + ) + self.validate_all( "CAST(x AS TEXT)", write={ "mysql": "CAST(x AS CHAR)", @@ -448,6 +467,7 @@ class TestMySQL(Validator): self.validate_all("CAST(x AS SIGNED INTEGER)", write={"mysql": "CAST(x AS SIGNED)"}) self.validate_all("CAST(x AS UNSIGNED)", write={"mysql": "CAST(x AS UNSIGNED)"}) self.validate_all("CAST(x AS UNSIGNED INTEGER)", write={"mysql": "CAST(x AS UNSIGNED)"}) + self.validate_all("TIME_STR_TO_TIME(x)", write={"mysql": "CAST(x AS DATETIME)"}) self.validate_all( """SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""", write={ diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index f30b38f..0c3b09f 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -23,6 +23,11 @@ class TestOracle(Validator): self.validate_identity( "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" ) + self.validate_identity( + "SELECT last_name, department_id, salary, MIN(salary) KEEP (DENSE_RANK FIRST ORDER BY commission_pct) " + 'OVER (PARTITION BY department_id) AS "Worst", MAX(salary) KEEP (DENSE_RANK LAST ORDER BY commission_pct) ' + 'OVER (PARTITION BY department_id) AS "Best" FROM employees ORDER BY department_id, salary, last_name' + ) self.validate_all( "NVL(NULL, 1)", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index b35665b..be34d8c 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -122,6 +122,10 @@ class TestPostgres(Validator): ) def test_postgres(self): + expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)") + unnest = expr.args["joins"][0].this.this + unnest.assert_is(exp.Unnest) + self.validate_identity("CAST(x AS MONEY)") self.validate_identity("CAST(x AS INT4RANGE)") self.validate_identity("CAST(x AS INT4MULTIRANGE)") @@ -414,7 +418,7 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT * FROM r CROSS JOIN LATERAL unnest(array(1)) AS s(location)", + "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", write={ "postgres": "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", }, diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index c0b77a3..a2800bd 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -109,6 +109,8 @@ class TestPresto(Validator): "spark": "CAST(x AS TIMESTAMP)", }, ) + self.validate_identity("CAST(x AS IPADDRESS)") + self.validate_identity("CAST(x AS IPPREFIX)") def test_regex(self): self.validate_all( @@ -459,6 +461,25 @@ class TestPresto(Validator): self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") + self.validate_all( + "STARTS_WITH('abc', 'a')", + read={"spark": "STARTSWITH('abc', 'a')"}, + write={ + "presto": "STARTS_WITH('abc', 'a')", + "spark": "STARTSWITH('abc', 'a')", + }, + ) + self.validate_all( + "IS_NAN(x)", + read={ + "spark": "ISNAN(x)", + }, + write={ + "presto": "IS_NAN(x)", + "spark": "ISNAN(x)", + "spark2": "ISNAN(x)", + }, + ) self.validate_all("VALUES 1, 2, 3", write={"presto": "VALUES (1), (2), (3)"}) self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"}) self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 82762e8..a889e1d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -45,6 +45,14 @@ class TestSnowflake(Validator): self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all( + "SELECT DATE_PART('year', TIMESTAMP '2020-01-01')", + write={ + "hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", + "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMPNTZ))", + "spark": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", + }, + ) + self.validate_all( "SELECT * FROM (VALUES (0) foo(bar))", write={"snowflake": "SELECT * FROM (VALUES (0)) AS foo(bar)"}, ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 0df6d0b..4d32241 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -22,6 +22,13 @@ class TestTeradata(Validator): }, ) + def test_statistics(self): + self.validate_identity("COLLECT STATISTICS ON tbl INDEX(col)") + self.validate_identity("COLLECT STATS ON tbl COLUMNS(col)") + self.validate_identity("COLLECT STATS COLUMNS(col) ON tbl") + self.validate_identity("HELP STATISTICS personel.employee") + self.validate_identity("HELP STATISTICS personnel.employee FROM my_qcd") + def test_create(self): self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)") self.validate_identity("CREATE TABLE x (y INT) PARTITION BY y INDEX (y)") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index f0a590f..5266bd4 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -336,7 +336,7 @@ class TestTSQL(Validator): "CAST(x as TIME(4))", write={ "spark": "CAST(x AS TIMESTAMP)", - "tsql": "CAST(x AS TIMESTAMP(4))", + "tsql": "CAST(x AS TIME(4))", }, ) @@ -352,7 +352,7 @@ class TestTSQL(Validator): "CAST(x as DATETIMEOFFSET)", write={ "spark": "CAST(x AS TIMESTAMP)", - "tsql": "CAST(x AS TIMESTAMPTZ)", + "tsql": "CAST(x AS DATETIMEOFFSET)", }, ) @@ -393,7 +393,30 @@ class TestTSQL(Validator): self.validate_all( "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ - "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))" + "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)", + "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", + }, + ) + self.validate_all( + "CREATE TABLE #mytemptable (a INTEGER)", + read={ + "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", + }, + write={ + "tsql": "CREATE TABLE #mytemptable (a INTEGER)", + "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", + }, + ) + self.validate_all( + "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", + write={ + "duckdb": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", + "oracle": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", + "snowflake": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", + "spark": "CREATE TEMPORARY VIEW mytemptable AS SELECT a FROM Source_Table", + "tsql": "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", }, ) @@ -535,6 +558,30 @@ WHERE for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + sql = """ + CREATE PROC [dbo].[transform_proc] AS + + DECLARE @CurrentDate VARCHAR(20); + SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120); + + CREATE TABLE [target_schema].[target_table] + WITH (DISTRIBUTION = REPLICATE, HEAP) + AS + + SELECT + @CurrentDate AS DWCreatedDate + FROM source_schema.sourcetable; + """ + + expected_sqls = [ + 'CREATE PROC "dbo"."transform_proc" AS DECLARE @CurrentDate VARCHAR(20)', + "SET @CurrentDate = CAST(FORMAT(GETDATE(), 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(20))", + 'CREATE TABLE "target_schema"."target_table" WITH (DISTRIBUTION=REPLICATE, HEAP) AS SELECT @CurrentDate AS DWCreatedDate FROM source_schema.sourcetable', + ] + + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + def test_charindex(self): self.validate_all( "CHARINDEX(x, y, 9)", @@ -795,31 +842,50 @@ WHERE ) def test_date_diff(self): - self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") - + self.validate_identity("SELECT DATEDIFF(hour, 1.5, '2021-01-01')") + self.validate_identity( + "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')", + "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + ) + self.validate_all( + "SELECT DATEDIFF(quarter, 0, '2021-01-01')", + write={ + "tsql": "SELECT DATEDIFF(quarter, CAST('1900-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(quarter, CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "duckdb": "SELECT DATE_DIFF('quarter', CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + }, + ) + self.validate_all( + "SELECT DATEDIFF(day, 1, '2021-01-01')", + write={ + "tsql": "SELECT DATEDIFF(day, CAST('1900-01-02' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(day, CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "duckdb": "SELECT DATE_DIFF('day', CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + }, + ) self.validate_all( "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", write={ - "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", - "spark": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", - "spark2": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", + "tsql": "SELECT DATEDIFF(year, CAST('2020/01/01' AS DATETIME2), CAST('2021/01/01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(year, CAST('2020/01/01' AS TIMESTAMP), CAST('2021/01/01' AS TIMESTAMP))", + "spark2": "SELECT MONTHS_BETWEEN(CAST('2021/01/01' AS TIMESTAMP), CAST('2020/01/01' AS TIMESTAMP)) / 12", }, ) self.validate_all( - "SELECT DATEDIFF(mm, 'start','end')", + "SELECT DATEDIFF(mm, 'start', 'end')", write={ - "databricks": "SELECT DATEDIFF(month, 'start', 'end')", - "spark2": "SELECT MONTHS_BETWEEN('end', 'start')", - "tsql": "SELECT DATEDIFF(month, 'start', 'end')", + "databricks": "SELECT DATEDIFF(month, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))", + "tsql": "SELECT DATEDIFF(month, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) self.validate_all( "SELECT DATEDIFF(quarter, 'start', 'end')", write={ - "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", - "spark": "SELECT DATEDIFF(quarter, 'start', 'end')", - "spark2": "SELECT MONTHS_BETWEEN('end', 'start') / 3", - "tsql": "SELECT DATEDIFF(quarter, 'start', 'end')", + "databricks": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3", + "tsql": "SELECT DATEDIFF(quarter, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) @@ -943,8 +1009,15 @@ WHERE expr = parse_one("#x", read="tsql") self.assertIsInstance(expr, exp.Column) self.assertIsInstance(expr.this, exp.Identifier) + self.assertTrue(expr.this.args.get("temporary")) self.assertEqual(expr.sql("tsql"), "#x") + expr = parse_one("##x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + self.assertTrue(expr.this.args.get("global")) + self.assertEqual(expr.sql("tsql"), "##x") + expr = parse_one("@x", read="tsql") self.assertIsInstance(expr, exp.Parameter) self.assertIsInstance(expr.this, exp.Var) @@ -955,6 +1028,24 @@ WHERE self.assertIsInstance(table.this, exp.Parameter) self.assertIsInstance(table.this.this, exp.Var) + def test_temp_table(self): + self.validate_all( + "SELECT * FROM #mytemptable", + write={ + "duckdb": "SELECT * FROM mytemptable", + "spark": "SELECT * FROM mytemptable", + "tsql": "SELECT * FROM #mytemptable", + }, + ) + self.validate_all( + "SELECT * FROM ##mytemptable", + write={ + "duckdb": "SELECT * FROM mytemptable", + "spark": "SELECT * FROM mytemptable", + "tsql": "SELECT * FROM ##mytemptable", + }, + ) + def test_system_time(self): self.validate_all( "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", |