summary refs log tree commit diff stats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r-- tests/dialects/test_bigquery.py 23
-rw-r--r-- tests/dialects/test_clickhouse.py 49
-rw-r--r-- tests/dialects/test_databricks.py 4
-rw-r--r-- tests/dialects/test_dialect.py 4
-rw-r--r-- tests/dialects/test_duckdb.py 1
-rw-r--r-- tests/dialects/test_mysql.py 2
-rw-r--r-- tests/dialects/test_postgres.py 5
-rw-r--r-- tests/dialects/test_presto.py 62
-rw-r--r-- tests/dialects/test_redshift.py 26
-rw-r--r-- tests/dialects/test_snowflake.py 42
-rw-r--r-- tests/dialects/test_spark.py 7
-rw-r--r-- tests/dialects/test_teradata.py 1
-rw-r--r-- tests/dialects/test_tsql.py 14
-rw-r--r-- tests/fixtures/identity.sql 1
-rw-r--r-- tests/fixtures/optimizer/canonicalize.sql 1
-rw-r--r-- tests/fixtures/optimizer/qualify_tables.sql 8
-rw-r--r-- tests/fixtures/optimizer/simplify.sql 67
-rw-r--r-- tests/fixtures/optimizer/tpc-ds/tpc-ds.sql 4
-rw-r--r-- tests/test_expressions.py 9
-rw-r--r-- tests/test_lineage.py 74
-rw-r--r-- tests/test_optimizer.py 41
-rw-r--r-- tests/test_parser.py 3
-rw-r--r-- tests/test_transpile.py 57
23 files changed, 475 insertions, 30 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 3cf95a7..3601e47 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -9,6 +9,10 @@ class TestBigQuery(Validator):
maxDiff = None
def test_bigquery(self):
+ self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'")
+ self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'")
+ self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')")
+
with self.assertRaises(TokenError):
transpile("'\\'", read="bigquery")
@@ -139,6 +143,20 @@ class TestBigQuery(Validator):
self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"})
self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"})
self.validate_all(
+ "SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)",
+ write={
+ "bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)",
+ "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)",
+ },
+ )
+ self.validate_all(
+ "NULL",
+ read={
+ "duckdb": "NULL = a",
+ "postgres": "a = NULL",
+ },
+ )
+ self.validate_all(
"SELECT '\\n'",
read={
"bigquery": "SELECT '''\n'''",
@@ -465,9 +483,8 @@ class TestBigQuery(Validator):
},
write={
"bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x",
- "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS (x)",
- "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
- "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
+ "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t(x)",
+ "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t(x)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 948c00e..93d1ced 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -6,6 +6,22 @@ class TestClickhouse(Validator):
dialect = "clickhouse"
def test_clickhouse(self):
+ self.validate_identity("x <> y")
+
+ self.validate_all(
+ "has([1], x)",
+ read={
+ "postgres": "x = any(array[1])",
+ },
+ )
+ self.validate_all(
+ "NOT has([1], x)",
+ read={
+ "postgres": "any(array[1]) <> x",
+ },
+ )
+ self.validate_identity("x = y")
+
string_types = [
"BLOB",
"LONGBLOB",
@@ -86,6 +102,39 @@ class TestClickhouse(Validator):
)
self.validate_all(
+ "SELECT CAST('2020-01-01' AS TIMESTAMP) + INTERVAL '500' microsecond",
+ read={
+ "duckdb": "SELECT TIMESTAMP '2020-01-01' + INTERVAL '500 us'",
+ "postgres": "SELECT TIMESTAMP '2020-01-01' + INTERVAL '500 us'",
+ },
+ )
+ self.validate_all(
+ "SELECT CURRENT_DATE()",
+ read={
+ "clickhouse": "SELECT CURRENT_DATE()",
+ "postgres": "SELECT CURRENT_DATE",
+ },
+ )
+ self.validate_all(
+ "SELECT CURRENT_TIMESTAMP()",
+ read={
+ "clickhouse": "SELECT CURRENT_TIMESTAMP()",
+ "postgres": "SELECT CURRENT_TIMESTAMP",
+ },
+ )
+ self.validate_all(
+ "SELECT match('ThOmAs', CONCAT('(?i)', 'thomas'))",
+ read={
+ "postgres": "SELECT 'ThOmAs' ~* 'thomas'",
+ },
+ )
+ self.validate_all(
+ "SELECT match('ThOmAs', CONCAT('(?i)', x)) FROM t",
+ read={
+ "postgres": "SELECT 'ThOmAs' ~* x FROM t",
+ },
+ )
+ self.validate_all(
"SELECT '\\0'",
read={
"mysql": "SELECT '\0'",
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index 7c03c83..8bb88b3 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -6,8 +6,8 @@ class TestDatabricks(Validator):
def test_databricks(self):
self.validate_identity("CREATE TABLE t (c STRUCT<interval: DOUBLE COMMENT 'aaa'>)")
- self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES (a.b=15)")
- self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES ('a.b'=15)")
+ self.validate_identity("CREATE TABLE my_table TBLPROPERTIES (a.b=15)")
+ self.validate_identity("CREATE TABLE my_table TBLPROPERTIES ('a.b'=15)")
self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO HOUR)")
self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO MINUTE)")
self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO SECOND)")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 91eba17..0d43b2a 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -99,6 +99,7 @@ class TestDialect(Validator):
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
+ "tsql": "CAST(a AS VARCHAR(MAX))",
"doris": "CAST(a AS STRING)",
},
)
@@ -179,6 +180,7 @@ class TestDialect(Validator):
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
+ "tsql": "CAST(a AS VARCHAR(MAX))",
"doris": "CAST(a AS STRING)",
},
)
@@ -197,6 +199,7 @@ class TestDialect(Validator):
"snowflake": "CAST(a AS VARCHAR)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS VARCHAR)",
+ "tsql": "CAST(a AS VARCHAR)",
"doris": "CAST(a AS VARCHAR)",
},
)
@@ -215,6 +218,7 @@ class TestDialect(Validator):
"snowflake": "CAST(a AS VARCHAR(3))",
"spark": "CAST(a AS VARCHAR(3))",
"starrocks": "CAST(a AS VARCHAR(3))",
+ "tsql": "CAST(a AS VARCHAR(3))",
"doris": "CAST(a AS VARCHAR(3))",
},
)
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 54553b3..f9de953 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -249,6 +249,7 @@ class TestDuckDB(Validator):
"SELECT ARRAY_LENGTH([0], 1) AS x",
write={"duckdb": "SELECT ARRAY_LENGTH([0], 1) AS x"},
)
+ self.validate_identity("REGEXP_REPLACE(this, pattern, replacement, modifiers)")
self.validate_all(
"REGEXP_MATCHES(x, y)",
write={
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index b9d1d26..dce2b9d 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -586,6 +586,8 @@ class TestMySQL(Validator):
write={
"mysql": "SELECT * FROM test LIMIT 1 OFFSET 1",
"postgres": "SELECT * FROM test LIMIT 0 + 1 OFFSET 0 + 1",
+ "presto": "SELECT * FROM test OFFSET 1 LIMIT 1",
+ "trino": "SELECT * FROM test OFFSET 1 LIMIT 1",
},
)
self.validate_all(
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 22bede4..3121cb0 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -732,3 +732,8 @@ class TestPostgres(Validator):
self.validate_all(
"VAR_POP(x)", read={"": "VARIANCE_POP(x)"}, write={"postgres": "VAR_POP(x)"}
)
+
+ def test_regexp_binary(self):
+ """See https://github.com/tobymao/sqlglot/pull/2404 for details."""
+ self.assertIsInstance(parse_one("'thomas' ~ '.*thomas.*'", read="postgres"), exp.Binary)
+ self.assertIsInstance(parse_one("'thomas' ~* '.*thomas.*'", read="postgres"), exp.Binary)
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index fd297d7..ed734b6 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -367,6 +367,21 @@ class TestPresto(Validator):
"CAST(x AS TIMESTAMP)",
read={"mysql": "TIMESTAMP(x)"},
)
+ self.validate_all(
+ "TIMESTAMP(x, 'America/Los_Angeles')",
+ write={
+ "duckdb": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'",
+ "presto": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'",
+ },
+ )
+ # this case isn't really correct, but it's a fall back for mysql's version
+ self.validate_all(
+ "TIMESTAMP(x, '12:00:00')",
+ write={
+ "duckdb": "TIMESTAMP(x, '12:00:00')",
+ "presto": "TIMESTAMP(x, '12:00:00')",
+ },
+ )
def test_ddl(self):
self.validate_all(
@@ -441,6 +456,22 @@ class TestPresto(Validator):
},
)
+ self.validate_all(
+ "CREATE OR REPLACE VIEW x (cola) SELECT 1 as cola",
+ write={
+ "spark": "CREATE OR REPLACE VIEW x (cola) AS SELECT 1 AS cola",
+ "presto": "CREATE OR REPLACE VIEW x AS SELECT 1 AS cola",
+ },
+ )
+
+ self.validate_all(
+ 'CREATE TABLE IF NOT EXISTS x ("cola" INTEGER, "ds" TEXT) WITH (PARTITIONED BY=("ds"))',
+ write={
+ "spark": "CREATE TABLE IF NOT EXISTS x (`cola` INT, `ds` STRING) PARTITIONED BY (`ds`)",
+ "presto": """CREATE TABLE IF NOT EXISTS x ("cola" INTEGER, "ds" VARCHAR) WITH (PARTITIONED_BY=ARRAY['ds'])""",
+ },
+ )
+
def test_quotes(self):
self.validate_all(
"''''",
@@ -528,6 +559,37 @@ class TestPresto(Validator):
)
self.validate_all(
+ "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ read={
+ "bigquery": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "clickhouse": "SELECT argMax(a.id, a.timestamp) FROM a",
+ "duckdb": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "snowflake": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "spark": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "teradata": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ },
+ write={
+ "bigquery": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "clickhouse": "SELECT argMax(a.id, a.timestamp) FROM a",
+ "duckdb": "SELECT ARG_MAX(a.id, a.timestamp) FROM a",
+ "presto": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "snowflake": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "spark": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ "teradata": "SELECT MAX_BY(a.id, a.timestamp) FROM a",
+ },
+ )
+ self.validate_all(
+ "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a",
+ write={
+ "clickhouse": "SELECT argMin(a.id, a.timestamp) FROM a",
+ "duckdb": "SELECT ARG_MIN(a.id, a.timestamp) FROM a",
+ "presto": "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a",
+ "snowflake": "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a",
+ "spark": "SELECT MIN_BY(a.id, a.timestamp) FROM a",
+ "teradata": "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a",
+ },
+ )
+ self.validate_all(
"""JSON '"foo"'""",
write={
"bigquery": """PARSE_JSON('"foo"')""",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index f182feb..c848010 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -6,6 +6,10 @@ class TestRedshift(Validator):
dialect = "redshift"
def test_redshift(self):
+ self.validate_identity(
+ "SELECT * FROM x WHERE y = DATEADD('month', -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))",
+ "SELECT * FROM x WHERE y = DATEADD(month, -1, CAST(DATE_TRUNC('month', (SELECT y FROM #temp_table)) AS DATE))",
+ )
self.validate_all(
"SELECT APPROXIMATE COUNT(DISTINCT y)",
read={
@@ -16,13 +20,6 @@ class TestRedshift(Validator):
"spark": "SELECT APPROX_COUNT_DISTINCT(y)",
},
)
- self.validate_identity("SELECT APPROXIMATE AS y")
-
- self.validate_identity(
- "SELECT 'a''b'",
- "SELECT 'a\\'b'",
- )
-
self.validate_all(
"x ~* 'pat'",
write={
@@ -30,7 +27,6 @@ class TestRedshift(Validator):
"snowflake": "REGEXP_LIKE(x, 'pat', 'i')",
},
)
-
self.validate_all(
"SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)",
read={
@@ -248,6 +244,19 @@ class TestRedshift(Validator):
self.validate_identity("CAST('foo' AS HLLSKETCH)")
self.validate_identity("'abc' SIMILAR TO '(b|c)%'")
self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)")
+ self.validate_identity("SELECT APPROXIMATE AS y")
+ self.validate_identity("CREATE TABLE t (c BIGINT IDENTITY(0, 1))")
+ self.validate_identity(
+ "SELECT 'a''b'",
+ "SELECT 'a\\'b'",
+ )
+ self.validate_identity(
+ "CREATE TABLE t (c BIGINT GENERATED BY DEFAULT AS IDENTITY (0, 1))",
+ "CREATE TABLE t (c BIGINT IDENTITY(0, 1))",
+ )
+ self.validate_identity(
+ "CREATE OR REPLACE VIEW v1 AS SELECT id, AVG(average_metric1) AS m1, AVG(average_metric2) AS m2 FROM t GROUP BY id WITH NO SCHEMA BINDING"
+ )
self.validate_identity(
"SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'"
)
@@ -301,6 +310,7 @@ ORDER BY
self.validate_identity(
"SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders AS val AT attr WHERE c_custkey = 9451"
)
+ self.validate_identity("SELECT JSON_PARSE('[]')")
def test_values(self):
# Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 7c36bea..65b77ea 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -9,6 +9,12 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
+ expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
+ expr.selects[0].assert_is(exp.AggFunc)
+ self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
+
+ self.validate_identity("SELECT DAYOFMONTH(CURRENT_TIMESTAMP())")
+ self.validate_identity("SELECT DAYOFYEAR(CURRENT_TIMESTAMP())")
self.validate_identity("LISTAGG(data['some_field'], ',')")
self.validate_identity("WEEKOFYEAR(tstamp)")
self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL")
@@ -36,6 +42,7 @@ class TestSnowflake(Validator):
self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'")
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
self.validate_identity("REGEXP_REPLACE('target', 'pattern', '\n')")
+ self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity(
'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)
@@ -58,6 +65,18 @@ class TestSnowflake(Validator):
"SELECT {'test': 'best'}::VARIANT",
"SELECT CAST(OBJECT_CONSTRUCT('test', 'best') AS VARIANT)",
)
+ self.validate_identity(
+ "SELECT {fn DAYNAME('2022-5-13')}",
+ "SELECT DAYNAME('2022-5-13')",
+ )
+ self.validate_identity(
+ "SELECT {fn LOG(5)}",
+ "SELECT LN(5)",
+ )
+ self.validate_identity(
+ "SELECT {fn CEILING(5.3)}",
+ "SELECT CEIL(5.3)",
+ )
self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"})
self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
@@ -911,7 +930,23 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS
f.value AS "Contact",
f1.value['type'] AS "Type",
f1.value['content'] AS "Details"
-FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERAL FLATTEN(input => f.value['business']) AS f1""",
+FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f(SEQ, KEY, PATH, INDEX, VALUE, THIS), LATERAL FLATTEN(input => f.value['business']) AS f1(SEQ, KEY, PATH, INDEX, VALUE, THIS)""",
+ },
+ pretty=True,
+ )
+
+ self.validate_all(
+ """
+ SELECT id as "ID",
+ value AS "Contact"
+ FROM persons p,
+ lateral flatten(input => p.c, path => 'contact')
+ """,
+ write={
+ "snowflake": """SELECT
+ id AS "ID",
+ value AS "Contact"
+FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattened(SEQ, KEY, PATH, INDEX, VALUE, THIS)""",
},
pretty=True,
)
@@ -1134,3 +1169,8 @@ MATCH_RECOGNIZE (
self.assertIsNotNone(table)
self.assertEqual(table.sql(dialect="snowflake"), '"TEST"."PUBLIC"."customers"')
+
+ def test_swap(self):
+ ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake")
+ assert isinstance(ast, exp.AlterTable)
+ assert isinstance(ast.args["actions"][0], exp.SwapTable)
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 9bb9d79..e08915b 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -230,6 +230,7 @@ TBLPROPERTIES (
self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean)
self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)")
+ self.validate_identity("SELECT CASE WHEN a = NULL THEN 1 ELSE 2 END")
self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
@@ -295,7 +296,7 @@ TBLPROPERTIES (
},
write={
"spark": "SELECT DATEDIFF(month, TO_DATE(CAST('1996-10-30' AS TIMESTAMP)), TO_DATE(CAST('1997-02-28 10:30:00' AS TIMESTAMP)))",
- "spark2": "SELECT MONTHS_BETWEEN(TO_DATE(CAST('1997-02-28 10:30:00' AS TIMESTAMP)), TO_DATE(CAST('1996-10-30' AS TIMESTAMP)))",
+ "spark2": "SELECT CAST(MONTHS_BETWEEN(TO_DATE(CAST('1997-02-28 10:30:00' AS TIMESTAMP)), TO_DATE(CAST('1996-10-30' AS TIMESTAMP))) AS INT)",
},
)
self.validate_all(
@@ -403,10 +404,10 @@ TBLPROPERTIES (
"SELECT DATEDIFF(MONTH, '2020-01-01', '2020-03-05')",
write={
"databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))",
- "hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))",
+ "hive": "SELECT CAST(MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01')) AS INT)",
"presto": "SELECT DATE_DIFF('MONTH', CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-05' AS TIMESTAMP) AS DATE))",
"spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))",
- "spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))",
+ "spark2": "SELECT CAST(MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01')) AS INT)",
"trino": "SELECT DATE_DIFF('MONTH', CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-05' AS TIMESTAMP) AS DATE))",
},
)
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 9dbac8c..b5c0fe8 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -5,6 +5,7 @@ class TestTeradata(Validator):
dialect = "teradata"
def test_teradata(self):
+ self.validate_identity("SELECT TOP 10 * FROM tbl")
self.validate_identity("SELECT * FROM tbl SAMPLE 5")
self.validate_identity(
"SELECT * FROM tbl SAMPLE 0.33, .25, .1",
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index f9a720a..4775020 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1058,18 +1058,18 @@ WHERE
},
)
self.validate_all(
- "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')",
+ "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')",
write={
- "tsql": "SELECT DATEDIFF(year, CAST('2020/01/01' AS DATETIME2), CAST('2021/01/01' AS DATETIME2))",
- "spark": "SELECT DATEDIFF(year, CAST('2020/01/01' AS TIMESTAMP), CAST('2021/01/01' AS TIMESTAMP))",
- "spark2": "SELECT MONTHS_BETWEEN(CAST('2021/01/01' AS TIMESTAMP), CAST('2020/01/01' AS TIMESTAMP)) / 12",
+ "tsql": "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))",
+ "spark": "SELECT DATEDIFF(year, CAST('2020-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))",
+ "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) AS INT) / 12",
},
)
self.validate_all(
"SELECT DATEDIFF(mm, 'start', 'end')",
write={
"databricks": "SELECT DATEDIFF(month, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))",
- "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))",
+ "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT)",
"tsql": "SELECT DATEDIFF(month, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))",
},
)
@@ -1078,7 +1078,7 @@ WHERE
write={
"databricks": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))",
"spark": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))",
- "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3",
+ "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT) / 3",
"tsql": "SELECT DATEDIFF(quarter, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))",
},
)
@@ -1374,7 +1374,7 @@ FROM OPENJSON(@json) WITH (
Date DATETIME2 '$.Order.Date',
Customer VARCHAR(200) '$.AccountNumber',
Quantity INTEGER '$.Item.Quantity',
- "Order" TEXT AS JSON
+ "Order" VARCHAR(MAX) AS JSON
)"""
},
pretty=True,
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 2738707..6e0a3e5 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -866,3 +866,4 @@ KILL CONNECTION 123
KILL QUERY '123'
CHR(97)
SELECT * FROM UNNEST(x) WITH ORDINALITY UNION ALL SELECT * FROM UNNEST(y) WITH ORDINALITY
+WITH use(use) AS (SELECT 1) SELECT use FROM use
diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql
index 2ba762d..954b1c1 100644
--- a/tests/fixtures/optimizer/canonicalize.sql
+++ b/tests/fixtures/optimizer/canonicalize.sql
@@ -16,7 +16,6 @@ SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0";
--------------------------------------
-- Ensure boolean predicates
--------------------------------------
-
SELECT a FROM x WHERE b;
SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE "x"."b" <> 0;
diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql
index f43ac01..3717cd4 100644
--- a/tests/fixtures/optimizer/qualify_tables.sql
+++ b/tests/fixtures/optimizer/qualify_tables.sql
@@ -109,3 +109,11 @@ SELECT * FROM ((SELECT * FROM c.db.t AS t) AS _q_0);
# title: wrapped subquery without alias joined with a table
SELECT * FROM ((SELECT * FROM t1) INNER JOIN t2 ON a = b);
SELECT * FROM ((SELECT * FROM c.db.t1 AS t1) AS _q_0 INNER JOIN c.db.t2 AS t2 ON a = b);
+
+# title: lateral unnest with alias
+SELECT x FROM t, LATERAL UNNEST(t.xs) AS x;
+SELECT x FROM c.db.t AS t, LATERAL UNNEST(t.xs) AS x;
+
+# title: lateral unnest without alias
+SELECT x FROM t, LATERAL UNNEST(t.xs);
+SELECT x FROM c.db.t AS t, LATERAL UNNEST(t.xs) AS _q_0;
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index e54170c..c53a972 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -911,13 +911,76 @@ t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b;
t1.a = 39 AND t2.b = 39 AND t3.c = 39;
x = 1 AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END;
-x = 1 AND CASE WHEN FALSE THEN FALSE ELSE TRUE END;
+x = 1;
x = 1 AND IF(x = 5, FALSE, TRUE);
-x = 1 AND CASE WHEN FALSE THEN FALSE ELSE TRUE END;
+x = 1;
+
+x = 1 AND CASE x WHEN 5 THEN FALSE ELSE TRUE END;
+x = 1;
x = y AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END;
x = y AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END;
x = 1 AND CASE WHEN y = 5 THEN x = z END;
x = 1 AND CASE WHEN y = 5 THEN 1 = z END;
+
+--------------------------------------
+-- Simplify Conditionals
+--------------------------------------
+IF(TRUE, x, y);
+x;
+
+IF(FALSE, x, y);
+y;
+
+IF(FALSE, x);
+NULL;
+
+IF(NULL, x, y);
+y;
+
+IF(cond, x, y);
+CASE WHEN cond THEN x ELSE y END;
+
+CASE WHEN TRUE THEN x ELSE y END;
+x;
+
+CASE WHEN FALSE THEN x ELSE y END;
+y;
+
+CASE WHEN FALSE THEN x WHEN FALSE THEN y WHEN TRUE THEN z END;
+z;
+
+CASE NULL WHEN NULL THEN x ELSE y END;
+y;
+
+CASE 4 WHEN 1 THEN x WHEN 2 THEN y WHEN 3 THEN z ELSE w END;
+w;
+
+CASE 4 WHEN 1 THEN x WHEN 2 THEN y WHEN 3 THEN z WHEN 4 THEN w END;
+w;
+
+CASE WHEN value = 1 THEN x ELSE y END;
+CASE WHEN value = 1 THEN x ELSE y END;
+
+CASE WHEN FALSE THEN x END;
+NULL;
+
+CASE 1 WHEN 1 + 1 THEN x END;
+NULL;
+
+CASE WHEN cond THEN x ELSE y END;
+CASE WHEN cond THEN x ELSE y END;
+
+CASE WHEN cond THEN x END;
+CASE WHEN cond THEN x END;
+
+CASE x WHEN y THEN z ELSE w END;
+CASE WHEN x = y THEN z ELSE w END;
+
+CASE x WHEN y THEN z END;
+CASE WHEN x = y THEN z END;
+
+CASE x1 + x2 WHEN x3 THEN x4 WHEN x5 + x6 THEN x7 ELSE x8 END;
+CASE WHEN (x1 + x2) = x3 THEN x4 WHEN (x1 + x2) = (x5 + x6) THEN x7 ELSE x8 END;
diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
index 91b553e..52ee12c 100644
--- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
+++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
@@ -4808,10 +4808,10 @@ WITH "foo" AS (
"foo"."i_item_sk" AS "i_item_sk",
"foo"."d_moy" AS "d_moy",
"foo"."mean" AS "mean",
- CASE "foo"."mean" WHEN FALSE THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov"
+ CASE WHEN "foo"."mean" = 0 THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov"
FROM "foo" AS "foo"
WHERE
- CASE "foo"."mean" WHEN FALSE THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1
+ CASE WHEN "foo"."mean" = 0 THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1
)
SELECT
"inv1"."w_warehouse_sk" AS "w_warehouse_sk",
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index f8c8bcc..6c48943 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -632,6 +632,11 @@ class TestExpressions(unittest.TestCase):
week = unit.find(exp.Week)
self.assertEqual(week.this, exp.var("thursday"))
+ for abbreviated_unit, unnabreviated_unit in exp.TimeUnit.UNABBREVIATED_UNIT_NAME.items():
+ interval = parse_one(f"interval '500 {abbreviated_unit}'")
+ self.assertIsInstance(interval.unit, exp.Var)
+ self.assertEqual(interval.unit.name, unnabreviated_unit)
+
def test_identifier(self):
self.assertTrue(exp.to_identifier('"x"').quoted)
self.assertFalse(exp.to_identifier("x").quoted)
@@ -861,6 +866,10 @@ FROM foo""",
self.assertEqual(exp.DataType.build("ARRAY<UNKNOWN>").sql(), "ARRAY<UNKNOWN>")
self.assertEqual(exp.DataType.build("ARRAY<NULL>").sql(), "ARRAY<NULL>")
+ self.assertEqual(exp.DataType.build("varchar(100) collate 'en-ci'").sql(), "VARCHAR(100)")
+
+ with self.assertRaises(ParseError):
+ exp.DataType.build("varchar(")
def test_rename_table(self):
self.assertEqual(
diff --git a/tests/test_lineage.py b/tests/test_lineage.py
index 0fd9da8..25329e2 100644
--- a/tests/test_lineage.py
+++ b/tests/test_lineage.py
@@ -199,3 +199,77 @@ class TestLineage(unittest.TestCase):
"SELECT x FROM (SELECT ax AS x FROM a UNION SELECT bx FROM b UNION SELECT cx FROM c)",
)
assert len(node.downstream) == 3
+
+ def test_lineage_lateral_flatten(self) -> None:
+ node = lineage(
+ "VALUE",
+ "SELECT FLATTENED.VALUE FROM TEST_TABLE, LATERAL FLATTEN(INPUT => RESULT, OUTER => TRUE) FLATTENED",
+ dialect="snowflake",
+ )
+ self.assertEqual(node.name, "VALUE")
+
+ downstream = node.downstream[0]
+ self.assertEqual(downstream.name, "FLATTENED.VALUE")
+ self.assertEqual(
+ downstream.source.sql(dialect="snowflake"),
+ "LATERAL FLATTEN(INPUT => TEST_TABLE.RESULT, OUTER => TRUE) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)",
+ )
+ self.assertEqual(
+ downstream.expression.sql(dialect="snowflake"),
+ "VALUE",
+ )
+ self.assertEqual(len(downstream.downstream), 1)
+
+ downstream = downstream.downstream[0]
+ self.assertEqual(downstream.name, "TEST_TABLE.RESULT")
+ self.assertEqual(downstream.source.sql(dialect="snowflake"), "TEST_TABLE AS TEST_TABLE")
+
+ def test_subquery(self) -> None:
+ node = lineage(
+ "output",
+ "SELECT (SELECT max(t3.my_column) my_column FROM foo t3) AS output FROM table3",
+ )
+ self.assertEqual(node.name, "SUBQUERY")
+ node = node.downstream[0]
+ self.assertEqual(node.name, "my_column")
+ node = node.downstream[0]
+ self.assertEqual(node.name, "t3.my_column")
+ self.assertEqual(node.source.sql(), "foo AS t3")
+
+ def test_lineage_cte_union(self) -> None:
+ query = """
+ WITH dataset AS (
+ SELECT *
+ FROM catalog.db.table_a
+
+ UNION
+
+ SELECT *
+ FROM catalog.db.table_b
+ )
+
+ SELECT x, created_at FROM dataset;
+ """
+ node = lineage("x", query)
+
+ self.assertEqual(node.name, "x")
+
+ downstream_a = node.downstream[0]
+ self.assertEqual(downstream_a.name, "0")
+ self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS table_a")
+ downstream_b = node.downstream[1]
+ self.assertEqual(downstream_b.name, "0")
+ self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b")
+
+ def test_select_star(self) -> None:
+ node = lineage("x", "SELECT x from (SELECT * from table_a)")
+
+ self.assertEqual(node.name, "x")
+
+ downstream = node.downstream[0]
+ self.assertEqual(downstream.name, "_q_0.x")
+ self.assertEqual(downstream.source.sql(), "SELECT * FROM table_a AS table_a")
+
+ downstream = downstream.downstream[0]
+ self.assertEqual(downstream.name, "*")
+ self.assertEqual(downstream.source.sql(), "table_a AS table_a")
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index c43a84e..8f5dd08 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -550,6 +550,47 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
+ def test_bracket_annotation(self):
+ expression = annotate_types(parse_one("SELECT A[:]")).expressions[0]
+
+ self.assertEqual(expression.type.this, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(expression.expressions[0].type.this, exp.DataType.Type.UNKNOWN)
+
+ expression = annotate_types(parse_one("SELECT ARRAY[1, 2, 3][1]")).expressions[0]
+ self.assertEqual(expression.this.type.sql(), "ARRAY<INT>")
+ self.assertEqual(expression.type.this, exp.DataType.Type.INT)
+
+ expression = annotate_types(parse_one("SELECT ARRAY[1, 2, 3][1 : 2]")).expressions[0]
+ self.assertEqual(expression.this.type.sql(), "ARRAY<INT>")
+ self.assertEqual(expression.type.sql(), "ARRAY<INT>")
+
+ expression = annotate_types(
+ parse_one("SELECT ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]][1][2]")
+ ).expressions[0]
+ self.assertEqual(expression.this.this.type.sql(), "ARRAY<ARRAY<INT>>")
+ self.assertEqual(expression.this.type.sql(), "ARRAY<INT>")
+ self.assertEqual(expression.type.this, exp.DataType.Type.INT)
+
+ expression = annotate_types(
+ parse_one("SELECT ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]][1:2]")
+ ).expressions[0]
+ self.assertEqual(expression.type.sql(), "ARRAY<ARRAY<INT>>")
+
+ expression = annotate_types(parse_one("MAP(1.0, 2, '2', 3.0)['2']", read="spark"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
+
+ expression = annotate_types(parse_one("MAP(1.0, 2, x, 3.0)[2]", read="spark"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.UNKNOWN)
+
+ expression = annotate_types(parse_one("MAP(ARRAY(1.0, x), ARRAY(2, 3.0))[x]"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
+
+ expression = annotate_types(
+ parse_one("SELECT MAP(1.0, 2, 2, t.y)[2] FROM t", read="spark"),
+ schema={"t": {"y": "int"}},
+ ).expressions[0]
+ self.assertEqual(expression.type.this, exp.DataType.Type.INT)
+
def test_interval_math_annotation(self):
schema = {
"x": {
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 53e1a85..f3e663e 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -234,6 +234,9 @@ class TestParser(unittest.TestCase):
"CREATE TABLE t (i UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()",
)
+ with self.assertRaises(ParseError):
+ parse_one("SELECT A[:")
+
def test_space(self):
self.assertEqual(
parse_one("SELECT ROW() OVER(PARTITION BY x) FROM x GROUP BY y").sql(),
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index d588f07..c16b1f6 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -19,6 +19,9 @@ class TestTranspile(unittest.TestCase):
def validate(self, sql, target, **kwargs):
self.assertEqual(transpile(sql, **kwargs)[0], target)
+ def test_weird_chars(self):
+ self.assertEqual(transpile("0Êß")[0], "0 AS Êß")
+
def test_alias(self):
self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP")
self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite")
@@ -87,7 +90,18 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
- self.validate("SELECT\n foo\n/* comments */\n;", "SELECT foo /* comments */")
+ self.validate(
+ "SELECT * FROM t1\n/*x*/\nUNION ALL SELECT * FROM t2",
+ "SELECT * FROM t1 /* x */ UNION ALL SELECT * FROM t2",
+ )
+ self.validate(
+ "SELECT * FROM t1\n/*x*/\nINTERSECT ALL SELECT * FROM t2",
+ "SELECT * FROM t1 /* x */ INTERSECT ALL SELECT * FROM t2",
+ )
+ self.validate(
+ "SELECT\n foo\n/* comments */\n;",
+ "SELECT foo /* comments */",
+ )
self.validate(
"SELECT * FROM a INNER /* comments */ JOIN b",
"SELECT * FROM a /* comments */ INNER JOIN b",
@@ -379,6 +393,47 @@ LEFT OUTER JOIN b""",
FROM tbl""",
pretty=True,
)
+ self.validate(
+ """
+SELECT
+ 'hotel1' AS hotel,
+ *
+FROM dw_1_dw_1_1.exactonline_1.transactionlines
+/*
+ UNION ALL
+ SELECT
+ 'Thon Partner Hotel Jølster' AS hotel,
+ name,
+ date,
+ CAST(identifier AS VARCHAR) AS identifier,
+ value
+ FROM d2o_889_oupjr_1348.public.accountvalues_forecast
+*/
+UNION ALL
+SELECT
+ 'hotel2' AS hotel,
+ *
+FROM dw_1_dw_1_1.exactonline_2.transactionlines""",
+ """SELECT
+ 'hotel1' AS hotel,
+ *
+FROM dw_1_dw_1_1.exactonline_1.transactionlines /*
+ UNION ALL
+ SELECT
+ 'Thon Partner Hotel Jølster' AS hotel,
+ name,
+ date,
+ CAST(identifier AS VARCHAR) AS identifier,
+ value
+ FROM d2o_889_oupjr_1348.public.accountvalues_forecast
+*/
+UNION ALL
+SELECT
+ 'hotel2' AS hotel,
+ *
+FROM dw_1_dw_1_1.exactonline_2.transactionlines""",
+ pretty=True,
+ )
def test_types(self):
self.validate("INT 1", "CAST(1 AS INT)")