From 042432fc9a1f7c3d5d552f12449fe45109fbcd57 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 4 May 2024 18:12:58 +0200 Subject: Adding upstream version 23.13.1. Signed-off-by: Daniel Baumann --- tests/dialects/test_databricks.py | 6 ++ tests/dialects/test_dialect.py | 2 +- tests/dialects/test_duckdb.py | 10 ++ tests/dialects/test_hive.py | 4 +- tests/dialects/test_mysql.py | 118 +++++++++++++++-------- tests/dialects/test_postgres.py | 39 +++++++- tests/dialects/test_presto.py | 12 ++- tests/dialects/test_prql.py | 13 +++ tests/dialects/test_redshift.py | 10 +- tests/dialects/test_snowflake.py | 51 ++++++---- tests/dialects/test_spark.py | 21 +++- tests/dialects/test_trino.py | 18 ++++ tests/dialects/test_tsql.py | 11 +++ tests/fixtures/identity.sql | 4 +- tests/fixtures/optimizer/qualify_columns.sql | 4 + tests/fixtures/optimizer/qualify_columns_ddl.sql | 14 +-- tests/fixtures/optimizer/qualify_tables.sql | 4 + tests/fixtures/optimizer/simplify.sql | 11 +++ tests/test_build.py | 4 + tests/test_lineage.py | 64 +++++++++++- 20 files changed, 341 insertions(+), 79 deletions(-) create mode 100644 tests/dialects/test_trino.py (limited to 'tests') diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index c15cf09..14a6bf3 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -25,6 +25,9 @@ class TestDatabricks(Validator): self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + self.validate_identity( + "CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA" + ) self.validate_identity( "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t" ) @@ -47,6 +50,9 @@ class TestDatabricks(Validator): self.validate_identity( "TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')" ) + self.validate_identity( + "COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS(opt1 = TRUE, opt2 = 'test') COPY_OPTIONS(opt3 = 5)" + ) self.validate_all( "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index ea38521..dda0eb2 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -293,7 +293,7 @@ class TestDialect(Validator): "bigquery": "CAST(a AS INT64)", "drill": "CAST(a AS INTEGER)", "duckdb": "CAST(a AS SMALLINT)", - "mysql": "CAST(a AS SMALLINT)", + "mysql": "CAST(a AS SIGNED)", "hive": "CAST(a AS SMALLINT)", "oracle": "CAST(a AS NUMBER)", "postgres": "CAST(a AS SMALLINT)", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 9105a49..2d0af13 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -324,6 +324,8 @@ class TestDuckDB(Validator): self.validate_identity( "SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias" ) + self.validate_identity("DATE_SUB('YEAR', col, '2020-01-01')").assert_is(exp.Anonymous) + self.validate_identity("DATESUB('YEAR', col, '2020-01-01')").assert_is(exp.Anonymous) self.validate_all("0b1010", write={"": "0 AS b1010"}) self.validate_all("0x1010", write={"": "0 AS x1010"}) @@ -724,6 +726,14 @@ class TestDuckDB(Validator): """SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC""", ) + self.validate_identity( + "COPY lineitem FROM 'lineitem.ndjson' WITH (FORMAT JSON, DELIMITER ',', AUTO_DETECT TRUE, COMPRESSION SNAPPY, CODEC ZSTD, FORCE_NOT_NULL(col1, col2))" + ) + self.validate_identity( + "COPY (SELECT 42 AS a, 'hello' AS b) TO 'query.json' WITH (FORMAT JSON, ARRAY TRUE)" + ) + self.validate_identity("COPY lineitem (l_orderkey) TO 'orderkey.tbl' WITH (DELIMITER '|')") + def test_array_index(self): with self.assertLogs(helper_logger) as cm: self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 9215f05..dfce446 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -334,7 +334,7 @@ class TestHive(Validator): "hive": "DATE_ADD('2020-01-01', 1)", "presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "redshift": "DATEADD(DAY, 1, '2020-01-01')", - "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))", + "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "spark": "DATE_ADD('2020-01-01', 1)", "tsql": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))", }, @@ -348,7 +348,7 @@ class TestHive(Validator): "hive": "DATE_ADD('2020-01-01', 1 * -1)", "presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "redshift": "DATEADD(DAY, 1 * -1, '2020-01-01')", - "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))", + "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "spark": "DATE_ADD('2020-01-01', 1 * -1)", "tsql": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))", }, diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index e8af5c6..53e2dab 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -1,4 +1,5 @@ from sqlglot import expressions as exp +from sqlglot.dialects.mysql import MySQL from tests.dialects.test_dialect import Validator @@ -6,21 +7,11 @@ class TestMySQL(Validator): dialect = "mysql" def test_ddl(self): - int_types = {"BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"} - - for t in int_types: + for t in ("BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"): self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)") self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)") self.validate_identity("CREATE TABLE t (id DECIMAL(20, 4) UNSIGNED)") - - self.validate_all( - "CREATE TABLE t (id INT UNSIGNED)", - write={ - "duckdb": "CREATE TABLE t (id UINTEGER)", - }, - ) - self.validate_identity("CREATE TABLE foo (a BIGINT, UNIQUE (b) USING BTREE)") self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("CREATE TABLE 00f (1d BIGINT)") @@ -97,6 +88,13 @@ class TestMySQL(Validator): "CREATE TABLE `foo` (a VARCHAR(10), INDEX idx_a (a DESC))", ) + self.validate_all( + "CREATE TABLE t (id INT UNSIGNED)", + write={ + "duckdb": "CREATE TABLE t (id UINTEGER)", + "mysql": "CREATE TABLE t (id INT UNSIGNED)", + }, + ) self.validate_all( "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", write={ @@ -109,15 +107,10 @@ class TestMySQL(Validator): self.validate_all( "CREATE TABLE x (id int not null auto_increment, primary key (id))", write={ + "mysql": "CREATE TABLE x (id INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id))", "sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)", }, ) - self.validate_all( - "CREATE TABLE x (id int not null auto_increment)", - write={ - "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)", - }, - ) def test_identity(self): self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") @@ -135,8 +128,6 @@ class TestMySQL(Validator): self.validate_identity("SELECT CAST('[4,5]' AS JSON) MEMBER OF('[[3,4],[4,5]]')") self.validate_identity("""SELECT 'ab' MEMBER OF('[23, "abc", 17, "ab", 10]')""") self.validate_identity("""SELECT * FROM foo WHERE 'ab' MEMBER OF(content)""") - self.validate_identity("CAST(x AS ENUM('a', 'b'))") - self.validate_identity("CAST(x AS SET('a', 'b'))") self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") self.validate_identity("x ->> '$.name'") self.validate_identity("SELECT CAST(`a`.`b` AS CHAR) FROM foo") @@ -226,29 +217,47 @@ class TestMySQL(Validator): self.validate_identity("SELECT * FROM t1 PARTITION(p0)") def test_types(self): - self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))") + for char_type in MySQL.Generator.CHAR_CAST_MAPPING: + with self.subTest(f"MySQL cast into {char_type}"): + self.validate_identity(f"CAST(x AS {char_type.value})", "CAST(x AS CHAR)") + + for signed_type in MySQL.Generator.SIGNED_CAST_MAPPING: + with self.subTest(f"MySQL cast into {signed_type}"): + self.validate_identity(f"CAST(x AS {signed_type.value})", "CAST(x AS SIGNED)") + + self.validate_identity("CAST(x AS ENUM('a', 'b'))") + self.validate_identity("CAST(x AS SET('a', 'b'))") + self.validate_identity( + "CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))", + "CAST(x AS SIGNED) + CAST(y AS YEAR(4))", + ) + self.validate_identity( + "CAST(x AS TIMESTAMP)", + "CAST(x AS DATETIME)", + ) + self.validate_identity( + "CAST(x AS TIMESTAMPTZ)", + "TIMESTAMP(x)", + ) + self.validate_identity( + "CAST(x AS TIMESTAMPLTZ)", + "TIMESTAMP(x)", + ) self.validate_all( "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT) + CAST(z AS TINYTEXT)", - read={ - "mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT) + CAST(z AS TINYTEXT)", - }, write={ + "mysql": "CAST(x AS CHAR) + CAST(y AS CHAR) + CAST(z AS CHAR)", "spark": "CAST(x AS TEXT) + CAST(y AS TEXT) + CAST(z AS TEXT)", }, ) self.validate_all( "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB) + CAST(z AS TINYBLOB)", - read={ - "mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB) + CAST(z AS TINYBLOB)", - }, write={ + "mysql": "CAST(x AS CHAR) + CAST(y AS CHAR) + CAST(z AS CHAR)", "spark": "CAST(x AS BLOB) + CAST(y AS BLOB) + CAST(z AS BLOB)", }, ) - self.validate_all("CAST(x AS TIMESTAMP)", write={"mysql": "CAST(x AS DATETIME)"}) - self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"mysql": "TIMESTAMP(x)"}) - self.validate_all("CAST(x AS TIMESTAMPLTZ)", write={"mysql": "TIMESTAMP(x)"}) def test_canonical_functions(self): self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)") @@ -457,63 +466,63 @@ class TestMySQL(Validator): "SELECT DATE_FORMAT('2017-06-15', '%Y')", write={ "mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y')", - "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'yyyy')", }, ) self.validate_all( "SELECT DATE_FORMAT('2017-06-15', '%m')", write={ "mysql": "SELECT DATE_FORMAT('2017-06-15', '%m')", - "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'mm')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'mm')", }, ) self.validate_all( "SELECT DATE_FORMAT('2017-06-15', '%d')", write={ "mysql": "SELECT DATE_FORMAT('2017-06-15', '%d')", - "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'DD')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'DD')", }, ) self.validate_all( "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')", write={ "mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')", - "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy-mm-DD')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'yyyy-mm-DD')", }, ) self.validate_all( "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')", write={ "mysql": "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')", - "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMPNTZ), 'hh24')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMP), 'hh24')", }, ) self.validate_all( "SELECT DATE_FORMAT('2017-06-15', '%w')", write={ "mysql": "SELECT DATE_FORMAT('2017-06-15', '%w')", - "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'dy')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'dy')", }, ) self.validate_all( "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", write={ "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", - "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMPNTZ), 'DY mmmm yyyy')", + "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMP), 'DY mmmm yyyy')", }, ) self.validate_all( "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%H:%i:%s')", write={ "mysql": "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%T')", - "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMPNTZ), 'hh24:mi:ss')", + "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMP), 'hh24:mi:ss')", }, ) self.validate_all( "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')", write={ "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')", - "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMPNTZ), 'DD yy DY DD mm mon')", + "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMP), 'DD yy DY DD mm mon')", }, ) @@ -598,6 +607,19 @@ class TestMySQL(Validator): ) def test_mysql(self): + self.validate_all( + "SELECT department, GROUP_CONCAT(name) AS employee_names FROM data GROUP BY department", + read={ + "postgres": "SELECT department, array_agg(name) AS employee_names FROM data GROUP BY department", + }, + ) + self.validate_all( + "SELECT UNIX_TIMESTAMP(CAST('2024-04-29 12:00:00' AS DATETIME))", + read={ + "mysql": "SELECT UNIX_TIMESTAMP(CAST('2024-04-29 12:00:00' AS DATETIME))", + "postgres": "SELECT EXTRACT(epoch FROM TIMESTAMP '2024-04-29 12:00:00')", + }, + ) self.validate_all( "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')", read={ @@ -1109,3 +1131,23 @@ COMMENT='客户账户表'""" "tsql": "CAST(a AS FLOAT) / NULLIF(b, 0)", }, ) + + def test_timestamp_trunc(self): + for dialect in ("postgres", "snowflake", "duckdb", "spark", "databricks"): + for unit in ( + "MILLISECOND", + "SECOND", + "DAY", + "MONTH", + "YEAR", + ): + with self.subTest(f"MySQL -> {dialect} Timestamp Trunc with unit {unit}: "): + self.validate_all( + f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})", + read={ + dialect: f"DATE_TRUNC({unit}, TIMESTAMP '2001-02-16 20:38:40')", + }, + write={ + "mysql": f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})", + }, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 5a55a7d..6b6117e 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -312,8 +312,32 @@ class TestPostgres(Validator): "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", ) - self.validate_identity("SELECT * FROM t1*", "SELECT * FROM t1") + self.validate_identity( + "SELECT * FROM t1*", + "SELECT * FROM t1", + ) + self.validate_identity( + "SELECT SUBSTRING('afafa' for 1)", + "SELECT SUBSTRING('afafa' FROM 1 FOR 1)", + ) + self.validate_identity( + "CAST(x AS INT8)", + "CAST(x AS BIGINT)", + ) + self.validate_all( + "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')", + write={ + "duckdb": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')", + "postgres": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')", + }, + ) + self.validate_all( + "CREATE TABLE t (c INT)", + read={ + "mysql": "CREATE TABLE t (c INT COMMENT 'comment 1') COMMENT = 'comment 2'", + }, + ) self.validate_all( 'SELECT * FROM "test_table" ORDER BY RANDOM() LIMIT 5', write={ @@ -449,7 +473,7 @@ class TestPostgres(Validator): write={ "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", - "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", }, ) self.validate_all( @@ -660,6 +684,16 @@ class TestPostgres(Validator): ) self.assertIsInstance(self.parse_one("id::UUID"), exp.Cast) + self.validate_identity( + "COPY tbl (col1, col2) FROM 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)" + ) + self.validate_identity( + "COPY tbl (col1, col2) TO 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)" + ) + self.validate_identity( + "COPY (SELECT * FROM t) TO 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)" + ) + def test_ddl(self): # Checks that user-defined types are parsed into DataType instead of Identifier self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is( @@ -676,6 +710,7 @@ class TestPostgres(Validator): cdef.args["kind"].assert_is(exp.DataType) self.assertEqual(expr.sql(dialect="postgres"), "CREATE TABLE t (x INTERVAL DAY)") + self.validate_identity("CREATE INDEX IF NOT EXISTS ON t(c)") self.validate_identity("CREATE INDEX et_vid_idx ON et(vid) INCLUDE (fid)") self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)") self.validate_identity("CREATE TABLE test (elems JSONB[])") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 4bafc08..108e916 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -10,6 +10,8 @@ class TestPresto(Validator): self.validate_identity("SELECT * FROM x qualify", "SELECT * FROM x AS qualify") self.validate_identity("CAST(x AS IPADDRESS)") self.validate_identity("CAST(x AS IPPREFIX)") + self.validate_identity("CAST(TDIGEST_AGG(1) AS TDIGEST)") + self.validate_identity("CAST(x AS HYPERLOGLOG)") self.validate_all( "CAST(x AS INTERVAL YEAR TO MONTH)", @@ -1059,6 +1061,15 @@ class TestPresto(Validator): ) def test_json(self): + with self.assertLogs(helper_logger): + self.validate_all( + """SELECT JSON_EXTRACT_SCALAR(TRY(FILTER(CAST(JSON_EXTRACT('{"k1": [{"k2": "{\\"k3\\": 1}", "k4": "v"}]}', '$.k1') AS ARRAY(MAP(VARCHAR, VARCHAR))), x -> x['k4'] = 'v')[1]['k2']), '$.k3')""", + write={ + "presto": """SELECT JSON_EXTRACT_SCALAR(TRY(FILTER(CAST(JSON_EXTRACT('{"k1": [{"k2": "{\\"k3\\": 1}", "k4": "v"}]}', '$.k1') AS ARRAY(MAP(VARCHAR, VARCHAR))), x -> x['k4'] = 'v')[1]['k2']), '$.k3')""", + "spark": """SELECT GET_JSON_OBJECT(FILTER(FROM_JSON(GET_JSON_OBJECT('{"k1": [{"k2": "{\\\\"k3\\\\": 1}", "k4": "v"}]}', '$.k1'), 'ARRAY>'), x -> x['k4'] = 'v')[0]['k2'], '$.k3')""", + }, + ) + self.validate_all( "SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))", write={ @@ -1073,7 +1084,6 @@ class TestPresto(Validator): "presto": 'SELECT CAST(JSON_PARSE(\'{"k1":1,"k2":23,"k3":456}\') AS MAP(VARCHAR, INTEGER))', }, ) - self.validate_all( "SELECT CAST(ARRAY [1, 23, 456] AS JSON)", write={ diff --git a/tests/dialects/test_prql.py b/tests/dialects/test_prql.py index 1a0eec2..5b438f1 100644 --- a/tests/dialects/test_prql.py +++ b/tests/dialects/test_prql.py @@ -66,3 +66,16 @@ class TestPRQL(Validator): "from x filter (a > 1 || null != b || c != null)", "SELECT * FROM x WHERE (a > 1 OR NOT b IS NULL OR NOT c IS NULL)", ) + self.validate_identity("from a aggregate { average x }", "SELECT AVG(x) FROM a") + self.validate_identity( + "from a aggregate { average x, min y, ct = sum z }", + "SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) AS ct FROM a", + ) + self.validate_identity( + "from a aggregate { average x, min y, sum z }", + "SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) FROM a", + ) + self.validate_identity( + "from a aggregate { min y, b = stddev x, max z }", + "SELECT MIN(y), STDDEV(x) AS b, MAX(z) FROM a", + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index a91f4f9..e227ea9 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -162,7 +162,7 @@ class TestRedshift(Validator): write={ "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", - "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", }, ) self.validate_all( @@ -271,7 +271,7 @@ class TestRedshift(Validator): "postgres": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL '18 MONTH'", "presto": "SELECT DATE_ADD('MONTH', 18, CAST('2008-02-28' AS TIMESTAMP))", "redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')", - "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))", "tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))", }, ) @@ -362,8 +362,10 @@ class TestRedshift(Validator): "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" ) self.validate_identity( - "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'", - check_command_warning=True, + "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole' REGION 'us-east-1' FORMAT orc", + ) + self.validate_identity( + "COPY customer FROM 's3://mybucket/mydata' CREDENTIALS 'aws_iam_role=arn:aws:iam:::role/;master_symmetric_key=' emptyasnull blanksasnull timeformat 'YYYY-MM-DD HH:MI:SS'" ) self.validate_identity( "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 1cbf68c..ed33366 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -10,6 +10,9 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity( + "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)" + ) self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)") self.validate_identity( "INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1" @@ -388,7 +391,7 @@ WHERE "SELECT DATE_PART('year', TIMESTAMP '2020-01-01')", write={ "hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", - "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMP))", "spark": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))", }, ) @@ -591,7 +594,7 @@ WHERE self.validate_all( "SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)", write={ - "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMPNTZ))", + "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))", }, ) self.validate_all( @@ -689,7 +692,7 @@ WHERE "SELECT TO_TIMESTAMP('2013-04-05 01:02:03')", write={ "bigquery": "SELECT CAST('2013-04-05 01:02:03' AS DATETIME)", - "snowflake": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMPNTZ)", + "snowflake": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMP)", "spark": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMP)", }, ) @@ -878,10 +881,6 @@ WHERE self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz") self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')") self.validate_identity("PUT file:///dir/tmp.csv @%table", check_command_warning=True) - self.validate_identity( - 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)', - check_command_warning=True, - ) self.validate_identity( "SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla" ) @@ -955,12 +954,16 @@ WHERE self.validate_identity("SELECT CAST('12:00:00' AS TIME)") self.validate_identity("SELECT DATE_PART(month, a)") - self.validate_all( - "SELECT CAST(a AS TIMESTAMP)", - write={ - "snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)", - }, - ) + for data_type in ( + "TIMESTAMP", + "TIMESTAMPLTZ", + "TIMESTAMPNTZ", + ): + self.validate_identity(f"CAST(a AS {data_type})") + + self.validate_identity("CAST(a AS TIMESTAMP_NTZ)", "CAST(a AS TIMESTAMPNTZ)") + self.validate_identity("CAST(a AS TIMESTAMP_LTZ)", "CAST(a AS TIMESTAMPLTZ)") + self.validate_all( "SELECT a::TIMESTAMP_LTZ(9)", write={ @@ -1000,14 +1003,14 @@ WHERE self.validate_all( "SELECT DATE_PART(epoch_second, foo) as ddate from table_name", write={ - "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name", + "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMP)) AS ddate FROM table_name", "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name", }, ) self.validate_all( "SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name", write={ - "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name", + "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name", "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name", }, ) @@ -1138,7 +1141,7 @@ WHERE ) self.validate_identity( "SELECT * FROM my_table AT (TIMESTAMP => 'Fri, 01 May 2015 16:20:00 -0700'::timestamp)", - "SELECT * FROM my_table AT (TIMESTAMP => CAST('Fri, 01 May 2015 16:20:00 -0700' AS TIMESTAMPNTZ))", + "SELECT * FROM my_table AT (TIMESTAMP => CAST('Fri, 01 May 2015 16:20:00 -0700' AS TIMESTAMP))", ) self.validate_identity( "SELECT * FROM my_table AT(TIMESTAMP => 'Fri, 01 May 2015 16:20:00 -0700'::timestamp_tz)", @@ -1581,7 +1584,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene "REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)", write={ "bigquery": "REGEXP_REPLACE(subject, pattern, replacement)", - "duckdb": "REGEXP_REPLACE(subject, pattern, replacement)", + "duckdb": "REGEXP_REPLACE(subject, pattern, replacement, parameters)", "hive": "REGEXP_REPLACE(subject, pattern, replacement)", "snowflake": "REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)", "spark": "REGEXP_REPLACE(subject, pattern, replacement, position)", @@ -1827,3 +1830,17 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""", expression = annotate_types(expression) self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)") + + def test_copy(self): + self.validate_identity( + """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""" + ) + self.validate_identity( + """COPY INTO temp FROM @random_stage/path/ FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64) VALIDATION_MODE = 'RETURN_3_ROWS'""" + ) + self.validate_identity( + """COPY INTO load1 FROM @%load1/data1/ FILES = ('test1.csv', 'test2.csv') FORCE = TRUE""" + ) + self.validate_identity( + """COPY INTO mytable FROM 'azure://myaccount.blob.core.windows.net/mycontainer/data/files' CREDENTIALS = (AZURE_SAS_TOKEN = 'token') ENCRYPTION = (TYPE = 'AZURE_CSE' MASTER_KEY = 'kPx...') FILE_FORMAT = (FORMAT_NAME = my_csv_format)""" + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 7534573..069ae42 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -343,7 +343,7 @@ TBLPROPERTIES ( "postgres": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'", "presto": "SELECT WITH_TIMEZONE(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul') AT TIME ZONE 'UTC'", "redshift": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'", - "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMPNTZ))", + "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMP))", "spark": "SELECT TO_UTC_TIMESTAMP(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')", }, ) @@ -523,7 +523,14 @@ TBLPROPERTIES ( }, ) - for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"): + for data_type in ( + "BOOLEAN", + "DATE", + "DOUBLE", + "FLOAT", + "INT", + "TIMESTAMP", + ): self.validate_all( f"{data_type}(x)", write={ @@ -531,6 +538,16 @@ TBLPROPERTIES ( "spark": f"CAST(x AS {data_type})", }, ) + + for ts_suffix in ("NTZ", "LTZ"): + self.validate_all( + f"TIMESTAMP_{ts_suffix}(x)", + write={ + "": f"CAST(x AS TIMESTAMP{ts_suffix})", + "spark": f"CAST(x AS TIMESTAMP_{ts_suffix})", + }, + ) + self.validate_all( "STRING(x)", write={ diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py new file mode 100644 index 0000000..ccc1407 --- /dev/null +++ b/tests/dialects/test_trino.py @@ -0,0 +1,18 @@ +from tests.dialects.test_dialect import Validator + + +class TestTrino(Validator): + dialect = "trino" + + def test_trim(self): + self.validate_identity("SELECT TRIM('!' FROM '!foo!')") + self.validate_identity("SELECT TRIM(BOTH '$' FROM '$var$')") + self.validate_identity("SELECT TRIM(TRAILING 'ER' FROM UPPER('worker'))") + self.validate_identity( + "SELECT TRIM(LEADING FROM ' abcd')", + "SELECT LTRIM(' abcd')", + ) + self.validate_identity( + "SELECT TRIM('!foo!', '!')", + "SELECT TRIM('!' FROM '!foo!')", + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 4a475f6..1538d47 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -29,6 +29,9 @@ class TestTSQL(Validator): self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)") self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))") + self.validate_identity( + "COPY INTO test_1 FROM 'path' WITH (FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY = 'Shared Access Signature', SECRET = 'token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')" + ) self.validate_all( "SELECT IIF(cond <> 0, 'True', 'False')", @@ -777,6 +780,14 @@ class TestTSQL(Validator): "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END", "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END", ) + + self.validate_all( + "CREATE TABLE [#temptest] (name VARCHAR)", + read={ + "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)", + "tsql": "CREATE TABLE [#temptest] (name VARCHAR)", + }, + ) self.validate_all( "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)", read={ diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index d51a978..6b742c3 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -642,6 +642,7 @@ CREATE TABLE T3 AS (SELECT DISTINCT A FROM T1 EXCEPT (SELECT A FROM T2) LIMIT 1) DESCRIBE x DESCRIBE EXTENDED a.b DESCRIBE FORMATTED a.b +DESCRIBE SELECT 1 DROP INDEX a.b.c DROP FUNCTION a.b.c (INT) DROP MATERIALIZED VIEW x.y.z @@ -867,4 +868,5 @@ SELECT only TRUNCATE(a, b) SELECT enum SELECT unlogged -SELECT name \ No newline at end of file +SELECT name +SELECT copy \ No newline at end of file diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index b020a27..6342cfc 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -523,6 +523,10 @@ SELECT t.c1 AS c1, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3); SELECT c.f::VARCHAR(MAX) AS f, e AS e FROM a.b AS c, c.d AS e; SELECT CAST(c.f AS VARCHAR(MAX)) AS f, e AS e FROM a.b AS c, c.d AS e; +# dialect: bigquery +WITH cte AS (SELECT 1 AS col) SELECT * FROM cte LEFT JOIN UNNEST((SELECT ARRAY_AGG(DISTINCT x) AS agg FROM UNNEST([1]) AS x WHERE col = 1)); +WITH cte AS (SELECT 1 AS col) SELECT * FROM cte AS cte LEFT JOIN UNNEST((SELECT ARRAY_AGG(DISTINCT x) AS agg FROM UNNEST([1]) AS x WHERE cte.col = 1)); + -------------------------------------- -- Window functions -------------------------------------- diff --git a/tests/fixtures/optimizer/qualify_columns_ddl.sql b/tests/fixtures/optimizer/qualify_columns_ddl.sql index 9b4bb34..75d84ca 100644 --- a/tests/fixtures/optimizer/qualify_columns_ddl.sql +++ b/tests/fixtures/optimizer/qualify_columns_ddl.sql @@ -1,26 +1,26 @@ # title: Create with CTE WITH cte AS (SELECT b FROM y) CREATE TABLE s AS SELECT * FROM cte; -CREATE TABLE s AS WITH cte AS (SELECT y.b AS b FROM y AS y) SELECT cte.b AS b FROM cte AS cte; +WITH cte AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS SELECT cte.b AS b FROM cte AS cte; # title: Create with CTE, query also has CTE WITH cte1 AS (SELECT b FROM y) CREATE TABLE s AS WITH cte2 AS (SELECT b FROM cte1) SELECT * FROM cte2; -CREATE TABLE s AS WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte2.b AS b FROM cte2 AS cte2; +WITH cte1 AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS WITH cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte2.b AS b FROM cte2 AS cte2; # title: Create without CTE CREATE TABLE foo AS SELECT a FROM tbl; CREATE TABLE foo AS SELECT tbl.a AS a FROM tbl AS tbl; # title: Create with complex CTE with derived table -WITH cte AS (SELECT a FROM (SELECT a from x)) CREATE TABLE s AS SELECT * FROM cte; -CREATE TABLE s AS WITH cte AS (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) SELECT cte.a AS a FROM cte AS cte; +WITH cte AS (SELECT a FROM (SELECT a FROM x)) CREATE TABLE s AS SELECT * FROM cte; +WITH cte AS (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) CREATE TABLE s AS SELECT cte.a AS a FROM cte AS cte; # title: Create wtih multiple CTEs WITH cte1 AS (SELECT b FROM y), cte2 AS (SELECT b FROM cte1) CREATE TABLE s AS SELECT * FROM cte2; -CREATE TABLE s AS WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte2.b AS b FROM cte2 AS cte2; +WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) CREATE TABLE s AS SELECT cte2.b AS b FROM cte2 AS cte2; # title: Create with multiple CTEs, selecting only from the first CTE (unnecessary code) WITH cte1 AS (SELECT b FROM y), cte2 AS (SELECT b FROM cte1) CREATE TABLE s AS SELECT * FROM cte1; -CREATE TABLE s AS WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte1.b AS b FROM cte1 AS cte1; +WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) CREATE TABLE s AS SELECT cte1.b AS b FROM cte1 AS cte1; # title: Create with multiple derived tables CREATE TABLE s AS SELECT * FROM (SELECT b FROM (SELECT b FROM y)); @@ -28,7 +28,7 @@ CREATE TABLE s AS SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT y.b A # title: Create with a CTE and a derived table WITH cte AS (SELECT b FROM y) CREATE TABLE s AS SELECT * FROM (SELECT b FROM (SELECT b FROM cte)); -CREATE TABLE s AS WITH cte AS (SELECT y.b AS b FROM y AS y) SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT cte.b AS b FROM cte AS cte) AS _q_0) AS _q_1; +WITH cte AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT cte.b AS b FROM cte AS cte) AS _q_0) AS _q_1; # title: Insert with CTE # dialect: spark diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index 104400e..30bf834 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -158,6 +158,10 @@ ALTER TABLE c.db.t ADD PRIMARY KEY (id) NOT ENFORCED; CREATE TABLE t1 AS (WITH cte AS (SELECT x FROM t2) SELECT * FROM cte); CREATE TABLE c.db.t1 AS (WITH cte AS (SELECT x FROM c.db.t2 AS t2) SELECT * FROM cte AS cte); +# title: delete statement +DELETE FROM t1 WHERE NOT c IN (SELECT c FROM t2); +DELETE FROM c.db.t1 WHERE NOT c IN (SELECT c FROM c.db.t2 AS t2); + # title: insert statement with cte # dialect: spark WITH cte AS (SELECT b FROM y) INSERT INTO s SELECT * FROM cte; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 6af51bf..75abc38 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -109,6 +109,10 @@ a AND b; (x is not null) != (y is null); (NOT x IS NULL) <> (y IS NULL); +# dialect: mysql +A XOR A; +FALSE; + -------------------------------------- -- Absorption -------------------------------------- @@ -232,6 +236,13 @@ x - 1; A AND D AND B AND E AND F AND G AND E AND A; A AND B AND D AND E AND F AND G; +A OR D OR B OR E OR F OR G OR E OR A; +A OR B OR D OR E OR F OR G; + +# dialect: mysql +A XOR D XOR B XOR E XOR F XOR G XOR C; +A XOR B XOR C XOR D XOR E XOR F XOR G; + A AND NOT B AND C AND B; FALSE; diff --git a/tests/test_build.py b/tests/test_build.py index ad0bb9a..da1677f 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -545,6 +545,10 @@ class TestBuild(unittest.TestCase): lambda: exp.update("tbl", {"x": 1}, from_="tbl2"), "UPDATE tbl SET x = 1 FROM tbl2", ), + ( + lambda: exp.update("tbl", {"x": 1}, from_="tbl2 cross join tbl3"), + "UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3", + ), ( lambda: union("SELECT * FROM foo", "SELECT * FROM bla"), "SELECT * FROM foo UNION SELECT * FROM bla", diff --git a/tests/test_lineage.py b/tests/test_lineage.py index c782d9a..3e17f95 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -224,16 +224,50 @@ class TestLineage(unittest.TestCase): downstream.source.sql(dialect="snowflake"), "LATERAL FLATTEN(INPUT => TEST_TABLE.RESULT, OUTER => TRUE) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)", ) - self.assertEqual( - downstream.expression.sql(dialect="snowflake"), - "VALUE", - ) + self.assertEqual(downstream.expression.sql(dialect="snowflake"), "VALUE") self.assertEqual(len(downstream.downstream), 1) downstream = downstream.downstream[0] self.assertEqual(downstream.name, "TEST_TABLE.RESULT") self.assertEqual(downstream.source.sql(dialect="snowflake"), "TEST_TABLE AS TEST_TABLE") + node = lineage( + "FIELD", + "SELECT FLATTENED.VALUE:field::text AS FIELD FROM SNOWFLAKE.SCHEMA.MODEL AS MODEL_ALIAS, LATERAL FLATTEN(INPUT => MODEL_ALIAS.A) AS FLATTENED", + schema={"SNOWFLAKE": {"SCHEMA": {"TABLE": {"A": "integer"}}}}, + sources={"SNOWFLAKE.SCHEMA.MODEL": "SELECT A FROM SNOWFLAKE.SCHEMA.TABLE"}, + dialect="snowflake", + ) + self.assertEqual(node.name, "FIELD") + + downstream = node.downstream[0] + self.assertEqual(downstream.name, "FLATTENED.VALUE") + self.assertEqual( + downstream.source.sql(dialect="snowflake"), + "LATERAL FLATTEN(INPUT => MODEL_ALIAS.A) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)", + ) + self.assertEqual(downstream.expression.sql(dialect="snowflake"), "VALUE") + self.assertEqual(len(downstream.downstream), 1) + + downstream = downstream.downstream[0] + self.assertEqual(downstream.name, "MODEL_ALIAS.A") + self.assertEqual(downstream.source_name, "SNOWFLAKE.SCHEMA.MODEL") + self.assertEqual( + downstream.source.sql(dialect="snowflake"), + "SELECT TABLE.A AS A FROM SNOWFLAKE.SCHEMA.TABLE AS TABLE", + ) + self.assertEqual(downstream.expression.sql(dialect="snowflake"), "TABLE.A AS A") + self.assertEqual(len(downstream.downstream), 1) + + downstream = downstream.downstream[0] + self.assertEqual(downstream.name, "TABLE.A") + self.assertEqual( + downstream.source.sql(dialect="snowflake"), "SNOWFLAKE.SCHEMA.TABLE AS TABLE" + ) + self.assertEqual( + downstream.expression.sql(dialect="snowflake"), "SNOWFLAKE.SCHEMA.TABLE AS TABLE" + ) + def test_subquery(self) -> None: node = lineage( "output", @@ -266,6 +300,7 @@ class TestLineage(unittest.TestCase): self.assertEqual(node.name, "a") node = node.downstream[0] self.assertEqual(node.name, "cte.a") + self.assertEqual(node.reference_node_name, "cte") node = node.downstream[0] self.assertEqual(node.name, "z.a") @@ -304,6 +339,27 @@ class TestLineage(unittest.TestCase): node = a.downstream[0] self.assertEqual(node.name, "foo.a") + # Select from derived table + node = lineage( + "a", + "SELECT a FROM (SELECT a FROM x) subquery", + ) + self.assertEqual(node.name, "a") + self.assertEqual(len(node.downstream), 1) + node = node.downstream[0] + self.assertEqual(node.name, "subquery.a") + self.assertEqual(node.reference_node_name, "subquery") + + node = lineage( + "a", + "SELECT a FROM (SELECT a FROM x)", + ) + self.assertEqual(node.name, "a") + self.assertEqual(len(node.downstream), 1) + node = node.downstream[0] + self.assertEqual(node.name, "_q_0.a") + self.assertEqual(node.reference_node_name, "_q_0") + def test_lineage_cte_union(self) -> None: query = """ WITH dataset AS ( -- cgit v1.2.3