summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_bigquery.py40
-rw-r--r--tests/dialects/test_clickhouse.py43
-rw-r--r--tests/dialects/test_databricks.py7
-rw-r--r--tests/dialects/test_dialect.py77
-rw-r--r--tests/dialects/test_duckdb.py9
-rw-r--r--tests/dialects/test_mysql.py64
-rw-r--r--tests/dialects/test_oracle.py16
-rw-r--r--tests/dialects/test_postgres.py22
-rw-r--r--tests/dialects/test_presto.py2
-rw-r--r--tests/dialects/test_redshift.py6
-rw-r--r--tests/dialects/test_snowflake.py111
-rw-r--r--tests/dialects/test_spark.py1
-rw-r--r--tests/dialects/test_tsql.py60
-rw-r--r--tests/fixtures/identity.sql6
-rw-r--r--tests/fixtures/optimizer/canonicalize.sql12
-rw-r--r--tests/fixtures/optimizer/normalize_identifiers.sql8
-rw-r--r--tests/fixtures/optimizer/optimizer.sql22
-rw-r--r--tests/fixtures/optimizer/simplify.sql23
-rw-r--r--tests/fixtures/optimizer/tpc-ds/tpc-ds.sql6
-rw-r--r--tests/test_executor.py2
-rw-r--r--tests/test_expressions.py14
-rw-r--r--tests/test_optimizer.py47
-rw-r--r--tests/test_parser.py25
-rw-r--r--tests/test_transpile.py33
24 files changed, 574 insertions, 82 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 1f5f902..8d172ea 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -122,6 +122,14 @@ class TestBigQuery(Validator):
"""SELECT JSON '"foo"' AS json_data""",
"""SELECT PARSE_JSON('"foo"') AS json_data""",
)
+ self.validate_identity(
+ "CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`",
+ "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d",
+ )
+ self.validate_identity(
+ "SELECT * FROM UNNEST(x) WITH OFFSET EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET",
+ "SELECT * FROM UNNEST(x) WITH OFFSET AS offset EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET AS offset",
+ )
self.validate_all("SELECT SPLIT(foo)", write={"bigquery": "SELECT SPLIT(foo, ',')"})
self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"})
@@ -131,6 +139,35 @@ class TestBigQuery(Validator):
self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"})
self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"})
self.validate_all(
+ "SELECT '\\n'",
+ read={
+ "bigquery": "SELECT '''\n'''",
+ },
+ write={
+ "bigquery": "SELECT '\\n'",
+ "postgres": "SELECT '\n'",
+ },
+ )
+ self.validate_all(
+ "TRIM(item, '*')",
+ read={
+ "snowflake": "TRIM(item, '*')",
+ "spark": "TRIM('*', item)",
+ },
+ write={
+ "bigquery": "TRIM(item, '*')",
+ "snowflake": "TRIM(item, '*')",
+ "spark": "TRIM('*' FROM item)",
+ },
+ )
+ self.validate_all(
+ "CREATE OR REPLACE TABLE `a.b.c` COPY `a.b.d`",
+ write={
+ "bigquery": "CREATE OR REPLACE TABLE a.b.c COPY a.b.d",
+ "snowflake": "CREATE OR REPLACE TABLE a.b.c CLONE a.b.d",
+ },
+ )
+ self.validate_all(
"SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)",
write={
"bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)",
@@ -608,6 +645,9 @@ class TestBigQuery(Validator):
"postgres": "CURRENT_DATE AT TIME ZONE 'UTC'",
},
)
+ self.validate_identity(
+ "SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)"
+ )
self.validate_all(
"SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10",
write={
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 40a270e..948c00e 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -24,6 +24,9 @@ class TestClickhouse(Validator):
self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
self.assertIsNone(expr._meta)
+ self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 0.01)")
+ self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)")
+ self.validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE 10000000")
self.validate_identity("CAST(x AS Nested(ID UInt32, Serial UInt32, EventTime DATETIME))")
self.validate_identity("CAST(x AS Enum('hello' = 1, 'world' = 2))")
self.validate_identity("CAST(x AS Enum('hello', 'world'))")
@@ -83,6 +86,16 @@ class TestClickhouse(Validator):
)
self.validate_all(
+ "SELECT '\\0'",
+ read={
+ "mysql": "SELECT '\0'",
+ },
+ write={
+ "clickhouse": "SELECT '\\0'",
+ "mysql": "SELECT '\0'",
+ },
+ )
+ self.validate_all(
"DATE_ADD('day', 1, x)",
read={
"clickhouse": "dateAdd(day, 1, x)",
@@ -224,6 +237,33 @@ class TestClickhouse(Validator):
self.validate_identity(
"SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
)
+ self.validate_all(
+ "SELECT quantile(0.5)(a)",
+ read={"duckdb": "SELECT quantile(a, 0.5)"},
+ write={"clickhouse": "SELECT quantile(0.5)(a)"},
+ )
+ self.validate_all(
+ "SELECT quantiles(0.5, 0.4)(a)",
+ read={"duckdb": "SELECT quantile(a, [0.5, 0.4])"},
+ write={"clickhouse": "SELECT quantiles(0.5, 0.4)(a)"},
+ )
+ self.validate_all(
+ "SELECT quantiles(0.5)(a)",
+ read={"duckdb": "SELECT quantile(a, [0.5])"},
+ write={"clickhouse": "SELECT quantiles(0.5)(a)"},
+ )
+
+ self.validate_identity("SELECT isNaN(x)")
+ self.validate_all(
+ "SELECT IS_NAN(x), ISNAN(x)",
+ write={"clickhouse": "SELECT isNaN(x), isNaN(x)"},
+ )
+
+ self.validate_identity("SELECT startsWith('a', 'b')")
+ self.validate_all(
+ "SELECT STARTS_WITH('a', 'b'), STARTSWITH('a', 'b')",
+ write={"clickhouse": "SELECT startsWith('a', 'b'), startsWith('a', 'b')"},
+ )
def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
@@ -305,6 +345,9 @@ class TestClickhouse(Validator):
def test_ddl(self):
self.validate_identity(
+ 'CREATE TABLE data5 ("x" UInt32, "y" UInt32) ENGINE=MergeTree ORDER BY (round(y / 1000000000), cityHash64(x)) SAMPLE BY cityHash64(x)'
+ )
+ self.validate_identity(
"CREATE TABLE foo (x UInt32) TTL time_column + INTERVAL '1' MONTH DELETE WHERE column = 'value'"
)
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index 3df968b..7c03c83 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -32,6 +32,13 @@ class TestDatabricks(Validator):
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
write={
"databricks": "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(TO_DATE(y))))",
+ "tsql": "CREATE TABLE foo (x AS YEAR(CAST(y AS DATE)))",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE t1 AS (SELECT c FROM t2)",
+ read={
+ "teradata": "CREATE TABLE t1 AS (SELECT c FROM t2) WITH DATA",
},
)
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 3e0ffd5..91eba17 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -5,6 +5,7 @@ from sqlglot import (
Dialects,
ErrorLevel,
ParseError,
+ TokenError,
UnsupportedError,
parse_one,
)
@@ -308,6 +309,44 @@ class TestDialect(Validator):
read={"postgres": "INET '127.0.0.1/32'"},
)
+ def test_heredoc_strings(self):
+ for dialect in ("clickhouse", "postgres", "redshift"):
+ # Invalid matching tag
+ with self.assertRaises(TokenError):
+ parse_one("SELECT $tag1$invalid heredoc string$tag2$", dialect=dialect)
+
+ # Unmatched tag
+ with self.assertRaises(TokenError):
+ parse_one("SELECT $tag1$invalid heredoc string", dialect=dialect)
+
+ # Without tag
+ self.validate_all(
+ "SELECT 'this is a heredoc string'",
+ read={
+ dialect: "SELECT $$this is a heredoc string$$",
+ },
+ )
+ self.validate_all(
+ "SELECT ''",
+ read={
+ dialect: "SELECT $$$$",
+ },
+ )
+
+ # With tag
+ self.validate_all(
+ "SELECT 'this is also a heredoc string'",
+ read={
+ dialect: "SELECT $foo$this is also a heredoc string$foo$",
+ },
+ )
+ self.validate_all(
+ "SELECT ''",
+ read={
+ dialect: "SELECT $foo$$foo$",
+ },
+ )
+
def test_decode(self):
self.validate_identity("DECODE(bin, charset)")
@@ -568,6 +607,7 @@ class TestDialect(Validator):
"presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)",
"snowflake": "CAST(x AS DATE)",
"doris": "TO_DATE(x)",
+ "mysql": "DATE(x)",
},
)
self.validate_all(
@@ -648,9 +688,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_ADD(x, 1, 'DAY')",
read={
- "mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"snowflake": "DATEADD('DAY', 1, x)",
- "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -842,6 +880,7 @@ class TestDialect(Validator):
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, CAST(CAST('2021-02-01' AS TIMESTAMP) AS DATE))",
"spark": "DATE_ADD('2021-02-01', 1)",
+ "mysql": "DATE_ADD('2021-02-01', INTERVAL 1 DAY)",
},
)
self.validate_all(
@@ -897,10 +936,7 @@ class TestDialect(Validator):
"bigquery",
"drill",
"duckdb",
- "mysql",
"presto",
- "starrocks",
- "doris",
)
},
write={
@@ -913,8 +949,25 @@ class TestDialect(Validator):
"presto",
"hive",
"spark",
+ )
+ },
+ )
+ self.validate_all(
+ f"{unit}(TS_OR_DS_TO_DATE(x))",
+ read={
+ dialect: f"{unit}(x)"
+ for dialect in (
+ "mysql",
+ "doris",
"starrocks",
+ )
+ },
+ write={
+ dialect: f"{unit}(x)"
+ for dialect in (
+ "mysql",
"doris",
+ "starrocks",
)
},
)
@@ -1790,3 +1843,17 @@ SELECT
with self.assertRaises(ParseError):
parse_one("CAST(x AS some_udt)", read="bigquery")
+
+ def test_qualify(self):
+ self.validate_all(
+ "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1",
+ write={
+ "duckdb": "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1",
+ "snowflake": "SELECT * FROM t QUALIFY COUNT(*) OVER () > 1",
+ "clickhouse": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "mysql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "oracle": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) _t WHERE _w > 1",
+ "postgres": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ "tsql": "SELECT * FROM (SELECT *, COUNT(*) OVER () AS _w FROM t) AS _t WHERE _w > 1",
+ },
+ )
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index dbf0a87..240f6f9 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -10,6 +10,10 @@ class TestDuckDB(Validator):
parse_one("select * from t limit (select 5)").sql(dialect="duckdb"),
exp.select("*").from_("t").limit(exp.select("5").subquery()).sql(dialect="duckdb"),
)
+ self.assertEqual(
+ parse_one("select * from t offset (select 5)").sql(dialect="duckdb"),
+ exp.select("*").from_("t").offset(exp.select("5").subquery()).sql(dialect="duckdb"),
+ )
for struct_value in ("{'a': 1}", "struct_pack(a := 1)"):
self.validate_all(struct_value, write={"presto": UnsupportedError})
@@ -287,6 +291,8 @@ class TestDuckDB(Validator):
"duckdb": "STRUCT_EXTRACT(x, 'abc')",
"presto": "x.abc",
"hive": "x.abc",
+ "postgres": "x.abc",
+ "redshift": "x.abc",
"spark": "x.abc",
},
)
@@ -446,6 +452,7 @@ class TestDuckDB(Validator):
write={
"duckdb": "SELECT QUANTILE_CONT(x, q) FROM t",
"postgres": "SELECT PERCENTILE_CONT(q) WITHIN GROUP (ORDER BY x) FROM t",
+ "snowflake": "SELECT PERCENTILE_CONT(q) WITHIN GROUP (ORDER BY x) FROM t",
},
)
self.validate_all(
@@ -453,6 +460,7 @@ class TestDuckDB(Validator):
write={
"duckdb": "SELECT QUANTILE_DISC(x, q) FROM t",
"postgres": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t",
+ "snowflake": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t",
},
)
self.validate_all(
@@ -460,6 +468,7 @@ class TestDuckDB(Validator):
write={
"duckdb": "SELECT QUANTILE_CONT(x, 0.5) FROM t",
"postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t",
+ "snowflake": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t",
},
)
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 20f872c..11f921c 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -12,6 +12,8 @@ class TestMySQL(Validator):
self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)")
self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)")
+ self.validate_identity("CREATE TABLE t (id DECIMAL(20, 4) UNSIGNED)")
+
self.validate_all(
"CREATE TABLE t (id INT UNSIGNED)",
write={
@@ -205,6 +207,9 @@ class TestMySQL(Validator):
)
self.validate_identity("INTERVAL '1' YEAR")
self.validate_identity("DATE_ADD(x, INTERVAL 1 YEAR)")
+ self.validate_identity("CHAR(0)")
+ self.validate_identity("CHAR(77, 121, 83, 81, '76')")
+ self.validate_identity("CHAR(77, 77.3, '77.3' USING utf8mb4)")
def test_types(self):
self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))")
@@ -244,6 +249,13 @@ class TestMySQL(Validator):
self.validate_identity(
"SELECT WEEK_OF_YEAR('2023-01-01')", "SELECT WEEKOFYEAR('2023-01-01')"
)
+ self.validate_all(
+ "CHAR(10)",
+ write={
+ "mysql": "CHAR(10)",
+ "presto": "CHR(10)",
+ },
+ )
def test_escape(self):
self.validate_identity("""'"abc"'""")
@@ -496,6 +508,56 @@ class TestMySQL(Validator):
self.validate_identity("FROM_UNIXTIME(a, b)")
self.validate_identity("FROM_UNIXTIME(a, b, c)")
self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)")
+ self.validate_all(
+ "SELECT TO_DAYS(x)",
+ write={
+ "mysql": "SELECT (DATEDIFF(x, '0000-01-01') + 1)",
+ "presto": "SELECT (DATE_DIFF('DAY', CAST(CAST('0000-01-01' AS TIMESTAMP) AS DATE), CAST(CAST(x AS TIMESTAMP) AS DATE)) + 1)",
+ },
+ )
+ self.validate_all(
+ "SELECT DATEDIFF(x, y)",
+ write={"mysql": "SELECT DATEDIFF(x, y)", "presto": "SELECT DATE_DIFF('day', y, x)"},
+ )
+ self.validate_all(
+ "DAYOFYEAR(x)",
+ write={
+ "mysql": "DAYOFYEAR(x)",
+ "": "DAY_OF_YEAR(TS_OR_DS_TO_DATE(x))",
+ },
+ )
+ self.validate_all(
+ "DAYOFMONTH(x)",
+ write={"mysql": "DAYOFMONTH(x)", "": "DAY_OF_MONTH(TS_OR_DS_TO_DATE(x))"},
+ )
+ self.validate_all(
+ "DAYOFWEEK(x)",
+ write={"mysql": "DAYOFWEEK(x)", "": "DAY_OF_WEEK(TS_OR_DS_TO_DATE(x))"},
+ )
+ self.validate_all(
+ "WEEKOFYEAR(x)",
+ write={"mysql": "WEEKOFYEAR(x)", "": "WEEK_OF_YEAR(TS_OR_DS_TO_DATE(x))"},
+ )
+ self.validate_all(
+ "DAY(x)",
+ write={"mysql": "DAY(x)", "": "DAY(TS_OR_DS_TO_DATE(x))"},
+ )
+ self.validate_all(
+ "WEEK(x)",
+ write={"mysql": "WEEK(x)", "": "WEEK(TS_OR_DS_TO_DATE(x))"},
+ )
+ self.validate_all(
+ "YEAR(x)",
+ write={"mysql": "YEAR(x)", "": "YEAR(TS_OR_DS_TO_DATE(x))"},
+ )
+ self.validate_all(
+ "DATE(x)",
+ read={"": "TS_OR_DS_TO_DATE(x)"},
+ )
+ self.validate_all(
+ "STR_TO_DATE(x, '%M')",
+ read={"": "TS_OR_DS_TO_DATE(x, '%B')"},
+ )
def test_mysql(self):
self.validate_all(
@@ -896,7 +958,7 @@ COMMENT='客户账户表'"""
self.validate_all(
"MONTHNAME(x)",
write={
- "": "TIME_TO_STR(x, '%B')",
+ "": "TIME_TO_STR(TS_OR_DS_TO_DATE(x), '%B')",
"mysql": "DATE_FORMAT(x, '%M')",
},
)
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 675ee8a..5572ec1 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -22,8 +22,6 @@ class TestOracle(Validator):
self.validate_identity("SELECT * FROM t FOR UPDATE OF s.t.c, s.t.v SKIP LOCKED")
self.validate_identity("SELECT STANDARD_HASH('hello')")
self.validate_identity("SELECT STANDARD_HASH('hello', 'MD5')")
- self.validate_identity("SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1")
- self.validate_identity("SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1")
self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain")
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT * FROM V$SESSION")
@@ -61,6 +59,20 @@ class TestOracle(Validator):
)
self.validate_all(
+ "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1",
+ write={
+ "oracle": "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1",
+ "spark": "SELECT CAST(NULL AS VARCHAR(2328)) AS COL1",
+ },
+ )
+ self.validate_all(
+ "SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1",
+ write={
+ "oracle": "SELECT CAST(NULL AS VARCHAR2(2328 BYTE)) AS COL1",
+ "spark": "SELECT CAST(NULL AS VARCHAR(2328)) AS COL1",
+ },
+ )
+ self.validate_all(
"NVL(NULL, 1)",
write={
"": "COALESCE(NULL, 1)",
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 6a3df47..0ddc106 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -10,6 +10,9 @@ class TestPostgres(Validator):
def test_ddl(self):
self.validate_identity(
+ "CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)"
+ )
+ self.validate_identity(
"CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])",
"CREATE TABLE test (x TIMESTAMP[][])",
)
@@ -149,15 +152,27 @@ class TestPostgres(Validator):
)
def test_postgres(self):
- expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)")
+ expr = parse_one(
+ "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres"
+ )
unnest = expr.args["joins"][0].this.this
unnest.assert_is(exp.Unnest)
alter_table_only = """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE NO ACTION ON UPDATE NO ACTION"""
- expr = parse_one(alter_table_only)
+ expr = parse_one(alter_table_only, read="postgres")
# Checks that user-defined types are parsed into DataType instead of Identifier
- parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(exp.DataType)
+ parse_one("CREATE TABLE t (a udt)", read="postgres").this.expressions[0].args[
+ "kind"
+ ].assert_is(exp.DataType)
+
+ # Checks that OID is parsed into a DataType (ObjectIdentifier)
+ self.assertIsInstance(
+ parse_one("CREATE TABLE public.propertydata (propertyvalue oid)", read="postgres").find(
+ exp.DataType
+ ),
+ exp.ObjectIdentifier,
+ )
self.assertIsInstance(expr, exp.AlterTable)
self.assertEqual(expr.sql(dialect="postgres"), alter_table_only)
@@ -192,7 +207,6 @@ class TestPostgres(Validator):
self.validate_identity("SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]")
self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]")
self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]")
- self.validate_identity("$x")
self.validate_identity("x$")
self.validate_identity("SELECT ARRAY[1, 2, 3]")
self.validate_identity("SELECT ARRAY(SELECT 1)")
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index a80013e..8edd31c 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -300,7 +300,6 @@ class TestPresto(Validator):
write={
"presto": "DATE_ADD('DAY', 1 * -1, x)",
},
- read={"mysql": "DATE_SUB(x, INTERVAL 1 DAY)"},
)
self.validate_all(
"NOW()",
@@ -503,6 +502,7 @@ class TestPresto(Validator):
@mock.patch("sqlglot.helper.logger")
def test_presto(self, logger):
+ self.validate_identity("string_agg(x, ',')", "ARRAY_JOIN(ARRAY_AGG(x), ',')")
self.validate_identity(
"SELECT * FROM example.testdb.customer_orders FOR VERSION AS OF 8954597067493422955"
)
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index c75654c..ae1b987 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -6,6 +6,11 @@ class TestRedshift(Validator):
dialect = "redshift"
def test_redshift(self):
+ self.validate_identity(
+ "SELECT 'a''b'",
+ "SELECT 'a\\'b'",
+ )
+
self.validate_all(
"x ~* 'pat'",
write={
@@ -226,7 +231,6 @@ class TestRedshift(Validator):
self.validate_identity("SELECT * FROM #x")
self.validate_identity("SELECT INTERVAL '5 day'")
self.validate_identity("foo$")
- self.validate_identity("$foo")
self.validate_identity("CAST('bla' AS SUPER)")
self.validate_identity("CREATE TABLE real1 (realcol REAL)")
self.validate_identity("CAST('foo' AS HLLSKETCH)")
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index a217394..7c36bea 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -1,6 +1,7 @@
from unittest import mock
from sqlglot import UnsupportedError, exp, parse_one
+from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from tests.dialects.test_dialect import Validator
@@ -8,34 +9,6 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
- self.validate_identity(
- 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
- )
-
- self.validate_all(
- "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
- read={
- "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
- },
- write={
- "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
- "snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
- },
- )
- self.validate_all(
- "SELECT INSERT(a, 0, 0, 'b')",
- read={
- "mysql": "SELECT INSERT(a, 0, 0, 'b')",
- "snowflake": "SELECT INSERT(a, 0, 0, 'b')",
- "tsql": "SELECT STUFF(a, 0, 0, 'b')",
- },
- write={
- "mysql": "SELECT INSERT(a, 0, 0, 'b')",
- "snowflake": "SELECT INSERT(a, 0, 0, 'b')",
- "tsql": "SELECT STUFF(a, 0, 0, 'b')",
- },
- )
-
self.validate_identity("LISTAGG(data['some_field'], ',')")
self.validate_identity("WEEKOFYEAR(tstamp)")
self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL")
@@ -54,7 +27,6 @@ class TestSnowflake(Validator):
self.validate_identity("$x") # parameter
self.validate_identity("a$b") # valid snowflake identifier
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
- self.validate_identity("PUT file:///dir/tmp.csv @%table")
self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)")
self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'")
self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
@@ -65,12 +37,16 @@ class TestSnowflake(Validator):
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
self.validate_identity("REGEXP_REPLACE('target', 'pattern', '\n')")
self.validate_identity(
- 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)'
+ 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
)
self.validate_identity(
"SELECT state, city, SUM(retail_price * quantity) AS gross_revenue FROM sales GROUP BY ALL"
)
self.validate_identity(
+ "SELECT * FROM foo window",
+ "SELECT * FROM foo AS window",
+ )
+ self.validate_identity(
r"SELECT RLIKE(a, $$regular expression with \ characters: \d{2}-\d{3}-\d{4}$$, 'i') FROM log_source",
r"SELECT REGEXP_LIKE(a, 'regular expression with \\ characters: \\d{2}-\\d{3}-\\d{4}', 'i') FROM log_source",
)
@@ -88,6 +64,36 @@ class TestSnowflake(Validator):
self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all(
+ "SELECT COLLATE('B', 'und:ci')",
+ write={
+ "bigquery": "SELECT COLLATE('B', 'und:ci')",
+ "snowflake": "SELECT COLLATE('B', 'und:ci')",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ read={
+ "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ },
+ write={
+ "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ "snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ },
+ )
+ self.validate_all(
+ "SELECT INSERT(a, 0, 0, 'b')",
+ read={
+ "mysql": "SELECT INSERT(a, 0, 0, 'b')",
+ "snowflake": "SELECT INSERT(a, 0, 0, 'b')",
+ "tsql": "SELECT STUFF(a, 0, 0, 'b')",
+ },
+ write={
+ "mysql": "SELECT INSERT(a, 0, 0, 'b')",
+ "snowflake": "SELECT INSERT(a, 0, 0, 'b')",
+ "tsql": "SELECT STUFF(a, 0, 0, 'b')",
+ },
+ )
+ self.validate_all(
"ARRAY_GENERATE_RANGE(0, 3)",
write={
"bigquery": "GENERATE_ARRAY(0, 3 - 1)",
@@ -513,6 +519,40 @@ class TestSnowflake(Validator):
},
)
+ def test_staged_files(self):
+ # Ensure we don't treat staged file paths as identifiers (i.e. they're not normalized)
+ staged_file = parse_one("SELECT * FROM @foo", read="snowflake")
+ self.assertEqual(
+ normalize_identifiers(staged_file, dialect="snowflake").sql(dialect="snowflake"),
+ staged_file.sql(dialect="snowflake"),
+ )
+
+ self.validate_identity("SELECT * FROM @~")
+ self.validate_identity("SELECT * FROM @~/some/path/to/file.csv")
+ self.validate_identity("SELECT * FROM @mystage")
+ self.validate_identity("SELECT * FROM '@mystage'")
+ self.validate_identity("SELECT * FROM @namespace.mystage/path/to/file.json.gz")
+ self.validate_identity("SELECT * FROM @namespace.%table_name/path/to/file.json.gz")
+ self.validate_identity("SELECT * FROM '@external/location' (FILE_FORMAT => 'path.to.csv')")
+ self.validate_identity("PUT file:///dir/tmp.csv @%table")
+ self.validate_identity(
+ 'COPY INTO NEW_TABLE ("foo", "bar") FROM (SELECT $1, $2, $3, $4 FROM @%old_table)'
+ )
+ self.validate_identity(
+ "SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla"
+ )
+ self.validate_identity(
+ "SELECT t.$1, t.$2 FROM @mystage1 (FILE_FORMAT => 'myformat', PATTERN => '.*data.*[.]csv.gz') AS t"
+ )
+ self.validate_identity(
+ "SELECT parse_json($1):a.b FROM @mystage2/data1.json.gz",
+ "SELECT PARSE_JSON($1)['a'].b FROM @mystage2/data1.json.gz",
+ )
+ self.validate_identity(
+ "SELECT * FROM @mystage t (c1)",
+ "SELECT * FROM @mystage AS t(c1)",
+ )
+
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)")
@@ -660,7 +700,6 @@ class TestSnowflake(Validator):
self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x")
self.validate_identity("CREATE DATABASE mytestdb_clone CLONE mytestdb")
self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema")
- self.validate_identity("CREATE TABLE orders_clone CLONE orders")
self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)")
self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)")
self.validate_identity(
@@ -680,6 +719,16 @@ class TestSnowflake(Validator):
)
self.validate_all(
+ "CREATE TABLE orders_clone CLONE orders",
+ read={
+ "bigquery": "CREATE TABLE orders_clone CLONE orders",
+ },
+ write={
+ "bigquery": "CREATE TABLE orders_clone CLONE orders",
+ "snowflake": "CREATE TABLE orders_clone CLONE orders",
+ },
+ )
+ self.validate_all(
"CREATE OR REPLACE TRANSIENT TABLE a (id INT)",
read={
"postgres": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)",
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 2e43ba5..0148e55 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -8,6 +8,7 @@ class TestSpark(Validator):
dialect = "spark"
def test_ddl(self):
+ self.validate_identity("CREATE TEMPORARY VIEW test AS SELECT 1")
self.validate_identity("CREATE TABLE foo (col VARCHAR(50))")
self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)")
self.validate_identity("CREATE TABLE foo (col STRING) CLUSTERED BY (col) INTO 10 BUCKETS")
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index f76894d..7d89d06 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -7,6 +7,45 @@ class TestTSQL(Validator):
def test_tsql(self):
self.validate_all(
+ "CREATE TABLE #mytemptable (a INTEGER)",
+ read={
+ "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
+ },
+ write={
+ "tsql": "CREATE TABLE #mytemptable (a INTEGER)",
+ "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)",
+ "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
+ "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)",
+ "hive": "CREATE TEMPORARY TABLE mytemptable (a INT)",
+ "spark2": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET",
+ "spark": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET",
+ "databricks": "CREATE TEMPORARY TABLE mytemptable (a INT) USING PARQUET",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))",
+ write={
+ "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT) USING PARQUET",
+ "tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))",
+ },
+ )
+ self.validate_all(
+ """CREATE TABLE [dbo].[mytable](
+ [email] [varchar](255) NOT NULL,
+ CONSTRAINT [UN_t_mytable] UNIQUE NONCLUSTERED
+ (
+ [email] ASC
+ )
+ )""",
+ write={
+ "hive": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)",
+ "spark2": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)",
+ "spark": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)",
+ "databricks": "CREATE TABLE `dbo`.`mytable` (`email` VARCHAR(255) NOT NULL)",
+ },
+ )
+
+ self.validate_all(
"CREATE TABLE x ( A INTEGER NOT NULL, B INTEGER NULL )",
write={
"tsql": "CREATE TABLE x (A INTEGER NOT NULL, B INTEGER NULL)",
@@ -492,6 +531,10 @@ class TestTSQL(Validator):
)
def test_ddl(self):
+ self.validate_identity(
+ "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END",
+ "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END",
+ )
self.validate_all(
"CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)",
read={
@@ -505,6 +548,9 @@ class TestTSQL(Validator):
"postgres": "CREATE TABLE tbl (id INT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10) PRIMARY KEY)",
"tsql": "CREATE TABLE tbl (id INTEGER NOT NULL IDENTITY(10, 1) PRIMARY KEY)",
},
+ write={
+ "databricks": "CREATE TABLE tbl (id BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1) PRIMARY KEY)",
+ },
)
self.validate_all(
"SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp",
@@ -561,22 +607,10 @@ class TestTSQL(Validator):
self.validate_all(
"CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))",
write={
- "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)",
+ "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT) USING PARQUET",
"tsql": "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))",
},
)
- self.validate_all(
- "CREATE TABLE #mytemptable (a INTEGER)",
- read={
- "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
- },
- write={
- "tsql": "CREATE TABLE #mytemptable (a INTEGER)",
- "snowflake": "CREATE TEMPORARY TABLE mytemptable (a INT)",
- "duckdb": "CREATE TEMPORARY TABLE mytemptable (a INT)",
- "oracle": "CREATE TEMPORARY TABLE mytemptable (a NUMBER)",
- },
- )
def test_insert_cte(self):
self.validate_all(
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 17506e4..2738707 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -771,8 +771,8 @@ ALTER TABLE integers DROP COLUMN k
ALTER TABLE integers DROP PRIMARY KEY
ALTER TABLE integers DROP COLUMN IF EXISTS k
ALTER TABLE integers DROP COLUMN k CASCADE
-ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR
-ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR USING CONCAT(i, '_', j)
+ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR
+ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR USING CONCAT(i, '_', j)
ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10
ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
@@ -864,3 +864,5 @@ SELECT x FROM y ORDER BY x ASC
KILL '123'
KILL CONNECTION 123
KILL QUERY '123'
+CHR(97)
+SELECT * FROM UNNEST(x) WITH ORDINALITY UNION ALL SELECT * FROM UNNEST(y) WITH ORDINALITY
diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql
index e27b2d3..2ba762d 100644
--- a/tests/fixtures/optimizer/canonicalize.sql
+++ b/tests/fixtures/optimizer/canonicalize.sql
@@ -29,6 +29,12 @@ SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0
SELECT a FROM x WHERE 1;
SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE 1 <> 0;
+SELECT a FROM x WHERE COALESCE(0, 1);
+SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE COALESCE(0 <> 0, 1 <> 0);
+
+SELECT a FROM x WHERE CASE WHEN COALESCE(b, 1) THEN 1 ELSE 0 END;
+SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE CASE WHEN COALESCE("x"."b" <> 0, 1 <> 0) THEN 1 ELSE 0 END <> 0;
+
--------------------------------------
-- Replace date functions
--------------------------------------
@@ -40,3 +46,9 @@ CAST('2023-01-01' AS TIMESTAMP);
TIMESTAMP('2023-01-01', '12:00:00');
TIMESTAMP('2023-01-01', '12:00:00');
+
+DATE_ADD(CAST("x" AS DATE), 1, 'YEAR');
+DATE_ADD(CAST("x" AS DATE), 1, 'YEAR');
+
+DATE_ADD('2023-01-01', 1, 'YEAR');
+DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'YEAR');
diff --git a/tests/fixtures/optimizer/normalize_identifiers.sql b/tests/fixtures/optimizer/normalize_identifiers.sql
index 2ab4778..4cb7dd1 100644
--- a/tests/fixtures/optimizer/normalize_identifiers.sql
+++ b/tests/fixtures/optimizer/normalize_identifiers.sql
@@ -62,3 +62,11 @@ SELECT a AS a FROM x UNION SELECT a AS a FROM x;
(SELECT A AS A FROM X);
(SELECT a AS a FROM x);
+
+# dialect: snowflake
+SELECT a /* sqlglot.meta case_sensitive */, b FROM table /* sqlglot.meta case_sensitive */;
+SELECT a /* sqlglot.meta case_sensitive */, B FROM table /* sqlglot.meta case_sensitive */;
+
+# dialect: redshift
+SELECT COALESCE(json_val.a /* sqlglot.meta case_sensitive */, json_val.A /* sqlglot.meta case_sensitive */) FROM table;
+SELECT COALESCE(json_val.a /* sqlglot.meta case_sensitive */, json_val.A /* sqlglot.meta case_sensitive */) FROM table;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index e59f14d..4cc62c9 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -1023,3 +1023,25 @@ SELECT
FROM "table1" AS "table1"
LEFT JOIN "alias3"
ON "table1"."cid" = "alias3"."cid";
+
+# title: CTE with EXPLODE cannot be merged
+# dialect: spark
+# execute: false
+SELECT Name,
+ FruitStruct.`$id`,
+ FruitStruct.value
+ FROM
+ (SELECT Name,
+ explode(Fruits) as FruitStruct
+ FROM fruits_table);
+WITH `_q_0` AS (
+ SELECT
+ `fruits_table`.`name` AS `name`,
+ EXPLODE(`fruits_table`.`fruits`) AS `fruitstruct`
+ FROM `fruits_table` AS `fruits_table`
+)
+SELECT
+ `_q_0`.`name` AS `name`,
+ `_q_0`.`fruitstruct`.`$id` AS `$id`,
+ `_q_0`.`fruitstruct`.`value` AS `value`
+FROM `_q_0` AS `_q_0`;
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index 584e9d6..a9ae192 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -444,6 +444,9 @@ CAST('1998-09-02 00:00:00' AS DATETIME);
CAST(x AS DATETIME) + interval '1' week;
CAST(x AS DATETIME) + INTERVAL '1' week;
+TS_OR_DS_TO_DATE('1998-12-01 00:00:01') - interval '90' day;
+CAST('1998-09-02' AS DATE);
+
--------------------------------------
-- Comparisons
--------------------------------------
@@ -681,6 +684,9 @@ CONCAT('a', x, y, 'bc');
'a' || 'b' || x;
CONCAT('ab', x);
+CONCAT(a, b) IN (SELECT * FROM foo WHERE cond);
+CONCAT(a, b) IN (SELECT * FROM foo WHERE cond);
+
--------------------------------------
-- DATE_TRUNC
--------------------------------------
@@ -740,6 +746,9 @@ x >= CAST('2022-01-01' AS DATE);
DATE_TRUNC('year', x) > CAST('2021-01-02' AS DATE);
x >= CAST('2022-01-01' AS DATE);
+DATE_TRUNC('year', x) > TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE('2021-01-02'));
+x >= CAST('2022-01-01' AS DATE);
+
-- right is not a date
DATE_TRUNC('year', x) <> '2021-01-02';
DATE_TRUNC('year', x) <> '2021-01-02';
@@ -758,6 +767,17 @@ x < CAST('2022-01-01' AS DATE) AND x >= CAST('2021-01-01' AS DATE);
TIMESTAMP_TRUNC(x, YEAR) = CAST('2021-01-01' AS DATETIME);
x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME);
+-- right side is not a date literal
+DATE_TRUNC('day', x) = CAST(y AS DATE);
+DATE_TRUNC('day', x) = CAST(y AS DATE);
+
+-- nested cast
+DATE_TRUNC('day', x) = CAST(CAST('2021-01-01 01:02:03' AS DATETIME) AS DATE);
+x < CAST('2021-01-02' AS DATE) AND x >= CAST('2021-01-01' AS DATE);
+
+TIMESTAMP_TRUNC(x, YEAR) = CAST(CAST('2021-01-01 01:02:03' AS DATE) AS DATETIME);
+x < CAST('2022-01-01 00:00:00' AS DATETIME) AND x >= CAST('2021-01-01 00:00:00' AS DATETIME);
+
--------------------------------------
-- EQUALITY
--------------------------------------
@@ -794,6 +814,9 @@ x = 2;
x - INTERVAL 1 DAY = CAST('2021-01-01' AS DATE);
x = CAST('2021-01-02' AS DATE);
+x - INTERVAL 1 DAY = TS_OR_DS_TO_DATE('2021-01-01 00:00:01');
+x = CAST('2021-01-02' AS DATE);
+
x - INTERVAL 1 HOUR > CAST('2021-01-01' AS DATETIME);
x > CAST('2021-01-01 01:00:00' AS DATETIME);
diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
index f50cf0b..2218182 100644
--- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
+++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
@@ -4793,10 +4793,10 @@ WITH "foo" AS (
"foo"."i_item_sk" AS "i_item_sk",
"foo"."d_moy" AS "d_moy",
"foo"."mean" AS "mean",
- CASE "foo"."mean" WHEN 0 THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov"
+ CASE "foo"."mean" WHEN FALSE THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov"
FROM "foo" AS "foo"
WHERE
- CASE "foo"."mean" WHEN 0 THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1
+ CASE "foo"."mean" WHEN FALSE THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1
)
SELECT
"inv1"."w_warehouse_sk" AS "w_warehouse_sk",
@@ -9775,7 +9775,7 @@ JOIN "date_dim" AS "d1"
ON "catalog_sales"."cs_sold_date_sk" = "d1"."d_date_sk"
AND "d1"."d_week_seq" = "d2"."d_week_seq"
AND "d1"."d_year" = 2002
- AND "d3"."d_date" > CONCAT("d1"."d_date", INTERVAL '5' day)
+ AND "d3"."d_date" > "d1"."d_date" + INTERVAL '5' day
GROUP BY
"item"."i_item_desc",
"warehouse"."w_warehouse_name",
diff --git a/tests/test_executor.py b/tests/test_executor.py
index ffe0229..c6b85c9 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -624,6 +624,8 @@ class TestExecutor(unittest.TestCase):
("LEFT('12345', 3)", "123"),
("RIGHT('12345', 3)", "345"),
("DATEDIFF('2022-01-03'::date, '2022-01-01'::TIMESTAMP::DATE)", 2),
+ ("TRIM(' foo ')", "foo"),
+ ("TRIM('afoob', 'ab')", "foo"),
]:
with self.subTest(sql):
result = execute(f"SELECT {sql}")
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index b1b5360..f8c8bcc 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -182,16 +182,21 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(parse_one("a.b.c").name, "c")
def test_table_name(self):
+ bq_dashed_table = exp.to_table("a-1.b.c", dialect="bigquery")
+ self.assertEqual(exp.table_name(bq_dashed_table), '"a-1".b.c')
+ self.assertEqual(exp.table_name(bq_dashed_table, dialect="bigquery"), "`a-1`.b.c")
+ self.assertEqual(exp.table_name("a-1.b.c", dialect="bigquery"), "`a-1`.b.c")
self.assertEqual(exp.table_name(parse_one("a", into=exp.Table)), "a")
self.assertEqual(exp.table_name(parse_one("a.b", into=exp.Table)), "a.b")
self.assertEqual(exp.table_name(parse_one("a.b.c", into=exp.Table)), "a.b.c")
self.assertEqual(exp.table_name("a.b.c"), "a.b.c")
+ self.assertEqual(exp.table_name(exp.to_table("a.b.c.d.e", dialect="bigquery")), "a.b.c.d.e")
+ self.assertEqual(exp.table_name(exp.to_table("'@foo'", dialect="snowflake")), "'@foo'")
+ self.assertEqual(exp.table_name(exp.to_table("@foo", dialect="snowflake")), "@foo")
self.assertEqual(
exp.table_name(parse_one("foo.`{bar,er}`", read="databricks"), dialect="databricks"),
"foo.`{bar,er}`",
)
- self.assertEqual(exp.table_name(exp.to_table("a-1.b.c", dialect="bigquery")), '"a-1".b.c')
- self.assertEqual(exp.table_name(exp.to_table("a.b.c.d.e", dialect="bigquery")), "a.b.c.d.e")
def test_table(self):
self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table))
@@ -946,3 +951,8 @@ FROM foo""",
with self.assertRaises(ParseError):
exp.DataType.build("foo")
+
+ def test_set_meta(self):
+ query = parse_one("SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */")
+ self.assertEqual(query.find(exp.Table).meta, {"x": "1", "y": "a", "z": True})
+ self.assertEqual(query.sql(), "SELECT * FROM foo /* sqlglot.meta x = 1, y = a, z */")
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 8775852..8fc3273 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -546,6 +546,53 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
+ def test_interval_math_annotation(self):
+ schema = {
+ "x": {
+ "a": "DATE",
+ "b": "DATETIME",
+ }
+ }
+ for sql, expected_type, *expected_sql in [
+ (
+ "SELECT '2023-01-01' + INTERVAL '1' DAY",
+ exp.DataType.Type.DATE,
+ "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY",
+ ),
+ (
+ "SELECT '2023-01-01' + INTERVAL '1' HOUR",
+ exp.DataType.Type.DATETIME,
+ "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR",
+ ),
+ (
+ "SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR",
+ exp.DataType.Type.DATETIME,
+ "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR",
+ ),
+ ("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN),
+ ("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE),
+ ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME),
+ ("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME),
+ ("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME),
+ (
+ "SELECT DATE_ADD('2023-01-01', 1, 'DAY')",
+ exp.DataType.Type.DATE,
+ "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')",
+ ),
+ (
+ "SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')",
+ exp.DataType.Type.DATETIME,
+ "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')",
+ ),
+ ("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE),
+ ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME),
+ ("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME),
+ ]:
+ with self.subTest(sql):
+ expression = annotate_types(parse_one(sql), schema=schema)
+ self.assertEqual(expected_type, expression.expressions[0].type.this)
+ self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql())
+
def test_lateral_annotation(self):
expression = optimizer.optimize(
parse_one("SELECT c FROM (select 1 a) as x LATERAL VIEW EXPLODE (a) AS c")
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 74463fd..53e1a85 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -690,6 +690,31 @@ class TestParser(unittest.TestCase):
LEFT JOIN b ON a.id = b.id
"""
)
+
+ self.assertIsNotNone(query)
+
+ query = parse_one(
+ """
+ SELECT *
+ FROM a
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ LEFT JOIN UNNEST(ARRAY[])
+ """
+ )
+
self.assertIsNotNone(query)
self.assertLessEqual(time.time() - now, 0.2)
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index a5b1977..d588f07 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -156,9 +156,7 @@ SELECT * FROM foo
-- comment 2
-- comment 3
SELECT * FROM foo""",
- """/* comment 1 */
-/* comment 2 */
-/* comment 3 */
+ """/* comment 1 */ /* comment 2 */ /* comment 3 */
SELECT
*
FROM foo""",
@@ -182,8 +180,7 @@ line3*/ /*another comment*/ where 1=1 -- comment at the end""",
*
FROM tbl /* line1
line2
-line3 */
-/* another comment */
+line3 */ /* another comment */
WHERE
1 = 1 /* comment at the end */""",
pretty=True,
@@ -310,9 +307,7 @@ FROM v""",
-- comment3
DROP TABLE IF EXISTS db.tba
""",
- """/* comment1 */
-/* comment2 */
-/* comment3 */
+ """/* comment1 */ /* comment2 */ /* comment3 */
DROP TABLE IF EXISTS db.tba""",
pretty=True,
)
@@ -337,9 +332,7 @@ SELECT
c
FROM tb_01
WHERE
- a /* comment5 */ = 1 AND b = 2 /* comment6 */
- /* and c = 1 */
- /* comment7 */""",
+ a /* comment5 */ = 1 AND b = 2 /* comment6 */ /* and c = 1 */ /* comment7 */""",
pretty=True,
)
self.validate(
@@ -375,11 +368,17 @@ INNER JOIN b""",
"""SELECT
*
FROM a
-/* comment 1 */
-/* comment 2 */
+/* comment 1 */ /* comment 2 */
LEFT OUTER JOIN b""",
pretty=True,
)
+ self.validate(
+ "SELECT\n a /* sqlglot.meta case_sensitive */ -- noqa\nFROM tbl",
+ """SELECT
+ a /* sqlglot.meta case_sensitive */ /* noqa */
+FROM tbl""",
+ pretty=True,
+ )
def test_types(self):
self.validate("INT 1", "CAST(1 AS INT)")
@@ -468,12 +467,12 @@ LEFT OUTER JOIN b""",
"ALTER TABLE integers ADD COLUMN k INT",
)
self.validate(
- "ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR",
- "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR",
+ "ALTER TABLE integers ALTER i TYPE VARCHAR",
+ "ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR",
)
self.validate(
"ALTER TABLE integers ALTER i TYPE VARCHAR COLLATE foo USING bar",
- "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR COLLATE foo USING bar",
+ "ALTER TABLE integers ALTER COLUMN i SET DATA TYPE VARCHAR COLLATE foo USING bar",
)
def test_time(self):
@@ -604,7 +603,7 @@ LEFT OUTER JOIN b""",
self.validate(
"CREATE TEMPORARY TABLE test AS SELECT 1",
"CREATE TEMPORARY VIEW test AS SELECT 1",
- write="spark",
+ write="spark2",
)
@mock.patch("sqlglot.helper.logger")