From 1a60bbae98d3b530924a6807a55f8250de19ea86 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 2 Dec 2022 10:16:29 +0100 Subject: Adding upstream version 10.1.3. Signed-off-by: Daniel Baumann --- tests/dataframe/unit/test_functions.py | 2 +- tests/dialects/test_clickhouse.py | 4 + tests/dialects/test_dialect.py | 11 +- tests/dialects/test_duckdb.py | 6 ++ tests/dialects/test_hive.py | 5 +- tests/dialects/test_mysql.py | 19 +++- tests/dialects/test_postgres.py | 21 ++++ tests/dialects/test_presto.py | 7 ++ tests/dialects/test_redshift.py | 16 +++ tests/dialects/test_snowflake.py | 20 +++- tests/dialects/test_spark.py | 11 +- tests/dialects/test_sqlite.py | 4 + tests/dialects/test_tsql.py | 28 ++++- tests/fixtures/identity.sql | 7 ++ tests/fixtures/optimizer/eliminate_subqueries.sql | 12 +++ tests/fixtures/optimizer/lower_identities.sql | 41 ++++++++ tests/fixtures/optimizer/optimizer.sql | 15 +++ tests/fixtures/optimizer/simplify.sql | 9 ++ tests/fixtures/optimizer/tpc-h/tpc-h.sql | 57 +++++----- tests/fixtures/optimizer/unnest_subqueries.sql | 84 ++++++++------- tests/test_executor.py | 71 ++++++++++++- tests/test_expressions.py | 28 ++--- tests/test_optimizer.py | 6 +- tests/test_parser.py | 67 ++++++++++-- tests/test_tokens.py | 14 +-- tests/test_transforms.py | 29 +++++- tests/test_transpile.py | 120 ++++++++++++++++++++-- 27 files changed, 589 insertions(+), 125 deletions(-) create mode 100644 tests/fixtures/optimizer/lower_identities.sql (limited to 'tests') diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 8e5e5cd..99b140d 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -1276,7 +1276,7 @@ class TestFunctions(unittest.TestCase): col = SF.concat(SF.col("cola"), SF.col("colb")) self.assertEqual("CONCAT(cola, colb)", col.sql()) col_single = SF.concat("cola") - self.assertEqual("CONCAT(cola)", col_single.sql()) + self.assertEqual("cola", col_single.sql()) def test_array_position(self): col_str = SF.array_position("cola", SF.col("colb")) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index efb41bb..c95c967 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -10,6 +10,10 @@ class TestClickhouse(Validator): self.validate_identity("SELECT * FROM x AS y FINAL") self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))") self.validate_identity("CAST((1, 2) AS Tuple(a Int8, b Int16))") + self.validate_identity("SELECT * FROM foo LEFT ANY JOIN bla") + self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla") + self.validate_identity("SELECT * FROM foo ASOF JOIN bla") + self.validate_identity("SELECT * FROM foo ANY JOIN bla") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 1b2f9c1..6033570 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -997,6 +997,13 @@ class TestDialect(Validator): "spark": "CONCAT_WS('-', x)", }, ) + self.validate_all( + "CONCAT(a)", + write={ + "mysql": "a", + "tsql": "a", + }, + ) self.validate_all( "IF(x > 1, 1, 0)", write={ @@ -1263,8 +1270,8 @@ class TestDialect(Validator): self.validate_all( """/* comment1 */ SELECT - x, -- comment2 - y -- comment3""", + x, /* comment2 */ + y /* comment3 */""", read={ "mysql": """SELECT # comment1 x, # comment2 diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 625156b..99b0493 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -89,6 +89,8 @@ class TestDuckDB(Validator): "presto": "CAST(COL AS ARRAY(BIGINT))", "hive": "CAST(COL AS ARRAY)", "spark": "CAST(COL AS ARRAY)", + "postgres": "CAST(COL AS BIGINT[])", + "snowflake": "CAST(COL AS ARRAY)", }, ) @@ -104,6 +106,10 @@ class TestDuckDB(Validator): "spark": "ARRAY(0, 1, 2)", }, ) + self.validate_all( + "SELECT ARRAY_LENGTH([0], 1) AS x", + write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"}, + ) self.validate_all( "REGEXP_MATCHES(x, y)", write={ diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 69c7630..22d7bce 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -139,7 +139,7 @@ class TestHive(Validator): "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", write={ "duckdb": "CREATE TABLE test AS SELECT 1", - "presto": "CREATE TABLE test WITH (FORMAT='parquet', x='1', Z='2') AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET', x='1', Z='2') AS SELECT 1", "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", "spark": "CREATE TABLE test USING PARQUET TBLPROPERTIES ('x'='1', 'Z'='2') AS SELECT 1", }, @@ -459,6 +459,7 @@ class TestHive(Validator): "hive": "MAP(a, b, c, d)", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "spark": "MAP(a, b, c, d)", + "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", }, write={ "": "MAP(ARRAY(a, c), ARRAY(b, d))", @@ -467,6 +468,7 @@ class TestHive(Validator): "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "hive": "MAP(a, b, c, d)", "spark": "MAP(a, b, c, d)", + "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", }, ) self.validate_all( @@ -476,6 +478,7 @@ class TestHive(Validator): "presto": "MAP(ARRAY[a], ARRAY[b])", "hive": "MAP(a, b)", "spark": "MAP(a, b)", + "snowflake": "OBJECT_CONSTRUCT(a, b)", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index af98249..5064dbe 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -23,6 +23,8 @@ class TestMySQL(Validator): self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')") self.validate_identity("@@GLOBAL.max_connections") + self.validate_identity("CREATE TABLE A LIKE B") + # SET Commands self.validate_identity("SET @var_name = expr") self.validate_identity("SET @name = 43") @@ -177,14 +179,27 @@ class TestMySQL(Validator): "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", write={ "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR ',')", - "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC)", + "sqlite": "GROUP_CONCAT(DISTINCT x)", + "tsql": "STRING_AGG(x, ',') WITHIN GROUP (ORDER BY y DESC)", + "postgres": "STRING_AGG(DISTINCT x, ',' ORDER BY y DESC NULLS LAST)", + }, + ) + self.validate_all( + "GROUP_CONCAT(x ORDER BY y SEPARATOR z)", + write={ + "mysql": "GROUP_CONCAT(x ORDER BY y SEPARATOR z)", + "sqlite": "GROUP_CONCAT(x, z)", + "tsql": "STRING_AGG(x, z) WITHIN GROUP (ORDER BY y)", + "postgres": "STRING_AGG(x, z ORDER BY y NULLS FIRST)", }, ) self.validate_all( "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", write={ "mysql": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC SEPARATOR '')", - "sqlite": "GROUP_CONCAT(DISTINCT x ORDER BY y DESC, '')", + "sqlite": "GROUP_CONCAT(DISTINCT x, '')", + "tsql": "STRING_AGG(x, '') WITHIN GROUP (ORDER BY y DESC)", + "postgres": "STRING_AGG(DISTINCT x, '' ORDER BY y DESC NULLS LAST)", }, ) self.validate_identity( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 8294eea..cd6117c 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -6,6 +6,9 @@ class TestPostgres(Validator): dialect = "postgres" def test_ddl(self): + self.validate_identity("CREATE TABLE test (foo HSTORE)") + self.validate_identity("CREATE TABLE test (foo JSONB)") + self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") self.validate_all( "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)", write={ @@ -60,6 +63,12 @@ class TestPostgres(Validator): ) def test_postgres(self): + self.validate_identity("SELECT ARRAY[1, 2, 3]") + self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") + self.validate_identity("STRING_AGG(x, y)") + self.validate_identity("STRING_AGG(x, ',' ORDER BY y)") + self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)") + self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)") self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") self.validate_identity( "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" @@ -86,6 +95,14 @@ class TestPostgres(Validator): self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + self.validate_all( + "END WORK AND NO CHAIN", + write={"postgres": "COMMIT AND NO CHAIN"}, + ) + self.validate_all( + "END AND CHAIN", + write={"postgres": "COMMIT AND CHAIN"}, + ) self.validate_all( "CREATE TABLE x (a UUID, b BYTEA)", write={ @@ -95,6 +112,10 @@ class TestPostgres(Validator): "spark": "CREATE TABLE x (a UUID, b BINARY)", }, ) + + self.validate_identity( + "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" + ) self.validate_all( "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)", write={ diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 8179cf7..70e1059 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -13,6 +13,7 @@ class TestPresto(Validator): "duckdb": "CAST(a AS INT[])", "presto": "CAST(a AS ARRAY(INTEGER))", "spark": "CAST(a AS ARRAY)", + "snowflake": "CAST(a AS ARRAY)", }, ) self.validate_all( @@ -31,6 +32,7 @@ class TestPresto(Validator): "duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])", "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", "spark": "CAST(ARRAY(1, 2) AS ARRAY)", + "snowflake": "CAST([1, 2] AS ARRAY)", }, ) self.validate_all( @@ -41,6 +43,7 @@ class TestPresto(Validator): "presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))", "hive": "CAST(MAP(1, 1) AS MAP)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP)", + "snowflake": "CAST(OBJECT_CONSTRUCT(1, 1) AS OBJECT)", }, ) self.validate_all( @@ -51,6 +54,7 @@ class TestPresto(Validator): "presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))", "hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP>)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP>)", + "snowflake": "CAST(OBJECT_CONSTRUCT('a', [1], 'b', [2], 'c', [3]) AS OBJECT)", }, ) self.validate_all( @@ -393,6 +397,7 @@ class TestPresto(Validator): write={ "hive": UnsupportedError, "spark": "MAP_FROM_ARRAYS(a, b)", + "snowflake": UnsupportedError, }, ) self.validate_all( @@ -401,6 +406,7 @@ class TestPresto(Validator): "hive": "MAP(a, c, b, d)", "presto": "MAP(ARRAY[a, b], ARRAY[c, d])", "spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))", + "snowflake": "OBJECT_CONSTRUCT(a, c, b, d)", }, ) self.validate_all( @@ -409,6 +415,7 @@ class TestPresto(Validator): "hive": "MAP('a', 'b')", "presto": "MAP(ARRAY['a'], ARRAY['b'])", "spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))", + "snowflake": "OBJECT_CONSTRUCT('a', 'b')", }, ) self.validate_all( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 5309a34..1943ee3 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -50,6 +50,12 @@ class TestRedshift(Validator): "redshift": 'SELECT tablename, "column" FROM pg_table_def WHERE "column" LIKE \'%start\\\\_%\' LIMIT 5' }, ) + self.validate_all( + "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", + write={ + "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + }, + ) def test_identity(self): self.validate_identity("CAST('bla' AS SUPER)") @@ -64,3 +70,13 @@ class TestRedshift(Validator): self.validate_identity( "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" ) + self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO") + self.validate_identity( + "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)" + ) + self.validate_identity( + "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" + ) + self.validate_identity( + "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 0e69f4e..baca269 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -172,13 +172,28 @@ class TestSnowflake(Validator): self.validate_all( "trim(date_column, 'UTC')", write={ + "bigquery": "TRIM(date_column, 'UTC')", "snowflake": "TRIM(date_column, 'UTC')", "postgres": "TRIM('UTC' FROM date_column)", }, ) self.validate_all( "trim(date_column)", - write={"snowflake": "TRIM(date_column)"}, + write={ + "snowflake": "TRIM(date_column)", + "bigquery": "TRIM(date_column)", + }, + ) + self.validate_all( + "DECODE(x, a, b, c, d)", + read={ + "": "MATCHES(x, a, b, c, d)", + }, + write={ + "": "MATCHES(x, a, b, c, d)", + "oracle": "DECODE(x, a, b, c, d)", + "snowflake": "DECODE(x, a, b, c, d)", + }, ) def test_null_treatment(self): @@ -370,7 +385,8 @@ class TestSnowflake(Validator): ) self.validate_all( - r"""SELECT * FROM TABLE(?)""", write={"snowflake": r"""SELECT * FROM TABLE(?)"""} + r"""SELECT * FROM TABLE(?)""", + write={"snowflake": r"""SELECT * FROM TABLE(?)"""}, ) self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 4470722..3a9f918 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -32,13 +32,14 @@ class TestSpark(Validator): "presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))", "hive": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", "spark": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY>)", + "snowflake": "CREATE TABLE db.example_table (col_a ARRAY, col_b ARRAY)", }, ) self.validate_all( "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", write={ "duckdb": "CREATE TABLE x", - "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", + "presto": "CREATE TABLE x WITH (TABLE_FORMAT='ICEBERG', PARTITIONED_BY=ARRAY['MONTHS'])", "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", }, @@ -94,6 +95,13 @@ TBLPROPERTIES ( pretty=True, ) + self.validate_all( + "CACHE TABLE testCache OPTIONS ('storageLevel' 'DISK_ONLY') SELECT * FROM testData", + write={ + "spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData" + }, + ) + def test_to_date(self): self.validate_all( "TO_DATE(x, 'yyyy-MM-dd')", @@ -271,6 +279,7 @@ TBLPROPERTIES ( "presto": "MAP(ARRAY[1], c)", "hive": "MAP(ARRAY(1), c)", "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", + "snowflake": "OBJECT_CONSTRUCT([1], c)", }, ) self.validate_all( diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 3cc974c..e54a4bc 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -5,6 +5,10 @@ class TestSQLite(Validator): dialect = "sqlite" def test_ddl(self): + self.validate_all( + "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", + write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"}, + ) self.validate_all( """ CREATE TABLE "Track" diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index a60f48d..afdd48a 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -17,7 +17,6 @@ class TestTSQL(Validator): "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", }, ) - self.validate_all( "CONVERT(INT, CONVERT(NUMERIC, '444.75'))", write={ @@ -25,6 +24,33 @@ class TestTSQL(Validator): "tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)", }, ) + self.validate_all( + "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)", + write={ + "tsql": "STRING_AGG(x, y) WITHIN GROUP (ORDER BY z DESC)", + "mysql": "GROUP_CONCAT(x ORDER BY z DESC SEPARATOR y)", + "sqlite": "GROUP_CONCAT(x, y)", + "postgres": "STRING_AGG(x, y ORDER BY z DESC NULLS LAST)", + }, + ) + self.validate_all( + "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z ASC)", + write={ + "tsql": "STRING_AGG(x, '|') WITHIN GROUP (ORDER BY z)", + "mysql": "GROUP_CONCAT(x ORDER BY z SEPARATOR '|')", + "sqlite": "GROUP_CONCAT(x, '|')", + "postgres": "STRING_AGG(x, '|' ORDER BY z NULLS FIRST)", + }, + ) + self.validate_all( + "STRING_AGG(x, '|')", + write={ + "tsql": "STRING_AGG(x, '|')", + "mysql": "GROUP_CONCAT(x SEPARATOR '|')", + "sqlite": "GROUP_CONCAT(x, '|')", + "postgres": "STRING_AGG(x, '|')", + }, + ) def test_types(self): self.validate_identity("CAST(x AS XML)") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 75bd25d..06ab96d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -34,6 +34,7 @@ x >> 1 x >> 1 | 1 & 1 ^ 1 x || y 1 - -1 +- -5 dec.x + y a.filter a.b.c @@ -438,6 +439,7 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) +CREATE TABLE foo (id INT PRIMARY KEY ASC) CREATE TABLE a.b AS SELECT 1 CREATE TABLE a.b AS SELECT a FROM a.c CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d @@ -579,6 +581,7 @@ SELECT * FROM (tbl1 JOIN (tbl2 JOIN tbl3) ON bla = foo) SELECT * FROM (tbl1 JOIN LATERAL (SELECT * FROM bla) AS tbl) SELECT CAST(x AS INT) /* comment */ FROM foo SELECT a /* x */, b /* x */ +SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */ SELECT * FROM foo /* x */, bla /* x */ SELECT 1 /* comment */ + 1 SELECT 1 /* c1 */ + 2 /* c2 */ @@ -588,3 +591,7 @@ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' +SELECT x AS INTO FROM bla +SELECT * INTO newevent FROM event +SELECT * INTO TEMPORARY newevent FROM event +SELECT * INTO UNLOGGED newevent FROM event diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql index f395c0a..c566657 100644 --- a/tests/fixtures/optimizer/eliminate_subqueries.sql +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -77,3 +77,15 @@ WITH x_2 AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT x.id FROM x -- Existing duplicate CTE WITH y AS (SELECT a FROM x) SELECT a FROM (SELECT a FROM x) AS y JOIN y AS z; WITH y AS (SELECT a FROM x) SELECT a FROM y AS y JOIN y AS z; + +-- Nested CTE +WITH cte1 AS (SELECT a FROM x) SELECT a FROM (WITH cte2 AS (SELECT a FROM cte1) SELECT a FROM cte2); +WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1), cte AS (SELECT a FROM cte2 AS cte2) SELECT a FROM cte AS cte; + +-- Nested CTE inside CTE +WITH cte1 AS (WITH cte2 AS (SELECT a FROM x) SELECT t.a FROM cte2 AS t) SELECT a FROM cte1; +WITH cte2 AS (SELECT a FROM x), cte1 AS (SELECT t.a FROM cte2 AS t) SELECT a FROM cte1; + +-- Duplicate CTE nested in CTE +WITH cte1 AS (SELECT a FROM x), cte2 AS (WITH cte3 AS (SELECT a FROM x) SELECT a FROM cte3) SELECT a FROM cte2; +WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FROM cte2; diff --git a/tests/fixtures/optimizer/lower_identities.sql b/tests/fixtures/optimizer/lower_identities.sql new file mode 100644 index 0000000..cea346f --- /dev/null +++ b/tests/fixtures/optimizer/lower_identities.sql @@ -0,0 +1,41 @@ +SELECT a FROM x; +SELECT a FROM x; + +SELECT "A" FROM "X"; +SELECT "A" FROM "X"; + +SELECT a AS A FROM x; +SELECT a AS A FROM x; + +SELECT * FROM x; +SELECT * FROM x; + +SELECT A FROM x; +SELECT a FROM x; + +SELECT a FROM X; +SELECT a FROM x; + +SELECT A AS A FROM (SELECT a AS A FROM x); +SELECT a AS A FROM (SELECT a AS a FROM x); + +SELECT a AS B FROM x ORDER BY B; +SELECT a AS B FROM x ORDER BY B; + +SELECT A FROM x ORDER BY A; +SELECT a FROM x ORDER BY a; + +SELECT A AS B FROM X GROUP BY A HAVING SUM(B) > 0; +SELECT a AS B FROM x GROUP BY a HAVING SUM(b) > 0; + +SELECT A AS B, SUM(B) AS C FROM X GROUP BY A HAVING C > 0; +SELECT a AS B, SUM(b) AS C FROM x GROUP BY a HAVING C > 0; + +SELECT A FROM X UNION SELECT A FROM X; +SELECT a FROM x UNION SELECT a FROM x; + +SELECT A AS A FROM X UNION SELECT A AS A FROM X; +SELECT a AS A FROM x UNION SELECT a AS A FROM x; + +(SELECT A AS A FROM X); +(SELECT a AS A FROM x); diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a1e531b..a692c7d 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -276,3 +276,18 @@ SELECT /*+ COALESCE(3), FROM `x` AS `x` JOIN `y` AS `y` ON `x`.`b` = `y`.`b`; + +WITH cte1 AS ( + WITH cte2 AS ( + SELECT a, b FROM x + ) + SELECT a1 + FROM ( + WITH cte3 AS (SELECT 1) + SELECT a AS a1, b AS b1 FROM cte2 + ) +) +SELECT a1 FROM cte1; +SELECT + "x"."a" AS "a1" +FROM "x" AS "x"; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 7207ba2..d9c7779 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -274,6 +274,15 @@ TRUE; -(-1); 1; +- -+1; +1; + ++-1; +-1; + +++1; +1; + 0.06 - 0.01; 0.05; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 8138b11..4893743 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -666,11 +666,20 @@ WITH "supplier_2" AS ( FROM "nation" AS "nation" WHERE "nation"."n_name" = 'GERMANY' +), "_u_0" AS ( + SELECT + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" + FROM "partsupp" AS "partsupp" + JOIN "supplier_2" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" + JOIN "nation_2" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" ) SELECT "partsupp"."ps_partkey" AS "ps_partkey", SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" FROM "partsupp" AS "partsupp" +CROSS JOIN "_u_0" AS "_u_0" JOIN "supplier_2" AS "supplier" ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" JOIN "nation_2" AS "nation" @@ -678,15 +687,7 @@ JOIN "nation_2" AS "nation" GROUP BY "partsupp"."ps_partkey" HAVING - SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > ( - SELECT - SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" - FROM "partsupp" AS "partsupp" - JOIN "supplier_2" AS "supplier" - ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" - JOIN "nation_2" AS "nation" - ON "supplier"."s_nationkey" = "nation"."n_nationkey" - ) + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > MAX("_u_0"."_col_0") ORDER BY "value" DESC; @@ -880,6 +881,10 @@ WITH "revenue" AS ( AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE) GROUP BY "lineitem"."l_suppkey" +), "_u_0" AS ( + SELECT + MAX("revenue"."total_revenue") AS "_col_0" + FROM "revenue" ) SELECT "supplier"."s_suppkey" AS "s_suppkey", @@ -889,12 +894,9 @@ SELECT "revenue"."total_revenue" AS "total_revenue" FROM "supplier" AS "supplier" JOIN "revenue" - ON "revenue"."total_revenue" = ( - SELECT - MAX("revenue"."total_revenue") AS "_col_0" - FROM "revenue" - ) - AND "supplier"."s_suppkey" = "revenue"."supplier_no" + ON "supplier"."s_suppkey" = "revenue"."supplier_no" +JOIN "_u_0" AS "_u_0" + ON "revenue"."total_revenue" = "_u_0"."_col_0" ORDER BY "s_suppkey"; @@ -1395,7 +1397,14 @@ order by cntrycode; WITH "_u_0" AS ( SELECT - "orders"."o_custkey" AS "_u_1" + AVG("customer"."c_acctbal") AS "_col_0" + FROM "customer" AS "customer" + WHERE + "customer"."c_acctbal" > 0.00 + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') +), "_u_1" AS ( + SELECT + "orders"."o_custkey" AS "_u_2" FROM "orders" AS "orders" GROUP BY "orders"."o_custkey" @@ -1405,18 +1414,12 @@ SELECT COUNT(*) AS "numcust", SUM("customer"."c_acctbal") AS "totacctbal" FROM "customer" AS "customer" -LEFT JOIN "_u_0" AS "_u_0" - ON "_u_0"."_u_1" = "customer"."c_custkey" +JOIN "_u_0" AS "_u_0" + ON "customer"."c_acctbal" > "_u_0"."_col_0" +LEFT JOIN "_u_1" AS "_u_1" + ON "_u_1"."_u_2" = "customer"."c_custkey" WHERE - "_u_0"."_u_1" IS NULL - AND "customer"."c_acctbal" > ( - SELECT - AVG("customer"."c_acctbal") AS "_col_0" - FROM "customer" AS "customer" - WHERE - "customer"."c_acctbal" > 0.00 - AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') - ) + "_u_1"."_u_2" IS NULL AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') GROUP BY SUBSTRING("customer"."c_phone", 1, 2) diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index f53121a..dc373a0 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -1,10 +1,12 @@ +--SELECT x.a > (SELECT SUM(y.a) AS b FROM y) FROM x; -------------------------------------- -- Unnest Subqueries -------------------------------------- SELECT * FROM x AS x WHERE - x.a IN (SELECT y.a AS a FROM y) + x.a = (SELECT SUM(y.a) AS a FROM y) + AND x.a IN (SELECT y.a AS a FROM y) AND x.a IN (SELECT y.b AS b FROM y) AND x.a = ANY (SELECT y.a AS a FROM y) AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) @@ -24,52 +26,57 @@ WHERE SELECT * FROM x AS x +CROSS JOIN ( + SELECT + SUM(y.a) AS a + FROM y +) AS "_u_0" LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_0" - ON x.a = "_u_0"."a" +) AS "_u_1" + ON x.a = "_u_1"."a" LEFT JOIN ( SELECT y.b AS b FROM y GROUP BY y.b -) AS "_u_1" - ON x.a = "_u_1"."b" +) AS "_u_2" + ON x.a = "_u_2"."b" LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_2" - ON x.a = "_u_2"."a" +) AS "_u_3" + ON x.a = "_u_3"."a" LEFT JOIN ( SELECT SUM(y.b) AS b, - y.a AS _u_4 + y.a AS _u_5 FROM y WHERE TRUE GROUP BY y.a -) AS "_u_3" - ON x.a = "_u_3"."_u_4" +) AS "_u_4" + ON x.a = "_u_4"."_u_5" LEFT JOIN ( SELECT SUM(y.b) AS b, - y.a AS _u_6 + y.a AS _u_7 FROM y WHERE TRUE GROUP BY y.a -) AS "_u_5" - ON x.a = "_u_5"."_u_6" +) AS "_u_6" + ON x.a = "_u_6"."_u_7" LEFT JOIN ( SELECT y.a AS a @@ -78,8 +85,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_7" - ON "_u_7".a = x.a +) AS "_u_8" + ON "_u_8".a = x.a LEFT JOIN ( SELECT y.a AS a @@ -88,31 +95,31 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_8" - ON "_u_8".a = x.a +) AS "_u_9" + ON "_u_9".a = x.a LEFT JOIN ( SELECT ARRAY_AGG(y.a) AS a, - y.b AS _u_10 + y.b AS _u_11 FROM y WHERE TRUE GROUP BY y.b -) AS "_u_9" - ON "_u_9"."_u_10" = x.a +) AS "_u_10" + ON "_u_10"."_u_11" = x.a LEFT JOIN ( SELECT SUM(y.a) AS a, - y.a AS _u_12, - ARRAY_AGG(y.b) AS _u_13 + y.a AS _u_13, + ARRAY_AGG(y.b) AS _u_14 FROM y WHERE TRUE AND TRUE AND TRUE GROUP BY y.a -) AS "_u_11" - ON "_u_11"."_u_12" = x.a AND "_u_11"."_u_12" = x.b +) AS "_u_12" + ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b LEFT JOIN ( SELECT y.a AS a @@ -121,37 +128,38 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_14" - ON x.a = "_u_14".a +) AS "_u_15" + ON x.a = "_u_15".a WHERE - NOT "_u_0"."a" IS NULL - AND NOT "_u_1"."b" IS NULL - AND NOT "_u_2"."a" IS NULL + x.a = "_u_0".a + AND NOT "_u_1"."a" IS NULL + AND NOT "_u_2"."b" IS NULL + AND NOT "_u_3"."a" IS NULL AND ( - x.a = "_u_3".b AND NOT "_u_3"."_u_4" IS NULL + x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL ) AND ( - x.a > "_u_5".b AND NOT "_u_5"."_u_6" IS NULL + x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL ) AND ( - None = "_u_7".a AND NOT "_u_7".a IS NULL + None = "_u_8".a AND NOT "_u_8".a IS NULL ) AND NOT ( - x.a = "_u_8".a AND NOT "_u_8".a IS NULL + x.a = "_u_9".a AND NOT "_u_9".a IS NULL ) AND ( - ARRAY_ANY("_u_9".a, _x -> _x = x.a) AND NOT "_u_9"."_u_10" IS NULL + ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL ) AND ( ( ( - x.a < "_u_11".a AND NOT "_u_11"."_u_12" IS NULL - ) AND NOT "_u_11"."_u_12" IS NULL + x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL + ) AND NOT "_u_12"."_u_13" IS NULL ) - AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d) + AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d) ) AND ( - NOT "_u_14".a IS NULL AND NOT "_u_14".a IS NULL + NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL ) AND x.a IN ( SELECT diff --git a/tests/test_executor.py b/tests/test_executor.py index 2c4d7cd..9d452e4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -68,13 +68,13 @@ class TestExecutor(unittest.TestCase): def test_execute_tpch(self): def to_csv(expression): - if isinstance(expression, exp.Table): + if isinstance(expression, exp.Table) and expression.name not in ("revenue"): return parse_one( f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression - for i, (sql, _) in enumerate(self.sqls[0:7]): + for i, (sql, _) in enumerate(self.sqls[0:16]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) @@ -165,6 +165,39 @@ class TestExecutor(unittest.TestCase): ["a"], [("a",)], ), + ( + "SELECT DISTINCT a FROM (SELECT 1 AS a UNION ALL SELECT 1 AS a)", + ["a"], + [(1,)], + ), + ( + "SELECT DISTINCT a, SUM(b) AS b " + "FROM (SELECT 'a' AS a, 1 AS b UNION ALL SELECT 'a' AS a, 2 AS b UNION ALL SELECT 'b' AS a, 1 AS b) " + "GROUP BY a " + "LIMIT 1", + ["a", "b"], + [("a", 3)], + ), + ( + "SELECT COUNT(1) AS a FROM (SELECT 1)", + ["a"], + [(1,)], + ), + ( + "SELECT COUNT(1) AS a FROM (SELECT 1) LIMIT 0", + ["a"], + [], + ), + ( + "SELECT a FROM x GROUP BY a LIMIT 0", + ["a"], + [], + ), + ( + "SELECT a FROM x LIMIT 0", + ["a"], + [], + ), ]: with self.subTest(sql): result = execute(sql, schema=schema, tables=tables) @@ -346,6 +379,28 @@ class TestExecutor(unittest.TestCase): ], ) + def test_execute_subqueries(self): + tables = { + "table": [ + {"a": 1, "b": 1}, + {"a": 2, "b": 2}, + ], + } + + self.assertEqual( + execute( + """ + SELECT * + FROM table + WHERE a = (SELECT MAX(a) FROM table) + """, + tables=tables, + ).rows, + [ + (2, 2), + ], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}} @@ -401,6 +456,7 @@ class TestExecutor(unittest.TestCase): ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), + ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]), ]: result = execute(sql) self.assertEqual(result.columns, tuple(cols)) @@ -462,7 +518,18 @@ class TestExecutor(unittest.TestCase): ("IF(false, 1, 0)", 0), ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"), ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)), + ("1 IN (1, 2, 3)", True), + ("1 IN (2, 3)", False), + ("NULL IS NULL", True), + ("NULL IS NOT NULL", False), + ("NULL = NULL", None), + ("NULL <> NULL", None), ]: with self.subTest(sql): result = execute(f"SELECT {sql}") self.assertEqual(result.rows, [(expected,)]) + + def test_case_sensitivity(self): + result = execute("SELECT A AS A FROM X", tables={"x": [{"a": 1}]}) + self.assertEqual(result.columns, ("A",)) + self.assertEqual(result.rows, [(1,)]) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index c0927ad..0e13ade 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -525,24 +525,14 @@ class TestExpressions(unittest.TestCase): ), exp.Properties( expressions=[ - exp.FileFormatProperty( - this=exp.Literal.string("FORMAT"), value=exp.Literal.string("parquet") - ), + exp.FileFormatProperty(this=exp.Literal.string("parquet")), exp.PartitionedByProperty( - this=exp.Literal.string("PARTITIONED_BY"), - value=exp.Tuple( - expressions=[exp.to_identifier("a"), exp.to_identifier("b")] - ), - ), - exp.AnonymousProperty( - this=exp.Literal.string("custom"), value=exp.Literal.number(1) - ), - exp.TableFormatProperty( - this=exp.Literal.string("TABLE_FORMAT"), - value=exp.to_identifier("test_format"), + this=exp.Tuple(expressions=[exp.to_identifier("a"), exp.to_identifier("b")]) ), - exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()), - exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()), + exp.Property(this=exp.Literal.string("custom"), value=exp.Literal.number(1)), + exp.TableFormatProperty(this=exp.to_identifier("test_format")), + exp.EngineProperty(this=exp.null()), + exp.CollateProperty(this=exp.true()), ] ), ) @@ -609,9 +599,9 @@ FROM foo""", """SELECT a, b AS B, - c, -- comment - d AS D, -- another comment - CAST(x AS INT) -- final comment + c, /* comment */ + d AS D, /* another comment */ + CAST(x AS INT) /* final comment */ FROM foo""", ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 6637a1d..ecf581d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -85,9 +85,8 @@ class TestOptimizer(unittest.TestCase): if leave_tables_isolated is not None: func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) - optimized = func(parse_one(sql, read=dialect), **func_kwargs) - with self.subTest(title): + optimized = func(parse_one(sql, read=dialect), **func_kwargs) self.assertEqual( expected, optimized.sql(pretty=pretty, dialect=dialect), @@ -168,6 +167,9 @@ class TestOptimizer(unittest.TestCase): def test_quote_identities(self): self.check_file("quote_identities", optimizer.quote_identities.quote_identities) + def test_lower_identities(self): + self.check_file("lower_identities", optimizer.lower_identities.lower_identities) + def test_pushdown_projection(self): def pushdown_projections(expression, **kwargs): expression = optimizer.qualify_tables.qualify_tables(expression) diff --git a/tests/test_parser.py b/tests/test_parser.py index c747ea3..fa7b589 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -15,6 +15,51 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType) self.assertIsInstance(parse_one("array", into=exp.DataType), exp.DataType) + def test_parse_into_error(self): + expected_message = "Failed to parse into []" + expected_errors = [ + { + "description": "Invalid expression / Unexpected token", + "line": 1, + "col": 1, + "start_context": "", + "highlight": "SELECT", + "end_context": " 1;", + "into_expression": exp.From, + } + ] + with self.assertRaises(ParseError) as ctx: + parse_one("SELECT 1;", "sqlite", [exp.From]) + self.assertEqual(str(ctx.exception), expected_message) + self.assertEqual(ctx.exception.errors, expected_errors) + + def test_parse_into_errors(self): + expected_message = "Failed to parse into [, ]" + expected_errors = [ + { + "description": "Invalid expression / Unexpected token", + "line": 1, + "col": 1, + "start_context": "", + "highlight": "SELECT", + "end_context": " 1;", + "into_expression": exp.From, + }, + { + "description": "Invalid expression / Unexpected token", + "line": 1, + "col": 1, + "start_context": "", + "highlight": "SELECT", + "end_context": " 1;", + "into_expression": exp.Join, + }, + ] + with self.assertRaises(ParseError) as ctx: + parse_one("SELECT 1;", "sqlite", [exp.From, exp.Join]) + self.assertEqual(str(ctx.exception), expected_message) + self.assertEqual(ctx.exception.errors, expected_errors) + def test_column(self): columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all(exp.Column) assert len(list(columns)) == 1 @@ -24,6 +69,9 @@ class TestParser(unittest.TestCase): def test_float(self): self.assertEqual(parse_one(".2"), parse_one("0.2")) + def test_unary_plus(self): + self.assertEqual(parse_one("+15"), exp.Literal.number(15)) + def test_table(self): tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] self.assertEqual(tables, ["a", "b.c", "d"]) @@ -157,8 +205,9 @@ class TestParser(unittest.TestCase): def test_comments(self): expression = parse_one( """ - --comment1 - SELECT /* this won't be used */ + --comment1.1 + --comment1.2 + SELECT /*comment1.3*/ a, --comment2 b as B, --comment3:testing "test--annotation", @@ -169,13 +218,13 @@ class TestParser(unittest.TestCase): """ ) - self.assertEqual(expression.comment, "comment1") - self.assertEqual(expression.expressions[0].comment, "comment2") - self.assertEqual(expression.expressions[1].comment, "comment3:testing") - self.assertEqual(expression.expressions[2].comment, None) - self.assertEqual(expression.expressions[3].comment, "comment4 --foo") - self.assertEqual(expression.expressions[4].comment, "") - self.assertEqual(expression.expressions[5].comment, " space") + self.assertEqual(expression.comments, ["comment1.1", "comment1.2", "comment1.3"]) + self.assertEqual(expression.expressions[0].comments, ["comment2"]) + self.assertEqual(expression.expressions[1].comments, ["comment3:testing"]) + self.assertEqual(expression.expressions[2].comments, None) + self.assertEqual(expression.expressions[3].comments, ["comment4 --foo"]) + self.assertEqual(expression.expressions[4].comments, [""]) + self.assertEqual(expression.expressions[5].comments, [" space"]) def test_type_literals(self): self.assertEqual(parse_one("int 1"), parse_one("CAST(1 AS INT)")) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index d4772ba..1d1b966 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -7,13 +7,13 @@ class TestTokens(unittest.TestCase): def test_comment_attachment(self): tokenizer = Tokenizer() sql_comment = [ - ("/*comment*/ foo", "comment"), - ("/*comment*/ foo --test", "comment"), - ("--comment\nfoo --test", "comment"), - ("foo --comment", "comment"), - ("foo", None), - ("foo /*comment 1*/ /*comment 2*/", "comment 1"), + ("/*comment*/ foo", ["comment"]), + ("/*comment*/ foo --test", ["comment", "test"]), + ("--comment\nfoo --test", ["comment", "test"]), + ("foo --comment", ["comment"]), + ("foo", []), + ("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]), ] for sql, comment in sql_comment: - self.assertEqual(tokenizer.tokenize(sql)[0].comment, comment) + self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 1928d2c..0bcd2ca 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,7 +1,7 @@ import unittest from sqlglot import parse_one -from sqlglot.transforms import unalias_group +from sqlglot.transforms import eliminate_distinct_on, unalias_group class TestTime(unittest.TestCase): @@ -35,3 +35,30 @@ class TestTime(unittest.TestCase): "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date", "SELECT the_date AS the_date, COUNT(*) AS the_count FROM x GROUP BY the_date", ) + + def test_eliminate_distinct_on(self): + self.validate( + eliminate_distinct_on, + "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", + 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + ) + self.validate( + eliminate_distinct_on, + "SELECT DISTINCT ON (a) a, b FROM x", + 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1', + ) + self.validate( + eliminate_distinct_on, + "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC", + 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + ) + self.validate( + eliminate_distinct_on, + "SELECT DISTINCT a, b FROM x ORDER BY c DESC", + "SELECT DISTINCT a, b FROM x ORDER BY c DESC", + ) + self.validate( + eliminate_distinct_on, + "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", + 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1', + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 1bd2527..7bf53e5 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -26,6 +26,7 @@ class TestTranspile(unittest.TestCase): ) self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date") self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") + self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row") for key in ("union", "filter", "over", "from", "join"): with self.subTest(f"alias {key}"): @@ -38,6 +39,11 @@ class TestTranspile(unittest.TestCase): def test_asc(self): self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x") + def test_unary(self): + self.validate("+++1", "1") + self.validate("+-1", "-1") + self.validate("+- - -1", "- - -1") + def test_paren(self): with self.assertRaises(ParseError): transpile("1 + (2 + 3") @@ -58,7 +64,7 @@ class TestTranspile(unittest.TestCase): ) self.validate( "SELECT FOO, /*x*/\nBAR, /*y*/\nBAZ", - "SELECT\n FOO -- x\n , BAR -- y\n , BAZ", + "SELECT\n FOO /* x */\n , BAR /* y */\n , BAZ", leading_comma=True, pretty=True, ) @@ -78,7 +84,8 @@ class TestTranspile(unittest.TestCase): def test_comments(self): self.validate("SELECT */*comment*/", "SELECT * /* comment */") self.validate( - "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */" + "SELECT * FROM table /*comment 1*/ /*comment 2*/", + "SELECT * FROM table /* comment 1 */ /* comment 2 */", ) self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") @@ -112,6 +119,53 @@ class TestTranspile(unittest.TestCase): ) self.validate( """ +-- comment 1 +-- comment 2 +-- comment 3 +SELECT * FROM foo + """, + "/* comment 1 */ /* comment 2 */ /* comment 3 */ SELECT * FROM foo", + ) + self.validate( + """ +-- comment 1 +-- comment 2 +-- comment 3 +SELECT * FROM foo""", + """/* comment 1 */ +/* comment 2 */ +/* comment 3 */ +SELECT + * +FROM foo""", + pretty=True, + ) + self.validate( + """ +SELECT * FROM tbl /*line1 +line2 +line3*/ /*another comment*/ where 1=1 -- comment at the end""", + """SELECT * FROM tbl /* line1 +line2 +line3 */ /* another comment */ WHERE 1 = 1 /* comment at the end */""", + ) + self.validate( + """ +SELECT * FROM tbl /*line1 +line2 +line3*/ /*another comment*/ where 1=1 -- comment at the end""", + """SELECT + * +FROM tbl /* line1 +line2 +line3 */ +/* another comment */ +WHERE + 1 = 1 /* comment at the end */""", + pretty=True, + ) + self.validate( + """ /* multi line comment @@ -130,8 +184,8 @@ class TestTranspile(unittest.TestCase): */ SELECT tbl.cola /* comment 1 */ + tbl.colb /* comment 2 */, - CAST(x AS INT), -- comment 3 - y -- comment 4 + CAST(x AS INT), /* comment 3 */ + y /* comment 4 */ FROM bar /* comment 5 */, tbl /* comment 6 */""", read="mysql", pretty=True, @@ -364,33 +418,79 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", @mock.patch("sqlglot.parser.logger") def test_error_level(self, logger): invalid = "x + 1. (" - errors = [ + expected_messages = [ "Required keyword: 'expressions' missing for . Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", ] + expected_errors = [ + { + "description": "Required keyword: 'expressions' missing for ", + "line": 1, + "col": 8, + "start_context": "x + 1. ", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + { + "description": "Expecting )", + "line": 1, + "col": 8, + "start_context": "x + 1. ", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + ] transpile(invalid, error_level=ErrorLevel.WARN) - for error in errors: + for error in expected_messages: assert_logger_contains(error, logger) with self.assertRaises(ParseError) as ctx: transpile(invalid, error_level=ErrorLevel.IMMEDIATE) - self.assertEqual(str(ctx.exception), errors[0]) + self.assertEqual(str(ctx.exception), expected_messages[0]) + self.assertEqual(ctx.exception.errors[0], expected_errors[0]) with self.assertRaises(ParseError) as ctx: transpile(invalid, error_level=ErrorLevel.RAISE) - self.assertEqual(str(ctx.exception), "\n\n".join(errors)) + self.assertEqual(str(ctx.exception), "\n\n".join(expected_messages)) + self.assertEqual(ctx.exception.errors, expected_errors) more_than_max_errors = "((((" - expected = ( + expected_messages = ( "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "Required keyword: 'this' missing for . Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" "... and 2 more" ) + expected_errors = [ + { + "description": "Expecting )", + "line": 1, + "col": 4, + "start_context": "(((", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + { + "description": "Required keyword: 'this' missing for ", + "line": 1, + "col": 4, + "start_context": "(((", + "highlight": "(", + "end_context": "", + "into_expression": None, + }, + ] + # Also expect three trailing structured errors that match the first + expected_errors += [expected_errors[0]] * 3 + with self.assertRaises(ParseError) as ctx: transpile(more_than_max_errors, error_level=ErrorLevel.RAISE) - self.assertEqual(str(ctx.exception), expected) + self.assertEqual(str(ctx.exception), expected_messages) + self.assertEqual(ctx.exception.errors, expected_errors) @mock.patch("sqlglot.generator.logger") def test_unsupported_level(self, logger): -- cgit v1.2.3