summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_databricks.py6
-rw-r--r--tests/dialects/test_dialect.py2
-rw-r--r--tests/dialects/test_duckdb.py10
-rw-r--r--tests/dialects/test_hive.py4
-rw-r--r--tests/dialects/test_mysql.py118
-rw-r--r--tests/dialects/test_postgres.py39
-rw-r--r--tests/dialects/test_presto.py12
-rw-r--r--tests/dialects/test_prql.py13
-rw-r--r--tests/dialects/test_redshift.py10
-rw-r--r--tests/dialects/test_snowflake.py52
-rw-r--r--tests/dialects/test_spark.py21
-rw-r--r--tests/dialects/test_trino.py18
-rw-r--r--tests/dialects/test_tsql.py11
-rw-r--r--tests/fixtures/identity.sql4
-rw-r--r--tests/fixtures/optimizer/qualify_columns.sql4
-rw-r--r--tests/fixtures/optimizer/qualify_columns_ddl.sql14
-rw-r--r--tests/fixtures/optimizer/qualify_tables.sql4
-rw-r--r--tests/fixtures/optimizer/simplify.sql11
-rw-r--r--tests/test_build.py4
-rw-r--r--tests/test_lineage.py64
20 files changed, 342 insertions, 79 deletions
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index c15cf09..14a6bf3 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -26,6 +26,9 @@ class TestDatabricks(Validator):
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))")
self.validate_identity(
+ "CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA"
+ )
+ self.validate_identity(
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
)
self.validate_identity(
@@ -47,6 +50,9 @@ class TestDatabricks(Validator):
self.validate_identity(
"TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')"
)
+ self.validate_identity(
+ "COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS(opt1 = TRUE, opt2 = 'test') COPY_OPTIONS(opt3 = 5)"
+ )
self.validate_all(
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index ea38521..dda0eb2 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -293,7 +293,7 @@ class TestDialect(Validator):
"bigquery": "CAST(a AS INT64)",
"drill": "CAST(a AS INTEGER)",
"duckdb": "CAST(a AS SMALLINT)",
- "mysql": "CAST(a AS SMALLINT)",
+ "mysql": "CAST(a AS SIGNED)",
"hive": "CAST(a AS SMALLINT)",
"oracle": "CAST(a AS NUMBER)",
"postgres": "CAST(a AS SMALLINT)",
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 9105a49..2d0af13 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -324,6 +324,8 @@ class TestDuckDB(Validator):
self.validate_identity(
"SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias"
)
+ self.validate_identity("DATE_SUB('YEAR', col, '2020-01-01')").assert_is(exp.Anonymous)
+ self.validate_identity("DATESUB('YEAR', col, '2020-01-01')").assert_is(exp.Anonymous)
self.validate_all("0b1010", write={"": "0 AS b1010"})
self.validate_all("0x1010", write={"": "0 AS x1010"})
@@ -724,6 +726,14 @@ class TestDuckDB(Validator):
"""SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC""",
)
+ self.validate_identity(
+ "COPY lineitem FROM 'lineitem.ndjson' WITH (FORMAT JSON, DELIMITER ',', AUTO_DETECT TRUE, COMPRESSION SNAPPY, CODEC ZSTD, FORCE_NOT_NULL(col1, col2))"
+ )
+ self.validate_identity(
+ "COPY (SELECT 42 AS a, 'hello' AS b) TO 'query.json' WITH (FORMAT JSON, ARRAY TRUE)"
+ )
+ self.validate_identity("COPY lineitem (l_orderkey) TO 'orderkey.tbl' WITH (DELIMITER '|')")
+
def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 9215f05..dfce446 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -334,7 +334,7 @@ class TestHive(Validator):
"hive": "DATE_ADD('2020-01-01', 1)",
"presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"redshift": "DATEADD(DAY, 1, '2020-01-01')",
- "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))",
+ "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2020-01-01', 1)",
"tsql": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))",
},
@@ -348,7 +348,7 @@ class TestHive(Validator):
"hive": "DATE_ADD('2020-01-01', 1 * -1)",
"presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"redshift": "DATEADD(DAY, 1 * -1, '2020-01-01')",
- "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))",
+ "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2020-01-01', 1 * -1)",
"tsql": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))",
},
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index e8af5c6..53e2dab 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -1,4 +1,5 @@
from sqlglot import expressions as exp
+from sqlglot.dialects.mysql import MySQL
from tests.dialects.test_dialect import Validator
@@ -6,21 +7,11 @@ class TestMySQL(Validator):
dialect = "mysql"
def test_ddl(self):
- int_types = {"BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"}
-
- for t in int_types:
+ for t in ("BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"):
self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)")
self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)")
self.validate_identity("CREATE TABLE t (id DECIMAL(20, 4) UNSIGNED)")
-
- self.validate_all(
- "CREATE TABLE t (id INT UNSIGNED)",
- write={
- "duckdb": "CREATE TABLE t (id UINTEGER)",
- },
- )
-
self.validate_identity("CREATE TABLE foo (a BIGINT, UNIQUE (b) USING BTREE)")
self.validate_identity("CREATE TABLE foo (id BIGINT)")
self.validate_identity("CREATE TABLE 00f (1d BIGINT)")
@@ -98,6 +89,13 @@ class TestMySQL(Validator):
)
self.validate_all(
+ "CREATE TABLE t (id INT UNSIGNED)",
+ write={
+ "duckdb": "CREATE TABLE t (id UINTEGER)",
+ "mysql": "CREATE TABLE t (id INT UNSIGNED)",
+ },
+ )
+ self.validate_all(
"CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'",
write={
"duckdb": "CREATE TABLE z (a INT)",
@@ -109,15 +107,10 @@ class TestMySQL(Validator):
self.validate_all(
"CREATE TABLE x (id int not null auto_increment, primary key (id))",
write={
+ "mysql": "CREATE TABLE x (id INT NOT NULL AUTO_INCREMENT, PRIMARY KEY (id))",
"sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)",
},
)
- self.validate_all(
- "CREATE TABLE x (id int not null auto_increment)",
- write={
- "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)",
- },
- )
def test_identity(self):
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
@@ -135,8 +128,6 @@ class TestMySQL(Validator):
self.validate_identity("SELECT CAST('[4,5]' AS JSON) MEMBER OF('[[3,4],[4,5]]')")
self.validate_identity("""SELECT 'ab' MEMBER OF('[23, "abc", 17, "ab", 10]')""")
self.validate_identity("""SELECT * FROM foo WHERE 'ab' MEMBER OF(content)""")
- self.validate_identity("CAST(x AS ENUM('a', 'b'))")
- self.validate_identity("CAST(x AS SET('a', 'b'))")
self.validate_identity("SELECT CURRENT_TIMESTAMP(6)")
self.validate_identity("x ->> '$.name'")
self.validate_identity("SELECT CAST(`a`.`b` AS CHAR) FROM foo")
@@ -226,29 +217,47 @@ class TestMySQL(Validator):
self.validate_identity("SELECT * FROM t1 PARTITION(p0)")
def test_types(self):
- self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))")
+ for char_type in MySQL.Generator.CHAR_CAST_MAPPING:
+ with self.subTest(f"MySQL cast into {char_type}"):
+ self.validate_identity(f"CAST(x AS {char_type.value})", "CAST(x AS CHAR)")
+
+ for signed_type in MySQL.Generator.SIGNED_CAST_MAPPING:
+ with self.subTest(f"MySQL cast into {signed_type}"):
+ self.validate_identity(f"CAST(x AS {signed_type.value})", "CAST(x AS SIGNED)")
+
+ self.validate_identity("CAST(x AS ENUM('a', 'b'))")
+ self.validate_identity("CAST(x AS SET('a', 'b'))")
+ self.validate_identity(
+ "CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))",
+ "CAST(x AS SIGNED) + CAST(y AS YEAR(4))",
+ )
+ self.validate_identity(
+ "CAST(x AS TIMESTAMP)",
+ "CAST(x AS DATETIME)",
+ )
+ self.validate_identity(
+ "CAST(x AS TIMESTAMPTZ)",
+ "TIMESTAMP(x)",
+ )
+ self.validate_identity(
+ "CAST(x AS TIMESTAMPLTZ)",
+ "TIMESTAMP(x)",
+ )
self.validate_all(
"CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT) + CAST(z AS TINYTEXT)",
- read={
- "mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT) + CAST(z AS TINYTEXT)",
- },
write={
+ "mysql": "CAST(x AS CHAR) + CAST(y AS CHAR) + CAST(z AS CHAR)",
"spark": "CAST(x AS TEXT) + CAST(y AS TEXT) + CAST(z AS TEXT)",
},
)
self.validate_all(
"CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB) + CAST(z AS TINYBLOB)",
- read={
- "mysql": "CAST(x AS MEDIUMBLOB) + CAST(y AS LONGBLOB) + CAST(z AS TINYBLOB)",
- },
write={
+ "mysql": "CAST(x AS CHAR) + CAST(y AS CHAR) + CAST(z AS CHAR)",
"spark": "CAST(x AS BLOB) + CAST(y AS BLOB) + CAST(z AS BLOB)",
},
)
- self.validate_all("CAST(x AS TIMESTAMP)", write={"mysql": "CAST(x AS DATETIME)"})
- self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"mysql": "TIMESTAMP(x)"})
- self.validate_all("CAST(x AS TIMESTAMPLTZ)", write={"mysql": "TIMESTAMP(x)"})
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)")
@@ -457,63 +466,63 @@ class TestMySQL(Validator):
"SELECT DATE_FORMAT('2017-06-15', '%Y')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'yyyy')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%m')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%m')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'mm')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'mm')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%d')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%d')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'DD')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'DD')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy-mm-DD')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'yyyy-mm-DD')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMPNTZ), 'hh24')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMP), 'hh24')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2017-06-15', '%w')",
write={
"mysql": "SELECT DATE_FORMAT('2017-06-15', '%w')",
- "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'dy')",
+ "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMP), 'dy')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')",
write={
"mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')",
- "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMPNTZ), 'DY mmmm yyyy')",
+ "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMP), 'DY mmmm yyyy')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('2007-10-04 22:23:00', '%H:%i:%s')",
write={
"mysql": "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%T')",
- "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMPNTZ), 'hh24:mi:ss')",
+ "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMP), 'hh24:mi:ss')",
},
)
self.validate_all(
"SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')",
write={
"mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')",
- "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMPNTZ), 'DD yy DY DD mm mon')",
+ "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMP), 'DD yy DY DD mm mon')",
},
)
@@ -599,6 +608,19 @@ class TestMySQL(Validator):
def test_mysql(self):
self.validate_all(
+ "SELECT department, GROUP_CONCAT(name) AS employee_names FROM data GROUP BY department",
+ read={
+ "postgres": "SELECT department, array_agg(name) AS employee_names FROM data GROUP BY department",
+ },
+ )
+ self.validate_all(
+ "SELECT UNIX_TIMESTAMP(CAST('2024-04-29 12:00:00' AS DATETIME))",
+ read={
+ "mysql": "SELECT UNIX_TIMESTAMP(CAST('2024-04-29 12:00:00' AS DATETIME))",
+ "postgres": "SELECT EXTRACT(epoch FROM TIMESTAMP '2024-04-29 12:00:00')",
+ },
+ )
+ self.validate_all(
"SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')",
read={
"sqlite": "SELECT JSON_EXTRACT('[10, 20, [30, 40]]', '$[1]')",
@@ -1109,3 +1131,23 @@ COMMENT='客户账户表'"""
"tsql": "CAST(a AS FLOAT) / NULLIF(b, 0)",
},
)
+
+ def test_timestamp_trunc(self):
+ for dialect in ("postgres", "snowflake", "duckdb", "spark", "databricks"):
+ for unit in (
+ "MILLISECOND",
+ "SECOND",
+ "DAY",
+ "MONTH",
+ "YEAR",
+ ):
+ with self.subTest(f"MySQL -> {dialect} Timestamp Trunc with unit {unit}: "):
+ self.validate_all(
+ f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
+ read={
+ dialect: f"DATE_TRUNC({unit}, TIMESTAMP '2001-02-16 20:38:40')",
+ },
+ write={
+ "mysql": f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
+ },
+ )
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 5a55a7d..6b6117e 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -312,9 +312,33 @@ class TestPostgres(Validator):
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
"MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
)
- self.validate_identity("SELECT * FROM t1*", "SELECT * FROM t1")
+ self.validate_identity(
+ "SELECT * FROM t1*",
+ "SELECT * FROM t1",
+ )
+ self.validate_identity(
+ "SELECT SUBSTRING('afafa' for 1)",
+ "SELECT SUBSTRING('afafa' FROM 1 FOR 1)",
+ )
+ self.validate_identity(
+ "CAST(x AS INT8)",
+ "CAST(x AS BIGINT)",
+ )
self.validate_all(
+ "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
+ write={
+ "duckdb": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
+ "postgres": "SELECT REGEXP_REPLACE('mr .', '[^a-zA-Z]', '', 'g')",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE t (c INT)",
+ read={
+ "mysql": "CREATE TABLE t (c INT COMMENT 'comment 1') COMMENT = 'comment 2'",
+ },
+ )
+ self.validate_all(
'SELECT * FROM "test_table" ORDER BY RANDOM() LIMIT 5',
write={
"bigquery": "SELECT * FROM `test_table` ORDER BY RAND() NULLS LAST LIMIT 5",
@@ -449,7 +473,7 @@ class TestPostgres(Validator):
write={
"postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
"redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
- "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -660,6 +684,16 @@ class TestPostgres(Validator):
)
self.assertIsInstance(self.parse_one("id::UUID"), exp.Cast)
+ self.validate_identity(
+ "COPY tbl (col1, col2) FROM 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
+ )
+ self.validate_identity(
+ "COPY tbl (col1, col2) TO 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
+ )
+ self.validate_identity(
+ "COPY (SELECT * FROM t) TO 'file' WITH (FORMAT format, HEADER MATCH, FREEZE TRUE)"
+ )
+
def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
@@ -676,6 +710,7 @@ class TestPostgres(Validator):
cdef.args["kind"].assert_is(exp.DataType)
self.assertEqual(expr.sql(dialect="postgres"), "CREATE TABLE t (x INTERVAL DAY)")
+ self.validate_identity("CREATE INDEX IF NOT EXISTS ON t(c)")
self.validate_identity("CREATE INDEX et_vid_idx ON et(vid) INCLUDE (fid)")
self.validate_identity("CREATE INDEX idx_x ON x USING BTREE(x, y) WHERE (NOT y IS NULL)")
self.validate_identity("CREATE TABLE test (elems JSONB[])")
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 4bafc08..108e916 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -10,6 +10,8 @@ class TestPresto(Validator):
self.validate_identity("SELECT * FROM x qualify", "SELECT * FROM x AS qualify")
self.validate_identity("CAST(x AS IPADDRESS)")
self.validate_identity("CAST(x AS IPPREFIX)")
+ self.validate_identity("CAST(TDIGEST_AGG(1) AS TDIGEST)")
+ self.validate_identity("CAST(x AS HYPERLOGLOG)")
self.validate_all(
"CAST(x AS INTERVAL YEAR TO MONTH)",
@@ -1059,6 +1061,15 @@ class TestPresto(Validator):
)
def test_json(self):
+ with self.assertLogs(helper_logger):
+ self.validate_all(
+ """SELECT JSON_EXTRACT_SCALAR(TRY(FILTER(CAST(JSON_EXTRACT('{"k1": [{"k2": "{\\"k3\\": 1}", "k4": "v"}]}', '$.k1') AS ARRAY(MAP(VARCHAR, VARCHAR))), x -> x['k4'] = 'v')[1]['k2']), '$.k3')""",
+ write={
+ "presto": """SELECT JSON_EXTRACT_SCALAR(TRY(FILTER(CAST(JSON_EXTRACT('{"k1": [{"k2": "{\\"k3\\": 1}", "k4": "v"}]}', '$.k1') AS ARRAY(MAP(VARCHAR, VARCHAR))), x -> x['k4'] = 'v')[1]['k2']), '$.k3')""",
+ "spark": """SELECT GET_JSON_OBJECT(FILTER(FROM_JSON(GET_JSON_OBJECT('{"k1": [{"k2": "{\\\\"k3\\\\": 1}", "k4": "v"}]}', '$.k1'), 'ARRAY<MAP<STRING, STRING>>'), x -> x['k4'] = 'v')[0]['k2'], '$.k3')""",
+ },
+ )
+
self.validate_all(
"SELECT CAST(JSON '[1,23,456]' AS ARRAY(INTEGER))",
write={
@@ -1073,7 +1084,6 @@ class TestPresto(Validator):
"presto": 'SELECT CAST(JSON_PARSE(\'{"k1":1,"k2":23,"k3":456}\') AS MAP(VARCHAR, INTEGER))',
},
)
-
self.validate_all(
"SELECT CAST(ARRAY [1, 23, 456] AS JSON)",
write={
diff --git a/tests/dialects/test_prql.py b/tests/dialects/test_prql.py
index 1a0eec2..5b438f1 100644
--- a/tests/dialects/test_prql.py
+++ b/tests/dialects/test_prql.py
@@ -66,3 +66,16 @@ class TestPRQL(Validator):
"from x filter (a > 1 || null != b || c != null)",
"SELECT * FROM x WHERE (a > 1 OR NOT b IS NULL OR NOT c IS NULL)",
)
+ self.validate_identity("from a aggregate { average x }", "SELECT AVG(x) FROM a")
+ self.validate_identity(
+ "from a aggregate { average x, min y, ct = sum z }",
+ "SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) AS ct FROM a",
+ )
+ self.validate_identity(
+ "from a aggregate { average x, min y, sum z }",
+ "SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) FROM a",
+ )
+ self.validate_identity(
+ "from a aggregate { min y, b = stddev x, max z }",
+ "SELECT MIN(y), STDDEV(x) AS b, MAX(z) FROM a",
+ )
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index a91f4f9..e227ea9 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -162,7 +162,7 @@ class TestRedshift(Validator):
write={
"postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
"redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
- "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATE_PART(minute, CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -271,7 +271,7 @@ class TestRedshift(Validator):
"postgres": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL '18 MONTH'",
"presto": "SELECT DATE_ADD('MONTH', 18, CAST('2008-02-28' AS TIMESTAMP))",
"redshift": "SELECT DATEADD(MONTH, 18, '2008-02-28')",
- "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS TIMESTAMP))",
"tsql": "SELECT DATEADD(MONTH, 18, CAST('2008-02-28' AS DATETIME2))",
},
)
@@ -362,8 +362,10 @@ class TestRedshift(Validator):
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
)
self.validate_identity(
- "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'",
- check_command_warning=True,
+ "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole' REGION 'us-east-1' FORMAT orc",
+ )
+ self.validate_identity(
+ "COPY customer FROM 's3://mybucket/mydata' CREDENTIALS 'aws_iam_role=arn:aws:iam::<aws-account-id>:role/<role-name>;master_symmetric_key=<root-key>' emptyasnull blanksasnull timeformat 'YYYY-MM-DD HH:MI:SS'"
)
self.validate_identity(
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'",
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 1cbf68c..dae8355 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -10,6 +10,9 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
+ self.validate_identity(
+ "MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
+ )
self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
self.validate_identity(
"INSERT OVERWRITE TABLE t SELECT 1", "INSERT OVERWRITE INTO t SELECT 1"
@@ -388,7 +391,7 @@ WHERE
"SELECT DATE_PART('year', TIMESTAMP '2020-01-01')",
write={
"hive": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))",
- "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DATE_PART('year', CAST('2020-01-01' AS TIMESTAMP))",
"spark": "SELECT EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMP))",
},
)
@@ -591,7 +594,7 @@ WHERE
self.validate_all(
"SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
write={
- "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))",
},
)
self.validate_all(
@@ -689,7 +692,7 @@ WHERE
"SELECT TO_TIMESTAMP('2013-04-05 01:02:03')",
write={
"bigquery": "SELECT CAST('2013-04-05 01:02:03' AS DATETIME)",
- "snowflake": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMPNTZ)",
+ "snowflake": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMP)",
"spark": "SELECT CAST('2013-04-05 01:02:03' AS TIMESTAMP)",
},
)
@@ -879,10 +882,6 @@ WHERE
self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')")
self.validate_identity("PUT file:///dir/tmp.csv @%table", check_command_warning=True)
self.validate_identity(
- 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)',
- check_command_warning=True,
- )
- self.validate_identity(
"SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla"
)
self.validate_identity(
@@ -955,12 +954,16 @@ WHERE
self.validate_identity("SELECT CAST('12:00:00' AS TIME)")
self.validate_identity("SELECT DATE_PART(month, a)")
- self.validate_all(
- "SELECT CAST(a AS TIMESTAMP)",
- write={
- "snowflake": "SELECT CAST(a AS TIMESTAMPNTZ)",
- },
- )
+ for data_type in (
+ "TIMESTAMP",
+ "TIMESTAMPLTZ",
+ "TIMESTAMPNTZ",
+ ):
+ self.validate_identity(f"CAST(a AS {data_type})")
+
+ self.validate_identity("CAST(a AS TIMESTAMP_NTZ)", "CAST(a AS TIMESTAMPNTZ)")
+ self.validate_identity("CAST(a AS TIMESTAMP_LTZ)", "CAST(a AS TIMESTAMPLTZ)")
+
self.validate_all(
"SELECT a::TIMESTAMP_LTZ(9)",
write={
@@ -1000,14 +1003,14 @@ WHERE
self.validate_all(
"SELECT DATE_PART(epoch_second, foo) as ddate from table_name",
write={
- "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) AS ddate FROM table_name",
+ "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) AS ddate FROM table_name",
},
)
self.validate_all(
"SELECT DATE_PART(epoch_milliseconds, foo) as ddate from table_name",
write={
- "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMPNTZ)) * 1000 AS ddate FROM table_name",
+ "snowflake": "SELECT EXTRACT(epoch_second FROM CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
"presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name",
},
)
@@ -1138,7 +1141,7 @@ WHERE
)
self.validate_identity(
"SELECT * FROM my_table AT (TIMESTAMP => 'Fri, 01 May 2015 16:20:00 -0700'::timestamp)",
- "SELECT * FROM my_table AT (TIMESTAMP => CAST('Fri, 01 May 2015 16:20:00 -0700' AS TIMESTAMPNTZ))",
+ "SELECT * FROM my_table AT (TIMESTAMP => CAST('Fri, 01 May 2015 16:20:00 -0700' AS TIMESTAMP))",
)
self.validate_identity(
"SELECT * FROM my_table AT(TIMESTAMP => 'Fri, 01 May 2015 16:20:00 -0700'::timestamp_tz)",
@@ -1581,7 +1584,7 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
"REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)",
write={
"bigquery": "REGEXP_REPLACE(subject, pattern, replacement)",
- "duckdb": "REGEXP_REPLACE(subject, pattern, replacement)",
+ "duckdb": "REGEXP_REPLACE(subject, pattern, replacement, parameters)",
"hive": "REGEXP_REPLACE(subject, pattern, replacement)",
"snowflake": "REGEXP_REPLACE(subject, pattern, replacement, position, occurrence, parameters)",
"spark": "REGEXP_REPLACE(subject, pattern, replacement, position)",
@@ -1827,3 +1830,18 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
expression = annotate_types(expression)
self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)")
+
+ def test_copy(self):
+ self.validate_identity(
+ """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' file_format = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""",
+ """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE""",
+ )
+ self.validate_identity(
+ """COPY INTO temp FROM @random_stage/path/ FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64) VALIDATION_MODE = 'RETURN_3_ROWS'"""
+ )
+ self.validate_identity(
+ """COPY INTO load1 FROM @%load1/data1/ FILES = ('test1.csv', 'test2.csv') FORCE = TRUE"""
+ )
+ self.validate_identity(
+ """COPY INTO mytable FROM 'azure://myaccount.blob.core.windows.net/mycontainer/data/files' CREDENTIALS = (AZURE_SAS_TOKEN = 'token') ENCRYPTION = (TYPE = 'AZURE_CSE' MASTER_KEY = 'kPx...') FILE_FORMAT = (FORMAT_NAME = my_csv_format)"""
+ )
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 7534573..069ae42 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -343,7 +343,7 @@ TBLPROPERTIES (
"postgres": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'",
"presto": "SELECT WITH_TIMEZONE(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul') AT TIME ZONE 'UTC'",
"redshift": "SELECT CAST('2016-08-31' AS TIMESTAMP) AT TIME ZONE 'Asia/Seoul' AT TIME ZONE 'UTC'",
- "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMPNTZ))",
+ "snowflake": "SELECT CONVERT_TIMEZONE('Asia/Seoul', 'UTC', CAST('2016-08-31' AS TIMESTAMP))",
"spark": "SELECT TO_UTC_TIMESTAMP(CAST('2016-08-31' AS TIMESTAMP), 'Asia/Seoul')",
},
)
@@ -523,7 +523,14 @@ TBLPROPERTIES (
},
)
- for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"):
+ for data_type in (
+ "BOOLEAN",
+ "DATE",
+ "DOUBLE",
+ "FLOAT",
+ "INT",
+ "TIMESTAMP",
+ ):
self.validate_all(
f"{data_type}(x)",
write={
@@ -531,6 +538,16 @@ TBLPROPERTIES (
"spark": f"CAST(x AS {data_type})",
},
)
+
+ for ts_suffix in ("NTZ", "LTZ"):
+ self.validate_all(
+ f"TIMESTAMP_{ts_suffix}(x)",
+ write={
+ "": f"CAST(x AS TIMESTAMP{ts_suffix})",
+ "spark": f"CAST(x AS TIMESTAMP_{ts_suffix})",
+ },
+ )
+
self.validate_all(
"STRING(x)",
write={
diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py
new file mode 100644
index 0000000..ccc1407
--- /dev/null
+++ b/tests/dialects/test_trino.py
@@ -0,0 +1,18 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestTrino(Validator):
+ dialect = "trino"
+
+ def test_trim(self):
+ self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
+ self.validate_identity("SELECT TRIM(BOTH '$' FROM '$var$')")
+ self.validate_identity("SELECT TRIM(TRAILING 'ER' FROM UPPER('worker'))")
+ self.validate_identity(
+ "SELECT TRIM(LEADING FROM ' abcd')",
+ "SELECT LTRIM(' abcd')",
+ )
+ self.validate_identity(
+ "SELECT TRIM('!foo!', '!')",
+ "SELECT TRIM('!' FROM '!foo!')",
+ )
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 4a475f6..1538d47 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -29,6 +29,9 @@ class TestTSQL(Validator):
self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)")
self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0")
self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))")
+ self.validate_identity(
+ "COPY INTO test_1 FROM 'path' WITH (FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY = 'Shared Access Signature', SECRET = 'token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')"
+ )
self.validate_all(
"SELECT IIF(cond <> 0, 'True', 'False')",
@@ -777,6 +780,14 @@ class TestTSQL(Validator):
"CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END",
"CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END",
)
+
+ self.validate_all(
+ "CREATE TABLE [#temptest] (name VARCHAR)",
+ read={
+ "duckdb": "CREATE TEMPORARY TABLE 'temptest' (name VARCHAR)",
+ "tsql": "CREATE TABLE [#temptest] (name VARCHAR)",
+ },
+ )
self.validate_all(
"CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)",
read={
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index d51a978..6b742c3 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -642,6 +642,7 @@ CREATE TABLE T3 AS (SELECT DISTINCT A FROM T1 EXCEPT (SELECT A FROM T2) LIMIT 1)
DESCRIBE x
DESCRIBE EXTENDED a.b
DESCRIBE FORMATTED a.b
+DESCRIBE SELECT 1
DROP INDEX a.b.c
DROP FUNCTION a.b.c (INT)
DROP MATERIALIZED VIEW x.y.z
@@ -867,4 +868,5 @@ SELECT only
TRUNCATE(a, b)
SELECT enum
SELECT unlogged
-SELECT name \ No newline at end of file
+SELECT name
+SELECT copy \ No newline at end of file
diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql
index b020a27..6342cfc 100644
--- a/tests/fixtures/optimizer/qualify_columns.sql
+++ b/tests/fixtures/optimizer/qualify_columns.sql
@@ -523,6 +523,10 @@ SELECT t.c1 AS c1, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3);
SELECT c.f::VARCHAR(MAX) AS f, e AS e FROM a.b AS c, c.d AS e;
SELECT CAST(c.f AS VARCHAR(MAX)) AS f, e AS e FROM a.b AS c, c.d AS e;
+# dialect: bigquery
+WITH cte AS (SELECT 1 AS col) SELECT * FROM cte LEFT JOIN UNNEST((SELECT ARRAY_AGG(DISTINCT x) AS agg FROM UNNEST([1]) AS x WHERE col = 1));
+WITH cte AS (SELECT 1 AS col) SELECT * FROM cte AS cte LEFT JOIN UNNEST((SELECT ARRAY_AGG(DISTINCT x) AS agg FROM UNNEST([1]) AS x WHERE cte.col = 1));
+
--------------------------------------
-- Window functions
--------------------------------------
diff --git a/tests/fixtures/optimizer/qualify_columns_ddl.sql b/tests/fixtures/optimizer/qualify_columns_ddl.sql
index 9b4bb34..75d84ca 100644
--- a/tests/fixtures/optimizer/qualify_columns_ddl.sql
+++ b/tests/fixtures/optimizer/qualify_columns_ddl.sql
@@ -1,26 +1,26 @@
# title: Create with CTE
WITH cte AS (SELECT b FROM y) CREATE TABLE s AS SELECT * FROM cte;
-CREATE TABLE s AS WITH cte AS (SELECT y.b AS b FROM y AS y) SELECT cte.b AS b FROM cte AS cte;
+WITH cte AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS SELECT cte.b AS b FROM cte AS cte;
# title: Create with CTE, query also has CTE
WITH cte1 AS (SELECT b FROM y) CREATE TABLE s AS WITH cte2 AS (SELECT b FROM cte1) SELECT * FROM cte2;
-CREATE TABLE s AS WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte2.b AS b FROM cte2 AS cte2;
+WITH cte1 AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS WITH cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte2.b AS b FROM cte2 AS cte2;
# title: Create without CTE
CREATE TABLE foo AS SELECT a FROM tbl;
CREATE TABLE foo AS SELECT tbl.a AS a FROM tbl AS tbl;
# title: Create with complex CTE with derived table
-WITH cte AS (SELECT a FROM (SELECT a from x)) CREATE TABLE s AS SELECT * FROM cte;
-CREATE TABLE s AS WITH cte AS (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) SELECT cte.a AS a FROM cte AS cte;
+WITH cte AS (SELECT a FROM (SELECT a FROM x)) CREATE TABLE s AS SELECT * FROM cte;
+WITH cte AS (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) CREATE TABLE s AS SELECT cte.a AS a FROM cte AS cte;
# title: Create wtih multiple CTEs
WITH cte1 AS (SELECT b FROM y), cte2 AS (SELECT b FROM cte1) CREATE TABLE s AS SELECT * FROM cte2;
-CREATE TABLE s AS WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte2.b AS b FROM cte2 AS cte2;
+WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) CREATE TABLE s AS SELECT cte2.b AS b FROM cte2 AS cte2;
# title: Create with multiple CTEs, selecting only from the first CTE (unnecessary code)
WITH cte1 AS (SELECT b FROM y), cte2 AS (SELECT b FROM cte1) CREATE TABLE s AS SELECT * FROM cte1;
-CREATE TABLE s AS WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) SELECT cte1.b AS b FROM cte1 AS cte1;
+WITH cte1 AS (SELECT y.b AS b FROM y AS y), cte2 AS (SELECT cte1.b AS b FROM cte1 AS cte1) CREATE TABLE s AS SELECT cte1.b AS b FROM cte1 AS cte1;
# title: Create with multiple derived tables
CREATE TABLE s AS SELECT * FROM (SELECT b FROM (SELECT b FROM y));
@@ -28,7 +28,7 @@ CREATE TABLE s AS SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT y.b A
# title: Create with a CTE and a derived table
WITH cte AS (SELECT b FROM y) CREATE TABLE s AS SELECT * FROM (SELECT b FROM (SELECT b FROM cte));
-CREATE TABLE s AS WITH cte AS (SELECT y.b AS b FROM y AS y) SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT cte.b AS b FROM cte AS cte) AS _q_0) AS _q_1;
+WITH cte AS (SELECT y.b AS b FROM y AS y) CREATE TABLE s AS SELECT _q_1.b AS b FROM (SELECT _q_0.b AS b FROM (SELECT cte.b AS b FROM cte AS cte) AS _q_0) AS _q_1;
# title: Insert with CTE
# dialect: spark
diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql
index 104400e..30bf834 100644
--- a/tests/fixtures/optimizer/qualify_tables.sql
+++ b/tests/fixtures/optimizer/qualify_tables.sql
@@ -158,6 +158,10 @@ ALTER TABLE c.db.t ADD PRIMARY KEY (id) NOT ENFORCED;
CREATE TABLE t1 AS (WITH cte AS (SELECT x FROM t2) SELECT * FROM cte);
CREATE TABLE c.db.t1 AS (WITH cte AS (SELECT x FROM c.db.t2 AS t2) SELECT * FROM cte AS cte);
+# title: delete statement
+DELETE FROM t1 WHERE NOT c IN (SELECT c FROM t2);
+DELETE FROM c.db.t1 WHERE NOT c IN (SELECT c FROM c.db.t2 AS t2);
+
# title: insert statement with cte
# dialect: spark
WITH cte AS (SELECT b FROM y) INSERT INTO s SELECT * FROM cte;
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index 6af51bf..75abc38 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -109,6 +109,10 @@ a AND b;
(x is not null) != (y is null);
(NOT x IS NULL) <> (y IS NULL);
+# dialect: mysql
+A XOR A;
+FALSE;
+
--------------------------------------
-- Absorption
--------------------------------------
@@ -232,6 +236,13 @@ x - 1;
A AND D AND B AND E AND F AND G AND E AND A;
A AND B AND D AND E AND F AND G;
+A OR D OR B OR E OR F OR G OR E OR A;
+A OR B OR D OR E OR F OR G;
+
+# dialect: mysql
+A XOR D XOR B XOR E XOR F XOR G XOR C;
+A XOR B XOR C XOR D XOR E XOR F XOR G;
+
A AND NOT B AND C AND B;
FALSE;
diff --git a/tests/test_build.py b/tests/test_build.py
index ad0bb9a..da1677f 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -546,6 +546,10 @@ class TestBuild(unittest.TestCase):
"UPDATE tbl SET x = 1 FROM tbl2",
),
(
+ lambda: exp.update("tbl", {"x": 1}, from_="tbl2 cross join tbl3"),
+ "UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3",
+ ),
+ (
lambda: union("SELECT * FROM foo", "SELECT * FROM bla"),
"SELECT * FROM foo UNION SELECT * FROM bla",
),
diff --git a/tests/test_lineage.py b/tests/test_lineage.py
index c782d9a..3e17f95 100644
--- a/tests/test_lineage.py
+++ b/tests/test_lineage.py
@@ -224,16 +224,50 @@ class TestLineage(unittest.TestCase):
downstream.source.sql(dialect="snowflake"),
"LATERAL FLATTEN(INPUT => TEST_TABLE.RESULT, OUTER => TRUE) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)",
)
- self.assertEqual(
- downstream.expression.sql(dialect="snowflake"),
- "VALUE",
- )
+ self.assertEqual(downstream.expression.sql(dialect="snowflake"), "VALUE")
self.assertEqual(len(downstream.downstream), 1)
downstream = downstream.downstream[0]
self.assertEqual(downstream.name, "TEST_TABLE.RESULT")
self.assertEqual(downstream.source.sql(dialect="snowflake"), "TEST_TABLE AS TEST_TABLE")
+ node = lineage(
+ "FIELD",
+ "SELECT FLATTENED.VALUE:field::text AS FIELD FROM SNOWFLAKE.SCHEMA.MODEL AS MODEL_ALIAS, LATERAL FLATTEN(INPUT => MODEL_ALIAS.A) AS FLATTENED",
+ schema={"SNOWFLAKE": {"SCHEMA": {"TABLE": {"A": "integer"}}}},
+ sources={"SNOWFLAKE.SCHEMA.MODEL": "SELECT A FROM SNOWFLAKE.SCHEMA.TABLE"},
+ dialect="snowflake",
+ )
+ self.assertEqual(node.name, "FIELD")
+
+ downstream = node.downstream[0]
+ self.assertEqual(downstream.name, "FLATTENED.VALUE")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"),
+ "LATERAL FLATTEN(INPUT => MODEL_ALIAS.A) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)",
+ )
+ self.assertEqual(downstream.expression.sql(dialect="snowflake"), "VALUE")
+ self.assertEqual(len(downstream.downstream), 1)
+
+ downstream = downstream.downstream[0]
+ self.assertEqual(downstream.name, "MODEL_ALIAS.A")
+ self.assertEqual(downstream.source_name, "SNOWFLAKE.SCHEMA.MODEL")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"),
+ "SELECT TABLE.A AS A FROM SNOWFLAKE.SCHEMA.TABLE AS TABLE",
+ )
+ self.assertEqual(downstream.expression.sql(dialect="snowflake"), "TABLE.A AS A")
+ self.assertEqual(len(downstream.downstream), 1)
+
+ downstream = downstream.downstream[0]
+ self.assertEqual(downstream.name, "TABLE.A")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"), "SNOWFLAKE.SCHEMA.TABLE AS TABLE"
+ )
+ self.assertEqual(
+ downstream.expression.sql(dialect="snowflake"), "SNOWFLAKE.SCHEMA.TABLE AS TABLE"
+ )
+
def test_subquery(self) -> None:
node = lineage(
"output",
@@ -266,6 +300,7 @@ class TestLineage(unittest.TestCase):
self.assertEqual(node.name, "a")
node = node.downstream[0]
self.assertEqual(node.name, "cte.a")
+ self.assertEqual(node.reference_node_name, "cte")
node = node.downstream[0]
self.assertEqual(node.name, "z.a")
@@ -304,6 +339,27 @@ class TestLineage(unittest.TestCase):
node = a.downstream[0]
self.assertEqual(node.name, "foo.a")
+ # Select from derived table
+ node = lineage(
+ "a",
+ "SELECT a FROM (SELECT a FROM x) subquery",
+ )
+ self.assertEqual(node.name, "a")
+ self.assertEqual(len(node.downstream), 1)
+ node = node.downstream[0]
+ self.assertEqual(node.name, "subquery.a")
+ self.assertEqual(node.reference_node_name, "subquery")
+
+ node = lineage(
+ "a",
+ "SELECT a FROM (SELECT a FROM x)",
+ )
+ self.assertEqual(node.name, "a")
+ self.assertEqual(len(node.downstream), 1)
+ node = node.downstream[0]
+ self.assertEqual(node.name, "_q_0.a")
+ self.assertEqual(node.reference_node_name, "_q_0")
+
def test_lineage_cte_union(self) -> None:
query = """
WITH dataset AS (