From 5d0ea770947ae1da51537ff75b14b48218d729aa Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sun, 3 Mar 2024 15:11:03 +0100
Subject: Adding upstream version 22.2.0.

Signed-off-by: Daniel Baumann
---
 tests/dataframe/unit/test_functions.py | 4 +-
 tests/dialects/test_bigquery.py | 164 ++++++++++++++++----
 tests/dialects/test_clickhouse.py | 20 +++
 tests/dialects/test_databricks.py | 5 +
 tests/dialects/test_dialect.py | 13 +-
 tests/dialects/test_doris.py | 10 ++
 tests/dialects/test_duckdb.py | 61 +++++++-
 tests/dialects/test_hive.py | 9 +-
 tests/dialects/test_mysql.py | 53 +++----
 tests/dialects/test_oracle.py | 22 ++-
 tests/dialects/test_postgres.py | 45 ++++++
 tests/dialects/test_presto.py | 16 ++
 tests/dialects/test_redshift.py | 13 +-
 tests/dialects/test_snowflake.py | 167 +++++++++++++++------
 tests/dialects/test_spark.py | 1 +
 tests/dialects/test_sqlite.py | 12 ++
 tests/dialects/test_tsql.py | 142 ++++++++++++++----
 tests/fixtures/identity.sql | 15 +-
 tests/fixtures/optimizer/optimizer.sql | 2 +-
 tests/fixtures/optimizer/qualify_columns.sql | 12 ++
 tests/fixtures/optimizer/qualify_tables.sql | 23 ++-
 tests/fixtures/optimizer/quote_identifiers.sql | 4 +
 tests/fixtures/optimizer/tpc-ds/call_center.csv.gz | Bin 421 -> 425 bytes
 .../fixtures/optimizer/tpc-ds/catalog_page.csv.gz | Bin 463753 -> 460883 bytes
 .../optimizer/tpc-ds/catalog_returns.csv.gz | Bin 157676 -> 158215 bytes
 .../fixtures/optimizer/tpc-ds/catalog_sales.csv.gz | Bin 1803802 -> 1814673 bytes
 tests/fixtures/optimizer/tpc-ds/customer.csv.gz | Bin 107615 -> 107573 bytes
 .../optimizer/tpc-ds/customer_address.csv.gz | Bin 28336 -> 28719 bytes
 .../optimizer/tpc-ds/customer_demographics.csv.gz | Bin 126457 -> 126715 bytes
 tests/fixtures/optimizer/tpc-ds/date_dim.csv.gz | Bin 1531293 -> 1575448 bytes
 .../optimizer/tpc-ds/household_demographics.csv.gz | Bin 23425 -> 23544 bytes
 tests/fixtures/optimizer/tpc-ds/income_band.csv.gz | Bin 188 -> 191 bytes
 tests/fixtures/optimizer/tpc-ds/inventory.csv.gz | Bin 206882 -> 202661 bytes
 tests/fixtures/optimizer/tpc-ds/item.csv.gz | Bin 31392 -> 31336 bytes
 tests/fixtures/optimizer/tpc-ds/promotion.csv.gz | Bin 497 -> 501 bytes
 tests/fixtures/optimizer/tpc-ds/reason.csv.gz | Bin 81 -> 83 bytes
 tests/fixtures/optimizer/tpc-ds/ship_mode.csv.gz | Bin 617 -> 633 bytes
 tests/fixtures/optimizer/tpc-ds/store.csv.gz | Bin 396 -> 397 bytes
 .../fixtures/optimizer/tpc-ds/store_returns.csv.gz | Bin 254858 -> 255650 bytes
 tests/fixtures/optimizer/tpc-ds/store_sales.csv.gz | Bin 2417178 -> 2436694 bytes
 tests/fixtures/optimizer/tpc-ds/time_dim.csv.gz | Bin 668972 -> 680588 bytes
 tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 25 +++
 tests/fixtures/optimizer/tpc-ds/warehouse.csv.gz | Bin 218 -> 221 bytes
 tests/fixtures/optimizer/tpc-ds/web_page.csv.gz | Bin 208 -> 212 bytes
 tests/fixtures/optimizer/tpc-ds/web_returns.csv.gz | Bin 67542 -> 67833 bytes
 tests/fixtures/optimizer/tpc-ds/web_sales.csv.gz | Bin 864379 -> 867887 bytes
 tests/fixtures/optimizer/tpc-ds/web_site.csv.gz | Bin 404 -> 406 bytes
 tests/fixtures/optimizer/tpc-h/tpc-h.sql | 1 -
 tests/fixtures/optimizer/unnest_subqueries.sql | 16 ++
 tests/test_build.py | 20 +++
 tests/test_diff.py | 43 ++++--
 tests/test_executor.py | 81 +++++++---
 tests/test_expressions.py | 15 +-
 tests/test_lineage.py | 65 +++++---
 tests/test_optimizer.py | 48 +++++-
 tests/test_parser.py | 3 +
 tests/test_serde.py | 3 +-
 tests/test_transpile.py | 2 -
 58 files changed, 895 insertions(+), 240 deletions(-)

(limited to 'tests')

diff --git a/tests/dataframe/unit/test_functions.py
b/tests/dataframe/unit/test_functions.py index 24904b7..e40d50d 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -280,9 +280,9 @@ class TestFunctions(unittest.TestCase): def test_signum(self): col_str = SF.signum("cola") - self.assertEqual("SIGNUM(cola)", col_str.sql()) + self.assertEqual("SIGN(cola)", col_str.sql()) col = SF.signum(SF.col("cola")) - self.assertEqual("SIGNUM(cola)", col.sql()) + self.assertEqual("SIGN(cola)", col.sql()) def test_sin(self): col_str = SF.sin("cola") diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 8c47948..0d94d19 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -5,7 +5,9 @@ from sqlglot import ( ParseError, TokenError, UnsupportedError, + exp, parse, + parse_one, transpile, ) from sqlglot.helper import logger as helper_logger @@ -18,6 +20,51 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + self.validate_all( + "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 as a, 'abc' AS b), STRUCT(str_col AS abc)", + write={ + "bigquery": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)", + "duckdb": "SELECT {'_0': 1, '_1': 2, '_2': 3}, {}, {'_0': 'abc'}, {'_0': 1, '_1': t.str_col}, {'a': 1, 'b': 'abc'}, {'abc': str_col}", + "hive": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1, 'abc'), STRUCT(str_col)", + "spark2": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)", + "spark": "SELECT STRUCT(1, 2, 3), STRUCT(), STRUCT('abc'), STRUCT(1, t.str_col), STRUCT(1 AS a, 'abc' AS b), STRUCT(str_col AS abc)", + "snowflake": "SELECT OBJECT_CONSTRUCT('_0', 1, '_1', 2, '_2', 3), OBJECT_CONSTRUCT(), OBJECT_CONSTRUCT('_0', 'abc'), OBJECT_CONSTRUCT('_0', 1, '_1', t.str_col), OBJECT_CONSTRUCT('a', 1, 'b', 'abc'), OBJECT_CONSTRUCT('abc', str_col)", + # fallback to unnamed without type inference + "trino": "SELECT ROW(1, 2, 3), ROW(), ROW('abc'), ROW(1, t.str_col), CAST(ROW(1, 'abc') AS ROW(a INTEGER, b VARCHAR)), ROW(str_col)", + }, + ) + self.validate_all( + "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)", + write={ + "bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)", + "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S.%f%z')", + }, + ) + + table = parse_one("x-0._y.z", dialect="bigquery", into=exp.Table) + self.assertEqual(table.catalog, "x-0") + self.assertEqual(table.db, "_y") + self.assertEqual(table.name, "z") + + table = parse_one("x-0._y", dialect="bigquery", into=exp.Table) + self.assertEqual(table.db, "x-0") + self.assertEqual(table.name, "_y") + + self.validate_identity("SELECT * FROM x-0.y") + self.assertEqual(exp.to_table("`x.y.z`", dialect="bigquery").sql(), '"x"."y"."z"') + self.assertEqual(exp.to_table("`x.y.z`", dialect="bigquery").sql("bigquery"), "`x.y.z`") + self.assertEqual(exp.to_table("`x`.`y`", dialect="bigquery").sql("bigquery"), "`x`.`y`") + + select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`") + self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF") + + self.validate_identity("SELECT `p.d.UdF`(data).* FROM `p.d.t`") + self.validate_identity("SELECT * FROM `my-project.my-dataset.my-table`") + self.validate_identity("CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`") + self.validate_identity("SELECT x, 1 AS y GROUP BY 1 ORDER BY 1") + 
self.validate_identity("SELECT * FROM x.*") + self.validate_identity("SELECT * FROM x.y*") + self.validate_identity("CASE A WHEN 90 THEN 'red' WHEN 50 THEN 'blue' ELSE 'green' END") self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'") self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'") self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')") @@ -90,6 +137,16 @@ class TestBigQuery(Validator): self.validate_identity("LOG(n, b)") self.validate_identity("SELECT COUNT(x RESPECT NULLS)") self.validate_identity("SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x") + self.validate_identity("SELECT ARRAY((SELECT AS STRUCT 1 AS a, 2 AS b))") + self.validate_identity("SELECT ARRAY((SELECT AS STRUCT 1 AS a, 2 AS b) LIMIT 10)") + self.validate_identity("CAST(x AS CHAR)", "CAST(x AS STRING)") + self.validate_identity("CAST(x AS NCHAR)", "CAST(x AS STRING)") + self.validate_identity("CAST(x AS NVARCHAR)", "CAST(x AS STRING)") + self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS TIMESTAMP)") + self.validate_identity("CAST(x AS RECORD)", "CAST(x AS STRUCT)") + self.validate_identity( + "SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`" + ) self.validate_identity( "SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)" ) @@ -120,6 +177,10 @@ class TestBigQuery(Validator): self.validate_identity( """SELECT JSON_EXTRACT_SCALAR('5')""", """SELECT JSON_EXTRACT_SCALAR('5', '$')""" ) + self.validate_identity( + "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", + "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)", + ) self.validate_identity( "select array_contains([1, 2, 3], 1)", "SELECT EXISTS(SELECT 1 FROM UNNEST([1, 2, 3]) AS _col WHERE _col = 1)", @@ -168,10 +229,6 @@ class TestBigQuery(Validator): """SELECT JSON '"foo"' AS json_data""", """SELECT PARSE_JSON('"foo"') AS json_data""", ) - self.validate_identity( - "CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`", - "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d", - ) self.validate_identity( "SELECT * FROM UNNEST(x) WITH OFFSET EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET", "SELECT * FROM UNNEST(x) WITH OFFSET AS offset EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET AS offset", @@ -185,6 +242,39 @@ class TestBigQuery(Validator): r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", ) + self.validate_all( + "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)", + write={ + "bigquery": "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E6S%z', x)", + "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S.%f%z')", + }, + ) + self.validate_all( + "SELECT results FROM Coordinates, Coordinates.position AS results", + write={ + "bigquery": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS results", + "presto": "SELECT results FROM Coordinates, UNNEST(Coordinates.position) AS _t(results)", + }, + ) + self.validate_all( + "SELECT results FROM Coordinates, `Coordinates.position` AS results", + write={ + "bigquery": "SELECT results FROM Coordinates, `Coordinates.position` AS results", + "presto": 'SELECT results FROM Coordinates, "Coordinates"."position" AS results', + }, + ) + self.validate_all( + "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results", + read={ + "presto": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS _t(results)", + "redshift": "SELECT results FROM Coordinates AS c, c.position AS results", + }, + write={ + "bigquery": "SELECT results FROM Coordinates AS c, UNNEST(c.position) AS results", + "presto": "SELECT results FROM Coordinates AS c, 
UNNEST(c.position) AS _t(results)", + "redshift": "SELECT results FROM Coordinates AS c, c.position AS results", + }, + ) self.validate_all( "TIMESTAMP(x)", write={ @@ -434,8 +524,8 @@ class TestBigQuery(Validator): self.validate_all( "CREATE OR REPLACE TABLE `a.b.c` COPY `a.b.d`", write={ - "bigquery": "CREATE OR REPLACE TABLE a.b.c COPY a.b.d", - "snowflake": "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d", + "bigquery": "CREATE OR REPLACE TABLE `a.b.c` COPY `a.b.d`", + "snowflake": 'CREATE OR REPLACE TABLE "a"."b"."c" CLONE "a"."b"."d"', }, ) ( @@ -475,11 +565,6 @@ class TestBigQuery(Validator): ), ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) - self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) - self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) - self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) - self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) - self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) self.validate_all( 'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', write={ @@ -566,11 +651,11 @@ class TestBigQuery(Validator): read={"spark": "select posexplode_outer([])"}, ) self.validate_all( - "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + "SELECT AS STRUCT ARRAY(SELECT AS STRUCT 1 AS b FROM x) AS y FROM z", write={ - "": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", - "bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", - "duckdb": "SELECT {'y': ARRAY(SELECT {'b': b} FROM x)} FROM z", + "": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT 1 AS b FROM x) AS y FROM z", + "bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT 1 AS b FROM x) AS y FROM z", + "duckdb": "SELECT {'y': ARRAY(SELECT {'b': 1} FROM x)} FROM z", }, ) self.validate_all( @@ -585,25 +670,9 @@ class TestBigQuery(Validator): "bigquery": "PARSE_TIMESTAMP('%Y.%m.%d %I:%M:%S%z', x)", }, ) - self.validate_all( + self.validate_identity( "CREATE TEMP TABLE foo AS SELECT 1", - write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"}, - ) - self.validate_all( - "SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`", - write={ - "bigquery": "SELECT * FROM SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW", - }, - ) - self.validate_all( - "SELECT * FROM `my-project.my-dataset.my-table`", - write={"bigquery": "SELECT * FROM `my-project`.`my-dataset`.`my-table`"}, - ) - self.validate_all( - "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", - write={ - "bigquery": "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)", - }, + "CREATE TEMPORARY TABLE foo AS SELECT 1", ) self.validate_all( "REGEXP_CONTAINS('foo', '.*')", @@ -1088,6 +1157,35 @@ WHERE self.assertIn("unsupported syntax", cm.output[0]) + with self.assertLogs(helper_logger): + statements = parse( + """ + BEGIN + DECLARE MY_VAR INT64 DEFAULT 1; + SET MY_VAR = (SELECT 0); + + IF MY_VAR = 1 THEN SELECT 'TRUE'; + ELSEIF MY_VAR = 0 THEN SELECT 'FALSE'; + ELSE SELECT 'NULL'; + END IF; + END + """, + read="bigquery", + ) + + expected_statements = ( + "BEGIN DECLARE MY_VAR INT64 DEFAULT 1", + "SET MY_VAR = (SELECT 0)", + "IF MY_VAR = 1 THEN SELECT 'TRUE'", + "ELSEIF MY_VAR = 0 THEN SELECT 'FALSE'", + "ELSE SELECT 'NULL'", + "END IF", + "END", + ) + + for actual, expected in zip(statements, expected_statements): + self.assertEqual(actual.sql(dialect="bigquery"), expected) + with 
self.assertLogs(helper_logger) as cm: self.validate_identity( "SELECT * FROM t AS t(c1, c2)", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 0148812..edf3da1 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -6,6 +6,21 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_all( + "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", + write={ + "": "SELECT * FROM x WHERE z = 2", + "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", + }, + ) + self.validate_all( + "SELECT * FROM x AS prewhere", + read={ + "clickhouse": "SELECT * FROM x AS prewhere", + "duckdb": "SELECT * FROM x prewhere", + }, + ) + self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y") string_types = [ @@ -77,6 +92,7 @@ class TestClickhouse(Validator): self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""") self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b") self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b") + self.validate_identity( "SELECT $1$foo$1$", "SELECT 'foo'", @@ -134,6 +150,9 @@ class TestClickhouse(Validator): self.validate_identity( "CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data" ) + self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster") + self.validate_identity("TRUNCATE DATABASE db") + self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster") self.validate_all( "SELECT arrayJoin([1,2,3])", @@ -373,6 +392,7 @@ class TestClickhouse(Validator): def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") + self.validate_identity("WITH ['c'] AS field_names SELECT field_names") self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts") self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 8222170..94f2dc2 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -38,6 +38,11 @@ class TestDatabricks(Validator): "CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $FOO$def add_one(x):\n return x+1$FOO$" ) + self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)") + self.validate_identity( + "TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')" + ) + self.validate_all( "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", write={ diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index b50fec8..5faed51 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1108,6 +1108,11 @@ class TestDialect(Validator): ) def test_order_by(self): + self.validate_identity( + "SELECT c FROM t ORDER BY a, b,", + "SELECT c FROM t ORDER BY a, b", + ) + self.validate_all( "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", write={ @@ -1777,7 +1782,7 @@ class TestDialect(Validator): "CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR, v2 VARCHAR2, nv NVARCHAR, nv2 NVARCHAR2)", write={ "duckdb": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)", - "hive": "CREATE TABLE t (c CHAR, nc CHAR, v1 STRING, v2 STRING, nv STRING, nv2 STRING)", + "hive": "CREATE TABLE t (c STRING, nc STRING, v1 STRING, v2 STRING, nv STRING, nv2 STRING)", "oracle": 
"CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR2, v2 VARCHAR2, nv NVARCHAR2, nv2 NVARCHAR2)", "postgres": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR, v2 VARCHAR, nv VARCHAR, nv2 VARCHAR)", "sqlite": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)", @@ -2301,3 +2306,9 @@ SELECT "tsql": UnsupportedError, }, ) + + def test_truncate(self): + self.validate_identity("TRUNCATE TABLE table") + self.validate_identity("TRUNCATE TABLE db.schema.test") + self.validate_identity("TRUNCATE TABLE IF EXISTS db.schema.test") + self.validate_identity("TRUNCATE TABLE t1, t2, t3") diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py index 5ae23ad..035289b 100644 --- a/tests/dialects/test_doris.py +++ b/tests/dialects/test_doris.py @@ -26,6 +26,16 @@ class TestDoris(Validator): "doris": "SELECT ARRAY_SUM(x -> x * x, ARRAY(2, 3))", }, ) + self.validate_all( + "MONTHS_ADD(d, n)", + read={ + "oracle": "ADD_MONTHS(d, n)", + }, + write={ + "doris": "MONTHS_ADD(d, n)", + "oracle": "ADD_MONTHS(d, n)", + }, + ) def test_identity(self): self.validate_identity("COALECSE(a, b, c, d)") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 9c48f69..58d1f06 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -7,9 +7,14 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_duckdb(self): - struct_pack = parse_one('STRUCT_PACK("a b" := 1)', read="duckdb") - self.assertIsInstance(struct_pack.expressions[0].this, exp.Identifier) - self.assertEqual(struct_pack.sql(dialect="duckdb"), "{'a b': 1}") + self.validate_all( + 'STRUCT_PACK("a b" := 1)', + write={ + "duckdb": "{'a b': 1}", + "spark": "STRUCT(1 AS `a b`)", + "snowflake": "OBJECT_CONSTRUCT('a b', 1)", + }, + ) self.validate_all( "SELECT SUM(X) OVER (ORDER BY x)", @@ -52,8 +57,21 @@ class TestDuckDB(Validator): exp.select("*").from_("t").offset(exp.select("5").subquery()).sql(dialect="duckdb"), ) - for struct_value in ("{'a': 1}", "struct_pack(a := 1)"): - self.validate_all(struct_value, write={"presto": UnsupportedError}) + self.validate_all( + "{'a': 1, 'b': '2'}", write={"presto": "CAST(ROW(1, '2') AS ROW(a INTEGER, b VARCHAR))"} + ) + self.validate_all( + "struct_pack(a := 1, b := 2)", + write={"presto": "CAST(ROW(1, 2) AS ROW(a INTEGER, b INTEGER))"}, + ) + + self.validate_all( + "struct_pack(a := 1, b := x)", + write={ + "duckdb": "{'a': 1, 'b': x}", + "presto": UnsupportedError, + }, + ) for join_type in ("SEMI", "ANTI"): exists = "EXISTS" if join_type == "SEMI" else "NOT EXISTS" @@ -171,7 +189,6 @@ class TestDuckDB(Validator): }, ) - self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC") self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y") self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x") self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)") @@ -209,6 +226,10 @@ class TestDuckDB(Validator): self.validate_identity("FROM (FROM tbl)", "SELECT * FROM (SELECT * FROM tbl)") self.validate_identity("FROM tbl", "SELECT * FROM tbl") self.validate_identity("x -> '$.family'") + self.validate_identity("CREATE TABLE color (name ENUM('RED', 'GREEN', 'BLUE'))") + self.validate_identity( + "SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE" + ) self.validate_identity( """SELECT '{"foo": [1, 2, 3]}' -> 'foo' -> 0""", """SELECT '{"foo": [1, 2, 3]}' -> '$.foo' -> '$[0]'""", @@ -623,6 +644,27 @@ class TestDuckDB(Validator): }, ) + self.validate_identity("SELECT * FROM 
RANGE(1, 5, 10)") + self.validate_identity("SELECT * FROM GENERATE_SERIES(2, 13, 4)") + + self.validate_all( + "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) t(i)) SELECT * FROM t", + write={ + "duckdb": "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) AS t(i)) SELECT * FROM t", + "sqlite": "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM (SELECT value AS i FROM GENERATE_SERIES(1, 5)) AS t) SELECT * FROM t", + }, + ) + + self.validate_identity( + """SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC""", + """SELECT i FROM RANGE(0, 5) AS _(i) ORDER BY i ASC""", + ) + + self.validate_identity( + """SELECT i FROM GENERATE_SERIES(12) AS _(i) ORDER BY i ASC""", + """SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC""", + ) + def test_array_index(self): with self.assertLogs(helper_logger) as cm: self.validate_all( @@ -994,3 +1036,10 @@ class TestDuckDB(Validator): read={"bigquery": "IS_INF(x)"}, write={"bigquery": "IS_INF(x)", "duckdb": "ISINF(x)"}, ) + + def test_parameter_token(self): + self.validate_all( + "SELECT $foo", + read={"bigquery": "SELECT @foo"}, + write={"bigquery": "SELECT @foo", "duckdb": "SELECT $foo"}, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index ea28f29..b892dd6 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -440,6 +440,9 @@ class TestHive(Validator): self.validate_identity( "SELECT key, value, GROUPING__ID, COUNT(*) FROM T1 GROUP BY key, value WITH ROLLUP" ) + self.validate_identity( + "TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address = 'abc')" + ) self.validate_all( "SELECT ${hiveconf:some_var}", @@ -611,12 +614,6 @@ class TestHive(Validator): "spark": "GET_JSON_OBJECT(x, '$.name')", }, ) - self.validate_all( - "STRUCT(a = b, c = d)", - read={ - "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", - }, - ) self.validate_all( "MAP(a, b, c, d)", read={ diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index fd27a1e..5f23c44 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -29,6 +29,7 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") + self.validate_identity("ALTER TABLE t1 ADD COLUMN x INT, ALGORITHM=INPLACE, LOCK=EXCLUSIVE") self.validate_identity( "CREATE TABLE `oauth_consumer` (`key` VARCHAR(32) NOT NULL, UNIQUE `OAUTH_CONSUMER_KEY` (`key`))" ) @@ -68,6 +69,26 @@ class TestMySQL(Validator): self.validate_identity( "CREATE OR REPLACE VIEW my_view AS SELECT column1 AS `boo`, column2 AS `foo` FROM my_table WHERE column3 = 'some_value' UNION SELECT q.* FROM fruits_table, JSON_TABLE(Fruits, '$[*]' COLUMNS(id VARCHAR(255) PATH '$.$id', value VARCHAR(255) PATH '$.value')) AS q", ) + self.validate_identity( + "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))", + "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", + ) + self.validate_identity( + "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE KEY d (b), KEY e (b))", + "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))", + ) + self.validate_identity( + "CREATE 
TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)", + "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)", + ) + self.validate_identity( + "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT", + "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", + ) + self.validate_identity( + "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC", + "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", + ) self.validate_all( "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", @@ -78,12 +99,6 @@ class TestMySQL(Validator): "sqlite": "CREATE TABLE z (a INTEGER)", }, ) - self.validate_all( - "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC", - write={ - "mysql": "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", - }, - ) self.validate_all( "CREATE TABLE x (id int not null auto_increment, primary key (id))", write={ @@ -96,33 +111,9 @@ class TestMySQL(Validator): "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)", }, ) - self.validate_all( - "CREATE TABLE `foo` (`id` char(36) NOT NULL DEFAULT (uuid()), PRIMARY KEY (`id`), UNIQUE KEY `id` (`id`))", - write={ - "mysql": "CREATE TABLE `foo` (`id` CHAR(36) NOT NULL DEFAULT (UUID()), PRIMARY KEY (`id`), UNIQUE `id` (`id`))", - }, - ) - self.validate_all( - "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE KEY d (b), KEY e (b))", - write={ - "mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))", - }, - ) - self.validate_all( - "CREATE TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)", - write={ - "mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)", - }, - ) - self.validate_all( - "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT", - write={ - "mysql": "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", - }, - ) - self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") def test_identity(self): + self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") self.validate_identity("SELECT @var1 := 1, @var2") self.validate_identity("UNLOCK TABLES") diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index bc8f8bb..9438507 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -1,4 +1,4 @@ -from sqlglot import exp, parse_one +from sqlglot import exp from sqlglot.errors import UnsupportedError from tests.dialects.test_dialect import Validator @@ -7,11 +7,18 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): - self.validate_identity("REGEXP_REPLACE('source', 'search')") - parse_one("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol", dialect="oracle").assert_is( - exp.AlterTable + self.validate_all( + "SELECT CONNECT_BY_ROOT x y", + write={ + "": "SELECT CONNECT_BY_ROOT(x) AS y", + "oracle": "SELECT CONNECT_BY_ROOT x AS y", + }, ) + self.parse_one("ALTER TABLE 
tbl_name DROP FOREIGN KEY fk_symbol").assert_is(exp.AlterTable) + self.validate_identity("CREATE GLOBAL TEMPORARY TABLE t AS SELECT * FROM orders") + self.validate_identity("CREATE PRIVATE TEMPORARY TABLE t AS SELECT * FROM orders") + self.validate_identity("REGEXP_REPLACE('source', 'search')") self.validate_identity("TIMESTAMP(3) WITH TIME ZONE") self.validate_identity("CURRENT_TIMESTAMP(precision)") self.validate_identity("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol") @@ -88,6 +95,13 @@ class TestOracle(Validator): ) self.validate_identity("SELECT TO_CHAR(-100, 'L99', 'NL_CURRENCY = '' AusDollars '' ')") + self.validate_all( + "TO_CHAR(x)", + write={ + "doris": "CAST(x AS STRING)", + "oracle": "TO_CHAR(x)", + }, + ) self.validate_all( "SELECT TO_CHAR(TIMESTAMP '1999-12-01 10:00:00')", write={ diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index d1ecb2a..1d0ea8b 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,8 +8,10 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_identity("1.x", "1. AS x") self.validate_identity("|/ x", "SQRT(x)") self.validate_identity("||/ x", "CBRT(x)") + expr = parse_one( "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres" ) @@ -82,6 +84,7 @@ class TestPostgres(Validator): self.validate_identity("CAST(1 AS DECIMAL) / CAST(2 AS DECIMAL) * -100") self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True) self.validate_identity("SELECT CURRENT_USER") + self.validate_identity("SELECT * FROM ONLY t1") self.validate_identity( """LAST_VALUE("col1") OVER (ORDER BY "col2" RANGE BETWEEN INTERVAL '1 DAY' PRECEDING AND '1 month' FOLLOWING)""" ) @@ -163,6 +166,9 @@ class TestPostgres(Validator): "SELECT $$Dianne's horse$$", "SELECT 'Dianne''s horse'", ) + self.validate_identity( + "COMMENT ON TABLE mytable IS $$doc this$$", "COMMENT ON TABLE mytable IS 'doc this'" + ) self.validate_identity( "UPDATE MYTABLE T1 SET T1.COL = 13", "UPDATE MYTABLE AS T1 SET T1.COL = 13", @@ -320,6 +326,7 @@ class TestPostgres(Validator): "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", ) + self.validate_identity("SELECT * FROM t1*", "SELECT * FROM t1") self.validate_all( "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t", @@ -653,6 +660,12 @@ class TestPostgres(Validator): self.validate_identity("CREATE TABLE t (c CHAR(2) UNIQUE NOT NULL) INHERITS (t1)") self.validate_identity("CREATE TABLE s.t (c CHAR(2) UNIQUE NOT NULL) INHERITS (s.t1, s.t2)") self.validate_identity("CREATE FUNCTION x(INT) RETURNS INT SET search_path = 'public'") + self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY") + self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY") + self.validate_identity("TRUNCATE TABLE t1 CASCADE") + self.validate_identity("TRUNCATE TABLE t1 RESTRICT") + self.validate_identity("TRUNCATE TABLE t1 CONTINUE IDENTITY CASCADE") + self.validate_identity("TRUNCATE TABLE t1 RESTART IDENTITY RESTRICT") self.validate_identity( "CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)" ) @@ -785,6 +798,10 @@ class TestPostgres(Validator): self.validate_identity( "CREATE INDEX index_ci_pipelines_on_project_idandrefandiddesc ON public.ci_pipelines USING 
btree(project_id, ref, id DESC)" ) + self.validate_identity( + "TRUNCATE TABLE ONLY t1, t2*, ONLY t3, t4, t5* RESTART IDENTITY CASCADE", + "TRUNCATE TABLE ONLY t1, t2, ONLY t3, t4, t5 RESTART IDENTITY CASCADE", + ) with self.assertRaises(ParseError): transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres") @@ -911,3 +928,31 @@ class TestPostgres(Validator): """See https://github.com/tobymao/sqlglot/pull/2404 for details.""" self.assertIsInstance(parse_one("'thomas' ~ '.*thomas.*'", read="postgres"), exp.Binary) self.assertIsInstance(parse_one("'thomas' ~* '.*thomas.*'", read="postgres"), exp.Binary) + + def test_unnest_json_array(self): + trino_input = """ + WITH t(boxcrate) AS ( + SELECT JSON '[{"boxes": [{"name": "f1", "type": "plant", "color": "red"}]}]' + ) + SELECT + JSON_EXTRACT_SCALAR(boxes,'$.name') AS name, + JSON_EXTRACT_SCALAR(boxes,'$.type') AS type, + JSON_EXTRACT_SCALAR(boxes,'$.color') AS color + FROM t + CROSS JOIN UNNEST(CAST(boxcrate AS array(json))) AS x(tbox) + CROSS JOIN UNNEST(CAST(json_extract(tbox, '$.boxes') AS array(json))) AS y(boxes) + """ + + expected_postgres = """WITH t(boxcrate) AS ( + SELECT + CAST('[{"boxes": [{"name": "f1", "type": "plant", "color": "red"}]}]' AS JSON) +) +SELECT + JSON_EXTRACT_PATH_TEXT(boxes, 'name') AS name, + JSON_EXTRACT_PATH_TEXT(boxes, 'type') AS type, + JSON_EXTRACT_PATH_TEXT(boxes, 'color') AS color +FROM t +CROSS JOIN JSON_ARRAY_ELEMENTS(CAST(boxcrate AS JSON)) AS x(tbox) +CROSS JOIN JSON_ARRAY_ELEMENTS(CAST(JSON_EXTRACT_PATH(tbox, 'boxes') AS JSON)) AS y(boxes)""" + + self.validate_all(expected_postgres, read={"trino": trino_input}, pretty=True) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index d3d1a76..2ea595e 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -647,6 +647,7 @@ class TestPresto(Validator): """JSON '"foo"'""", write={ "bigquery": """PARSE_JSON('"foo"')""", + "postgres": """CAST('"foo"' AS JSON)""", "presto": """JSON_PARSE('"foo"')""", "snowflake": """PARSE_JSON('"foo"')""", }, @@ -1142,3 +1143,18 @@ MATCH_RECOGNIZE ( "presto": "DATE_FORMAT(ts, '%y')", }, ) + + def test_signum(self): + self.validate_all( + "SIGN(x)", + read={ + "presto": "SIGN(x)", + "spark": "SIGNUM(x)", + "starrocks": "SIGN(x)", + }, + write={ + "presto": "SIGN(x)", + "spark": "SIGN(x)", + "starrocks": "SIGN(x)", + }, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 33cfa0c..506f429 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -515,6 +515,11 @@ FROM ( ) def test_column_unnesting(self): + self.validate_identity("SELECT c.*, o FROM bloo AS c, c.c_orders AS o") + self.validate_identity( + "SELECT c.*, o, l FROM bloo AS c, c.c_orders AS o, o.o_lineitems AS l" + ) + ast = parse_one("SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3", read="redshift") ast.args["from"].this.assert_is(exp.Table) ast.args["joins"][0].this.assert_is(exp.Table) @@ -522,7 +527,7 @@ FROM ( ast = parse_one("SELECT * FROM t AS t CROSS JOIN t.c1", read="redshift") ast.args["from"].this.assert_is(exp.Table) - ast.args["joins"][0].this.assert_is(exp.Column) + ast.args["joins"][0].this.assert_is(exp.Unnest) self.assertEqual(ast.sql("redshift"), "SELECT * FROM t AS t CROSS JOIN t.c1") ast = parse_one( @@ -530,9 +535,9 @@ FROM ( ) joins = ast.args["joins"] ast.args["from"].this.assert_is(exp.Table) - joins[0].this.this.assert_is(exp.Column) - joins[1].this.this.assert_is(exp.Column) - 
joins[2].this.this.assert_is(exp.Dot) + joins[0].this.assert_is(exp.Unnest) + joins[1].this.assert_is(exp.Unnest) + joins[2].this.assert_is(exp.Unnest).expressions[0].assert_is(exp.Dot) self.assertEqual( ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l" ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 4e4feb3..e48f811 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -40,6 +40,7 @@ WHERE )""", ) + self.validate_identity("ALTER TABLE authors ADD CONSTRAINT c1 UNIQUE (id, email)") self.validate_identity("RM @parquet_stage", check_command_warning=True) self.validate_identity("REMOVE @parquet_stage", check_command_warning=True) self.validate_identity("SELECT TIMESTAMP_FROM_PARTS(d, t)") @@ -84,6 +85,7 @@ WHERE self.validate_identity( "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)" ) + self.validate_identity("x:from", "GET_PATH(x, 'from')") self.validate_identity( "value:values::string", "CAST(GET_PATH(value, 'values') AS TEXT)", @@ -371,15 +373,17 @@ WHERE write={"snowflake": "SELECT * FROM (VALUES (0)) AS foo(bar)"}, ) self.validate_all( - "OBJECT_CONSTRUCT(a, b, c, d)", + "OBJECT_CONSTRUCT('a', b, 'c', d)", read={ - "": "STRUCT(a as b, c as d)", + "": "STRUCT(b as a, d as c)", }, write={ "duckdb": "{'a': b, 'c': d}", - "snowflake": "OBJECT_CONSTRUCT(a, b, c, d)", + "snowflake": "OBJECT_CONSTRUCT('a', b, 'c', d)", }, ) + self.validate_identity("OBJECT_CONSTRUCT(a, b, c, d)") + self.validate_all( "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", write={ @@ -1461,26 +1465,22 @@ MATCH_RECOGNIZE ( pretty=True, ) - def test_show(self): - # Parsed as Show - self.validate_identity("SHOW PRIMARY KEYS") - self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT") - self.validate_identity("SHOW PRIMARY KEYS IN DATABASE") - self.validate_identity("SHOW PRIMARY KEYS IN DATABASE foo") - self.validate_identity("SHOW PRIMARY KEYS IN TABLE") - self.validate_identity("SHOW PRIMARY KEYS IN TABLE foo") - self.validate_identity( - 'SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', - 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"', - ) - self.validate_identity( - 'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', - 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"', - ) + def test_show_users(self): + self.validate_identity("SHOW USERS") + self.validate_identity("SHOW TERSE USERS") + self.validate_identity("SHOW USERS LIKE '_foo%' STARTS WITH 'bar' LIMIT 5 FROM 'baz'") + + def test_show_schemas(self): self.validate_identity( "show terse schemas in database db1 starts with 'a' limit 10 from 'b'", "SHOW TERSE SCHEMAS IN DATABASE db1 STARTS WITH 'a' LIMIT 10 FROM 'b'", ) + + ast = parse_one("SHOW SCHEMAS IN DATABASE db1", read="snowflake") + self.assertEqual(ast.args.get("scope_kind"), "DATABASE") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1") + + def test_show_objects(self): self.validate_identity( "show terse objects in schema db1.schema1 starts with 'a' limit 10 from 'b'", "SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", @@ -1489,6 +1489,23 @@ MATCH_RECOGNIZE ( "show terse objects in db1.schema1 starts with 'a' limit 10 from 'b'", "SHOW TERSE OBJECTS IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", ) + + ast = parse_one("SHOW OBJECTS IN db1.schema1", read="snowflake") + self.assertEqual(ast.args.get("scope_kind"), "SCHEMA") + 
self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1.schema1") + + def test_show_columns(self): + self.validate_identity("SHOW COLUMNS") + self.validate_identity("SHOW COLUMNS IN TABLE dt_test") + self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN TABLE dt_test") + self.validate_identity("SHOW COLUMNS IN VIEW") + self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN VIEW dt_test") + + ast = parse_one("SHOW COLUMNS LIKE '_testing%' IN dt_test", read="snowflake") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "dt_test") + self.assertEqual(ast.find(exp.Literal).sql(dialect="snowflake"), "'_testing%'") + + def test_show_tables(self): self.validate_identity( "SHOW TABLES LIKE 'line%' IN tpch.public", "SHOW TABLES LIKE 'line%' IN SCHEMA tpch.public", @@ -1506,47 +1523,97 @@ MATCH_RECOGNIZE ( "SHOW TERSE TABLES IN SCHEMA db1.schema1 STARTS WITH 'a' LIMIT 10 FROM 'b'", ) - ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', read="snowflake") - table = ast.find(exp.Table) + ast = parse_one("SHOW TABLES IN db1.schema1", read="snowflake") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1.schema1") + + def test_show_primary_keys(self): + self.validate_identity("SHOW PRIMARY KEYS") + self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT") + self.validate_identity("SHOW PRIMARY KEYS IN DATABASE") + self.validate_identity("SHOW PRIMARY KEYS IN DATABASE foo") + self.validate_identity("SHOW PRIMARY KEYS IN TABLE") + self.validate_identity("SHOW PRIMARY KEYS IN TABLE foo") + self.validate_identity( + 'SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."foo"', + 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."foo"', + ) + self.validate_identity( + 'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."foo"', + 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."foo"', + ) - self.assertEqual(table.sql(dialect="snowflake"), '"TEST"."PUBLIC"."customers"') + ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."foo"', read="snowflake") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), '"TEST"."PUBLIC"."foo"') - self.validate_identity("SHOW COLUMNS") - self.validate_identity("SHOW COLUMNS IN TABLE dt_test") - self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN TABLE dt_test") - self.validate_identity("SHOW COLUMNS IN VIEW") - self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN VIEW dt_test") + def test_show_views(self): + self.validate_identity("SHOW TERSE VIEWS") + self.validate_identity("SHOW VIEWS") + self.validate_identity("SHOW VIEWS LIKE 'foo%'") + self.validate_identity("SHOW VIEWS IN ACCOUNT") + self.validate_identity("SHOW VIEWS IN DATABASE") + self.validate_identity("SHOW VIEWS IN DATABASE foo") + self.validate_identity("SHOW VIEWS IN SCHEMA foo") + self.validate_identity( + "SHOW VIEWS IN foo", + "SHOW VIEWS IN SCHEMA foo", + ) - self.validate_identity("SHOW USERS") - self.validate_identity("SHOW TERSE USERS") - self.validate_identity("SHOW USERS LIKE '_foo%' STARTS WITH 'bar' LIMIT 5 FROM 'baz'") + ast = parse_one("SHOW VIEWS IN db1.schema1", read="snowflake") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), "db1.schema1") - ast = parse_one("SHOW COLUMNS LIKE '_testing%' IN dt_test", read="snowflake") - table = ast.find(exp.Table) - literal = ast.find(exp.Literal) + def test_show_unique_keys(self): + self.validate_identity("SHOW UNIQUE KEYS") + self.validate_identity("SHOW UNIQUE KEYS IN ACCOUNT") + self.validate_identity("SHOW UNIQUE KEYS IN DATABASE") + self.validate_identity("SHOW UNIQUE KEYS IN DATABASE foo") + 
self.validate_identity("SHOW UNIQUE KEYS IN TABLE") + self.validate_identity("SHOW UNIQUE KEYS IN TABLE foo") + self.validate_identity( + 'SHOW UNIQUE KEYS IN "TEST"."PUBLIC"."foo"', + 'SHOW UNIQUE KEYS IN SCHEMA "TEST"."PUBLIC"."foo"', + ) + self.validate_identity( + 'SHOW TERSE UNIQUE KEYS IN "TEST"."PUBLIC"."foo"', + 'SHOW UNIQUE KEYS IN SCHEMA "TEST"."PUBLIC"."foo"', + ) - self.assertEqual(table.sql(dialect="snowflake"), "dt_test") + ast = parse_one('SHOW UNIQUE KEYS IN "TEST"."PUBLIC"."foo"', read="snowflake") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), '"TEST"."PUBLIC"."foo"') - self.assertEqual(literal.sql(dialect="snowflake"), "'_testing%'") + def test_show_imported_keys(self): + self.validate_identity("SHOW IMPORTED KEYS") + self.validate_identity("SHOW IMPORTED KEYS IN ACCOUNT") + self.validate_identity("SHOW IMPORTED KEYS IN DATABASE") + self.validate_identity("SHOW IMPORTED KEYS IN DATABASE foo") + self.validate_identity("SHOW IMPORTED KEYS IN TABLE") + self.validate_identity("SHOW IMPORTED KEYS IN TABLE foo") + self.validate_identity( + 'SHOW IMPORTED KEYS IN "TEST"."PUBLIC"."foo"', + 'SHOW IMPORTED KEYS IN SCHEMA "TEST"."PUBLIC"."foo"', + ) + self.validate_identity( + 'SHOW TERSE IMPORTED KEYS IN "TEST"."PUBLIC"."foo"', + 'SHOW IMPORTED KEYS IN SCHEMA "TEST"."PUBLIC"."foo"', + ) - ast = parse_one("SHOW SCHEMAS IN DATABASE db1", read="snowflake") - self.assertEqual(ast.args.get("scope_kind"), "DATABASE") - table = ast.find(exp.Table) - self.assertEqual(table.sql(dialect="snowflake"), "db1") + ast = parse_one('SHOW IMPORTED KEYS IN "TEST"."PUBLIC"."foo"', read="snowflake") + self.assertEqual(ast.find(exp.Table).sql(dialect="snowflake"), '"TEST"."PUBLIC"."foo"') - ast = parse_one("SHOW OBJECTS IN db1.schema1", read="snowflake") - self.assertEqual(ast.args.get("scope_kind"), "SCHEMA") - table = ast.find(exp.Table) - self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1") + def test_show_sequences(self): + self.validate_identity("SHOW TERSE SEQUENCES") + self.validate_identity("SHOW SEQUENCES") + self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN ACCOUNT") + self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN DATABASE") + self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN DATABASE foo") + self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN SCHEMA") + self.validate_identity("SHOW SEQUENCES LIKE '_foo%' IN SCHEMA foo") + self.validate_identity( + "SHOW SEQUENCES LIKE '_foo%' IN foo", + "SHOW SEQUENCES LIKE '_foo%' IN SCHEMA foo", + ) - ast = parse_one("SHOW TABLES IN db1.schema1", read="snowflake") + ast = parse_one("SHOW SEQUENCES IN dt_test", read="snowflake") self.assertEqual(ast.args.get("scope_kind"), "SCHEMA") - table = ast.find(exp.Table) - self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1") - - users_exp = self.validate_identity("SHOW USERS") - self.assertTrue(isinstance(users_exp, exp.Show)) - self.assertEqual(users_exp.this, "USERS") def test_storage_integration(self): self.validate_identity( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 196735b..1cf1ede 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -16,6 +16,7 @@ class TestSpark(Validator): self.validate_identity( "CREATE TABLE foo (col STRING) CLUSTERED BY (col) SORTED BY (col) INTO 10 BUCKETS" ) + self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)") self.validate_all( "CREATE TABLE db.example_table (col_a struct)", diff --git a/tests/dialects/test_sqlite.py 
b/tests/dialects/test_sqlite.py index f7a3dd7..2421987 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -1,5 +1,7 @@ from tests.dialects.test_dialect import Validator +from sqlglot.helper import logger as helper_logger + class TestSQLite(Validator): dialect = "sqlite" @@ -76,6 +78,7 @@ class TestSQLite(Validator): self.validate_identity( """SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0""" ) + self.validate_identity("SELECT * FROM GENERATE_SERIES(1, 5)") self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"}) self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"}) @@ -178,3 +181,12 @@ class TestSQLite(Validator): "CREATE TABLE foo (bar LONGVARCHAR)", write={"sqlite": "CREATE TABLE foo (bar TEXT)"}, ) + + def test_warnings(self): + with self.assertLogs(helper_logger) as cm: + self.validate_identity( + "SELECT * FROM t AS t(c1, c2)", + "SELECT * FROM t AS t", + ) + + self.assertIn("Named columns are not supported in table alias.", cm.output[0]) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index a304a9e..ed474fd 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,6 +1,7 @@ from sqlglot import exp, parse, parse_one from sqlglot.parser import logger as parser_logger from tests.dialects.test_dialect import Validator +from sqlglot.errors import ParseError class TestTSQL(Validator): @@ -27,6 +28,7 @@ class TestTSQL(Validator): self.validate_identity("SELECT * FROM t WHERE NOT c", "SELECT * FROM t WHERE NOT c <> 0") self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)") self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") + self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))") self.validate_all( "SELECT IIF(cond <> 0, 'True', 'False')", @@ -142,7 +144,7 @@ class TestTSQL(Validator): "tsql": "CREATE TABLE #mytemptable (a INTEGER)", "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)", "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)", - "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)", + "oracle": "CREATE GLOBAL TEMPORARY TABLE mytemptable (a NUMBER)", "hive": "CREATE TEMPORARY TABLE mytemptable (a INT)", "spark2": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET", "spark": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET", @@ -281,7 +283,7 @@ class TestTSQL(Validator): "CONVERT(INT, CONVERT(NUMERIC, '444.75'))", write={ "mysql": "CAST(CAST('444.75' AS DECIMAL) AS SIGNED)", - "tsql": "CAST(CAST('444.75' AS NUMERIC) AS INTEGER)", + "tsql": "CONVERT(INTEGER, CONVERT(NUMERIC, '444.75'))", }, ) self.validate_all( @@ -356,6 +358,76 @@ class TestTSQL(Validator): self.validate_identity("HASHBYTES('MD2', 'x')") self.validate_identity("LOG(n, b)") + def test_option(self): + possible_options = [ + "HASH GROUP", + "ORDER GROUP", + "CONCAT UNION", + "HASH UNION", + "MERGE UNION", + "LOOP JOIN", + "MERGE JOIN", + "HASH JOIN", + "DISABLE_OPTIMIZED_PLAN_FORCING", + "EXPAND VIEWS", + "FAST 15", + "FORCE ORDER", + "FORCE EXTERNALPUSHDOWN", + "DISABLE EXTERNALPUSHDOWN", + "FORCE SCALEOUTEXECUTION", + "DISABLE SCALEOUTEXECUTION", + "IGNORE_NONCLUSTERED_COLUMNSTORE_INDEX", + "KEEP PLAN", + "KEEPFIXED PLAN", + "MAX_GRANT_PERCENT = 5", + "MIN_GRANT_PERCENT = 10", + "MAXDOP 13", + "MAXRECURSION 8", + "NO_PERFORMANCE_SPOOL", + "OPTIMIZE FOR UNKNOWN", + "PARAMETERIZATION 
SIMPLE", + "PARAMETERIZATION FORCED", + "QUERYTRACEON 99", + "RECOMPILE", + "ROBUST PLAN", + "USE PLAN N''", + "LABEL = 'MyLabel'", + ] + + possible_statements = [ + # These should be un-commented once support for the OPTION clause is added for DELETE, MERGE and UPDATE + # "DELETE FROM Table1", + # "MERGE INTO Locations AS T USING locations_stage AS S ON T.LocationID = S.LocationID WHEN MATCHED THEN UPDATE SET LocationName = S.LocationName", + # "UPDATE Customers SET ContactName = 'Alfred Schmidt', City = 'Frankfurt' WHERE CustomerID = 1", + "SELECT * FROM Table1", + "SELECT * FROM Table1 WHERE id = 2", + ] + + for statement in possible_statements: + for option in possible_options: + query = f"{statement} OPTION({option})" + result = self.validate_identity(query) + options = result.args.get("options") + self.assertIsInstance(options, list, f"When parsing query {query}") + is_query_options = map(lambda o: isinstance(o, exp.QueryOption), options) + self.assertTrue(all(is_query_options), f"When parsing query {query}") + + self.validate_identity( + f"{statement} OPTION(RECOMPILE, USE PLAN N'', MAX_GRANT_PERCENT = 5)" + ) + + raising_queries = [ + # Missing parentheses + "SELECT * FROM Table1 OPTION HASH GROUP", + # Must be followed by 'PLAN" + "SELECT * FROM Table1 OPTION(KEEPFIXED)", + # Missing commas + "SELECT * FROM Table1 OPTION(HASH GROUP HASH GROUP)", + ] + for query in raising_queries: + with self.assertRaises(ParseError, msg=f"When running '{query}'"): + self.parse_one(query) + def test_types(self): self.validate_identity("CAST(x AS XML)") self.validate_identity("CAST(x AS UNIQUEIDENTIFIER)") @@ -525,7 +597,7 @@ class TestTSQL(Validator): "CAST(x as NCHAR(1))", write={ "spark": "CAST(x AS CHAR(1))", - "tsql": "CAST(x AS CHAR(1))", + "tsql": "CAST(x AS NCHAR(1))", }, ) @@ -533,7 +605,7 @@ class TestTSQL(Validator): "CAST(x as NVARCHAR(2))", write={ "spark": "CAST(x AS VARCHAR(2))", - "tsql": "CAST(x AS VARCHAR(2))", + "tsql": "CAST(x AS NVARCHAR(2))", }, ) @@ -692,12 +764,7 @@ class TestTSQL(Validator): "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", read={ "": "CREATE TABLE foo.bar.baz AS SELECT * FROM a.b.c", - }, - ) - self.validate_all( - "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", - read={ - "": "CREATE TABLE foo.bar.baz AS (SELECT * FROM a.b.c)", + "duckdb": "CREATE TABLE foo.bar.baz AS (SELECT * FROM a.b.c)", }, ) self.validate_all( @@ -759,11 +826,6 @@ class TestTSQL(Validator): ) def test_transaction(self): - # BEGIN { TRAN | TRANSACTION } - # [ { transaction_name | @tran_name_variable } - # [ WITH MARK [ 'description' ] ] - # ] - # [ ; ] self.validate_identity("BEGIN TRANSACTION") self.validate_all("BEGIN TRAN", write={"tsql": "BEGIN TRANSACTION"}) self.validate_identity("BEGIN TRANSACTION transaction_name") @@ -771,8 +833,6 @@ class TestTSQL(Validator): self.validate_identity("BEGIN TRANSACTION transaction_name WITH MARK 'description'") def test_commit(self): - # COMMIT [ { TRAN | TRANSACTION } [ transaction_name | @tran_name_variable ] ] [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] [ ; ] - self.validate_all("COMMIT", write={"tsql": "COMMIT TRANSACTION"}) self.validate_all("COMMIT TRAN", write={"tsql": "COMMIT TRANSACTION"}) self.validate_identity("COMMIT TRANSACTION") @@ -787,11 +847,6 @@ class TestTSQL(Validator): ) def test_rollback(self): - # Applies to SQL Server and Azure SQL Database - # ROLLBACK { TRAN | TRANSACTION } - # [ transaction_name | @tran_name_variable - # | savepoint_name | @savepoint_variable ] - # [ ; ] 
self.validate_all("ROLLBACK", write={"tsql": "ROLLBACK TRANSACTION"}) self.validate_all("ROLLBACK TRAN", write={"tsql": "ROLLBACK TRANSACTION"}) self.validate_identity("ROLLBACK TRANSACTION") @@ -911,7 +966,7 @@ WHERE expected_sqls = [ "CREATE PROC [dbo].[transform_proc] AS DECLARE @CurrentDate VARCHAR(20)", - "SET @CurrentDate = CAST(FORMAT(GETDATE(), 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(20))", + "SET @CurrentDate = CONVERT(VARCHAR(20), GETDATE(), 120)", "CREATE TABLE [target_schema].[target_table] (a INTEGER) WITH (DISTRIBUTION=REPLICATE, HEAP)", ] @@ -1090,155 +1145,173 @@ WHERE }, ) - def test_convert_date_format(self): + def test_convert(self): self.validate_all( "CONVERT(NVARCHAR(200), x)", write={ "spark": "CAST(x AS VARCHAR(200))", + "tsql": "CONVERT(NVARCHAR(200), x)", }, ) self.validate_all( "CONVERT(NVARCHAR, x)", write={ "spark": "CAST(x AS VARCHAR(30))", + "tsql": "CONVERT(NVARCHAR, x)", }, ) self.validate_all( "CONVERT(NVARCHAR(MAX), x)", write={ "spark": "CAST(x AS STRING)", + "tsql": "CONVERT(NVARCHAR(MAX), x)", }, ) self.validate_all( "CONVERT(VARCHAR(200), x)", write={ "spark": "CAST(x AS VARCHAR(200))", + "tsql": "CONVERT(VARCHAR(200), x)", }, ) self.validate_all( "CONVERT(VARCHAR, x)", write={ "spark": "CAST(x AS VARCHAR(30))", + "tsql": "CONVERT(VARCHAR, x)", }, ) self.validate_all( "CONVERT(VARCHAR(MAX), x)", write={ "spark": "CAST(x AS STRING)", + "tsql": "CONVERT(VARCHAR(MAX), x)", }, ) self.validate_all( "CONVERT(CHAR(40), x)", write={ "spark": "CAST(x AS CHAR(40))", + "tsql": "CONVERT(CHAR(40), x)", }, ) self.validate_all( "CONVERT(CHAR, x)", write={ "spark": "CAST(x AS CHAR(30))", + "tsql": "CONVERT(CHAR, x)", }, ) self.validate_all( "CONVERT(NCHAR(40), x)", write={ "spark": "CAST(x AS CHAR(40))", + "tsql": "CONVERT(NCHAR(40), x)", }, ) self.validate_all( "CONVERT(NCHAR, x)", write={ "spark": "CAST(x AS CHAR(30))", + "tsql": "CONVERT(NCHAR, x)", }, ) self.validate_all( "CONVERT(VARCHAR, x, 121)", write={ "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + "tsql": "CONVERT(VARCHAR, x, 121)", }, ) self.validate_all( "CONVERT(VARCHAR(40), x, 121)", write={ "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))", + "tsql": "CONVERT(VARCHAR(40), x, 121)", }, ) self.validate_all( "CONVERT(VARCHAR(MAX), x, 121)", write={ - "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS STRING)", + "tsql": "CONVERT(VARCHAR(MAX), x, 121)", }, ) self.validate_all( "CONVERT(NVARCHAR, x, 121)", write={ "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + "tsql": "CONVERT(NVARCHAR, x, 121)", }, ) self.validate_all( "CONVERT(NVARCHAR(40), x, 121)", write={ "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(40))", + "tsql": "CONVERT(NVARCHAR(40), x, 121)", }, ) self.validate_all( "CONVERT(NVARCHAR(MAX), x, 121)", write={ - "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "spark": "CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS STRING)", + "tsql": "CONVERT(NVARCHAR(MAX), x, 121)", }, ) self.validate_all( "CONVERT(DATE, x, 121)", write={ "spark": "TO_DATE(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "tsql": "CONVERT(DATE, x, 121)", }, ) self.validate_all( "CONVERT(DATETIME, x, 121)", write={ "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "tsql": "CONVERT(DATETIME2, x, 121)", }, ) self.validate_all( "CONVERT(DATETIME2, x, 121)", write={ "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS')", + "tsql": 
"CONVERT(DATETIME2, x, 121)", }, ) self.validate_all( "CONVERT(INT, x)", write={ "spark": "CAST(x AS INT)", + "tsql": "CONVERT(INTEGER, x)", }, ) self.validate_all( "CONVERT(INT, x, 121)", write={ "spark": "CAST(x AS INT)", + "tsql": "CONVERT(INTEGER, x, 121)", }, ) self.validate_all( "TRY_CONVERT(NVARCHAR, x, 121)", write={ "spark": "TRY_CAST(DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss.SSSSSS') AS VARCHAR(30))", + "tsql": "TRY_CONVERT(NVARCHAR, x, 121)", }, ) self.validate_all( "TRY_CONVERT(INT, x)", write={ "spark": "TRY_CAST(x AS INT)", + "tsql": "TRY_CONVERT(INTEGER, x)", }, ) self.validate_all( "TRY_CAST(x AS INT)", write={ "spark": "TRY_CAST(x AS INT)", - }, - ) - self.validate_all( - "CAST(x AS INT)", - write={ - "spark": "CAST(x AS INT)", + "tsql": "TRY_CAST(x AS INTEGER)", }, ) self.validate_all( @@ -1246,6 +1319,7 @@ WHERE write={ "mysql": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, '%Y-%m-%d %T') AS CHAR(10)) AS y FROM testdb.dbo.test", "spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + "tsql": "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) AS y FROM testdb.dbo.test", }, ) self.validate_all( @@ -1253,12 +1327,14 @@ WHERE write={ "mysql": "SELECT CAST(y.x AS CHAR(10)) AS z FROM testdb.dbo.test AS y", "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", + "tsql": "SELECT CONVERT(VARCHAR(10), y.x) AS z FROM testdb.dbo.test AS y", }, ) self.validate_all( "SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test", write={ "spark": "SELECT CAST((SELECT x FROM y) AS STRING) AS test", + "tsql": "SELECT CAST((SELECT x FROM y) AS VARCHAR) AS test", }, ) @@ -1654,7 +1730,7 @@ FROM OPENJSON(@json) WITH ( Date DATETIME2 '$.Order.Date', Customer VARCHAR(200) '$.AccountNumber', Quantity INTEGER '$.Item.Quantity', - [Order] VARCHAR(MAX) AS JSON + [Order] NVARCHAR(MAX) AS JSON )""" }, pretty=True, diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index d9efc57..6d3bb07 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -196,10 +196,10 @@ SET LOCAL variable = value @"x" COMMIT USE db -USE role x -USE warehouse x -USE database x -USE schema x.y +USE ROLE x +USE WAREHOUSE x +USE DATABASE x +USE SCHEMA x.y NOT 1 NOT NOT 1 SELECT * FROM test @@ -643,6 +643,7 @@ DROP MATERIALIZED VIEW x.y.z CACHE TABLE x CACHE LAZY TABLE x CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') +CACHE LAZY TABLE x OPTIONS(N'storageLevel' = 'value') CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1 CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a @@ -708,6 +709,7 @@ COMMENT ON COLUMN my_schema.my_table.my_column IS 'Employee ID number' COMMENT ON DATABASE my_database IS 'Development Database' COMMENT ON PROCEDURE my_proc(integer, integer) IS 'Runs a report' COMMENT ON TABLE my_schema.my_table IS 'Employee Information' +COMMENT ON TABLE my_schema.my_table IS N'National String' WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a WITH a AS (SELECT * FROM b) UPDATE a SET col = 1 WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a @@ -785,6 +787,7 @@ ALTER TABLE baa ADD CONSTRAINT boo PRIMARY KEY (x, y) NOT ENFORCED DEFERRABLE IN ALTER TABLE baa ADD CONSTRAINT boo FOREIGN KEY (x, y) REFERENCES persons ON UPDATE NO ACTION ON DELETE NO ACTION MATCH FULL ALTER TABLE a ADD PRIMARY KEY (x, y) NOT ENFORCED ALTER TABLE a ADD FOREIGN KEY (x, y) REFERENCES bla +ALTER TABLE s_ut 
ADD CONSTRAINT s_ut_uq UNIQUE hajo SELECT partition FROM a SELECT end FROM a SELECT id FROM b.a AS a QUALIFY ROW_NUMBER() OVER (PARTITION BY br ORDER BY sadf DESC) = 1 @@ -850,3 +853,7 @@ CAST(foo AS BPCHAR) values SELECT values SELECT values AS values FROM t WHERE values + 1 > 3 +SELECT truncate +SELECT only +TRUNCATE(a, b) +SELECT enum diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a33c81b..990453b 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -820,7 +820,7 @@ SELECT `TOp_TeRmS`.`refresh_date` AS `day`, `TOp_TeRmS`.`term` AS `top_term`, `TOp_TeRmS`.`rank` AS `rank` -FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `TOp_TeRmS` +FROM `bigquery-public-data.GooGle_tReNDs.TOp_TeRmS` AS `TOp_TeRmS` WHERE `TOp_TeRmS`.`rank` = 1 AND CAST(`TOp_TeRmS`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index df8c1a5..71c6f45 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -172,6 +172,10 @@ SELECT _q_0._col_0 AS _col_0, _q_0._col_1 AS _col_1 FROM (VALUES (1, 2)) AS _q_0 select * from (values (1, 2)) x; SELECT x._col_0 AS _col_0, x._col_1 AS _col_1 FROM (VALUES (1, 2)) AS x(_col_0, _col_1); +# execute: false +SELECT SOME_UDF(data).* FROM t; +SELECT SOME_UDF(t.data).* FROM t AS t; + -------------------------------------- -- Derived tables -------------------------------------- @@ -333,6 +337,10 @@ WITH cte AS (SELECT 1 AS x) SELECT cte.a AS a FROM cte AS cte(a); WITH cte(x, y) AS (SELECT 1, 2) SELECT cte.* FROM cte AS cte(a); WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.a AS a, cte.y AS y FROM cte AS cte(a); +-- Cannot pop table column aliases for recursive ctes (redshift). 
+WITH RECURSIVE cte(x) AS (SELECT 1), cte2(y) AS (SELECT 2) SELECT * FROM cte, cte2; +WITH RECURSIVE cte(x) AS (SELECT 1 AS x), cte2(y) AS (SELECT 2 AS y) SELECT cte.x AS x, cte2.y AS y FROM cte AS cte, cte2 AS cte2; + # execute: false WITH player AS (SELECT player.name, player.asset.info FROM players) SELECT * FROM player; WITH player AS (SELECT players.player.name AS name, players.player.asset.info AS info FROM players AS players) SELECT player.name AS name, player.info AS info FROM player AS player; @@ -549,6 +557,10 @@ SELECT x.a + x.b AS f, (x.a + x.b) * x.b AS _col_1 FROM x AS x; SELECT x.a + x.b AS f, f, f + 5 FROM x; SELECT x.a + x.b AS f, x.a + x.b AS _col_1, x.a + x.b + 5 AS _col_2 FROM x AS x; +# title: expand double agg if window func +SELECT a, SUM(b) AS c, SUM(c) OVER(PARTITION BY a) AS d from x group by 1 ORDER BY a; +SELECT x.a AS a, SUM(x.b) AS c, SUM(SUM(x.b)) OVER (PARTITION BY x.a) AS d FROM x AS x GROUP BY x.a ORDER BY a; + -------------------------------------- -- Wrapped tables / join constructs -------------------------------------- diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index 1426aa7..99b5153 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -19,6 +19,21 @@ SELECT 1 FROM x.y.z AS z; SELECT 1 FROM y.z AS z, z.a; SELECT 1 FROM c.y.z AS z, z.a; +# title: bigquery implicit unnest syntax, coordinates.position should be a column, not a table +# dialect: bigquery +SELECT results FROM Coordinates, coordinates.position AS results; +SELECT results FROM c.db.Coordinates AS Coordinates, UNNEST(coordinates.position) AS results; + +# title: bigquery implicit unnest syntax, table is already qualified +# dialect: bigquery +SELECT results FROM db.coordinates, Coordinates.position AS results; +SELECT results FROM c.db.coordinates AS coordinates, UNNEST(Coordinates.position) AS results; + +# title: bigquery schema name clashes with CTE name - this is a join, not an implicit unnest +# dialect: bigquery +WITH Coordinates AS (SELECT [1, 2] AS position) SELECT results FROM Coordinates, `Coordinates.position` AS results; +WITH Coordinates AS (SELECT [1, 2] AS position) SELECT results FROM Coordinates AS Coordinates, `c.Coordinates.position` AS results; + # title: single cte WITH a AS (SELECT 1 FROM z) SELECT 1 FROM a; WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a AS a; @@ -83,7 +98,7 @@ SELECT * FROM ((c.db.a AS foo CROSS JOIN c.db.b AS bar) CROSS JOIN c.db.c AS baz SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1); SELECT * FROM (c.db.tbl1 AS tbl1 CROSS JOIN (SELECT * FROM c.db.tbl2 AS tbl2) AS t1); -# title: wrapped join with subquery with alias, parentheses can't be omitted because of alias +# title: wrapped join with subquery with alias, parentheses cant be omitted because of alias SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) AS t2; SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 CROSS JOIN (SELECT * FROM c.db.tbl2 AS tbl2) AS t1) AS t2; @@ -95,7 +110,7 @@ SELECT * FROM c.db.a AS a LEFT JOIN (c.db.b AS b INNER JOIN c.db.c AS c ON c.id SELECT * FROM a LEFT JOIN b INNER JOIN c ON c.id = b.id ON b.id = a.id; SELECT * FROM c.db.a AS a LEFT JOIN c.db.b AS b INNER JOIN c.db.c AS c ON c.id = b.id ON b.id = a.id; -# title: parentheses can't be omitted because alias shadows inner table names +# title: parentheses cant be omitted because alias shadows inner table names SELECT t.a FROM (tbl AS tbl) AS t; SELECT t.a FROM (SELECT * FROM c.db.tbl
AS tbl) AS t; @@ -146,3 +161,7 @@ CREATE TABLE c.db.t1 AS (WITH cte AS (SELECT x FROM c.db.t2 AS t2) SELECT * FROM # title: insert statement with cte WITH cte AS (SELECT b FROM y) INSERT INTO s SELECT * FROM cte; WITH cte AS (SELECT b FROM c.db.y AS y) INSERT INTO c.db.s SELECT * FROM cte AS cte; + +# title: qualify wrapped query +(SELECT x FROM t); +(SELECT x FROM c.db.t AS t); diff --git a/tests/fixtures/optimizer/quote_identifiers.sql b/tests/fixtures/optimizer/quote_identifiers.sql index 21181f7..34500c4 100644 --- a/tests/fixtures/optimizer/quote_identifiers.sql +++ b/tests/fixtures/optimizer/quote_identifiers.sql @@ -29,3 +29,7 @@ SELECT "dual" FROM "t"; # dialect: snowflake SELECT * FROM t AS dual; SELECT * FROM "t" AS "dual"; + +# dialect: bigquery +SELECT `p.d.udf`(data).* FROM `p.d.t`; +SELECT `p.d.udf`(`data`).* FROM `p.d.t`; diff --git a/tests/fixtures/optimizer/tpc-ds/call_center.csv.gz b/tests/fixtures/optimizer/tpc-ds/call_center.csv.gz index f36e23b..ad5043f 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/call_center.csv.gz and b/tests/fixtures/optimizer/tpc-ds/call_center.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/catalog_page.csv.gz b/tests/fixtures/optimizer/tpc-ds/catalog_page.csv.gz index 702242c..eed1508 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/catalog_page.csv.gz and b/tests/fixtures/optimizer/tpc-ds/catalog_page.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/catalog_returns.csv.gz b/tests/fixtures/optimizer/tpc-ds/catalog_returns.csv.gz index e87a0ec..e160514 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/catalog_returns.csv.gz and b/tests/fixtures/optimizer/tpc-ds/catalog_returns.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/catalog_sales.csv.gz b/tests/fixtures/optimizer/tpc-ds/catalog_sales.csv.gz index a40b0da..1828149 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/catalog_sales.csv.gz and b/tests/fixtures/optimizer/tpc-ds/catalog_sales.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/customer.csv.gz b/tests/fixtures/optimizer/tpc-ds/customer.csv.gz index f4af4f7..2277f72 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/customer.csv.gz and b/tests/fixtures/optimizer/tpc-ds/customer.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/customer_address.csv.gz b/tests/fixtures/optimizer/tpc-ds/customer_address.csv.gz index 8698e39..c553721 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/customer_address.csv.gz and b/tests/fixtures/optimizer/tpc-ds/customer_address.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/customer_demographics.csv.gz b/tests/fixtures/optimizer/tpc-ds/customer_demographics.csv.gz index c4b7b68..dfc65a0 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/customer_demographics.csv.gz and b/tests/fixtures/optimizer/tpc-ds/customer_demographics.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/date_dim.csv.gz b/tests/fixtures/optimizer/tpc-ds/date_dim.csv.gz index 35be08f..26280bf 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/date_dim.csv.gz and b/tests/fixtures/optimizer/tpc-ds/date_dim.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/household_demographics.csv.gz b/tests/fixtures/optimizer/tpc-ds/household_demographics.csv.gz index b8addb7..f0cde03 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/household_demographics.csv.gz and b/tests/fixtures/optimizer/tpc-ds/household_demographics.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/income_band.csv.gz 
b/tests/fixtures/optimizer/tpc-ds/income_band.csv.gz index d34d870..4374587 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/income_band.csv.gz and b/tests/fixtures/optimizer/tpc-ds/income_band.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/inventory.csv.gz b/tests/fixtures/optimizer/tpc-ds/inventory.csv.gz index c6f0d47..5afaaf6 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/inventory.csv.gz and b/tests/fixtures/optimizer/tpc-ds/inventory.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/item.csv.gz b/tests/fixtures/optimizer/tpc-ds/item.csv.gz index 4a316cd..9f65d87 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/item.csv.gz and b/tests/fixtures/optimizer/tpc-ds/item.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/promotion.csv.gz b/tests/fixtures/optimizer/tpc-ds/promotion.csv.gz index 339666c..e8692c2 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/promotion.csv.gz and b/tests/fixtures/optimizer/tpc-ds/promotion.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/reason.csv.gz b/tests/fixtures/optimizer/tpc-ds/reason.csv.gz index 0094849..de1f50f 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/reason.csv.gz and b/tests/fixtures/optimizer/tpc-ds/reason.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/ship_mode.csv.gz b/tests/fixtures/optimizer/tpc-ds/ship_mode.csv.gz index 8dec386..14465e8 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/ship_mode.csv.gz and b/tests/fixtures/optimizer/tpc-ds/ship_mode.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/store.csv.gz b/tests/fixtures/optimizer/tpc-ds/store.csv.gz index b4e8de0..8d04078 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/store.csv.gz and b/tests/fixtures/optimizer/tpc-ds/store.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/store_returns.csv.gz b/tests/fixtures/optimizer/tpc-ds/store_returns.csv.gz index 8469492..cba1300 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/store_returns.csv.gz and b/tests/fixtures/optimizer/tpc-ds/store_returns.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/store_sales.csv.gz b/tests/fixtures/optimizer/tpc-ds/store_sales.csv.gz index 3dd22e1..68caa83 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/store_sales.csv.gz and b/tests/fixtures/optimizer/tpc-ds/store_sales.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/time_dim.csv.gz b/tests/fixtures/optimizer/tpc-ds/time_dim.csv.gz index bf4fcaf..3e0fa35 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/time_dim.csv.gz and b/tests/fixtures/optimizer/tpc-ds/time_dim.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 5ea51e0..76e6431 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -1,6 +1,7 @@ -------------------------------------- -- TPC-DS 1 -------------------------------------- +# execute: true WITH customer_total_return AS (SELECT sr_customer_sk AS ctr_customer_sk, sr_store_sk AS ctr_store_sk, @@ -219,6 +220,7 @@ ORDER BY -------------------------------------- -- TPC-DS 3 -------------------------------------- +# execute: true SELECT dt.d_year, item.i_brand_id brand_id, item.i_brand brand, @@ -859,6 +861,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 6 -------------------------------------- +# execute: true SELECT a.ca_state state, Count(*) cnt FROM customer_address a, @@ -924,6 +927,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 7 
-------------------------------------- +# execute: true SELECT i_item_id, Avg(ss_quantity) agg1, Avg(ss_list_price) agg2, @@ -1247,6 +1251,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 9 -------------------------------------- +# execute: true SELECT CASE WHEN (SELECT Count(*) FROM store_sales @@ -1448,6 +1453,7 @@ WHERE -------------------------------------- -- TPC-DS 10 -------------------------------------- +# execute: true SELECT cd_gender, cd_marital_status, cd_education_status, @@ -3056,6 +3062,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 24 -------------------------------------- +# execute: true WITH ssales AS (SELECT c_last_name, c_first_name, @@ -3158,6 +3165,7 @@ HAVING -------------------------------------- -- TPC-DS 25 -------------------------------------- +# execute: true SELECT i_item_id, i_item_desc, s_store_id, @@ -3247,6 +3255,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 26 -------------------------------------- +# execute: true SELECT i_item_id, Avg(cs_quantity) agg1, Avg(cs_list_price) agg2, @@ -3527,6 +3536,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 29 -------------------------------------- +# execute: true SELECT i_item_id, i_item_desc, s_store_id, @@ -3726,6 +3736,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 31 -------------------------------------- +# execute: true WITH ss AS (SELECT ca_county, d_qoy, @@ -3948,6 +3959,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 33 -------------------------------------- +# execute: true WITH ss AS (SELECT i_manufact_id, Sum(ss_ext_sales_price) total_sales @@ -5014,6 +5026,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 43 -------------------------------------- +# execute: true SELECT s_store_name, s_store_id, Sum(CASE @@ -6194,6 +6207,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 52 -------------------------------------- +# execute: true SELECT dt.d_year, item.i_brand_id brand_id, item.i_brand brand, @@ -6357,6 +6371,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 54 -------------------------------------- +# execute: true WITH my_customers AS (SELECT DISTINCT c_customer_sk, c_current_addr_sk @@ -6493,6 +6508,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 55 -------------------------------------- +# execute: true SELECT i_brand_id brand_id, i_brand brand, Sum(ss_ext_sales_price) ext_price @@ -6531,6 +6547,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 56 -------------------------------------- +# execute: true WITH ss AS (SELECT i_item_id, Sum(ss_ext_sales_price) total_sales @@ -7231,6 +7248,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 60 -------------------------------------- +# execute: true WITH ss AS (SELECT i_item_id, Sum(ss_ext_sales_price) total_sales @@ -8012,6 +8030,7 @@ ORDER BY -------------------------------------- -- TPC-DS 65 -------------------------------------- +# execute: true SELECT s_store_name, i_item_desc, sc.revenue, @@ -9113,6 +9132,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 69 -------------------------------------- +# execute: true SELECT cd_gender, cd_marital_status, cd_education_status, @@ -9355,6 +9375,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 71 -------------------------------------- +# execute: true SELECT i_brand_id brand_id, i_brand brand, t_hour, @@ -11064,6 +11085,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 83 
-------------------------------------- +# execute: true WITH sr_items AS (SELECT i_item_id item_id, Sum(sr_return_quantity) sr_item_qty @@ -11262,6 +11284,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 84 -------------------------------------- +# execute: true SELECT c_customer_id AS customer_id, c_last_name || ', ' @@ -11563,6 +11586,7 @@ FROM "cool_cust" AS "cool_cust"; -------------------------------------- -- TPC-DS 88 -------------------------------------- +# execute: true select * from (select count(*) h8_30_to_9 @@ -12140,6 +12164,7 @@ LIMIT 100; -------------------------------------- -- TPC-DS 93 -------------------------------------- +# execute: true SELECT ss_customer_sk, Sum(act_sales) sumsales FROM (SELECT ss_item_sk, diff --git a/tests/fixtures/optimizer/tpc-ds/warehouse.csv.gz b/tests/fixtures/optimizer/tpc-ds/warehouse.csv.gz index 1dd95a0..cf64636 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/warehouse.csv.gz and b/tests/fixtures/optimizer/tpc-ds/warehouse.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/web_page.csv.gz b/tests/fixtures/optimizer/tpc-ds/web_page.csv.gz index 10a06a2..894ce3b 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/web_page.csv.gz and b/tests/fixtures/optimizer/tpc-ds/web_page.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/web_returns.csv.gz b/tests/fixtures/optimizer/tpc-ds/web_returns.csv.gz index 811e079..21f7040 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/web_returns.csv.gz and b/tests/fixtures/optimizer/tpc-ds/web_returns.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/web_sales.csv.gz b/tests/fixtures/optimizer/tpc-ds/web_sales.csv.gz index b1ac3b8..b384c78 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/web_sales.csv.gz and b/tests/fixtures/optimizer/tpc-ds/web_sales.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-ds/web_site.csv.gz b/tests/fixtures/optimizer/tpc-ds/web_site.csv.gz index ccedce2..b9b5f72 100644 Binary files a/tests/fixtures/optimizer/tpc-ds/web_site.csv.gz and b/tests/fixtures/optimizer/tpc-ds/web_site.csv.gz differ diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index a99abcd..39b5ffa 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -1047,7 +1047,6 @@ WITH "_u_0" AS ( "lineitem"."l_orderkey" AS "l_orderkey" FROM "lineitem" AS "lineitem" GROUP BY - "lineitem"."l_orderkey", "lineitem"."l_orderkey" HAVING SUM("lineitem"."l_quantity") > 300 diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index 3caeef6..45e462b 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -25,6 +25,7 @@ WHERE AND x.a > ALL (SELECT y.c FROM y WHERE y.a = x.a) AND x.a > (SELECT COUNT(*) as d FROM y WHERE y.a = x.a) AND x.a = SUM(SELECT 1) -- invalid statement left alone + AND x.a IN (SELECT max(y.b) AS b FROM y GROUP BY y.a) ; SELECT * @@ -155,6 +156,20 @@ LEFT JOIN ( y.a ) AS _u_21 ON _u_21._u_22 = x.a +LEFT JOIN ( + SELECT + _q.b + FROM ( + SELECT + MAX(y.b) AS b + FROM y + GROUP BY + y.a + ) AS _q + GROUP BY + _q.b +) AS _u_24 + ON x.a = _u_24.b WHERE x.a = _u_0.a AND NOT _u_1.a IS NULL @@ -212,6 +227,7 @@ WHERE AND x.a > COALESCE(_u_21.d, 0) AND x.a = SUM(SELECT 1) /* invalid statement left alone */ + AND NOT _u_24.b IS NULL ; SELECT CAST(( diff --git a/tests/test_build.py b/tests/test_build.py index f0c631f..cdddd4f 100644 --- 
a/tests/test_build.py +++ b/tests/test_build.py @@ -94,6 +94,7 @@ class TestBuild(unittest.TestCase): (lambda: select("x").from_("tbl"), "SELECT x FROM tbl"), (lambda: select("x", "y").from_("tbl"), "SELECT x, y FROM tbl"), (lambda: select("x").select("y").from_("tbl"), "SELECT x, y FROM tbl"), + (lambda: select("comment", "begin"), "SELECT comment, begin"), ( lambda: select("x").select("y", append=False).from_("tbl"), "SELECT y FROM tbl", @@ -501,6 +502,25 @@ class TestBuild(unittest.TestCase): ), "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", ), + (lambda: parse_one("(SELECT 1)").select("2"), "(SELECT 1, 2)"), + ( + lambda: parse_one("(SELECT 1)").limit(1), + "SELECT * FROM ((SELECT 1)) AS _l_0 LIMIT 1", + ), + ( + lambda: parse_one("WITH t AS (SELECT 1) (SELECT 1)").limit(1), + "SELECT * FROM (WITH t AS (SELECT 1) (SELECT 1)) AS _l_0 LIMIT 1", + ), + ( + lambda: parse_one("(SELECT 1 LIMIT 2)").limit(1), + "SELECT * FROM ((SELECT 1 LIMIT 2)) AS _l_0 LIMIT 1", + ), + (lambda: parse_one("(SELECT 1)").subquery(), "((SELECT 1))"), + (lambda: parse_one("(SELECT 1)").subquery("alias"), "((SELECT 1)) AS alias"), + ( + lambda: parse_one("(select * from foo)").with_("foo", "select 1 as c"), + "WITH foo AS (SELECT 1 AS c) (SELECT * FROM foo)", + ), ( lambda: exp.update("tbl", {"x": None, "y": {"x": 1}}), "UPDATE tbl SET x = NULL, y = MAP(ARRAY('x'), ARRAY(1))", diff --git a/tests/test_diff.py b/tests/test_diff.py index d5fa163..fa012a8 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -2,7 +2,7 @@ import unittest from sqlglot import exp, parse_one from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff -from sqlglot.expressions import Join, to_identifier +from sqlglot.expressions import Join, to_table class TestDiff(unittest.TestCase): @@ -18,7 +18,6 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), [ - Remove(to_identifier("b", quoted=False)), # the Identifier node Remove(parse_one("b")), # the Column node ], ) @@ -26,7 +25,6 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), [ - Insert(to_identifier("c", quoted=False)), # the Identifier node Insert(parse_one("c")), # the Column node ], ) @@ -38,9 +36,39 @@ class TestDiff(unittest.TestCase): ), [ Update( - to_identifier("table_one", quoted=False), - to_identifier("table_two", quoted=False), - ), # the Identifier node + to_table("table_one", quoted=False), + to_table("table_two", quoted=False), + ), # the Table node + ], + ) + + def test_lambda(self): + self._validate_delta_only( + diff(parse_one("SELECT a, b, c, x(a -> a)"), parse_one("SELECT a, b, c, x(b -> b)")), + [ + Update( + exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]), + exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]), + ), + ], + ) + + def test_udf(self): + self._validate_delta_only( + diff(parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')), + [ + Insert(parse_one('"my.udf2"()')), + Remove(parse_one('"my.udf1"()')), + ], + ) + self._validate_delta_only( + diff( + parse_one('SELECT a, b, "my.udf"(x, y, z)'), + parse_one('SELECT a, b, "my.udf"(x, y, w)'), + ), + [ + Insert(exp.column("w")), + Remove(exp.column("z")), ], ) @@ -95,7 +123,6 @@ class TestDiff(unittest.TestCase): diff(parse_one(expr_src), parse_one(expr_tgt)), [ Remove(parse_one("LOWER(c) AS c")), # the Alias node - Remove(to_identifier("c", 
quoted=False)), # the Identifier node Remove(parse_one("LOWER(c)")), # the Lower node Remove(parse_one("'filter'")), # the Literal node Insert(parse_one("'different_filter'")), # the Literal node @@ -162,9 +189,7 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff(expr_src, expr_tgt), [ - Insert(expression=exp.to_identifier("b")), Insert(expression=exp.to_column("tbl.b")), - Insert(expression=exp.to_identifier("tbl")), ], ) diff --git a/tests/test_executor.py b/tests/test_executor.py index 9a2b46b..981c1d4 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,3 +1,4 @@ +import os import datetime import unittest from datetime import date @@ -17,40 +18,53 @@ from tests.helpers import ( FIXTURES_DIR, SKIP_INTEGRATION, TPCH_SCHEMA, + TPCDS_SCHEMA, load_sql_fixture_pairs, + string_to_bool, ) -DIR = FIXTURES_DIR + "/optimizer/tpc-h/" +DIR_TPCH = FIXTURES_DIR + "/optimizer/tpc-h/" +DIR_TPCDS = FIXTURES_DIR + "/optimizer/tpc-ds/" @unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") class TestExecutor(unittest.TestCase): @classmethod def setUpClass(cls): - cls.conn = duckdb.connect() + cls.tpch_conn = duckdb.connect() + cls.tpcds_conn = duckdb.connect() for table, columns in TPCH_SCHEMA.items(): - cls.conn.execute( + cls.tpch_conn.execute( f""" CREATE VIEW {table} AS SELECT * - FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns}) + FROM READ_CSV('{DIR_TPCH}{table}.csv.gz', delim='|', header=True, columns={columns}) + """ + ) + + for table, columns in TPCDS_SCHEMA.items(): + cls.tpcds_conn.execute( + f""" + CREATE VIEW {table} AS + SELECT * + FROM READ_CSV('{DIR_TPCDS}{table}.csv.gz', delim='|', header=True, columns={columns}) """ ) cls.cache = {} - cls.sqls = [ - (sql, expected) - for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") - ] + cls.tpch_sqls = list(load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql")) + cls.tpcds_sqls = list(load_sql_fixture_pairs("optimizer/tpc-ds/tpc-ds.sql")) @classmethod def tearDownClass(cls): - cls.conn.close() + cls.tpch_conn.close() + cls.tpcds_conn.close() - def cached_execute(self, sql): + def cached_execute(self, sql, tpch=True): + conn = self.tpch_conn if tpch else self.tpcds_conn if sql not in self.cache: - self.cache[sql] = self.conn.execute(transpile(sql, write="duckdb")[0]).fetchdf() + self.cache[sql] = conn.execute(transpile(sql, write="duckdb")[0]).fetchdf() return self.cache[sql] def rename_anonymous(self, source, target): @@ -66,18 +80,28 @@ class TestExecutor(unittest.TestCase): self.assertEqual(generate(parse_one("x is null")), "scope[None][x] is None") def test_optimized_tpch(self): - for i, (sql, optimized) in enumerate(self.sqls, start=1): + for i, (_, sql, optimized) in enumerate(self.tpch_sqls, start=1): with self.subTest(f"{i}, {sql}"): - a = self.cached_execute(sql) - b = self.conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf() + a = self.cached_execute(sql, tpch=True) + b = self.tpch_conn.execute(transpile(optimized, write="duckdb")[0]).fetchdf() self.rename_anonymous(b, a) assert_frame_equal(a, b) + def subtestHelper(self, i, table, tpch=True): + with self.subTest(f"{'tpc-h' if tpch else 'tpc-ds'} {i + 1}"): + _, sql, _ = self.tpch_sqls[i] if tpch else self.tpcds_sqls[i] + a = self.cached_execute(sql, tpch=tpch) + b = pd.DataFrame( + ((np.nan if c is None else c for c in r) for r in table.rows), + columns=table.columns, + ) + assert_frame_equal(a, b, check_dtype=False, check_index_type=False) + def 
test_execute_tpch(self): def to_csv(expression): if isinstance(expression, exp.Table) and expression.name not in ("revenue"): return parse_one( - f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" + f"READ_CSV('{DIR_TPCH}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression @@ -87,19 +111,26 @@ class TestExecutor(unittest.TestCase): execute, ( (parse_one(sql).transform(to_csv).sql(pretty=True), TPCH_SCHEMA) - for sql, _ in self.sqls + for _, sql, _ in self.tpch_sqls ), ) ): - with self.subTest(f"tpch-h {i + 1}"): - sql, _ = self.sqls[i] - a = self.cached_execute(sql) - b = pd.DataFrame( - ((np.nan if c is None else c for c in r) for r in table.rows), - columns=table.columns, - ) - - assert_frame_equal(a, b, check_dtype=False, check_index_type=False) + self.subtestHelper(i, table, tpch=True) + + def test_execute_tpcds(self): + def to_csv(expression): + if isinstance(expression, exp.Table) and os.path.exists( + f"{DIR_TPCDS}{expression.name}.csv.gz" + ): + return parse_one( + f"READ_CSV('{DIR_TPCDS}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" + ) + return expression + + for i, (meta, sql, _) in enumerate(self.tpcds_sqls): + if string_to_bool(meta.get("execute")): + table = execute(parse_one(sql).transform(to_csv).sql(pretty=True), TPCDS_SCHEMA) + self.subtestHelper(i, table, tpch=False) def test_execute_callable(self): tables = { diff --git a/tests/test_expressions.py b/tests/test_expressions.py index d42eeca..11f8fd3 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -249,7 +249,7 @@ class TestExpressions(unittest.TestCase): {"example.table": "`my-project.example.table`"}, dialect="bigquery", ).sql(), - 'SELECT * FROM "my-project".example.table /* example.table */', + 'SELECT * FROM "my-project"."example"."table" /* example.table */', ) def test_expand(self): @@ -313,6 +313,18 @@ class TestExpressions(unittest.TestCase): ).sql(), "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100", ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from foo WHERE x > ? 
AND y IS ?"), 0, False + ).sql(), + "SELECT * FROM foo WHERE x > 0 AND y IS FALSE", + ) + self.assertEqual( + exp.replace_placeholders( + parse_one("select * from foo WHERE x > :int1 AND y IS :bool1"), int1=0, bool1=False + ).sql(), + "SELECT * FROM foo WHERE x > 0 AND y IS FALSE", + ) def test_function_building(self): self.assertEqual(exp.func("max", 1).sql(), "MAX(1)") @@ -645,6 +657,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex) self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5) self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform) + self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths) def test_column(self): column = parse_one("a.b.c.d") diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 922edcb..ed1a448 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -25,21 +25,21 @@ class TestLineage(unittest.TestCase): node.source.sql(), "SELECT z.a AS a FROM (SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */) AS z /* source: z */", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") downstream = node.downstream[0] self.assertEqual( downstream.source.sql(), "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */", ) - self.assertEqual(downstream.alias, "z") + self.assertEqual(downstream.source_name, "z") downstream = downstream.downstream[0] self.assertEqual( downstream.source.sql(), "SELECT x.a AS a FROM x AS x", ) - self.assertEqual(downstream.alias, "y") + self.assertEqual(downstream.source_name, "y") self.assertGreater(len(node.to_html()._repr_html_()), 1000) def test_lineage_sql_with_cte(self) -> None: @@ -53,7 +53,8 @@ class TestLineage(unittest.TestCase): node.source.sql(), "WITH z AS (SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */) SELECT z.a AS a FROM z AS z", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") + self.assertEqual(node.reference_node_name, "") # Node containing expanded CTE expression downstream = node.downstream[0] @@ -61,14 +62,16 @@ class TestLineage(unittest.TestCase): downstream.source.sql(), "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */", ) - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.source_name, "") + self.assertEqual(downstream.reference_node_name, "z") downstream = downstream.downstream[0] self.assertEqual( downstream.source.sql(), "SELECT x.a AS a FROM x AS x", ) - self.assertEqual(downstream.alias, "y") + self.assertEqual(downstream.source_name, "y") + self.assertEqual(downstream.reference_node_name, "") def test_lineage_source_with_cte(self) -> None: node = lineage( @@ -81,21 +84,24 @@ class TestLineage(unittest.TestCase): node.source.sql(), "SELECT z.a AS a FROM (WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y AS y) AS z /* source: z */", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") + self.assertEqual(node.reference_node_name, "") downstream = node.downstream[0] self.assertEqual( downstream.source.sql(), "WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y AS y", ) - self.assertEqual(downstream.alias, "z") + self.assertEqual(downstream.source_name, "z") + self.assertEqual(downstream.reference_node_name, "") downstream = downstream.downstream[0] self.assertEqual( downstream.source.sql(), "SELECT x.a AS a FROM x AS x", ) - self.assertEqual(downstream.alias, "z") + 
self.assertEqual(downstream.source_name, "z") + self.assertEqual(downstream.reference_node_name, "y") def test_lineage_source_with_star(self) -> None: node = lineage( @@ -106,14 +112,16 @@ class TestLineage(unittest.TestCase): node.source.sql(), "WITH y AS (SELECT * FROM x AS x) SELECT y.a AS a FROM y AS y", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") + self.assertEqual(node.reference_node_name, "") downstream = node.downstream[0] self.assertEqual( downstream.source.sql(), "SELECT * FROM x AS x", ) - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.source_name, "") + self.assertEqual(downstream.reference_node_name, "y") def test_lineage_external_col(self) -> None: node = lineage( @@ -124,14 +132,16 @@ class TestLineage(unittest.TestCase): node.source.sql(), "WITH y AS (SELECT * FROM x AS x) SELECT a AS a FROM y AS y JOIN z AS z ON y.uid = z.uid", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") + self.assertEqual(node.reference_node_name, "") downstream = node.downstream[0] self.assertEqual( downstream.source.sql(), "?", ) - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.source_name, "") + self.assertEqual(downstream.reference_node_name, "") def test_lineage_values(self) -> None: node = lineage( @@ -143,17 +153,17 @@ class TestLineage(unittest.TestCase): node.source.sql(), "SELECT y.a AS a FROM (SELECT t.a AS a FROM (VALUES (1), (2)) AS t(a)) AS y /* source: y */", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") downstream = node.downstream[0] self.assertEqual(downstream.source.sql(), "SELECT t.a AS a FROM (VALUES (1), (2)) AS t(a)") self.assertEqual(downstream.expression.sql(), "t.a AS a") - self.assertEqual(downstream.alias, "y") + self.assertEqual(downstream.source_name, "y") downstream = downstream.downstream[0] self.assertEqual(downstream.source.sql(), "(VALUES (1), (2)) AS t(a)") self.assertEqual(downstream.expression.sql(), "a") - self.assertEqual(downstream.alias, "y") + self.assertEqual(downstream.source_name, "y") def test_lineage_cte_name_appears_in_schema(self) -> None: schema = {"a": {"b": {"t1": {"c1": "int"}, "t2": {"c2": "int"}}}} @@ -168,22 +178,22 @@ class TestLineage(unittest.TestCase): node.source.sql(), "WITH t1 AS (SELECT t2.c2 AS c2 FROM a.b.t2 AS t2), inter AS (SELECT t1.c2 AS c2 FROM t1 AS t1) SELECT inter.c2 AS c2 FROM inter AS inter", ) - self.assertEqual(node.alias, "") + self.assertEqual(node.source_name, "") downstream = node.downstream[0] self.assertEqual(downstream.source.sql(), "SELECT t1.c2 AS c2 FROM t1 AS t1") self.assertEqual(downstream.expression.sql(), "t1.c2 AS c2") - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.source_name, "") downstream = downstream.downstream[0] self.assertEqual(downstream.source.sql(), "SELECT t2.c2 AS c2 FROM a.b.t2 AS t2") self.assertEqual(downstream.expression.sql(), "t2.c2 AS c2") - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.source_name, "") downstream = downstream.downstream[0] self.assertEqual(downstream.source.sql(), "a.b.t2 AS t2") self.assertEqual(downstream.expression.sql(), "a.b.t2 AS t2") - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.source_name, "") self.assertEqual(downstream.downstream, []) @@ -280,9 +290,11 @@ class TestLineage(unittest.TestCase): downstream_a = node.downstream[0] self.assertEqual(downstream_a.name, "0") self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS 
table_a") + self.assertEqual(downstream_a.reference_node_name, "dataset") downstream_b = node.downstream[1] self.assertEqual(downstream_b.name, "0") self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b") + self.assertEqual(downstream_b.reference_node_name, "dataset") def test_lineage_source_union(self) -> None: query = "SELECT x, created_at FROM dataset;" @@ -306,12 +318,14 @@ class TestLineage(unittest.TestCase): downstream_a = node.downstream[0] self.assertEqual(downstream_a.name, "0") - self.assertEqual(downstream_a.alias, "dataset") + self.assertEqual(downstream_a.source_name, "dataset") self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS table_a") + self.assertEqual(downstream_a.reference_node_name, "") downstream_b = node.downstream[1] self.assertEqual(downstream_b.name, "0") - self.assertEqual(downstream_b.alias, "dataset") + self.assertEqual(downstream_b.source_name, "dataset") self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b") + self.assertEqual(downstream_b.reference_node_name, "") def test_select_star(self) -> None: node = lineage("x", "SELECT x from (SELECT * from table_a)") @@ -332,3 +346,10 @@ class TestLineage(unittest.TestCase): "with _data as (select [struct(1 as a, 2 as b)] as col) select b from _data cross join unnest(col)", ) self.assertEqual(node.name, "b") + + def test_lineage_normalize(self) -> None: + node = lineage("a", "WITH x AS (SELECT 1 a) SELECT a FROM x", dialect="snowflake") + self.assertEqual(node.name, "A") + + with self.assertRaises(sqlglot.errors.SqlglotError): + lineage('"a"', "WITH x AS (SELECT 1 a) SELECT a FROM x", dialect="snowflake") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index af8c3cd..046e5a6 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -205,6 +205,7 @@ class TestOptimizer(unittest.TestCase): optimizer.qualify_tables.qualify_tables, db="db", catalog="c", + set_dialect=True, ) def test_normalize(self): @@ -285,6 +286,15 @@ class TestOptimizer(unittest.TestCase): "SELECT `test`.`bar_bazfoo_$id` AS `bar_bazfoo_$id` FROM `test` AS `test`", ) + qualified = optimizer.qualify.qualify( + parse_one("WITH t AS (SELECT 1 AS c) (SELECT c FROM t)") + ) + self.assertIs(qualified.selects[0].parent, qualified.this) + self.assertEqual( + qualified.sql(), + 'WITH "t" AS (SELECT 1 AS "c") (SELECT "t"."c" AS "c" FROM "t" AS "t")', + ) + self.check_file( "qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True ) @@ -348,6 +358,23 @@ class TestOptimizer(unittest.TestCase): self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto")) self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql()) + anon_unquoted_str = parse_one("anonymous(x, y)") + self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS x,y") + + anon_unquoted_identifier = exp.Anonymous( + this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")] + ) + self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS x,y") + + anon_quoted = parse_one('"anonymous"(x, y)') + self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous" x,y') + + with self.assertRaises(ValueError) as e: + anon_invalid = exp.Anonymous(this=5) + optimizer.simplify.gen(anon_invalid) + + self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception)) + def test_unnest_subqueries(self): self.check_file( "unnest_subqueries", @@ -982,9 
+1009,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.selects[0].type.sql(), "ARRAY") schema = MappingSchema({"t": {"c": "STRUCT<`f` STRING>"}}, dialect="bigquery") - expression = annotate_types(parse_one("SELECT t.c FROM t"), schema=schema) + expression = annotate_types(parse_one("SELECT t.c, [t.c] FROM t"), schema=schema) self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>") + self.assertEqual( + expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>" + ) expression = annotate_types( parse_one("SELECT unnest(t.x) FROM t AS t", dialect="postgres"), @@ -1010,6 +1040,22 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(exp.DataType.Type.USERDEFINED, expression.selects[0].type.this) self.assertEqual(expression.selects[0].type.sql(dialect="postgres"), "IPADDRESS") + def test_unnest_annotation(self): + expression = annotate_types( + optimizer.qualify.qualify( + parse_one( + """ + SELECT a, a.b, a.b.c FROM x, UNNEST(x.a) AS a + """, + read="bigquery", + ) + ), + schema={"x": {"a": "ARRAY<STRUCT<b STRUCT<c int>>>"}}, + ) + self.assertEqual(expression.selects[0].type, exp.DataType.build("STRUCT<b STRUCT<c int>>")) + self.assertEqual(expression.selects[1].type, exp.DataType.build("STRUCT<c int>")) + self.assertEqual(expression.selects[2].type, exp.DataType.build("int")) + def test_recursive_cte(self): query = parse_one( """ diff --git a/tests/test_parser.py b/tests/test_parser.py index 035b5de..791d352 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -852,3 +852,6 @@ class TestParser(unittest.TestCase): ): with self.subTest(dialect): self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql) + + def test_distinct_from(self): + self.assertIsInstance(parse_one("a IS DISTINCT FROM b OR c IS DISTINCT FROM d"), exp.Or) diff --git a/tests/test_serde.py b/tests/test_serde.py index 1043fcf..40d6134 100644 --- a/tests/test_serde.py +++ b/tests/test_serde.py @@ -6,8 +6,7 @@ from sqlglot.optimizer.annotate_types import annotate_types from tests.helpers import load_sql_fixtures -class CustomExpression(exp.Expression): - ... +class CustomExpression(exp.Expression): ... class TestSerDe(unittest.TestCase): diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 99b3fac..49deda9 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -747,7 +747,6 @@ FROM base""", "ALTER SEQUENCE IF EXISTS baz RESTART WITH boo", "ALTER SESSION SET STATEMENT_TIMEOUT_IN_SECONDS=3", "ALTER TABLE integers DROP PRIMARY KEY", - "ALTER TABLE s_ut ADD CONSTRAINT s_ut_uq UNIQUE hajo", "ALTER TABLE table1 MODIFY COLUMN name1 SET TAG foo='bar'", "ALTER TABLE table1 RENAME COLUMN c1 AS c2", "ALTER TABLE table1 RENAME COLUMN c1 TO c2, c2 TO c3", @@ -769,7 +768,6 @@ FROM base""", "SET -v", "SET @user OFF", "SHOW TABLES", - "TRUNCATE TABLE x", "VACUUM FREEZE my_table", ): with self.subTest(sql): -- cgit v1.2.3
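
The test_build additions in this patch pin down how query-builder methods treat wrapped (parenthesized) queries: select() extends the inner SELECT in place, while limit() cannot apply inside the parentheses and instead wraps the whole expression in an outer subquery aliased _l_0. A minimal sketch of those round-trips, assuming the sqlglot 22.2.0 API the tests exercise:

    from sqlglot import parse_one

    # select() reaches into the wrapped query and appends a projection.
    parse_one("(SELECT 1)").select("2").sql()
    # '(SELECT 1, 2)'

    # limit() wraps the parenthesized query in an outer SELECT.
    parse_one("(SELECT 1)").limit(1).sql()
    # 'SELECT * FROM ((SELECT 1)) AS _l_0 LIMIT 1'

    # subquery() keeps the outer parentheses and attaches the alias.
    parse_one("(SELECT 1)").subquery("alias").sql()
    # '((SELECT 1)) AS alias'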
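
The new test_expressions assertions cover both placeholder styles accepted by exp.replace_placeholders: positional arguments fill ? markers in order, and keyword arguments fill named :placeholders. A short usage sketch mirroring the expected strings in the tests:

    from sqlglot import exp, parse_one

    # Positional: each ? is replaced by the next argument, left to right.
    exp.replace_placeholders(
        parse_one("SELECT * FROM foo WHERE x > ? AND y IS ?"), 0, False
    ).sql()
    # 'SELECT * FROM foo WHERE x > 0 AND y IS FALSE'

    # Named: :int1 and :bool1 are matched by keyword.
    exp.replace_placeholders(
        parse_one("SELECT * FROM foo WHERE x > :int1 AND y IS :bool1"), int1=0, bool1=False
    ).sql()
    # 'SELECT * FROM foo WHERE x > 0 AND y IS FALSE'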
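
The test_lineage churn reflects an API rename: Node.alias becomes Node.source_name, and the new Node.reference_node_name identifies the CTE or derived table a column was resolved through. A minimal sketch of walking the graph with the renamed attributes, assuming sqlglot 22.2.0 (the query here is illustrative, not taken from the patch):

    from sqlglot.lineage import lineage

    # Trace column "a" back through the CTE z to the source table y.
    node = lineage("a", "WITH z AS (SELECT y.a AS a FROM y) SELECT a FROM z")
    for n in node.walk():
        print(n.name, n.source_name, n.reference_node_name, n.source.sql())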