diff options
Diffstat (limited to 'tests')
27 files changed, 447 insertions, 61 deletions
diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 37ea2e1..8b44b9f 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -1152,17 +1152,17 @@ class TestFunctions(unittest.TestCase): def test_regexp_extract(self): col_str = SF.regexp_extract("cola", r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col_str.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col_str.sql()) col = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)", 1) - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)', 1)", col.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)', 1)", col.sql()) col_no_idx = SF.regexp_extract(SF.col("cola"), r"(\d+)-(\d+)") - self.assertEqual("REGEXP_EXTRACT(cola, '(\\\d+)-(\\\d+)')", col_no_idx.sql()) + self.assertEqual("REGEXP_EXTRACT(cola, '(\\\\d+)-(\\\\d+)')", col_no_idx.sql()) def test_regexp_replace(self): col_str = SF.regexp_replace("cola", r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col_str.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col_str.sql()) col = SF.regexp_replace(SF.col("cola"), r"(\d+)", "--") - self.assertEqual("REGEXP_REPLACE(cola, '(\\\d+)', '--')", col.sql()) + self.assertEqual("REGEXP_REPLACE(cola, '(\\\\d+)', '--')", col.sql()) def test_initcap(self): col_str = SF.initcap("cola") diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py index 70a868a..45d736f 100644 --- a/tests/dataframe/unit/test_window.py +++ b/tests/dataframe/unit/test_window.py @@ -15,11 +15,11 @@ class TestDataframeWindow(unittest.TestCase): def test_window_spec_rows_between(self): rows_between = WindowSpec().rowsBetween(3, 5) - self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) def test_window_spec_range_between(self): range_between = WindowSpec().rangeBetween(3, 5) - self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) def test_window_partition_by(self): partition_by = Window.partitionBy(F.col("cola"), F.col("colb")) @@ -31,46 +31,46 @@ class TestDataframeWindow(unittest.TestCase): def test_window_rows_between(self): rows_between = Window.rowsBetween(3, 5) - self.assertEqual("OVER ( ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) + self.assertEqual("OVER (ROWS BETWEEN 3 PRECEDING AND 5 FOLLOWING)", rows_between.sql()) def test_window_range_between(self): range_between = Window.rangeBetween(3, 5) - self.assertEqual("OVER ( RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) + self.assertEqual("OVER (RANGE BETWEEN 3 PRECEDING AND 5 FOLLOWING)", range_between.sql()) def test_window_rows_unbounded(self): rows_between_unbounded_start = Window.rowsBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", rows_between_unbounded_start.sql(), ) rows_between_unbounded_end = Window.rowsBetween(1, Window.unboundedFollowing) self.assertEqual( - "OVER ( ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_end.sql(), ) rows_between_unbounded_both = Window.rowsBetween( Window.unboundedPreceding, Window.unboundedFollowing ) self.assertEqual( - "OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", rows_between_unbounded_both.sql(), ) def test_window_range_unbounded(self): range_between_unbounded_start = Window.rangeBetween(Window.unboundedPreceding, 2) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", + "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING)", range_between_unbounded_start.sql(), ) range_between_unbounded_end = Window.rangeBetween(1, Window.unboundedFollowing) self.assertEqual( - "OVER ( RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (RANGE BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_end.sql(), ) range_between_unbounded_both = Window.rangeBetween( Window.unboundedPreceding, Window.unboundedFollowing ) self.assertEqual( - "OVER ( RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + "OVER (RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", range_between_unbounded_both.sql(), ) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 258e47f..c61a2f3 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -125,7 +125,7 @@ class TestBigQuery(Validator): }, ) self.validate_all( - "CURRENT_DATE", + "CURRENT_TIMESTAMP()", read={ "tsql": "GETDATE()", }, @@ -300,6 +300,14 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", + write={ + "spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)", + "bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])", + "snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", + }, + ) + self.validate_all( "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", write={ "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", @@ -324,3 +332,35 @@ class TestBigQuery(Validator): "SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a", write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"}, ) + + def test_remove_precision_parameterized_types(self): + self.validate_all( + "SELECT CAST(1 AS NUMERIC(10, 2))", + write={ + "bigquery": "SELECT CAST(1 AS NUMERIC)", + }, + ) + self.validate_all( + "CREATE TABLE test (a NUMERIC(10, 2))", + write={ + "bigquery": "CREATE TABLE test (a NUMERIC(10, 2))", + }, + ) + self.validate_all( + "SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))", + write={ + "bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)", + }, + ) + self.validate_all( + "SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)", + write={ + "bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)", + }, + ) + self.validate_all( + "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))", + write={ + "bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))", + }, + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index c95c967..109e9f3 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -14,6 +14,9 @@ class TestClickhouse(Validator): self.validate_identity("SELECT * FROM foo LEFT ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ASOF JOIN bla") self.validate_identity("SELECT * FROM foo ANY JOIN bla") + self.validate_identity("SELECT quantile(0.5)(a)") + self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") + self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", @@ -38,3 +41,9 @@ class TestClickhouse(Validator): "SELECT x #! comment", write={"": "SELECT x /* comment */"}, ) + self.validate_all( + "SELECT quantileIf(0.5)(a, true)", + write={ + "clickhouse": "SELECT quantileIf(0.5)(a, TRUE)", + }, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index ced7102..284a30d 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -85,7 +85,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS BINARY(4))", write={ - "bigquery": "CAST(a AS BINARY(4))", + "bigquery": "CAST(a AS BINARY)", "clickhouse": "CAST(a AS BINARY(4))", "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BINARY(4))", @@ -104,7 +104,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS VARBINARY(4))", write={ - "bigquery": "CAST(a AS VARBINARY(4))", + "bigquery": "CAST(a AS VARBINARY)", "clickhouse": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS VARBINARY(4))", "mysql": "CAST(a AS VARBINARY(4))", @@ -181,7 +181,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS VARCHAR(3))", write={ - "bigquery": "CAST(a AS STRING(3))", + "bigquery": "CAST(a AS STRING)", "drill": "CAST(a AS VARCHAR(3))", "duckdb": "CAST(a AS TEXT(3))", "mysql": "CAST(a AS VARCHAR(3))", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index a7f3b8f..bbf00b1 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -339,6 +339,24 @@ class TestHive(Validator): def test_hive(self): self.validate_all( + "SELECT A.1a AS b FROM test_a AS A", + write={ + "spark": "SELECT A.1a AS b FROM test_a AS A", + }, + ) + self.validate_all( + "SELECT 1_a AS a FROM test_table", + write={ + "spark": "SELECT 1_a AS a FROM test_table", + }, + ) + self.validate_all( + "SELECT a_b AS 1_a FROM test_table", + write={ + "spark": "SELECT a_b AS 1_a FROM test_table", + }, + ) + self.validate_all( "PERCENTILE(x, 0.5)", write={ "duckdb": "QUANTILE(x, 0.5)", @@ -411,7 +429,7 @@ class TestHive(Validator): "INITCAP('new york')", write={ "duckdb": "INITCAP('new york')", - "presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", + "presto": r"REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", "hive": "INITCAP('new york')", "spark": "INITCAP('new york')", }, diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 1e048d5..583d349 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -122,6 +122,10 @@ class TestPostgres(Validator): "TO_TIMESTAMP(123::DOUBLE PRECISION)", write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"}, ) + self.validate_all( + "SELECT to_timestamp(123)::time without time zone", + write={"postgres": "SELECT CAST(TO_TIMESTAMP(123) AS TIME)"}, + ) self.validate_identity( "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 70e1059..ee535e9 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -60,11 +60,11 @@ class TestPresto(Validator): self.validate_all( "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", write={ - "bigquery": "CAST(x AS TIMESTAMPTZ(9))", + "bigquery": "CAST(x AS TIMESTAMPTZ)", "duckdb": "CAST(x AS TIMESTAMPTZ(9))", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", - "hive": "CAST(x AS TIMESTAMPTZ(9))", - "spark": "CAST(x AS TIMESTAMPTZ(9))", + "hive": "CAST(x AS TIMESTAMPTZ)", + "spark": "CAST(x AS TIMESTAMPTZ)", }, ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index df62c6c..0e9ce9b 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -523,3 +523,33 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA "spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)", }, ) + + def test_describe_table(self): + self.validate_all( + "DESCRIBE TABLE db.table", + write={ + "snowflake": "DESCRIBE TABLE db.table", + "spark": "DESCRIBE db.table", + }, + ) + self.validate_all( + "DESCRIBE db.table", + write={ + "snowflake": "DESCRIBE TABLE db.table", + "spark": "DESCRIBE db.table", + }, + ) + self.validate_all( + "DESC TABLE db.table", + write={ + "snowflake": "DESCRIBE TABLE db.table", + "spark": "DESCRIBE db.table", + }, + ) + self.validate_all( + "DESC VIEW db.table", + write={ + "snowflake": "DESCRIBE VIEW db.table", + "spark": "DESCRIBE db.table", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 7395e72..f287a89 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -207,6 +207,7 @@ TBLPROPERTIES ( ) def test_spark(self): + self.validate_identity("SELECT UNIX_TIMESTAMP()") self.validate_all( "ARRAY_SORT(x, (left, right) -> -1)", write={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b4ac094..b74c05f 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -6,6 +6,8 @@ class TestTSQL(Validator): def test_tsql(self): self.validate_identity('SELECT "x"."y" FROM foo') + self.validate_identity("SELECT * FROM #foo") + self.validate_identity("SELECT * FROM ##foo") self.validate_identity( "SELECT DISTINCT DepartmentName, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY BaseRate) OVER (PARTITION BY DepartmentName) AS MedianCont FROM dbo.DimEmployee" ) @@ -71,6 +73,12 @@ class TestTSQL(Validator): "tsql": "CAST(x AS DATETIME2)", }, ) + self.validate_all( + "CAST(x AS DATETIME2(6))", + write={ + "hive": "CAST(x AS TIMESTAMP)", + }, + ) def test_charindex(self): self.validate_all( @@ -300,6 +308,12 @@ class TestTSQL(Validator): "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", }, ) + self.validate_all( + "SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test", + write={ + "spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test", + }, + ) def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") @@ -441,3 +455,13 @@ class TestTSQL(Validator): "SELECT '''test'''", write={"spark": r"SELECT '\'test\''"}, ) + + def test_eomonth(self): + self.validate_all( + "EOMONTH(GETDATE())", + write={"spark": "LAST_DAY(CURRENT_TIMESTAMP())"}, + ) + self.validate_all( + "EOMONTH(GETDATE(), -1)", + write={"spark": "LAST_DAY(ADD_MONTHS(CURRENT_TIMESTAMP(), -1))"}, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index e6a6e6b..beb5703 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -89,6 +89,7 @@ POSEXPLODE("x") AS ("a", "b") POSEXPLODE("x") AS ("a", "b", "c") STR_POSITION(x, 'a') STR_POSITION(x, 'a', 3) +LEVENSHTEIN('gumbo', 'gambol', 2, 1, 1) SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] x[ORDINAL(1)][SAFE_OFFSET(2)] x LIKE SUBSTR('abc', 1, 1) @@ -425,6 +426,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) +SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) SELECT SUM(x) FILTER(WHERE x > 1) @@ -450,14 +452,24 @@ SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) +SELECT * FROM t WITH (TABLOCK, INDEX(myindex)) +SELECT * FROM t WITH (NOWAIT) +CREATE TABLE foo AS (SELECT 1) UNION ALL (SELECT 2) CREATE TABLE foo (id INT PRIMARY KEY ASC) CREATE TABLE a.b AS SELECT 1 +CREATE TABLE a.b AS SELECT 1 WITH DATA AND STATISTICS +CREATE TABLE a.b AS SELECT 1 WITH NO DATA AND NO STATISTICS +CREATE TABLE a.b AS (SELECT 1) NO PRIMARY INDEX +CREATE TABLE a.b AS (SELECT 1) UNIQUE PRIMARY INDEX index1 (a) UNIQUE INDEX index2 (b) +CREATE TABLE a.b AS (SELECT 1) PRIMARY AMP INDEX index1 (a) UNIQUE INDEX index2 (b) CREATE TABLE a.b AS SELECT a FROM a.c CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE TEMPORARY TABLE x AS SELECT a FROM d CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d CREATE VIEW x AS SELECT a FROM b CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b +CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d +CREATE VIEW IF NOT EXISTS z (a, b COMMENT 'b', c COMMENT 'c') AS SELECT a, b, c FROM d CREATE OR REPLACE VIEW x AS SELECT * CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * CREATE TEMPORARY VIEW x AS SELECT a FROM d @@ -490,6 +502,8 @@ CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) CREATE TABLE z (a INT REFERENCES parent(b, c)) CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) +CREATE VIEW z (a, b) +CREATE VIEW z (a, b COMMENT 'b', c COMMENT 'c') CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f AS 'g' CREATE FUNCTION f @@ -559,6 +573,7 @@ INSERT INTO x.z IF EXISTS SELECT * FROM y INSERT INTO x VALUES (1, 'a', 2.0) INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x) INSERT INTO y (a, b, c) SELECT a, b, c FROM x +INSERT INTO y (SELECT 1) UNION (SELECT 2) INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y INSERT OVERWRITE DIRECTORY 'x' SELECT 1 @@ -627,3 +642,4 @@ ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10 ALTER TABLE integers ALTER COLUMN i DROP DEFAULT ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT +SELECT div.a FROM test_table AS div diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index 4a3ad4b..4c06e42 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -311,3 +311,42 @@ FROM ON t1.cola = t2.cola; SELECT /*+ BROADCAST(a2) */ a1.cola AS cola, a2.cola AS cola FROM VALUES (1) AS a1(cola) JOIN VALUES (1) AS a2(cola) ON a1.cola = a2.cola; + +# title: Nested subquery selects from same table as another subquery +WITH i AS ( + SELECT + x.a AS a + FROM x AS x +), j AS ( + SELECT + x.a, + x.b + FROM x AS x +), k AS ( + SELECT + j.a, + j.b + FROM j AS j +) +SELECT + i.a, + k.b +FROM i AS i +LEFT JOIN k AS k +ON i.a = k.a; +SELECT x.a AS a, x_2.b AS b FROM x AS x LEFT JOIN x AS x_2 ON x.a = x_2.a; + +# title: Outer select joins on inner select join +WITH i AS ( + SELECT + x.a AS a + FROM y AS y + JOIN x AS x + ON y.b = x.b +) +SELECT + x.a AS a +FROM x AS x +LEFT JOIN i AS i + ON x.a = i.a; +WITH i AS (SELECT x.a AS a FROM y AS y JOIN x AS x ON y.b = x.b) SELECT x.a AS a FROM x AS x LEFT JOIN i AS i ON x.a = i.a; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index b502d81..664b3c7 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -105,7 +105,7 @@ LEFT JOIN "_u_0" AS "_u_0" JOIN "y" AS "y" ON "x"."b" = "y"."b" WHERE - "_u_0"."_col_0" >= 0 AND "x"."a" > 1 AND NOT "_u_0"."_u_1" IS NULL + "_u_0"."_col_0" >= 0 AND "x"."a" > 1 GROUP BY "x"."a"; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index 2a21f65..b9f6c3f 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -54,3 +54,6 @@ WITH t1 AS (SELECT q.cola AS cola FROM UNNEST(ARRAY(STRUCT(1 AS cola, 'test' AS SELECT x FROM VALUES(1, 2) AS q(x, y); SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); + +SELECT i.a FROM x AS i LEFT JOIN (SELECT a, b FROM (SELECT a, b FROM x)) AS j ON i.a = j.a; +SELECT i.a AS a FROM x AS i LEFT JOIN (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS j ON i.a = j.a; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index cf4195d..4e9e70c 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -375,6 +375,18 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; date '1998-12-01' + interval '90' foo; CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; +CAST(x AS DATE) + interval '1' week; +CAST(x AS DATE) + INTERVAL '1' week; + +CAST('2008-11-11' AS DATETIME) + INTERVAL '5' MONTH; +CAST('2009-04-11 00:00:00' AS DATETIME); + +datetime '1998-12-01' - interval '90' day; +CAST('1998-09-02 00:00:00' AS DATETIME); + +CAST(x AS DATETIME) + interval '1' week; +CAST(x AS DATETIME) + INTERVAL '1' week; + -------------------------------------- -- Comparisons -------------------------------------- diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 9c1f138..272fb26 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -150,7 +150,6 @@ WHERE "part"."p_size" = 15 AND "part"."p_type" LIKE '%BRASS' AND "partsupp"."ps_supplycost" = "_u_0"."_col_0" - AND NOT "_u_0"."_u_1" IS NULL ORDER BY "s_acctbal" DESC, "n_name", @@ -1008,7 +1007,7 @@ JOIN "part" AS "part" LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."_u_1" = "part"."p_partkey" WHERE - "lineitem"."l_quantity" < "_u_0"."_col_0" AND NOT "_u_0"."_u_1" IS NULL; + "lineitem"."l_quantity" < "_u_0"."_col_0"; -------------------------------------- -- TPC-H 18 @@ -1253,10 +1252,7 @@ WITH "_u_0" AS ( LEFT JOIN "_u_3" AS "_u_3" ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" WHERE - "partsupp"."ps_availqty" > "_u_0"."_col_0" - AND NOT "_u_0"."_u_1" IS NULL - AND NOT "_u_0"."_u_2" IS NULL - AND NOT "_u_3"."p_partkey" IS NULL + "partsupp"."ps_availqty" > "_u_0"."_col_0" AND NOT "_u_3"."p_partkey" IS NULL GROUP BY "partsupp"."ps_suppkey" ) diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index a444945..9d760e0 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -22,6 +22,8 @@ WHERE AND x.a > ANY (SELECT y.a FROM y) AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10) AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10) + AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a) + AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a) ; SELECT * @@ -130,37 +132,42 @@ LEFT JOIN ( y.a ) AS _u_15 ON x.a = _u_15.a +LEFT JOIN ( + SELECT + ARRAY_AGG(c), + y.a AS _u_20 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS _u_19 + ON _u_19._u_20 = x.a +LEFT JOIN ( + SELECT + COUNT(*) AS d, + y.a AS _u_22 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS _u_21 + ON _u_21._u_22 = x.a WHERE x.a = _u_0.a AND NOT "_u_1"."a" IS NULL AND NOT "_u_2"."b" IS NULL AND NOT "_u_3"."a" IS NULL + AND x.a = _u_4.b + AND x.a > _u_6.b + AND x.a = _u_8.a + AND NOT x.a = _u_9.a + AND ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND ( - x.a = _u_4.b AND NOT _u_4._u_5 IS NULL - ) - AND ( - x.a > _u_6.b AND NOT _u_6._u_7 IS NULL - ) - AND ( - None = _u_8.a AND NOT _u_8.a IS NULL - ) - AND NOT ( - x.a = _u_9.a AND NOT _u_9.a IS NULL - ) - AND ( - ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL - ) - AND ( - ( - ( - x.a < _u_12.a AND NOT _u_12._u_13 IS NULL - ) AND NOT _u_12._u_13 IS NULL - ) - AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d) - ) - AND ( - NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL + x.a < _u_12.a AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d) ) + AND NOT _u_15.a IS NULL AND x.a IN ( SELECT y.a AS a @@ -199,4 +206,6 @@ WHERE WHERE y.a = x.a OFFSET 10 - ); + ) + AND ARRAY_ALL(_u_19."", _x -> _x = x.a) + AND x.a > COALESCE(_u_21.d, 0); diff --git a/tests/helpers.py b/tests/helpers.py index 9abdaae..bab4da0 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -27,8 +27,7 @@ def assert_logger_contains(message, logger, level="error"): def load_sql_fixtures(filename): with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: - for sql in _filter_comments(f.read()).splitlines(): - yield sql + yield from _filter_comments(f.read()).splitlines() def load_sql_fixture_pairs(filename): diff --git a/tests/test_executor.py b/tests/test_executor.py index b705551..f45a5d4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -401,6 +401,36 @@ class TestExecutor(unittest.TestCase): ], ) + def test_correlated_count(self): + tables = { + "parts": [{"pnum": 0, "qoh": 1}], + "supplies": [], + } + + schema = { + "parts": {"pnum": "int", "qoh": "int"}, + "supplies": {"pnum": "int", "shipdate": "int"}, + } + + self.assertEqual( + execute( + """ + select * + from parts + where parts.qoh >= ( + select count(supplies.shipdate) + 1 + from supplies + where supplies.pnum = parts.pnum and supplies.shipdate < 10 + ) + """, + tables=tables, + schema=schema, + ).rows, + [ + (0, 1), + ], + ) + def test_table_depth_mismatch(self): tables = {"table": []} schema = {"db": {"table": {"col": "VARCHAR"}}} diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 1e23983..906e08c 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -646,3 +646,72 @@ FROM foo""", exp.Column(this=exp.to_identifier("colb")), ], ) + + def test_values(self): + self.assertEqual( + exp.values([(1, 2), (3, 4)], "t", ["a", "b"]).sql(), + "(VALUES (1, 2), (3, 4)) AS t(a, b)", + ) + self.assertEqual( + exp.values( + [(1, 2), (3, 4)], + "t", + {"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")}, + ).sql(), + "(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)", + ) + with self.assertRaises(ValueError): + exp.values([(1, 2), (3, 4)], columns=["a"]) + + def test_data_type_builder(self): + self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT") + self.assertEqual(exp.DataType.build("DECIMAL(10, 2)").sql(), "DECIMAL(10, 2)") + self.assertEqual(exp.DataType.build("VARCHAR(255)").sql(), "VARCHAR(255)") + self.assertEqual(exp.DataType.build("ARRAY<INT>").sql(), "ARRAY<INT>") + self.assertEqual(exp.DataType.build("CHAR").sql(), "CHAR") + self.assertEqual(exp.DataType.build("NCHAR").sql(), "CHAR") + self.assertEqual(exp.DataType.build("VARCHAR").sql(), "VARCHAR") + self.assertEqual(exp.DataType.build("NVARCHAR").sql(), "VARCHAR") + self.assertEqual(exp.DataType.build("TEXT").sql(), "TEXT") + self.assertEqual(exp.DataType.build("BINARY").sql(), "BINARY") + self.assertEqual(exp.DataType.build("VARBINARY").sql(), "VARBINARY") + self.assertEqual(exp.DataType.build("INT").sql(), "INT") + self.assertEqual(exp.DataType.build("TINYINT").sql(), "TINYINT") + self.assertEqual(exp.DataType.build("SMALLINT").sql(), "SMALLINT") + self.assertEqual(exp.DataType.build("BIGINT").sql(), "BIGINT") + self.assertEqual(exp.DataType.build("FLOAT").sql(), "FLOAT") + self.assertEqual(exp.DataType.build("DOUBLE").sql(), "DOUBLE") + self.assertEqual(exp.DataType.build("DECIMAL").sql(), "DECIMAL") + self.assertEqual(exp.DataType.build("BOOLEAN").sql(), "BOOLEAN") + self.assertEqual(exp.DataType.build("JSON").sql(), "JSON") + self.assertEqual(exp.DataType.build("JSONB").sql(), "JSONB") + self.assertEqual(exp.DataType.build("INTERVAL").sql(), "INTERVAL") + self.assertEqual(exp.DataType.build("TIME").sql(), "TIME") + self.assertEqual(exp.DataType.build("TIMESTAMP").sql(), "TIMESTAMP") + self.assertEqual(exp.DataType.build("TIMESTAMPTZ").sql(), "TIMESTAMPTZ") + self.assertEqual(exp.DataType.build("TIMESTAMPLTZ").sql(), "TIMESTAMPLTZ") + self.assertEqual(exp.DataType.build("DATE").sql(), "DATE") + self.assertEqual(exp.DataType.build("DATETIME").sql(), "DATETIME") + self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY") + self.assertEqual(exp.DataType.build("MAP").sql(), "MAP") + self.assertEqual(exp.DataType.build("UUID").sql(), "UUID") + self.assertEqual(exp.DataType.build("GEOGRAPHY").sql(), "GEOGRAPHY") + self.assertEqual(exp.DataType.build("GEOMETRY").sql(), "GEOMETRY") + self.assertEqual(exp.DataType.build("STRUCT").sql(), "STRUCT") + self.assertEqual(exp.DataType.build("NULLABLE").sql(), "NULLABLE") + self.assertEqual(exp.DataType.build("HLLSKETCH").sql(), "HLLSKETCH") + self.assertEqual(exp.DataType.build("HSTORE").sql(), "HSTORE") + self.assertEqual(exp.DataType.build("SUPER").sql(), "SUPER") + self.assertEqual(exp.DataType.build("SERIAL").sql(), "SERIAL") + self.assertEqual(exp.DataType.build("SMALLSERIAL").sql(), "SMALLSERIAL") + self.assertEqual(exp.DataType.build("BIGSERIAL").sql(), "BIGSERIAL") + self.assertEqual(exp.DataType.build("XML").sql(), "XML") + self.assertEqual(exp.DataType.build("UNIQUEIDENTIFIER").sql(), "UNIQUEIDENTIFIER") + self.assertEqual(exp.DataType.build("MONEY").sql(), "MONEY") + self.assertEqual(exp.DataType.build("SMALLMONEY").sql(), "SMALLMONEY") + self.assertEqual(exp.DataType.build("ROWVERSION").sql(), "ROWVERSION") + self.assertEqual(exp.DataType.build("IMAGE").sql(), "IMAGE") + self.assertEqual(exp.DataType.build("VARIANT").sql(), "VARIANT") + self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT") + self.assertEqual(exp.DataType.build("NULL").sql(), "NULL") + self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 1c97be7..887f427 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -299,10 +299,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(set(scopes[6].sources), {"q", "z", "r", "s"}) self.assertEqual(len(scopes[6].columns), 6) - self.assertEqual(set(c.table for c in scopes[6].columns), {"r", "s"}) + self.assertEqual({c.table for c in scopes[6].columns}, {"r", "s"}) self.assertEqual(scopes[6].source_columns("q"), []) self.assertEqual(len(scopes[6].source_columns("r")), 2) - self.assertEqual(set(c.table for c in scopes[6].source_columns("r")), {"r"}) + self.assertEqual({c.table for c in scopes[6].source_columns("r")}, {"r"}) self.assertEqual({c.sql() for c in scopes[-1].find_all(exp.Column)}, {"r.b", "s.b"}) self.assertEqual(scopes[-1].find(exp.Column).sql(), "r.b") @@ -578,3 +578,16 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') scope_t, scope_y = build_scope(query).cte_scopes self.assertEqual(set(scope_t.cte_sources), {"t"}) self.assertEqual(set(scope_y.cte_sources), {"t", "y"}) + + def test_schema_with_spaces(self): + schema = { + "a": { + "b c": "text", + '"d e"': "text", + } + } + + self.assertEqual( + optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema), + parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'), + ) diff --git a/tests/test_parser.py b/tests/test_parser.py index ae2e4cd..03b801b 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -8,7 +8,8 @@ from tests.helpers import assert_logger_contains class TestParser(unittest.TestCase): def test_parse_empty(self): - self.assertIsNone(parse_one("")) + with self.assertRaises(ParseError) as ctx: + parse_one("") def test_parse_into(self): self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join) @@ -90,6 +91,9 @@ class TestParser(unittest.TestCase): parse_one("""SELECT * FROM x CROSS JOIN y, z LATERAL VIEW EXPLODE(y)""").sql(), """SELECT * FROM x, z CROSS JOIN y LATERAL VIEW EXPLODE(y)""", ) + self.assertIsNone( + parse_one("create table a as (select b from c) index").find(exp.TableAlias) + ) def test_command(self): expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive") @@ -155,6 +159,11 @@ class TestParser(unittest.TestCase): assert expressions[0].args["from"].expressions[0].this.name == "a" assert expressions[1].args["from"].expressions[0].this.name == "b" + expressions = parse("SELECT 1; ; SELECT 2") + + assert len(expressions) == 3 + assert expressions[1] is None + def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint) diff --git a/tests/test_schema.py b/tests/test_schema.py index 6c1ca9c..3dd9103 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -184,3 +184,19 @@ class TestSchema(unittest.TestCase): schema = MappingSchema({"foo": {"bar": parse_one("INT", into=exp.DataType)}}) self.assertEqual(schema.get_column_type("foo", "bar").this, exp.DataType.Type.INT) + + def test_schema_normalization(self): + schema = MappingSchema( + schema={"x": {"`y`": {"Z": {"a": "INT", "`B`": "VARCHAR"}, "w": {"C": "INT"}}}}, + dialect="spark", + ) + + table_z = exp.Table(this="z", db="y", catalog="x") + table_w = exp.Table(this="w", db="y", catalog="x") + + self.assertEqual(schema.column_names(table_z), ["a", "B"]) + self.assertEqual(schema.column_names(table_w), ["c"]) + + # Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql + schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse") + self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"]) diff --git a/tests/test_serde.py b/tests/test_serde.py new file mode 100644 index 0000000..603a155 --- /dev/null +++ b/tests/test_serde.py @@ -0,0 +1,33 @@ +import json +import unittest + +from sqlglot import exp, parse_one +from sqlglot.optimizer.annotate_types import annotate_types +from tests.helpers import load_sql_fixtures + + +class CustomExpression(exp.Expression): + ... + + +class TestSerDe(unittest.TestCase): + def dump_load(self, expression): + return exp.Expression.load(json.loads(json.dumps(expression.dump()))) + + def test_serde(self): + for sql in load_sql_fixtures("identity.sql"): + with self.subTest(sql): + before = parse_one(sql) + after = self.dump_load(before) + self.assertEqual(before, after) + + def test_custom_expression(self): + before = CustomExpression() + after = self.dump_load(before) + self.assertEqual(before, after) + + def test_type_annotations(self): + before = annotate_types(parse_one("CAST('1' AS INT)")) + after = self.dump_load(before) + self.assertEqual(before.type, after.type) + self.assertEqual(before.this.type, after.this.type) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index cfb8d2b..cc9af7e 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1,7 +1,11 @@ import unittest from sqlglot import parse_one -from sqlglot.transforms import eliminate_distinct_on, unalias_group +from sqlglot.transforms import ( + eliminate_distinct_on, + remove_precision_parameterized_types, + unalias_group, +) class TestTime(unittest.TestCase): @@ -62,3 +66,10 @@ class TestTime(unittest.TestCase): "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1', ) + + def test_remove_precision_parameterized_types(self): + self.validate( + remove_precision_parameterized_types, + "SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))", + "SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)", + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 9253ded..3a7fea4 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -118,6 +118,11 @@ class TestTranspile(unittest.TestCase): "SELECT x FROM foo /* x */", ) self.validate( + """select x, -- + from foo""", + "SELECT x FROM foo", + ) + self.validate( """ -- comment 1 -- comment 2 |