summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_bigquery.py136
-rw-r--r--tests/dialects/test_clickhouse.py120
-rw-r--r--tests/dialects/test_databricks.py12
-rw-r--r--tests/dialects/test_dialect.py128
-rw-r--r--tests/dialects/test_duckdb.py45
-rw-r--r--tests/dialects/test_hive.py19
-rw-r--r--tests/dialects/test_mysql.py4
-rw-r--r--tests/dialects/test_oracle.py24
-rw-r--r--tests/dialects/test_postgres.py18
-rw-r--r--tests/dialects/test_presto.py7
-rw-r--r--tests/dialects/test_snowflake.py16
-rw-r--r--tests/dialects/test_spark.py20
-rw-r--r--tests/dialects/test_sqlite.py4
-rw-r--r--tests/dialects/test_teradata.py4
-rw-r--r--tests/dialects/test_trino.py21
-rw-r--r--tests/dialects/test_tsql.py9
-rw-r--r--tests/fixtures/identity.sql1
-rw-r--r--tests/fixtures/optimizer/annotate_functions.sql8
-rw-r--r--tests/fixtures/optimizer/qualify_tables.sql20
-rw-r--r--tests/test_diff.py139
-rw-r--r--tests/test_optimizer.py43
21 files changed, 689 insertions, 109 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 3b317bc..eeb49f3 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -9,7 +9,6 @@ from sqlglot import (
UnsupportedError,
exp,
parse,
- parse_one,
transpile,
)
from sqlglot.helper import logger as helper_logger
@@ -85,12 +84,21 @@ LANGUAGE js AS
"PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E*S%z', x)",
)
- table = parse_one("x-0._y.z", dialect="bigquery", into=exp.Table)
+ for prefix in ("c.db.", "db.", ""):
+ with self.subTest(f"Parsing {prefix}INFORMATION_SCHEMA.X into a Table"):
+ table = self.parse_one(f"`{prefix}INFORMATION_SCHEMA.X`", into=exp.Table)
+ this = table.this
+
+ self.assertIsInstance(this, exp.Identifier)
+ self.assertTrue(this.quoted)
+ self.assertEqual(this.name, "INFORMATION_SCHEMA.X")
+
+ table = self.parse_one("x-0._y.z", into=exp.Table)
self.assertEqual(table.catalog, "x-0")
self.assertEqual(table.db, "_y")
self.assertEqual(table.name, "z")
- table = parse_one("x-0._y", dialect="bigquery", into=exp.Table)
+ table = self.parse_one("x-0._y", into=exp.Table)
self.assertEqual(table.db, "x-0")
self.assertEqual(table.name, "_y")
@@ -165,6 +173,7 @@ LANGUAGE js AS
self.validate_identity("SELECT * FROM foo.bar.25ab c", "SELECT * FROM foo.bar.`25ab` AS c")
self.validate_identity("x <> ''")
self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))")
+ self.validate_identity("DATE_TRUNC(col, MONTH, 'UTC+8')")
self.validate_identity("SELECT b'abc'")
self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b")
self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b")
@@ -182,7 +191,6 @@ LANGUAGE js AS
self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1")
self.validate_identity("SELECT TIMESTAMP_SECONDS(2) AS t")
self.validate_identity("SELECT TIMESTAMP_MILLIS(2) AS t")
- self.validate_identity("""SELECT JSON_EXTRACT_SCALAR('{"a": 5}', '$.a')""")
self.validate_identity("UPDATE x SET y = NULL")
self.validate_identity("LOG(n, b)")
self.validate_identity("SELECT COUNT(x RESPECT NULLS)")
@@ -194,11 +202,11 @@ LANGUAGE js AS
self.validate_identity("CAST(x AS NVARCHAR)", "CAST(x AS STRING)")
self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS TIMESTAMP)")
self.validate_identity("CAST(x AS RECORD)", "CAST(x AS STRUCT)")
- self.validate_identity(
- "MERGE INTO dataset.NewArrivals USING (SELECT * FROM UNNEST([('microwave', 10, 'warehouse #1'), ('dryer', 30, 'warehouse #1'), ('oven', 20, 'warehouse #2')])) ON FALSE WHEN NOT MATCHED THEN INSERT ROW WHEN NOT MATCHED BY SOURCE THEN DELETE"
+ self.validate_identity("EDIT_DISTANCE('a', 'a', max_distance => 2)").assert_is(
+ exp.Levenshtein
)
self.validate_identity(
- "SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`"
+ "MERGE INTO dataset.NewArrivals USING (SELECT * FROM UNNEST([('microwave', 10, 'warehouse #1'), ('dryer', 30, 'warehouse #1'), ('oven', 20, 'warehouse #2')])) ON FALSE WHEN NOT MATCHED THEN INSERT ROW WHEN NOT MATCHED BY SOURCE THEN DELETE"
)
self.validate_identity(
"SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)"
@@ -228,10 +236,23 @@ LANGUAGE js AS
"SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)",
)
self.validate_identity(
- """SELECT JSON_EXTRACT_SCALAR('5')""", """SELECT JSON_EXTRACT_SCALAR('5', '$')"""
+ "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id",
)
self.validate_identity(
- "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id",
+ "SELECT * FROM `proj.dataset.INFORMATION_SCHEMA.SOME_VIEW`",
+ "SELECT * FROM `proj.dataset.INFORMATION_SCHEMA.SOME_VIEW` AS `proj.dataset.INFORMATION_SCHEMA.SOME_VIEW`",
+ )
+ self.validate_identity(
+ "SELECT * FROM region_or_dataset.INFORMATION_SCHEMA.TABLES",
+ "SELECT * FROM region_or_dataset.`INFORMATION_SCHEMA.TABLES` AS TABLES",
+ )
+ self.validate_identity(
+ "SELECT * FROM region_or_dataset.INFORMATION_SCHEMA.TABLES AS some_name",
+ "SELECT * FROM region_or_dataset.`INFORMATION_SCHEMA.TABLES` AS some_name",
+ )
+ self.validate_identity(
+ "SELECT * FROM proj.region_or_dataset.INFORMATION_SCHEMA.TABLES",
+ "SELECT * FROM proj.region_or_dataset.`INFORMATION_SCHEMA.TABLES` AS TABLES",
)
self.validate_identity(
"CREATE VIEW `d.v` OPTIONS (expiration_timestamp=TIMESTAMP '2020-01-02T04:05:06.007Z') AS SELECT 1 AS c",
@@ -303,6 +324,13 @@ LANGUAGE js AS
)
self.validate_all(
+ "EDIT_DISTANCE(a, b)",
+ write={
+ "bigquery": "EDIT_DISTANCE(a, b)",
+ "duckdb": "LEVENSHTEIN(a, b)",
+ },
+ )
+ self.validate_all(
"SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')",
write={
"bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')",
@@ -361,10 +389,19 @@ LANGUAGE js AS
write={
"bigquery": "TIMESTAMP(x)",
"duckdb": "CAST(x AS TIMESTAMPTZ)",
+ "snowflake": "CAST(x AS TIMESTAMPTZ)",
"presto": "CAST(x AS TIMESTAMP WITH TIME ZONE)",
},
)
self.validate_all(
+ "SELECT TIMESTAMP('2008-12-25 15:30:00', 'America/Los_Angeles')",
+ write={
+ "bigquery": "SELECT TIMESTAMP('2008-12-25 15:30:00', 'America/Los_Angeles')",
+ "duckdb": "SELECT CAST('2008-12-25 15:30:00' AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'",
+ "snowflake": "SELECT CONVERT_TIMEZONE('America/Los_Angeles', CAST('2008-12-25 15:30:00' AS TIMESTAMP))",
+ },
+ )
+ self.validate_all(
"SELECT SUM(x IGNORE NULLS) AS x",
read={
"bigquery": "SELECT SUM(x IGNORE NULLS) AS x",
@@ -629,6 +666,7 @@ LANGUAGE js AS
write={
"bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)",
"databricks": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')",
+ "snowflake": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')",
},
),
)
@@ -639,6 +677,7 @@ LANGUAGE js AS
"bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL '1' MILLISECOND)",
"databricks": "SELECT TIMESTAMPADD(MILLISECOND, '1', '2023-01-01T00:00:00')",
"duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) + INTERVAL '1' MILLISECOND",
+ "snowflake": "SELECT TIMESTAMPADD(MILLISECOND, '1', '2023-01-01T00:00:00')",
},
),
)
@@ -670,6 +709,7 @@ LANGUAGE js AS
"databricks": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
"mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)",
"spark": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))",
+ "snowflake": "SELECT TIMESTAMPADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMPTZ))",
},
)
self.validate_all(
@@ -677,6 +717,14 @@ LANGUAGE js AS
write={
"bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL '10' MINUTE)",
"mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)",
+ "snowflake": "SELECT TIMESTAMPADD(MINUTE, '10' * -1, CAST('2008-12-25 15:30:00+00' AS TIMESTAMPTZ))",
+ },
+ )
+ self.validate_all(
+ 'SELECT TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL col MINUTE)',
+ write={
+ "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL col MINUTE)",
+ "snowflake": "SELECT TIMESTAMPADD(MINUTE, col * -1, CAST('2008-12-25 15:30:00+00' AS TIMESTAMPTZ))",
},
)
self.validate_all(
@@ -1113,7 +1161,8 @@ LANGUAGE js AS
write={
"bigquery": "CURRENT_TIME()",
"duckdb": "CURRENT_TIME",
- "presto": "CURRENT_TIME()",
+ "presto": "CURRENT_TIME",
+ "trino": "CURRENT_TIME",
"hive": "CURRENT_TIME()",
"spark": "CURRENT_TIME()",
},
@@ -1491,6 +1540,14 @@ WHERE
},
)
self.validate_all(
+ "SELECT PARSE_DATE('%Y%m%d', '20081225')",
+ write={
+ "bigquery": "SELECT PARSE_DATE('%Y%m%d', '20081225')",
+ "duckdb": "SELECT CAST(STRPTIME('20081225', '%Y%m%d') AS DATE)",
+ "snowflake": "SELECT DATE('20081225', 'yyyymmDD')",
+ },
+ )
+ self.validate_all(
"SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text",
write={
"bigquery": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text",
@@ -1504,9 +1561,48 @@ WHERE
"duckdb": "SELECT ARRAY_TO_STRING(LIST_TRANSFORM(['cake', 'pie', NULL], x -> COALESCE(x, 'MISSING')), '--') AS text",
},
)
+ self.validate_all(
+ "STRING(a)",
+ write={
+ "bigquery": "STRING(a)",
+ "snowflake": "CAST(a AS VARCHAR)",
+ "duckdb": "CAST(a AS TEXT)",
+ },
+ )
+ self.validate_all(
+ "STRING('2008-12-25 15:30:00', 'America/New_York')",
+ write={
+ "bigquery": "STRING('2008-12-25 15:30:00', 'America/New_York')",
+ "snowflake": "CAST(CONVERT_TIMEZONE('UTC', 'America/New_York', '2008-12-25 15:30:00') AS VARCHAR)",
+ "duckdb": "CAST(CAST('2008-12-25 15:30:00' AS TIMESTAMP) AT TIME ZONE 'UTC' AT TIME ZONE 'America/New_York' AS TEXT)",
+ },
+ )
self.validate_identity("SELECT * FROM a-b c", "SELECT * FROM a-b AS c")
+ self.validate_all(
+ "SAFE_DIVIDE(x, y)",
+ write={
+ "bigquery": "SAFE_DIVIDE(x, y)",
+ "duckdb": "IF((y) <> 0, (x) / (y), NULL)",
+ "presto": "IF((y) <> 0, (x) / (y), NULL)",
+ "trino": "IF((y) <> 0, (x) / (y), NULL)",
+ "hive": "IF((y) <> 0, (x) / (y), NULL)",
+ "spark2": "IF((y) <> 0, (x) / (y), NULL)",
+ "spark": "IF((y) <> 0, (x) / (y), NULL)",
+ "databricks": "IF((y) <> 0, (x) / (y), NULL)",
+ "snowflake": "IFF((y) <> 0, (x) / (y), NULL)",
+ },
+ )
+ self.validate_all(
+ """SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""",
+ write={
+ "bigquery": """SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""",
+ "duckdb": """SELECT '{"class": {"students": []}}' -> '$.class'""",
+ "snowflake": """SELECT GET_PATH(PARSE_JSON('{"class": {"students": []}}'), 'class')""",
+ },
+ )
+
def test_errors(self):
with self.assertRaises(TokenError):
transpile("'\\'", read="bigquery")
@@ -2000,3 +2096,23 @@ OPTIONS (
"bigquery": f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order}) AS ids FROM colors GROUP BY 1",
},
)
+
+ def test_json_extract_scalar(self):
+ for func in ("JSON_EXTRACT_SCALAR", "JSON_VALUE"):
+ with self.subTest(f"Testing BigQuery's {func}"):
+ self.validate_all(
+ f"SELECT {func}('5')",
+ write={
+ "bigquery": f"SELECT {func}('5', '$')",
+ "duckdb": """SELECT '5' ->> '$'""",
+ },
+ )
+
+ self.validate_all(
+ f"""SELECT {func}('{{"name": "Jakob", "age": "6"}}', '$.age')""",
+ write={
+ "bigquery": f"""SELECT {func}('{{"name": "Jakob", "age": "6"}}', '$.age')""",
+ "duckdb": """SELECT '{"name": "Jakob", "age": "6"}' ->> '$.age'""",
+ "snowflake": """SELECT JSON_EXTRACT_PATH_TEXT('{"name": "Jakob", "age": "6"}', 'age')""",
+ },
+ )
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 56ff06f..a0efb54 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -1,4 +1,4 @@
-from datetime import date
+from datetime import date, datetime, timezone
from sqlglot import exp, parse_one
from sqlglot.dialects import ClickHouse
from sqlglot.expressions import convert
@@ -88,6 +88,7 @@ class TestClickhouse(Validator):
self.validate_identity("CAST(x AS DATETIME)", "CAST(x AS DateTime)")
self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS DateTime)")
self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)")
+ self.validate_identity("CAST(x AS DECIMAL(38, 2))", "CAST(x AS Decimal(38, 2))")
self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src")
self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""")
self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b")
@@ -96,6 +97,9 @@ class TestClickhouse(Validator):
self.validate_identity("TRUNCATE DATABASE db")
self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster")
self.validate_identity(
+ "SELECT CAST(1730098800 AS DateTime64) AS DATETIME, 'test' AS interp ORDER BY DATETIME WITH FILL FROM toDateTime64(1730098800, 3) - INTERVAL '7' HOUR TO toDateTime64(1730185140, 3) - INTERVAL '7' HOUR STEP toIntervalSecond(900) INTERPOLATE (interp)"
+ )
+ self.validate_identity(
"SELECT number, COUNT() OVER (PARTITION BY number % 3) AS partition_count FROM numbers(10) WINDOW window_name AS (PARTITION BY number) QUALIFY partition_count = 4 ORDER BY number"
)
self.validate_identity(
@@ -150,6 +154,10 @@ class TestClickhouse(Validator):
"CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)"
)
self.validate_identity(
+ "SELECT (toUInt8('1') + toUInt8('2')) IS NOT NULL",
+ "SELECT NOT ((toUInt8('1') + toUInt8('2')) IS NULL)",
+ )
+ self.validate_identity(
"SELECT $1$foo$1$",
"SELECT 'foo'",
)
@@ -424,8 +432,13 @@ class TestClickhouse(Validator):
)
self.validate_all(
"SELECT quantile(0.5)(a)",
- read={"duckdb": "SELECT quantile(a, 0.5)"},
- write={"clickhouse": "SELECT quantile(0.5)(a)"},
+ read={
+ "duckdb": "SELECT quantile(a, 0.5)",
+ "clickhouse": "SELECT median(a)",
+ },
+ write={
+ "clickhouse": "SELECT quantile(0.5)(a)",
+ },
)
self.validate_all(
"SELECT quantiles(0.5, 0.4)(a)",
@@ -526,6 +539,10 @@ class TestClickhouse(Validator):
"SELECT * FROM ABC WHERE hasAny(COLUMNS('.*field') APPLY(toUInt64) APPLY(to), (SELECT groupUniqArray(toUInt64(field))))"
)
self.validate_identity("SELECT col apply", "SELECT col AS apply")
+ self.validate_identity(
+ "SELECT name FROM data WHERE (SELECT DISTINCT name FROM data) IS NOT NULL",
+ "SELECT name FROM data WHERE NOT ((SELECT DISTINCT name FROM data) IS NULL)",
+ )
def test_clickhouse_values(self):
values = exp.select("*").from_(
@@ -645,6 +662,12 @@ class TestClickhouse(Validator):
write={"clickhouse": f"CAST(pow(2, 32) AS {data_type})"},
)
+ def test_geom_types(self):
+ data_types = ["Point", "Ring", "LineString", "MultiLineString", "Polygon", "MultiPolygon"]
+ for data_type in data_types:
+ with self.subTest(f"Casting to ClickHouse {data_type}"):
+ self.validate_identity(f"SELECT CAST(val AS {data_type})")
+
def test_ddl(self):
db_table_expr = exp.Table(this=None, db=exp.to_identifier("foo"), catalog=None)
create_with_cluster = exp.Create(
@@ -678,6 +701,7 @@ class TestClickhouse(Validator):
"CREATE TABLE foo ENGINE=Memory AS (SELECT * FROM db.other_table) COMMENT 'foo'",
)
+ self.validate_identity("CREATE FUNCTION linear_equation AS (x, k, b) -> k * x + b")
self.validate_identity("CREATE MATERIALIZED VIEW a.b TO a.c (c Int32) AS SELECT * FROM a.d")
self.validate_identity("""CREATE TABLE ip_data (ip4 IPv4, ip6 IPv6) ENGINE=TinyLog()""")
self.validate_identity("""CREATE TABLE dates (dt1 Date32) ENGINE=TinyLog()""")
@@ -702,6 +726,10 @@ class TestClickhouse(Validator):
"CREATE TABLE foo (x UInt32) TTL time_column + INTERVAL '1' MONTH DELETE WHERE column = 'value'"
)
self.validate_identity(
+ "CREATE FUNCTION parity_str AS (n) -> IF(n % 2, 'odd', 'even')",
+ "CREATE FUNCTION parity_str AS n -> CASE WHEN n % 2 THEN 'odd' ELSE 'even' END",
+ )
+ self.validate_identity(
"CREATE TABLE a ENGINE=Memory AS SELECT 1 AS c COMMENT 'foo'",
"CREATE TABLE a ENGINE=Memory AS (SELECT 1 AS c) COMMENT 'foo'",
)
@@ -1094,6 +1122,92 @@ LIFETIME(MIN 0 MAX 0)""",
convert(date(2020, 1, 1)).sql(dialect=self.dialect), "toDate('2020-01-01')"
)
+ # no fractional seconds
+ self.assertEqual(
+ convert(datetime(2020, 1, 1, 0, 0, 1)).sql(dialect=self.dialect),
+ "CAST('2020-01-01 00:00:01' AS DateTime64(6))",
+ )
+ self.assertEqual(
+ convert(datetime(2020, 1, 1, 0, 0, 1, tzinfo=timezone.utc)).sql(dialect=self.dialect),
+ "CAST('2020-01-01 00:00:01' AS DateTime64(6, 'UTC'))",
+ )
+
+ # with fractional seconds
+ self.assertEqual(
+ convert(datetime(2020, 1, 1, 0, 0, 1, 1)).sql(dialect=self.dialect),
+ "CAST('2020-01-01 00:00:01.000001' AS DateTime64(6))",
+ )
+ self.assertEqual(
+ convert(datetime(2020, 1, 1, 0, 0, 1, 1, tzinfo=timezone.utc)).sql(
+ dialect=self.dialect
+ ),
+ "CAST('2020-01-01 00:00:01.000001' AS DateTime64(6, 'UTC'))",
+ )
+
+ def test_timestr_to_time(self):
+ # no fractional seconds
+ time_strings = [
+ "2020-01-01 00:00:01",
+ "2020-01-01 00:00:01+01:00",
+ " 2020-01-01 00:00:01-01:00 ",
+ "2020-01-01T00:00:01+01:00",
+ ]
+ for time_string in time_strings:
+ with self.subTest(f"'{time_string}'"):
+ self.assertEqual(
+ exp.TimeStrToTime(this=exp.Literal.string(time_string)).sql(
+ dialect=self.dialect
+ ),
+ f"CAST('{time_string}' AS DateTime64(6))",
+ )
+
+ time_strings_no_utc = ["2020-01-01 00:00:01" for i in range(4)]
+ for utc, no_utc in zip(time_strings, time_strings_no_utc):
+ with self.subTest(f"'{time_string}' with UTC timezone"):
+ self.assertEqual(
+ exp.TimeStrToTime(
+ this=exp.Literal.string(utc), zone=exp.Literal.string("UTC")
+ ).sql(dialect=self.dialect),
+ f"CAST('{no_utc}' AS DateTime64(6, 'UTC'))",
+ )
+
+ # with fractional seconds
+ time_strings = [
+ "2020-01-01 00:00:01.001",
+ "2020-01-01 00:00:01.000001",
+ "2020-01-01 00:00:01.001+00:00",
+ "2020-01-01 00:00:01.000001-00:00",
+ "2020-01-01 00:00:01.0001",
+ "2020-01-01 00:00:01.1+00:00",
+ ]
+
+ for time_string in time_strings:
+ with self.subTest(f"'{time_string}'"):
+ self.assertEqual(
+ exp.TimeStrToTime(this=exp.Literal.string(time_string[0])).sql(
+ dialect=self.dialect
+ ),
+ f"CAST('{time_string[0]}' AS DateTime64(6))",
+ )
+
+ time_strings_no_utc = [
+ "2020-01-01 00:00:01.001000",
+ "2020-01-01 00:00:01.000001",
+ "2020-01-01 00:00:01.001000",
+ "2020-01-01 00:00:01.000001",
+ "2020-01-01 00:00:01.000100",
+ "2020-01-01 00:00:01.100000",
+ ]
+
+ for utc, no_utc in zip(time_strings, time_strings_no_utc):
+ with self.subTest(f"'{time_string}' with UTC timezone"):
+ self.assertEqual(
+ exp.TimeStrToTime(
+ this=exp.Literal.string(utc), zone=exp.Literal.string("UTC")
+ ).sql(dialect=self.dialect),
+ f"CAST('{no_utc}' AS DateTime64(6, 'UTC'))",
+ )
+
def test_grant(self):
self.validate_identity("GRANT SELECT(x, y) ON db.table TO john WITH GRANT OPTION")
self.validate_identity("GRANT INSERT(x, y) ON db.table TO john")
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index f7ec756..d0090b9 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -7,6 +7,7 @@ class TestDatabricks(Validator):
dialect = "databricks"
def test_databricks(self):
+ self.validate_identity("SELECT * FROM stream")
self.validate_identity("SELECT t.current_time FROM t")
self.validate_identity("ALTER TABLE labels ADD COLUMN label_score FLOAT")
self.validate_identity("DESCRIBE HISTORY a.b")
@@ -116,6 +117,17 @@ class TestDatabricks(Validator):
},
)
+ self.validate_all(
+ "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
+ read={
+ "databricks": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
+ "spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
+ },
+ write={
+ "spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)",
+ },
+ )
+
# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 84a4ff3..85402e2 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -526,7 +526,7 @@ class TestDialect(Validator):
write={
"": "SELECT NVL2(a, b, c)",
"bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
- "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "clickhouse": "SELECT CASE WHEN NOT (a IS NULL) THEN b ELSE c END",
"databricks": "SELECT NVL2(a, b, c)",
"doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
"drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
@@ -552,7 +552,7 @@ class TestDialect(Validator):
write={
"": "SELECT NVL2(a, b)",
"bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END",
- "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "clickhouse": "SELECT CASE WHEN NOT (a IS NULL) THEN b END",
"databricks": "SELECT NVL2(a, b)",
"doris": "SELECT CASE WHEN NOT a IS NULL THEN b END",
"drill": "SELECT CASE WHEN NOT a IS NULL THEN b END",
@@ -651,7 +651,7 @@ class TestDialect(Validator):
"snowflake": "CAST('2020-01-01' AS TIMESTAMP)",
"spark": "CAST('2020-01-01' AS TIMESTAMP)",
"trino": "CAST('2020-01-01' AS TIMESTAMP)",
- "clickhouse": "CAST('2020-01-01' AS Nullable(DateTime))",
+ "clickhouse": "CAST('2020-01-01' AS DateTime64(6))",
"drill": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
@@ -688,7 +688,7 @@ class TestDialect(Validator):
"snowflake": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMPTZ)",
"spark": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)",
"trino": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP WITH TIME ZONE)",
- "clickhouse": "CAST('2020-01-01 12:13:14' AS Nullable(DateTime('America/Los_Angeles')))",
+ "clickhouse": "CAST('2020-01-01 12:13:14' AS DateTime64(6, 'America/Los_Angeles'))",
"drill": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)",
"hive": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)",
"presto": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP WITH TIME ZONE)",
@@ -709,7 +709,7 @@ class TestDialect(Validator):
"snowflake": "CAST(col AS TIMESTAMPTZ)",
"spark": "CAST(col AS TIMESTAMP)",
"trino": "CAST(col AS TIMESTAMP WITH TIME ZONE)",
- "clickhouse": "CAST(col AS Nullable(DateTime('America/Los_Angeles')))",
+ "clickhouse": "CAST(col AS DateTime64(6, 'America/Los_Angeles'))",
"drill": "CAST(col AS TIMESTAMP)",
"hive": "CAST(col AS TIMESTAMP)",
"presto": "CAST(col AS TIMESTAMP WITH TIME ZONE)",
@@ -2893,3 +2893,121 @@ FROM subquery2""",
"snowflake": "UUID_STRING()",
},
)
+
+ def test_escaped_identifier_delimiter(self):
+ for dialect in ("databricks", "hive", "mysql", "spark2", "spark"):
+ with self.subTest(f"Testing escaped backtick in identifier name for {dialect}"):
+ self.validate_all(
+ 'SELECT 1 AS "x`"',
+ read={
+ dialect: "SELECT 1 AS `x```",
+ },
+ write={
+ dialect: "SELECT 1 AS `x```",
+ },
+ )
+
+ for dialect in (
+ "",
+ "clickhouse",
+ "duckdb",
+ "postgres",
+ "presto",
+ "trino",
+ "redshift",
+ "snowflake",
+ "sqlite",
+ ):
+ with self.subTest(f"Testing escaped double-quote in identifier name for {dialect}"):
+ self.validate_all(
+ 'SELECT 1 AS "x"""',
+ read={
+ dialect: 'SELECT 1 AS "x"""',
+ },
+ write={
+ dialect: 'SELECT 1 AS "x"""',
+ },
+ )
+
+ for dialect in ("clickhouse", "sqlite"):
+ with self.subTest(f"Testing escaped backtick in identifier name for {dialect}"):
+ self.validate_all(
+ 'SELECT 1 AS "x`"',
+ read={
+ dialect: "SELECT 1 AS `x```",
+ },
+ write={
+ dialect: 'SELECT 1 AS "x`"',
+ },
+ )
+
+ self.validate_all(
+ 'SELECT 1 AS "x`"',
+ read={
+ "clickhouse": "SELECT 1 AS `x\\``",
+ },
+ write={
+ "clickhouse": 'SELECT 1 AS "x`"',
+ },
+ )
+ for name in ('"x\\""', '`x"`'):
+ with self.subTest(f"Testing ClickHouse delimiter escaping: {name}"):
+ self.validate_all(
+ 'SELECT 1 AS "x"""',
+ read={
+ "clickhouse": f"SELECT 1 AS {name}",
+ },
+ write={
+ "clickhouse": 'SELECT 1 AS "x"""',
+ },
+ )
+
+ for name in ("[[x]]]", '"[x]"'):
+ with self.subTest(f"Testing T-SQL delimiter escaping: {name}"):
+ self.validate_all(
+ 'SELECT 1 AS "[x]"',
+ read={
+ "tsql": f"SELECT 1 AS {name}",
+ },
+ write={
+ "tsql": "SELECT 1 AS [[x]]]",
+ },
+ )
+ for name in ('[x"]', '"x"""'):
+ with self.subTest(f"Testing T-SQL delimiter escaping: {name}"):
+ self.validate_all(
+ 'SELECT 1 AS "x"""',
+ read={
+ "tsql": f"SELECT 1 AS {name}",
+ },
+ write={
+ "tsql": 'SELECT 1 AS [x"]',
+ },
+ )
+
+ def test_median(self):
+ for suffix in (
+ "",
+ " OVER ()",
+ ):
+ self.validate_all(
+ f"MEDIAN(x){suffix}",
+ read={
+ "snowflake": f"MEDIAN(x){suffix}",
+ "duckdb": f"MEDIAN(x){suffix}",
+ "spark": f"MEDIAN(x){suffix}",
+ "databricks": f"MEDIAN(x){suffix}",
+ "redshift": f"MEDIAN(x){suffix}",
+ "oracle": f"MEDIAN(x){suffix}",
+ },
+ write={
+ "snowflake": f"MEDIAN(x){suffix}",
+ "duckdb": f"MEDIAN(x){suffix}",
+ "spark": f"MEDIAN(x){suffix}",
+ "databricks": f"MEDIAN(x){suffix}",
+ "redshift": f"MEDIAN(x){suffix}",
+ "oracle": f"MEDIAN(x){suffix}",
+ "clickhouse": f"MEDIAN(x){suffix}",
+ "postgres": f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
+ },
+ )
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 1f8fb81..b59ac9f 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -620,12 +620,6 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "IF((y) <> 0, (x) / (y), NULL)",
- read={
- "bigquery": "SAFE_DIVIDE(x, y)",
- },
- )
- self.validate_all(
"STRUCT_PACK(x := 1, y := '2')",
write={
"bigquery": "STRUCT(1 AS x, '2' AS y)",
@@ -758,16 +752,9 @@ class TestDuckDB(Validator):
"snowflake": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t",
},
)
- self.validate_all(
- "SELECT MEDIAN(x) FROM t",
- write={
- "duckdb": "SELECT QUANTILE_CONT(x, 0.5) FROM t",
- "postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t",
- "snowflake": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t",
- },
- )
with self.assertRaises(UnsupportedError):
+ # bq has the position arg, but duckdb doesn't
transpile(
"SELECT REGEXP_EXTRACT(a, 'pattern', 1) from table",
read="bigquery",
@@ -775,6 +762,36 @@ class TestDuckDB(Validator):
unsupported_level=ErrorLevel.IMMEDIATE,
)
+ self.validate_all(
+ "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t",
+ read={
+ "duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t",
+ "bigquery": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t",
+ "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern') FROM t",
+ },
+ write={
+ "duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t",
+ "bigquery": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t",
+ "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern') FROM t",
+ },
+ )
+ self.validate_all(
+ "SELECT REGEXP_EXTRACT(a, 'pattern', 2, 'i') FROM t",
+ read={
+ "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, 'i', 2) FROM t",
+ },
+ write={
+ "duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern', 2, 'i') FROM t",
+ "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, 'i', 2) FROM t",
+ },
+ )
+ self.validate_identity(
+ "SELECT REGEXP_EXTRACT(a, 'pattern', 0)",
+ "SELECT REGEXP_EXTRACT(a, 'pattern')",
+ )
+ self.validate_identity("SELECT REGEXP_EXTRACT(a, 'pattern', 0, 'i')")
+ self.validate_identity("SELECT REGEXP_EXTRACT(a, 'pattern', 1, 'i')")
+
self.validate_identity("SELECT ISNAN(x)")
self.validate_all(
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index e40a85a..f13d92c 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -1,5 +1,7 @@
from tests.dialects.test_dialect import Validator
+from sqlglot import exp
+
class TestHive(Validator):
dialect = "hive"
@@ -787,6 +789,23 @@ class TestHive(Validator):
},
)
+ self.validate_identity("EXISTS(col, x -> x % 2 = 0)").assert_is(exp.Exists)
+
+ self.validate_all(
+ "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ read={
+ "hive": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ "spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ "spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ "databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ },
+ write={
+ "spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ "spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ "databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)",
+ },
+ )
+
def test_escapes(self) -> None:
self.validate_identity("'\n'", "'\\n'")
self.validate_identity("'\\n'")
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index bd0d6c3..52b04ea 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -388,7 +388,7 @@ class TestMySQL(Validator):
"sqlite": "SELECT x'CC'",
"starrocks": "SELECT x'CC'",
"tableau": "SELECT 204",
- "teradata": "SELECT 204",
+ "teradata": "SELECT X'CC'",
"trino": "SELECT X'CC'",
"tsql": "SELECT 0xCC",
}
@@ -409,7 +409,7 @@ class TestMySQL(Validator):
"sqlite": "SELECT x'0000CC'",
"starrocks": "SELECT x'0000CC'",
"tableau": "SELECT 204",
- "teradata": "SELECT 204",
+ "teradata": "SELECT X'0000CC'",
"trino": "SELECT X'0000CC'",
"tsql": "SELECT 0x0000CC",
}
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 36ce5d0..0784810 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -120,13 +120,6 @@ class TestOracle(Validator):
},
)
self.validate_all(
- "TRUNC(SYSDATE, 'YEAR')",
- write={
- "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())",
- "oracle": "TRUNC(SYSDATE, 'YEAR')",
- },
- )
- self.validate_all(
"SELECT * FROM test WHERE MOD(col1, 4) = 3",
read={
"duckdb": "SELECT * FROM test WHERE col1 % 4 = 3",
@@ -632,3 +625,20 @@ WHERE
self.validate_identity("GRANT UPDATE, TRIGGER ON TABLE t TO anita, zhi")
self.validate_identity("GRANT EXECUTE ON PROCEDURE p TO george")
self.validate_identity("GRANT USAGE ON SEQUENCE order_id TO sales_role")
+
+ def test_datetrunc(self):
+ self.validate_all(
+ "TRUNC(SYSDATE, 'YEAR')",
+ write={
+ "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())",
+ "oracle": "TRUNC(SYSDATE, 'YEAR')",
+ },
+ )
+
+ # Make sure units are not normalized e.g 'Q' -> 'QUARTER' and 'W' -> 'WEEK'
+ # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html
+ for unit in (
+ "'Q'",
+ "'W'",
+ ):
+ self.validate_identity(f"TRUNC(x, {unit})")
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 62ae247..4b54cd0 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -797,6 +797,24 @@ class TestPostgres(Validator):
self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1 FOR 1)")
self.validate_identity("ARRAY[1, 2, 3] && ARRAY[1, 2]").assert_is(exp.ArrayOverlaps)
+ self.validate_all(
+ """SELECT JSONB_EXISTS('{"a": [1,2,3]}', 'a')""",
+ write={
+ "postgres": """SELECT JSONB_EXISTS('{"a": [1,2,3]}', 'a')""",
+ "duckdb": """SELECT JSON_EXISTS('{"a": [1,2,3]}', '$.a')""",
+ },
+ )
+ self.validate_all(
+ "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)",
+ write={
+ "postgres": "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)",
+ "hive": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
+ "spark2": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
+ "spark": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
+ "databricks": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)",
+ },
+ )
+
def test_ddl(self):
# Checks that user-defined types are parsed into DataType instead of Identifier
self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is(
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 4c10a45..31a078c 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -414,13 +414,6 @@ class TestPresto(Validator):
"CAST(x AS TIMESTAMP)",
read={"mysql": "TIMESTAMP(x)"},
)
- self.validate_all(
- "TIMESTAMP(x, 'America/Los_Angeles')",
- write={
- "duckdb": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'",
- "presto": "AT_TIMEZONE(CAST(x AS TIMESTAMP), 'America/Los_Angeles')",
- },
- )
# this case isn't really correct, but it's a fall back for mysql's version
self.validate_all(
"TIMESTAMP(x, '12:00:00')",
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 409a5a6..8357642 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -331,10 +331,15 @@ WHERE
"snowflake": "SELECT TIME_FROM_PARTS(12, 34, 56, 987654321)",
},
)
+ self.validate_identity(
+ "SELECT TIMESTAMPNTZFROMPARTS(2013, 4, 5, 12, 00, 00)",
+ "SELECT TIMESTAMP_FROM_PARTS(2013, 4, 5, 12, 00, 00)",
+ )
self.validate_all(
"SELECT TIMESTAMP_FROM_PARTS(2013, 4, 5, 12, 00, 00)",
read={
"duckdb": "SELECT MAKE_TIMESTAMP(2013, 4, 5, 12, 00, 00)",
+ "snowflake": "SELECT TIMESTAMP_NTZ_FROM_PARTS(2013, 4, 5, 12, 00, 00)",
},
write={
"duckdb": "SELECT MAKE_TIMESTAMP(2013, 4, 5, 12, 00, 00)",
@@ -519,7 +524,6 @@ WHERE
self.validate_all(
f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
read={
- "snowflake": f"SELECT MEDIAN(x){suffix}",
"postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
},
write={
@@ -529,15 +533,6 @@ WHERE
"snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
},
)
- self.validate_all(
- f"SELECT MEDIAN(x){suffix}",
- write={
- "": f"SELECT PERCENTILE_CONT(x, 0.5){suffix}",
- "duckdb": f"SELECT QUANTILE_CONT(x, 0.5){suffix}",
- "postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
- "snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
- },
- )
for func in (
"CORR",
"COVAR_POP",
@@ -1768,7 +1763,6 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
"REGEXP_SUBSTR(subject, pattern)",
read={
"bigquery": "REGEXP_EXTRACT(subject, pattern)",
- "snowflake": "REGEXP_EXTRACT(subject, pattern)",
},
write={
"bigquery": "REGEXP_EXTRACT(subject, pattern)",
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 01859c6..486bf79 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -315,6 +315,20 @@ TBLPROPERTIES (
},
)
self.validate_all(
+ "SELECT ARRAY_AGG(1)",
+ write={
+ "duckdb": "SELECT ARRAY_AGG(1)",
+ "spark": "SELECT COLLECT_LIST(1)",
+ },
+ )
+ self.validate_all(
+ "SELECT ARRAY_AGG(DISTINCT STRUCT('a'))",
+ write={
+ "duckdb": "SELECT ARRAY_AGG(DISTINCT {'col1': 'a'})",
+ "spark": "SELECT COLLECT_LIST(DISTINCT STRUCT('a' AS col1))",
+ },
+ )
+ self.validate_all(
"SELECT DATE_FORMAT(DATE '2020-01-01', 'EEEE') AS weekday",
write={
"presto": "SELECT DATE_FORMAT(CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP), '%W') AS weekday",
@@ -875,3 +889,9 @@ TBLPROPERTIES (
"databricks": "SELECT * FROM db.table1 EXCEPT SELECT * FROM db.table2",
},
)
+
+ def test_string(self):
+ for dialect in ("hive", "spark2", "spark", "databricks"):
+ with self.subTest(f"Testing STRING() for {dialect}"):
+ query = parse_one("STRING(a)", dialect=dialect)
+ self.assertEqual(query.sql(dialect), "CAST(a AS STRING)")
diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py
index 230c0e8..e37cdc8 100644
--- a/tests/dialects/test_sqlite.py
+++ b/tests/dialects/test_sqlite.py
@@ -222,3 +222,7 @@ class TestSQLite(Validator):
"mysql": "CREATE TABLE `x` (`Name` VARCHAR(200) NOT NULL)",
},
)
+
+ self.validate_identity(
+ "CREATE TABLE store (store_id INTEGER PRIMARY KEY AUTOINCREMENT, mgr_id INTEGER NOT NULL UNIQUE REFERENCES staff ON UPDATE CASCADE)"
+ )
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 466f5d5..8951ebe 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -32,6 +32,10 @@ class TestTeradata(Validator):
},
)
+ self.validate_identity("SELECT 0x1d", "SELECT X'1d'")
+ self.validate_identity("SELECT X'1D'", "SELECT X'1D'")
+ self.validate_identity("SELECT x'1d'", "SELECT X'1d'")
+
self.validate_identity(
"RENAME TABLE emp TO employee", check_command_warning=True
).assert_is(exp.Command)
diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py
index 33a0229..8e968e9 100644
--- a/tests/dialects/test_trino.py
+++ b/tests/dialects/test_trino.py
@@ -9,9 +9,30 @@ class TestTrino(Validator):
self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)")
self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)")
+
+ def test_listagg(self):
self.validate_identity(
"SELECT LISTAGG(DISTINCT col, ',') WITHIN GROUP (ORDER BY col ASC) FROM tbl"
)
+ self.validate_identity(
+ "SELECT LISTAGG(col, '; ' ON OVERFLOW ERROR) WITHIN GROUP (ORDER BY col ASC) FROM tbl"
+ )
+ self.validate_identity(
+ "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE WITH COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl"
+ )
+ self.validate_identity(
+ "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE WITHOUT COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl"
+ )
+ self.validate_identity(
+ "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE '...' WITH COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl"
+ )
+ self.validate_identity(
+ "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE '...' WITHOUT COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl"
+ )
+ self.validate_identity(
+ "SELECT LISTAGG(col) WITHIN GROUP (ORDER BY col DESC) FROM tbl",
+ "SELECT LISTAGG(col, ',') WITHIN GROUP (ORDER BY col DESC) FROM tbl",
+ )
def test_trim(self):
self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 042891a..e4bd9a7 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -411,6 +411,7 @@ class TestTSQL(Validator):
},
)
self.validate_identity("HASHBYTES('MD2', 'x')")
+ self.validate_identity("LOG(n)")
self.validate_identity("LOG(n, b)")
self.validate_all(
@@ -921,6 +922,12 @@ class TestTSQL(Validator):
},
)
self.validate_all(
+ "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'baz' AND table_schema = 'bar' AND table_catalog = 'foo') EXEC('WITH cte1 AS (SELECT 1 AS col_a), cte2 AS (SELECT 1 AS col_b) SELECT * INTO foo.bar.baz FROM (SELECT col_a FROM cte1 UNION ALL SELECT col_b FROM cte2) AS temp')",
+ read={
+ "": "CREATE TABLE IF NOT EXISTS foo.bar.baz AS WITH cte1 AS (SELECT 1 AS col_a), cte2 AS (SELECT 1 AS col_b) SELECT col_a FROM cte1 UNION ALL SELECT col_b FROM cte2"
+ },
+ )
+ self.validate_all(
"CREATE OR ALTER VIEW a.b AS SELECT 1",
read={
"": "CREATE OR REPLACE VIEW a.b AS SELECT 1",
@@ -1567,7 +1574,7 @@ WHERE
"SELECT DATEDIFF(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo",
write={
"tsql": "SELECT DATEDIFF(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo",
- "clickhouse": "SELECT DATE_DIFF(DAY, CAST(a AS Nullable(DateTime)), CAST(b AS Nullable(DateTime))) AS x FROM foo",
+ "clickhouse": "SELECT DATE_DIFF(DAY, CAST(CAST(a AS Nullable(DateTime)) AS DateTime64(6)), CAST(CAST(b AS Nullable(DateTime)) AS DateTime64(6))) AS x FROM foo",
},
)
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index bed2502..33199de 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -250,7 +250,6 @@ SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x
SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x
SELECT X((a, b) -> a + b, z -> z) AS x
SELECT X(a -> a + ("z" - 1))
-SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)
SELECT test.* FROM test
SELECT a AS b FROM test
SELECT "a"."b" FROM "a"
diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql
index 1f59a5a..1dd7c2d 100644
--- a/tests/fixtures/optimizer/annotate_functions.sql
+++ b/tests/fixtures/optimizer/annotate_functions.sql
@@ -307,3 +307,11 @@ ARRAY<STRING>;
# dialect: bigquery
SPLIT(tbl.bin_col, delim);
ARRAY<BINARY>;
+
+# dialect: bigquery
+STRING(json_expr);
+STRING;
+
+# dialect: bigquery
+STRING(timestamp_expr, timezone);
+STRING; \ No newline at end of file
diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql
index 49e07fa..03e8dbe 100644
--- a/tests/fixtures/optimizer/qualify_tables.sql
+++ b/tests/fixtures/optimizer/qualify_tables.sql
@@ -14,6 +14,26 @@ SELECT 1 FROM x.y.z AS z;
SELECT 1 FROM x.y.z AS z;
SELECT 1 FROM x.y.z AS z;
+# title: only information schema
+# dialect: bigquery
+SELECT * FROM information_schema.tables;
+SELECT * FROM c.db.`information_schema.tables` AS tables;
+
+# title: information schema with db
+# dialect: bigquery
+SELECT * FROM y.information_schema.tables;
+SELECT * FROM c.y.`information_schema.tables` AS tables;
+
+# title: information schema with db, catalog
+# dialect: bigquery
+SELECT * FROM x.y.information_schema.tables;
+SELECT * FROM x.y.`information_schema.tables` AS tables;
+
+# title: information schema with db, catalog, alias
+# dialect: bigquery
+SELECT * FROM x.y.information_schema.tables AS z;
+SELECT * FROM x.y.`information_schema.tables` AS z;
+
# title: redshift unnest syntax, z.a should be a column, not a table
# dialect: redshift
SELECT 1 FROM y.z AS z, z.a;
diff --git a/tests/test_diff.py b/tests/test_diff.py
index f0e0747..440502e 100644
--- a/tests/test_diff.py
+++ b/tests/test_diff.py
@@ -2,7 +2,6 @@ import unittest
from sqlglot import exp, parse_one
from sqlglot.diff import Insert, Move, Remove, Update, diff
-from sqlglot.expressions import Join, to_table
def diff_delta_only(source, target, matchings=None, **kwargs):
@@ -14,22 +13,24 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")),
[
- Remove(parse_one("a + b")), # the Add node
- Insert(parse_one("a - b")), # the Sub node
+ Remove(expression=parse_one("a + b")), # the Add node
+ Insert(expression=parse_one("a - b")), # the Sub node
+ Move(source=parse_one("a"), target=parse_one("a")), # the `a` Column node
+ Move(source=parse_one("b"), target=parse_one("b")), # the `b` Column node
],
)
self._validate_delta_only(
diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")),
[
- Remove(parse_one("b")), # the Column node
+ Remove(expression=parse_one("b")), # the Column node
],
)
self._validate_delta_only(
diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")),
[
- Insert(parse_one("c")), # the Column node
+ Insert(expression=parse_one("c")), # the Column node
],
)
@@ -40,8 +41,8 @@ class TestDiff(unittest.TestCase):
),
[
Update(
- to_table("table_one", quoted=False),
- to_table("table_two", quoted=False),
+ source=exp.to_table("table_one", quoted=False),
+ target=exp.to_table("table_two", quoted=False),
), # the Table node
],
)
@@ -53,8 +54,12 @@ class TestDiff(unittest.TestCase):
),
[
Update(
- exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]),
- exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]),
+ source=exp.Lambda(
+ this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]
+ ),
+ target=exp.Lambda(
+ this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]
+ ),
),
],
)
@@ -65,8 +70,8 @@ class TestDiff(unittest.TestCase):
parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()')
),
[
- Insert(parse_one('"my.udf2"()')),
- Remove(parse_one('"my.udf1"()')),
+ Insert(expression=parse_one('"my.udf2"()')),
+ Remove(expression=parse_one('"my.udf1"()')),
],
)
self._validate_delta_only(
@@ -75,41 +80,73 @@ class TestDiff(unittest.TestCase):
parse_one('SELECT a, b, "my.udf"(x, y, w)'),
),
[
- Insert(exp.column("w")),
- Remove(exp.column("z")),
+ Insert(expression=exp.column("w")),
+ Remove(expression=exp.column("z")),
],
)
def test_node_position_changed(self):
+ expr_src = parse_one("SELECT a, b, c")
+ expr_tgt = parse_one("SELECT c, a, b")
+
self._validate_delta_only(
- diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")),
+ diff_delta_only(expr_src, expr_tgt),
[
- Move(parse_one("c")), # the Column node
+ Move(source=expr_src.selects[2], target=expr_tgt.selects[0]),
],
)
+ expr_src = parse_one("SELECT a + b")
+ expr_tgt = parse_one("SELECT b + a")
+
self._validate_delta_only(
- diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT b + a")),
+ diff_delta_only(expr_src, expr_tgt),
[
- Move(parse_one("a")), # the Column node
+ Move(source=expr_src.selects[0].left, target=expr_tgt.selects[0].right),
],
)
+ expr_src = parse_one("SELECT aaaa AND bbbb")
+ expr_tgt = parse_one("SELECT bbbb AND aaaa")
+
self._validate_delta_only(
- diff_delta_only(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")),
+ diff_delta_only(expr_src, expr_tgt),
[
- Move(parse_one("aaaa")), # the Column node
+ Move(source=expr_src.selects[0].left, target=expr_tgt.selects[0].right),
],
)
+ expr_src = parse_one("SELECT aaaa OR bbbb OR cccc")
+ expr_tgt = parse_one("SELECT cccc OR bbbb OR aaaa")
+
self._validate_delta_only(
- diff_delta_only(
- parse_one("SELECT aaaa OR bbbb OR cccc"),
- parse_one("SELECT cccc OR bbbb OR aaaa"),
- ),
+ diff_delta_only(expr_src, expr_tgt),
+ [
+ Move(source=expr_src.selects[0].left.left, target=expr_tgt.selects[0].right),
+ Move(source=expr_src.selects[0].right, target=expr_tgt.selects[0].left.left),
+ ],
+ )
+
+ expr_src = parse_one("SELECT a, b FROM t WHERE CONCAT('a', 'b') = 'ab'")
+ expr_tgt = parse_one("SELECT a FROM t WHERE CONCAT('a', 'b', b) = 'ab'")
+
+ self._validate_delta_only(
+ diff_delta_only(expr_src, expr_tgt),
+ [
+ Move(source=expr_src.selects[1], target=expr_tgt.find(exp.Concat).expressions[-1]),
+ ],
+ )
+
+ expr_src = parse_one("SELECT a as a, b as b FROM t WHERE CONCAT('a', 'b') = 'ab'")
+ expr_tgt = parse_one("SELECT a as a FROM t WHERE CONCAT('a', 'b', b) = 'ab'")
+
+ b_alias = expr_src.selects[1]
+
+ self._validate_delta_only(
+ diff_delta_only(expr_src, expr_tgt),
[
- Move(parse_one("aaaa")), # the Column node
- Move(parse_one("cccc")), # the Column node
+ Remove(expression=b_alias),
+ Move(source=b_alias.this, target=expr_tgt.find(exp.Concat).expressions[-1]),
],
)
@@ -130,23 +167,30 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)),
[
- Remove(parse_one("LOWER(c) AS c")), # the Alias node
- Remove(parse_one("LOWER(c)")), # the Lower node
- Remove(parse_one("'filter'")), # the Literal node
- Insert(parse_one("'different_filter'")), # the Literal node
+ Remove(expression=parse_one("LOWER(c) AS c")), # the Alias node
+ Remove(expression=parse_one("LOWER(c)")), # the Lower node
+ Remove(expression=parse_one("'filter'")), # the Literal node
+ Insert(expression=parse_one("'different_filter'")), # the Literal node
+ Move(source=parse_one("c"), target=parse_one("c")), # the new Column c
],
)
def test_join(self):
- expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key"
- expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key"
+ expr_src = parse_one("SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key")
+ expr_tgt = parse_one("SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key")
- changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt))
+ src_join = expr_src.find(exp.Join)
+ tgt_join = expr_tgt.find(exp.Join)
- self.assertEqual(len(changes), 2)
- self.assertTrue(isinstance(changes[0], Remove))
- self.assertTrue(isinstance(changes[1], Insert))
- self.assertTrue(all(isinstance(c.expression, Join) for c in changes))
+ self._validate_delta_only(
+ diff_delta_only(expr_src, expr_tgt),
+ [
+ Remove(expression=src_join),
+ Insert(expression=tgt_join),
+ Move(source=exp.to_table("t2"), target=exp.to_table("t2")),
+ Move(source=src_join.args["on"], target=tgt_join.args["on"]),
+ ],
+ )
def test_window_functions(self):
expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)")
@@ -157,8 +201,8 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt),
[
- Remove(parse_one("ROW_NUMBER()")),
- Insert(parse_one("RANK()")),
+ Remove(expression=parse_one("ROW_NUMBER()")),
+ Insert(expression=parse_one("RANK()")),
Update(source=expr_src.selects[0], target=expr_tgt.selects[0]),
],
)
@@ -178,20 +222,21 @@ class TestDiff(unittest.TestCase):
self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt),
[
- Remove(expr_src),
- Insert(expr_tgt),
- Insert(exp.Literal.number(2)),
- Insert(exp.Literal.number(3)),
- Insert(exp.Literal.number(4)),
+ Remove(expression=expr_src),
+ Insert(expression=expr_tgt),
+ Insert(expression=exp.Literal.number(2)),
+ Insert(expression=exp.Literal.number(3)),
+ Insert(expression=exp.Literal.number(4)),
+ Move(source=exp.Literal.number(1), target=exp.Literal.number(1)),
],
)
self._validate_delta_only(
diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]),
[
- Insert(exp.Literal.number(2)),
- Insert(exp.Literal.number(3)),
- Insert(exp.Literal.number(4)),
+ Insert(expression=exp.Literal.number(2)),
+ Insert(expression=exp.Literal.number(3)),
+ Insert(expression=exp.Literal.number(4)),
],
)
@@ -274,7 +319,7 @@ class TestDiff(unittest.TestCase):
source=expr_src.find(exp.Order).expressions[0],
target=expr_tgt.find(exp.Order).expressions[0],
),
- Move(parse_one("a")),
+ Move(source=expr_src.selects[0], target=expr_tgt.selects[1]),
],
)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 9313285..0fa4ff6 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -315,7 +315,7 @@ class TestOptimizer(unittest.TestCase):
),
dialect="bigquery",
).sql(),
- 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA"."COLUMNS" AS "COLUMNS") SELECT "x"."a" AS "a" FROM "x" AS "x"',
+ 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA.COLUMNS" AS "columns") SELECT "x"."a" AS "a" FROM "x" AS "x"',
)
self.assertEqual(
@@ -1337,6 +1337,47 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(union_by_name.selects[0].type.this, exp.DataType.Type.BIGINT)
self.assertEqual(union_by_name.selects[1].type.this, exp.DataType.Type.DOUBLE)
+ # Test chained UNIONs
+ sql = """
+ WITH t AS
+ (
+ SELECT NULL AS col
+ UNION
+ SELECT NULL AS col
+ UNION
+ SELECT 'a' AS col
+ UNION
+ SELECT NULL AS col
+ UNION
+ SELECT NULL AS col
+ )
+ SELECT col FROM t;
+ """
+ self.assertEqual(optimizer.optimize(sql).selects[0].type.this, exp.DataType.Type.VARCHAR)
+
+ # Test UNIONs with nested subqueries
+ sql = """
+ WITH t AS
+ (
+ SELECT NULL AS col
+ UNION
+ (SELECT NULL AS col UNION ALL SELECT 'a' AS col)
+ )
+ SELECT col FROM t;
+ """
+ self.assertEqual(optimizer.optimize(sql).selects[0].type.this, exp.DataType.Type.VARCHAR)
+
+ sql = """
+ WITH t AS
+ (
+ (SELECT NULL AS col UNION ALL SELECT 'a' AS col)
+ UNION
+ SELECT NULL AS col
+ )
+ SELECT col FROM t;
+ """
+ self.assertEqual(optimizer.optimize(sql).selects[0].type.this, exp.DataType.Type.VARCHAR)
+
def test_recursive_cte(self):
query = parse_one(
"""