From 379c6d1f52e1d311867c4f789dc389da1d9af898 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 6 Aug 2023 09:48:11 +0200 Subject: Merging upstream version 17.9.1. Signed-off-by: Daniel Baumann --- tests/dialects/test_bigquery.py | 8 +- tests/dialects/test_clickhouse.py | 18 ++- tests/dialects/test_drill.py | 2 +- tests/dialects/test_mysql.py | 20 ++++ tests/dialects/test_oracle.py | 5 + tests/dialects/test_postgres.py | 6 +- tests/dialects/test_presto.py | 21 ++++ tests/dialects/test_snowflake.py | 8 ++ tests/dialects/test_teradata.py | 7 ++ tests/dialects/test_tsql.py | 123 ++++++++++++++++++--- tests/fixtures/identity.sql | 4 + tests/fixtures/optimizer/normalize_identifiers.sql | 7 ++ tests/fixtures/optimizer/optimizer.sql | 19 +++- tests/fixtures/optimizer/qualify_columns.sql | 10 ++ tests/fixtures/optimizer/qualify_columns_ddl.sql | 35 ++++++ tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 30 ++++- tests/test_optimizer.py | 22 +++- tests/test_parser.py | 118 +++++++++++++++++++- tests/test_schema.py | 34 ++++++ tests/test_transforms.py | 2 +- 20 files changed, 463 insertions(+), 36 deletions(-) create mode 100644 tests/fixtures/optimizer/qualify_columns_ddl.sql (limited to 'tests') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 69042be..8d01ebe 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -29,6 +29,10 @@ class TestBigQuery(Validator): with self.assertRaises(ParseError): transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery") + with self.assertRaises(ParseError): + transpile("DATE_ADD(x, day)", read="bigquery") + + self.validate_identity("STRING_AGG(DISTINCT a ORDER BY b DESC, c DESC LIMIT 10)") self.validate_identity("SELECT PARSE_TIMESTAMP('%c', 'Thu Dec 25 07:30:00 2008', 'UTC')") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MAX sold) FROM fruits") self.validate_identity("SELECT ANY_VALUE(fruit HAVING MIN sold) FROM fruits") @@ -389,7 +393,7 @@ class 
TestBigQuery(Validator): }, ) self.validate_all( - "current_timestamp", + "CURRENT_TIMESTAMP", write={ "bigquery": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP", @@ -400,7 +404,7 @@ class TestBigQuery(Validator): }, ) self.validate_all( - "current_timestamp()", + "CURRENT_TIMESTAMP()", write={ "bigquery": "CURRENT_TIMESTAMP()", "duckdb": "CURRENT_TIMESTAMP", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index bc82645..16c10fe 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -6,6 +6,12 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + expr = parse_one("count(x)") + self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") + self.assertIsNone(expr._meta) + + self.validate_identity("SELECT isNaN(1.0)") + self.validate_identity("SELECT startsWith('Spider-Man', 'Spi')") self.validate_identity("SELECT xor(TRUE, FALSE)") self.validate_identity("ATTACH DATABASE DEFAULT ENGINE = ORDINARY") self.validate_identity("CAST(['hello'], 'Array(Enum8(''hello'' = 1))')") @@ -162,7 +168,7 @@ class TestClickhouse(Validator): ORDER BY loyalty ASC """, write={ - "clickhouse": "SELECT loyalty, COUNT() FROM hits LEFT SEMI JOIN users USING (UserID)" + "clickhouse": "SELECT loyalty, count() FROM hits LEFT SEMI JOIN users USING (UserID)" + " GROUP BY loyalty ORDER BY loyalty" }, ) @@ -247,7 +253,7 @@ class TestClickhouse(Validator): for data_type in data_types: self.validate_all( f"pow(2, 32)::{data_type}", - write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"}, + write={"clickhouse": f"CAST(pow(2, 32) AS {data_type})"}, ) def test_ddl(self): @@ -304,8 +310,8 @@ GROUP BY id, toStartOfDay(timestamp) SET - max_hits = MAX(max_hits), - sum_hits = SUM(sum_hits)""", + max_hits = max(max_hits), + sum_hits = sum(sum_hits)""", }, pretty=True, ) @@ -447,8 +453,8 @@ GROUP BY k1, k2 SET - x = MAX(x), - y = MIN(y)""", + x = max(x), + y = min(y)""", }, pretty=True, ) diff 
--git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py index a7f609a..41c02fb 100644 --- a/tests/dialects/test_drill.py +++ b/tests/dialects/test_drill.py @@ -66,7 +66,7 @@ class TestDrill(Validator): write={ "drill": "SELECT * FROM (SELECT education_level, salary, marital_status, " "EXTRACT(year FROM age(birth_date)) AS age FROM cp.`employee.json`) " - "PIVOT(AVG(salary) AS avg_salary, AVG(age) AS avg_age FOR marital_status " + "PIVOT(avg(salary) AS avg_salary, avg(age) AS avg_age FOR marital_status " "IN ('M' AS married, 'S' AS single))" }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index ae2fa41..d021d62 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -9,6 +9,12 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") + self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))") + self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") + self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") + self.validate_identity( + "CREATE TABLE foo (a BIGINT, INDEX b USING HASH (c) COMMENT 'd' VISIBLE ENGINE_ATTRIBUTE = 'e' WITH PARSER foo)" + ) self.validate_identity( "DELETE t1 FROM t1 LEFT JOIN t2 ON t1.id = t2.id WHERE t2.id IS NULL" ) @@ -67,6 +73,12 @@ class TestMySQL(Validator): "mysql": "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", }, ) + self.validate_all( + "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE KEY d (b), KEY e (b))", + write={ + "mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e 
(b))", + }, + ) def test_identity(self): self.validate_identity("SELECT 1 XOR 0") @@ -436,6 +448,13 @@ class TestMySQL(Validator): self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)") def test_mysql(self): + self.validate_all( + "SELECT * FROM test LIMIT 0 + 1, 0 + 1", + write={ + "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", + "postgres": "SELECT * FROM test LIMIT 0 + 1 OFFSET 0 + 1", + }, + ) self.validate_all( "CAST(x AS TEXT)", write={ @@ -448,6 +467,7 @@ class TestMySQL(Validator): self.validate_all("CAST(x AS SIGNED INTEGER)", write={"mysql": "CAST(x AS SIGNED)"}) self.validate_all("CAST(x AS UNSIGNED)", write={"mysql": "CAST(x AS UNSIGNED)"}) self.validate_all("CAST(x AS UNSIGNED INTEGER)", write={"mysql": "CAST(x AS UNSIGNED)"}) + self.validate_all("TIME_STR_TO_TIME(x)", write={"mysql": "CAST(x AS DATETIME)"}) self.validate_all( """SELECT 17 MEMBER OF('[23, "abc", 17, "ab", 10]')""", write={ diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index f30b38f..0c3b09f 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -23,6 +23,11 @@ class TestOracle(Validator): self.validate_identity( "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" ) + self.validate_identity( + "SELECT last_name, department_id, salary, MIN(salary) KEEP (DENSE_RANK FIRST ORDER BY commission_pct) " + 'OVER (PARTITION BY department_id) AS "Worst", MAX(salary) KEEP (DENSE_RANK LAST ORDER BY commission_pct) ' + 'OVER (PARTITION BY department_id) AS "Best" FROM employees ORDER BY department_id, salary, last_name' + ) self.validate_all( "NVL(NULL, 1)", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index b35665b..be34d8c 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -122,6 +122,10 @@ class TestPostgres(Validator): ) def test_postgres(self): + expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS 
s(location)") + unnest = expr.args["joins"][0].this.this + unnest.assert_is(exp.Unnest) + self.validate_identity("CAST(x AS MONEY)") self.validate_identity("CAST(x AS INT4RANGE)") self.validate_identity("CAST(x AS INT4MULTIRANGE)") @@ -414,7 +418,7 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT * FROM r CROSS JOIN LATERAL unnest(array(1)) AS s(location)", + "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", write={ "postgres": "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", }, diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index c0b77a3..a2800bd 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -109,6 +109,8 @@ class TestPresto(Validator): "spark": "CAST(x AS TIMESTAMP)", }, ) + self.validate_identity("CAST(x AS IPADDRESS)") + self.validate_identity("CAST(x AS IPPREFIX)") def test_regex(self): self.validate_all( @@ -459,6 +461,25 @@ class TestPresto(Validator): self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") + self.validate_all( + "STARTS_WITH('abc', 'a')", + read={"spark": "STARTSWITH('abc', 'a')"}, + write={ + "presto": "STARTS_WITH('abc', 'a')", + "spark": "STARTSWITH('abc', 'a')", + }, + ) + self.validate_all( + "IS_NAN(x)", + read={ + "spark": "ISNAN(x)", + }, + write={ + "presto": "IS_NAN(x)", + "spark": "ISNAN(x)", + "spark2": "ISNAN(x)", + }, + ) self.validate_all("VALUES 1, 2, 3", write={"presto": "VALUES (1), (2), (3)"}) self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"}) self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 82762e8..a889e1d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -44,6 +44,14 @@ class TestSnowflake(Validator): self.validate_all("CAST(x AS CHAR VARYING)", 
write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) + self.validate_all( + "SELECT DATE_PART('year', TIMESTAMP '2020-01-01')", + write={ + "hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", + "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMPNTZ))", + "spark": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", + }, + ) self.validate_all( "SELECT * FROM (VALUES (0) foo(bar))", write={"snowflake": "SELECT * FROM (VALUES (0)) AS foo(bar)"}, diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 0df6d0b..4d32241 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -22,6 +22,13 @@ class TestTeradata(Validator): }, ) + def test_statistics(self): + self.validate_identity("COLLECT STATISTICS ON tbl INDEX(col)") + self.validate_identity("COLLECT STATS ON tbl COLUMNS(col)") + self.validate_identity("COLLECT STATS COLUMNS(col) ON tbl") + self.validate_identity("HELP STATISTICS personel.employee") + self.validate_identity("HELP STATISTICS personnel.employee FROM my_qcd") + def test_create(self): self.validate_identity("CREATE TABLE x (y INT) PRIMARY INDEX (y) PARTITION BY y INDEX (y)") self.validate_identity("CREATE TABLE x (y INT) PARTITION BY y INDEX (y)") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index f0a590f..5266bd4 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -336,7 +336,7 @@ class TestTSQL(Validator): "CAST(x as TIME(4))", write={ "spark": "CAST(x AS TIMESTAMP)", - "tsql": "CAST(x AS TIMESTAMP(4))", + "tsql": "CAST(x AS TIME(4))", }, ) @@ -352,7 +352,7 @@ class TestTSQL(Validator): "CAST(x as DATETIMEOFFSET)", write={ "spark": "CAST(x AS TIMESTAMP)", - "tsql": "CAST(x AS TIMESTAMPTZ)", + "tsql": "CAST(x AS DATETIMEOFFSET)", }, ) 
@@ -393,7 +393,30 @@ class TestTSQL(Validator): self.validate_all( "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ - "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIMESTAMP(4), d FLOAT(24))" + "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)", + "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", + }, + ) + self.validate_all( + "CREATE TABLE #mytemptable (a INTEGER)", + read={ + "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", + }, + write={ + "tsql": "CREATE TABLE #mytemptable (a INTEGER)", + "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", + "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", + }, + ) + self.validate_all( + "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", + write={ + "duckdb": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", + "oracle": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", + "snowflake": "CREATE TEMPORARY TABLE mytemptable AS SELECT a FROM Source_Table", + "spark": "CREATE TEMPORARY VIEW mytemptable AS SELECT a FROM Source_Table", + "tsql": "CREATE TABLE #mytemptable AS SELECT a FROM Source_Table", }, ) @@ -535,6 +558,30 @@ WHERE for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + sql = """ + CREATE PROC [dbo].[transform_proc] AS + + DECLARE @CurrentDate VARCHAR(20); + SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120); + + CREATE TABLE [target_schema].[target_table] + WITH (DISTRIBUTION = REPLICATE, HEAP) + AS + + SELECT + @CurrentDate AS DWCreatedDate + FROM source_schema.sourcetable; + """ + + expected_sqls = [ + 'CREATE PROC "dbo"."transform_proc" AS DECLARE @CurrentDate VARCHAR(20)', + "SET @CurrentDate = CAST(FORMAT(GETDATE(), 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(20))", + 'CREATE TABLE "target_schema"."target_table" WITH 
(DISTRIBUTION=REPLICATE, HEAP) AS SELECT @CurrentDate AS DWCreatedDate FROM source_schema.sourcetable', + ] + + for expr, expected_sql in zip(parse(sql, read="tsql"), expected_sqls): + self.assertEqual(expr.sql(dialect="tsql"), expected_sql) + def test_charindex(self): self.validate_all( "CHARINDEX(x, y, 9)", @@ -795,31 +842,50 @@ WHERE ) def test_date_diff(self): - self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") - + self.validate_identity("SELECT DATEDIFF(hour, 1.5, '2021-01-01')") + self.validate_identity( + "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')", + "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + ) + self.validate_all( + "SELECT DATEDIFF(quarter, 0, '2021-01-01')", + write={ + "tsql": "SELECT DATEDIFF(quarter, CAST('1900-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(quarter, CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "duckdb": "SELECT DATE_DIFF('quarter', CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + }, + ) + self.validate_all( + "SELECT DATEDIFF(day, 1, '2021-01-01')", + write={ + "tsql": "SELECT DATEDIFF(day, CAST('1900-01-02' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(day, CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "duckdb": "SELECT DATE_DIFF('day', CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + }, + ) self.validate_all( "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", write={ - "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", - "spark": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", - "spark2": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", + "tsql": "SELECT DATEDIFF(year, CAST('2020/01/01' AS DATETIME2), CAST('2021/01/01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(year, CAST('2020/01/01' AS TIMESTAMP), CAST('2021/01/01' AS TIMESTAMP))", + "spark2": "SELECT 
MONTHS_BETWEEN(CAST('2021/01/01' AS TIMESTAMP), CAST('2020/01/01' AS TIMESTAMP)) / 12", }, ) self.validate_all( - "SELECT DATEDIFF(mm, 'start','end')", + "SELECT DATEDIFF(mm, 'start', 'end')", write={ - "databricks": "SELECT DATEDIFF(month, 'start', 'end')", - "spark2": "SELECT MONTHS_BETWEEN('end', 'start')", - "tsql": "SELECT DATEDIFF(month, 'start', 'end')", + "databricks": "SELECT DATEDIFF(month, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))", + "tsql": "SELECT DATEDIFF(month, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) self.validate_all( "SELECT DATEDIFF(quarter, 'start', 'end')", write={ - "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", - "spark": "SELECT DATEDIFF(quarter, 'start', 'end')", - "spark2": "SELECT MONTHS_BETWEEN('end', 'start') / 3", - "tsql": "SELECT DATEDIFF(quarter, 'start', 'end')", + "databricks": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3", + "tsql": "SELECT DATEDIFF(quarter, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) @@ -943,8 +1009,15 @@ WHERE expr = parse_one("#x", read="tsql") self.assertIsInstance(expr, exp.Column) self.assertIsInstance(expr.this, exp.Identifier) + self.assertTrue(expr.this.args.get("temporary")) self.assertEqual(expr.sql("tsql"), "#x") + expr = parse_one("##x", read="tsql") + self.assertIsInstance(expr, exp.Column) + self.assertIsInstance(expr.this, exp.Identifier) + self.assertTrue(expr.this.args.get("global")) + self.assertEqual(expr.sql("tsql"), "##x") + expr = parse_one("@x", read="tsql") self.assertIsInstance(expr, exp.Parameter) self.assertIsInstance(expr.this, exp.Var) @@ -955,6 +1028,24 @@ WHERE self.assertIsInstance(table.this, exp.Parameter) 
self.assertIsInstance(table.this.this, exp.Var) + def test_temp_table(self): + self.validate_all( + "SELECT * FROM #mytemptable", + write={ + "duckdb": "SELECT * FROM mytemptable", + "spark": "SELECT * FROM mytemptable", + "tsql": "SELECT * FROM #mytemptable", + }, + ) + self.validate_all( + "SELECT * FROM ##mytemptable", + write={ + "duckdb": "SELECT * FROM mytemptable", + "spark": "SELECT * FROM mytemptable", + "tsql": "SELECT * FROM ##mytemptable", + }, + ) + def test_system_time(self): self.validate_all( "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index b460c15..10f77ac 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -730,6 +730,7 @@ WITH a AS (SELECT * FROM b) DELETE FROM a WITH a AS (SELECT * FROM b) CACHE TABLE a SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? SELECT :hello, ? FROM x LIMIT :my_limit +SELECT a FROM b WHERE c IS ? SELECT * FROM x OFFSET @skip FETCH NEXT @take ROWS ONLY WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a @@ -848,3 +849,6 @@ SELECT * FROM current_date SELECT * FROM schema.current_date SELECT /*+ SOME_HINT(foo) */ 1 SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) +/* comment1 */ INSERT INTO x /* comment2 */ VALUES (1, 2, 3) +/* comment1 */ UPDATE tbl /* comment2 */ SET x = 2 WHERE x < 2 +/* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1 diff --git a/tests/fixtures/optimizer/normalize_identifiers.sql b/tests/fixtures/optimizer/normalize_identifiers.sql index ddb755f..2ab4778 100644 --- a/tests/fixtures/optimizer/normalize_identifiers.sql +++ b/tests/fixtures/optimizer/normalize_identifiers.sql @@ -1,3 +1,10 @@ +foo; +foo; + +# dialect: snowflake +foo + "bar".baz; +FOO + "bar".BAZ; + SELECT a FROM x; SELECT a FROM x; diff --git 
a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 14f5cfe..981e052 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -638,7 +638,7 @@ SELECT FROM "users" AS "u" CROSS JOIN LATERAL ( SELECT - "l"."log_date" + "l"."log_date" AS "log_date" FROM "logs" AS "l" WHERE "l"."log_date" <= 100 AND "l"."user_id" = "u"."user_id" @@ -890,3 +890,20 @@ FROM ( JOIN "y" AS "y" ON "x"."a" = "y"."c" ); + +# title: replace scalar subquery, wrap resulting column in a MAX +SELECT a, SUM(c) / (SELECT SUM(c) FROM y) * 100 AS foo FROM y INNER JOIN x ON y.b = x.b GROUP BY a; +WITH "_u_0" AS ( + SELECT + SUM("y"."c") AS "_col_0" + FROM "y" AS "y" +) +SELECT + "x"."a" AS "a", + SUM("y"."c") / MAX("_u_0"."_col_0") * 100 AS "foo" +FROM "y" AS "y" +CROSS JOIN "_u_0" AS "_u_0" +JOIN "x" AS "x" + ON "y"."b" = "x"."b" +GROUP BY + "x"."a"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 90505ac..8a2519e 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -93,6 +93,16 @@ SELECT 2 AS "2" FROM x AS x GROUP BY 1; SELECT 'a' AS a FROM x GROUP BY 1; SELECT 'a' AS a FROM x AS x GROUP BY 1; +# execute: false +# dialect: oracle +SELECT t."col" FROM tbl t; +SELECT T."col" AS "col" FROM TBL T; + +# execute: false +# dialect: oracle +WITH base AS (SELECT x.dummy AS COL_1 FROM dual x) SELECT b."COL_1" FROM base b; +WITH BASE AS (SELECT X.DUMMY AS COL_1 FROM DUAL X) SELECT B.COL_1 AS COL_1 FROM BASE B; + # execute: false -- this query seems to be invalid in postgres and duckdb but valid in bigquery SELECT 2 a FROM x GROUP BY 1 HAVING a > 1; diff --git a/tests/fixtures/optimizer/qualify_columns_ddl.sql b/tests/fixtures/optimizer/qualify_columns_ddl.sql new file mode 100644 index 0000000..87e0f6d --- /dev/null +++ b/tests/fixtures/optimizer/qualify_columns_ddl.sql @@ -0,0 +1,35 @@ +# title: Create with CTE 
+WITH cte AS (SELECT b FROM y) CREATE TABLE s AS SELECT * FROM cte; +WITH cte AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS SELECT cte.b AS b FROM cte; + +# title: Create without CTE +CREATE TABLE foo AS SELECT a FROM tbl; +CREATE TABLE foo AS SELECT tbl.a AS a FROM tbl AS tbl; + +# title: Create with complex CTE with derived table +WITH cte AS (SELECT a FROM (SELECT a from x)) CREATE TABLE s AS SELECT * FROM cte; +WITH cte AS (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) CREATE TABLE s AS SELECT cte.a AS a FROM cte; + +# title: Create with multiple CTEs +WITH cte1 AS (SELECT b FROM y), cte2 AS (SELECT b FROM cte1) CREATE TABLE s AS SELECT * FROM cte2; +WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1) CREATE TABLE s AS SELECT cte2.b AS b FROM cte2; + +# title: Create with multiple CTEs, selecting only from the first CTE (unnecessary code) +WITH cte1 AS (SELECT b FROM y), cte2 AS (SELECT b FROM cte1) CREATE TABLE s AS SELECT * FROM cte1; +WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1) CREATE TABLE s AS SELECT cte1.b AS b FROM cte1; + +# title: Create with multiple derived tables +CREATE TABLE s AS SELECT * FROM (SELECT b FROM (SELECT b FROM y)); +CREATE TABLE s AS SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT y.b AS b FROM y AS y) AS _q_0) AS _q_1; + +# title: Create with a CTE and a derived table +WITH cte AS (SELECT b FROM y) CREATE TABLE s AS SELECT * FROM (SELECT b FROM (SELECT b FROM cte)); +WITH cte AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT cte.b AS b FROM cte) AS _q_0) AS _q_1; + +# title: Insert with CTE +WITH cte AS (SELECT b FROM y) INSERT INTO s SELECT * FROM cte; +WITH cte AS (SELECT y.b AS b FROM y AS y) INSERT INTO s SELECT cte.b AS b FROM cte; + +# title: Insert without CTE +INSERT INTO foo SELECT a FROM tbl; +INSERT INTO foo SELECT tbl.a AS a FROM tbl AS tbl; diff --git 
a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 8aaf50c..1205c33 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -1449,11 +1449,31 @@ WITH "_u_0" AS ( "store_sales"."ss_quantity" <= 80 AND "store_sales"."ss_quantity" >= 61 ) SELECT - CASE WHEN "_u_0"."_col_0" > 3672 THEN "_u_1"."_col_0" ELSE "_u_2"."_col_0" END AS "bucket1", - CASE WHEN "_u_3"."_col_0" > 3392 THEN "_u_4"."_col_0" ELSE "_u_5"."_col_0" END AS "bucket2", - CASE WHEN "_u_6"."_col_0" > 32784 THEN "_u_7"."_col_0" ELSE "_u_8"."_col_0" END AS "bucket3", - CASE WHEN "_u_9"."_col_0" > 26032 THEN "_u_10"."_col_0" ELSE "_u_11"."_col_0" END AS "bucket4", - CASE WHEN "_u_12"."_col_0" > 23982 THEN "_u_13"."_col_0" ELSE "_u_14"."_col_0" END AS "bucket5" + CASE + WHEN MAX("_u_0"."_col_0") > 3672 + THEN MAX("_u_1"."_col_0") + ELSE MAX("_u_2"."_col_0") + END AS "bucket1", + CASE + WHEN MAX("_u_3"."_col_0") > 3392 + THEN MAX("_u_4"."_col_0") + ELSE MAX("_u_5"."_col_0") + END AS "bucket2", + CASE + WHEN MAX("_u_6"."_col_0") > 32784 + THEN MAX("_u_7"."_col_0") + ELSE MAX("_u_8"."_col_0") + END AS "bucket3", + CASE + WHEN MAX("_u_9"."_col_0") > 26032 + THEN MAX("_u_10"."_col_0") + ELSE MAX("_u_11"."_col_0") + END AS "bucket4", + CASE + WHEN MAX("_u_12"."_col_0") > 23982 + THEN MAX("_u_13"."_col_0") + ELSE MAX("_u_14"."_col_0") + END AS "bucket5" FROM "reason" AS "reason" CROSS JOIN "_u_0" AS "_u_0" CROSS JOIN "_u_1" AS "_u_1" diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index cd0b9b1..64d7db7 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -242,7 +242,10 @@ class TestOptimizer(unittest.TestCase): "CREATE FUNCTION `udfs`.`myTest`(`x` FLOAT64) AS (1)", ) - self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) + self.check_file( + "qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True + ) + 
self.check_file("qualify_columns_ddl", qualify_columns, schema=self.schema) def test_qualify_columns__with_invisible(self): schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}}) @@ -448,6 +451,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ) self.assertEqual(set(scopes[3].sources), {""}) + inner_query = "SELECT bar FROM baz" + for udtf in (f"UNNEST(({inner_query}))", f"LATERAL ({inner_query})"): + sql = f"SELECT a FROM foo CROSS JOIN {udtf}" + expression = parse_one(sql) + + for scopes in traverse_scope(expression), list(build_scope(expression).traverse()): + self.assertEqual(len(scopes), 3) + + self.assertEqual(scopes[0].expression.sql(), inner_query) + self.assertEqual(set(scopes[0].sources), {"baz"}) + + self.assertEqual(scopes[1].expression.sql(), udtf) + self.assertEqual(set(scopes[1].sources), {"", "foo"}) # foo is a lateral source + + self.assertEqual(scopes[2].expression.sql(), f"SELECT a FROM foo CROSS JOIN {udtf}") + self.assertEqual(set(scopes[2].sources), {"", "foo"}) + @patch("sqlglot.optimizer.scope.logger") def test_scope_warning(self, logger): self.assertEqual(len(traverse_scope(parse_one("WITH q AS (@y) SELECT * FROM q"))), 1) diff --git a/tests/test_parser.py b/tests/test_parser.py index 07686af..027a9ca 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -253,7 +253,7 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("INTERVAL '1' DAY").args["unit"], exp.Var) self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") - def test_comments(self): + def test_comments_select(self): expression = parse_one( """ --comment1.1 @@ -277,6 +277,120 @@ class TestParser(unittest.TestCase): self.assertEqual(expression.expressions[4].comments, [""]) self.assertEqual(expression.expressions[5].comments, [" space"]) + def test_comments_select_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT 1) + SELECT 
/*comment2*/ + a.* + FROM /*comment3*/ + a + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.args.get("from").comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + + def test_comments_insert(self): + expression = parse_one( + """ + --comment1.1 + --comment1.2 + INSERT INTO /*comment1.3*/ + x /*comment2*/ + VALUES /*comment3*/ + (1, 'a', 2.0) + """ + ) + + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.this.comments, ["comment2"]) + + def test_comments_insert_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT 1) + INSERT INTO /*comment2*/ + b /*comment3*/ + SELECT * FROM a + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + + def test_comments_update(self): + expression = parse_one( + """ + --comment1.1 + --comment1.2 + UPDATE /*comment1.3*/ + tbl /*comment2*/ + SET /*comment3*/ + x = 2 + WHERE /*comment4*/ + x <> 2 + """ + ) + + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.this.comments, ["comment2"]) + self.assertEqual(expression.args.get("where").comments, ["comment4"]) + + def test_comments_update_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT * FROM b) + UPDATE /*comment2*/ + a /*comment3*/ + SET col = 1 + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + + def test_comments_delete(self): + expression = parse_one( + """ + --comment1.1 + --comment1.2 + DELETE /*comment1.3*/ + FROM /*comment2*/ + x /*comment3*/ + WHERE 
/*comment4*/ + y > 1 + """ + ) + + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("where").comments, ["comment4"]) + + def test_comments_delete_cte(self): + expression = parse_one( + """ + /*comment1.1*/ + /*comment1.2*/ + WITH a AS (SELECT * FROM b) + --comment2 + DELETE FROM a /*comment3*/ + """ + ) + + self.assertEqual(expression.comments, ["comment2"]) + self.assertEqual(expression.this.comments, ["comment3"]) + self.assertEqual(expression.args.get("with").comments, ["comment1.1", "comment1.2"]) + def test_type_literals(self): self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) self.assertEqual(parse_one("int.5"), parse_one("CAST(0.5 AS INT)")) @@ -528,7 +642,7 @@ class TestParser(unittest.TestCase): now = time.time() query = parse_one( """ - select * + SELECT * FROM a LEFT JOIN b ON a.id = b.id LEFT JOIN b ON a.id = b.id diff --git a/tests/test_schema.py b/tests/test_schema.py index b89754f..626fa11 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -238,3 +238,37 @@ class TestSchema(unittest.TestCase): schema = MappingSchema(schema={"Foo": {"`BaR`": "int"}}, dialect="bigquery") self.assertEqual(schema.column_names("Foo"), ["bar"]) self.assertEqual(schema.column_names("foo"), []) + + # Check that the schema's normalization setting can be overridden + schema = MappingSchema(schema={"X": {"y": "int"}}, normalize=False, dialect="snowflake") + self.assertEqual(schema.column_names("x", normalize=True), ["y"]) + + def test_same_number_of_qualifiers(self): + schema = MappingSchema({"x": {"y": {"c1": "int"}}}) + + with self.assertRaises(SchemaError) as ctx: + schema.add_table("z", {"c2": "int"}) + + self.assertEqual( + str(ctx.exception), + "Table z must match the schema's nesting level: 2.", + ) + + schema = MappingSchema() + schema.add_table("x.y", {"c1": "int"}) + + with self.assertRaises(SchemaError) as 
ctx: + schema.add_table("z", {"c2": "int"}) + + self.assertEqual( + str(ctx.exception), + "Table z must match the schema's nesting level: 2.", + ) + + with self.assertRaises(SchemaError) as ctx: + MappingSchema({"x": {"y": {"c1": "int"}}, "z": {"c2": "int"}}) + + self.assertEqual( + str(ctx.exception), + "Table z must match the schema's nesting level: 2.", + ) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 24d8c30..80d12ac 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -99,7 +99,7 @@ class TestTransforms(unittest.TestCase): self.validate( eliminate_qualify, "SELECT * FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", - "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1", + "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1", ) self.validate( eliminate_qualify, -- cgit v1.2.3