diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_bigquery.py | 136 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 120 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 12 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 128 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 19 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 24 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 16 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_trino.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 9 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 1 | ||||
-rw-r--r-- | tests/fixtures/optimizer/annotate_functions.sql | 8 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_tables.sql | 20 | ||||
-rw-r--r-- | tests/test_diff.py | 139 | ||||
-rw-r--r-- | tests/test_optimizer.py | 43 |
21 files changed, 689 insertions, 109 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 3b317bc..eeb49f3 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -9,7 +9,6 @@ from sqlglot import ( UnsupportedError, exp, parse, - parse_one, transpile, ) from sqlglot.helper import logger as helper_logger @@ -85,12 +84,21 @@ LANGUAGE js AS "PARSE_TIMESTAMP('%Y-%m-%dT%H:%M:%E*S%z', x)", ) - table = parse_one("x-0._y.z", dialect="bigquery", into=exp.Table) + for prefix in ("c.db.", "db.", ""): + with self.subTest(f"Parsing {prefix}INFORMATION_SCHEMA.X into a Table"): + table = self.parse_one(f"`{prefix}INFORMATION_SCHEMA.X`", into=exp.Table) + this = table.this + + self.assertIsInstance(this, exp.Identifier) + self.assertTrue(this.quoted) + self.assertEqual(this.name, "INFORMATION_SCHEMA.X") + + table = self.parse_one("x-0._y.z", into=exp.Table) self.assertEqual(table.catalog, "x-0") self.assertEqual(table.db, "_y") self.assertEqual(table.name, "z") - table = parse_one("x-0._y", dialect="bigquery", into=exp.Table) + table = self.parse_one("x-0._y", into=exp.Table) self.assertEqual(table.db, "x-0") self.assertEqual(table.name, "_y") @@ -165,6 +173,7 @@ LANGUAGE js AS self.validate_identity("SELECT * FROM foo.bar.25ab c", "SELECT * FROM foo.bar.`25ab` AS c") self.validate_identity("x <> ''") self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))") + self.validate_identity("DATE_TRUNC(col, MONTH, 'UTC+8')") self.validate_identity("SELECT b'abc'") self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b") @@ -182,7 +191,6 @@ LANGUAGE js AS self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") self.validate_identity("SELECT TIMESTAMP_SECONDS(2) AS t") self.validate_identity("SELECT TIMESTAMP_MILLIS(2) AS t") - self.validate_identity("""SELECT JSON_EXTRACT_SCALAR('{"a": 5}', '$.a')""") self.validate_identity("UPDATE x SET y = NULL") self.validate_identity("LOG(n, b)") self.validate_identity("SELECT COUNT(x RESPECT NULLS)") @@ -194,11 +202,11 @@ LANGUAGE js AS self.validate_identity("CAST(x AS NVARCHAR)", "CAST(x AS STRING)") self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS TIMESTAMP)") self.validate_identity("CAST(x AS RECORD)", "CAST(x AS STRUCT)") - self.validate_identity( - "MERGE INTO dataset.NewArrivals USING (SELECT * FROM UNNEST([('microwave', 10, 'warehouse #1'), ('dryer', 30, 'warehouse #1'), ('oven', 20, 'warehouse #2')])) ON FALSE WHEN NOT MATCHED THEN INSERT ROW WHEN NOT MATCHED BY SOURCE THEN DELETE" + self.validate_identity("EDIT_DISTANCE('a', 'a', max_distance => 2)").assert_is( + exp.Levenshtein ) self.validate_identity( - "SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`" + "MERGE INTO dataset.NewArrivals USING (SELECT * FROM UNNEST([('microwave', 10, 'warehouse #1'), ('dryer', 30, 'warehouse #1'), ('oven', 20, 'warehouse #2')])) ON FALSE WHEN NOT MATCHED THEN INSERT ROW WHEN NOT MATCHED BY SOURCE THEN DELETE" ) self.validate_identity( "SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)" @@ -228,10 +236,23 @@ LANGUAGE js AS "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", ) self.validate_identity( - """SELECT JSON_EXTRACT_SCALAR('5')""", """SELECT JSON_EXTRACT_SCALAR('5', '$')""" + "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id", ) self.validate_identity( - "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id", + "SELECT * FROM `proj.dataset.INFORMATION_SCHEMA.SOME_VIEW`", + "SELECT * FROM `proj.dataset.INFORMATION_SCHEMA.SOME_VIEW` AS `proj.dataset.INFORMATION_SCHEMA.SOME_VIEW`", + ) + self.validate_identity( + "SELECT * FROM region_or_dataset.INFORMATION_SCHEMA.TABLES", + "SELECT * FROM region_or_dataset.`INFORMATION_SCHEMA.TABLES` AS TABLES", + ) + self.validate_identity( + "SELECT * FROM region_or_dataset.INFORMATION_SCHEMA.TABLES AS some_name", + "SELECT * FROM region_or_dataset.`INFORMATION_SCHEMA.TABLES` AS some_name", + ) + self.validate_identity( + "SELECT * FROM proj.region_or_dataset.INFORMATION_SCHEMA.TABLES", + "SELECT * FROM proj.region_or_dataset.`INFORMATION_SCHEMA.TABLES` AS TABLES", ) self.validate_identity( "CREATE VIEW `d.v` OPTIONS (expiration_timestamp=TIMESTAMP '2020-01-02T04:05:06.007Z') AS SELECT 1 AS c", @@ -303,6 +324,13 @@ LANGUAGE js AS ) self.validate_all( + "EDIT_DISTANCE(a, b)", + write={ + "bigquery": "EDIT_DISTANCE(a, b)", + "duckdb": "LEVENSHTEIN(a, b)", + }, + ) + self.validate_all( "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')", write={ "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')", @@ -361,10 +389,19 @@ LANGUAGE js AS write={ "bigquery": "TIMESTAMP(x)", "duckdb": "CAST(x AS TIMESTAMPTZ)", + "snowflake": "CAST(x AS TIMESTAMPTZ)", "presto": "CAST(x AS TIMESTAMP WITH TIME ZONE)", }, ) self.validate_all( + "SELECT TIMESTAMP('2008-12-25 15:30:00', 'America/Los_Angeles')", + write={ + "bigquery": "SELECT TIMESTAMP('2008-12-25 15:30:00', 'America/Los_Angeles')", + "duckdb": "SELECT CAST('2008-12-25 15:30:00' AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'", + "snowflake": "SELECT CONVERT_TIMEZONE('America/Los_Angeles', CAST('2008-12-25 15:30:00' AS TIMESTAMP))", + }, + ) + self.validate_all( "SELECT SUM(x IGNORE NULLS) AS x", read={ "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", @@ -629,6 +666,7 @@ LANGUAGE js AS write={ "bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", "databricks": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')", + "snowflake": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')", }, ), ) @@ -639,6 +677,7 @@ LANGUAGE js AS "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL '1' MILLISECOND)", "databricks": "SELECT TIMESTAMPADD(MILLISECOND, '1', '2023-01-01T00:00:00')", "duckdb": "SELECT CAST('2023-01-01T00:00:00' AS DATETIME) + INTERVAL '1' MILLISECOND", + "snowflake": "SELECT TIMESTAMPADD(MILLISECOND, '1', '2023-01-01T00:00:00')", }, ), ) @@ -670,6 +709,7 @@ LANGUAGE js AS "databricks": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", "mysql": "SELECT DATE_ADD(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)", "spark": "SELECT DATE_ADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "snowflake": "SELECT TIMESTAMPADD(MINUTE, '10', CAST('2008-12-25 15:30:00+00' AS TIMESTAMPTZ))", }, ) self.validate_all( @@ -677,6 +717,14 @@ LANGUAGE js AS write={ "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL '10' MINUTE)", "mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL '10' MINUTE)", + "snowflake": "SELECT TIMESTAMPADD(MINUTE, '10' * -1, CAST('2008-12-25 15:30:00+00' AS TIMESTAMPTZ))", + }, + ) + self.validate_all( + 'SELECT TIMESTAMP_SUB(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL col MINUTE)', + write={ + "bigquery": "SELECT TIMESTAMP_SUB(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL col MINUTE)", + "snowflake": "SELECT TIMESTAMPADD(MINUTE, col * -1, CAST('2008-12-25 15:30:00+00' AS TIMESTAMPTZ))", }, ) self.validate_all( @@ -1113,7 +1161,8 @@ LANGUAGE js AS write={ "bigquery": "CURRENT_TIME()", "duckdb": "CURRENT_TIME", - "presto": "CURRENT_TIME()", + "presto": "CURRENT_TIME", + "trino": "CURRENT_TIME", "hive": "CURRENT_TIME()", "spark": "CURRENT_TIME()", }, @@ -1491,6 +1540,14 @@ WHERE }, ) self.validate_all( + "SELECT PARSE_DATE('%Y%m%d', '20081225')", + write={ + "bigquery": "SELECT PARSE_DATE('%Y%m%d', '20081225')", + "duckdb": "SELECT CAST(STRPTIME('20081225', '%Y%m%d') AS DATE)", + "snowflake": "SELECT DATE('20081225', 'yyyymmDD')", + }, + ) + self.validate_all( "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text", write={ "bigquery": "SELECT ARRAY_TO_STRING(['cake', 'pie', NULL], '--') AS text", @@ -1504,9 +1561,48 @@ WHERE "duckdb": "SELECT ARRAY_TO_STRING(LIST_TRANSFORM(['cake', 'pie', NULL], x -> COALESCE(x, 'MISSING')), '--') AS text", }, ) + self.validate_all( + "STRING(a)", + write={ + "bigquery": "STRING(a)", + "snowflake": "CAST(a AS VARCHAR)", + "duckdb": "CAST(a AS TEXT)", + }, + ) + self.validate_all( + "STRING('2008-12-25 15:30:00', 'America/New_York')", + write={ + "bigquery": "STRING('2008-12-25 15:30:00', 'America/New_York')", + "snowflake": "CAST(CONVERT_TIMEZONE('UTC', 'America/New_York', '2008-12-25 15:30:00') AS VARCHAR)", + "duckdb": "CAST(CAST('2008-12-25 15:30:00' AS TIMESTAMP) AT TIME ZONE 'UTC' AT TIME ZONE 'America/New_York' AS TEXT)", + }, + ) self.validate_identity("SELECT * FROM a-b c", "SELECT * FROM a-b AS c") + self.validate_all( + "SAFE_DIVIDE(x, y)", + write={ + "bigquery": "SAFE_DIVIDE(x, y)", + "duckdb": "IF((y) <> 0, (x) / (y), NULL)", + "presto": "IF((y) <> 0, (x) / (y), NULL)", + "trino": "IF((y) <> 0, (x) / (y), NULL)", + "hive": "IF((y) <> 0, (x) / (y), NULL)", + "spark2": "IF((y) <> 0, (x) / (y), NULL)", + "spark": "IF((y) <> 0, (x) / (y), NULL)", + "databricks": "IF((y) <> 0, (x) / (y), NULL)", + "snowflake": "IFF((y) <> 0, (x) / (y), NULL)", + }, + ) + self.validate_all( + """SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""", + write={ + "bigquery": """SELECT JSON_QUERY('{"class": {"students": []}}', '$.class')""", + "duckdb": """SELECT '{"class": {"students": []}}' -> '$.class'""", + "snowflake": """SELECT GET_PATH(PARSE_JSON('{"class": {"students": []}}'), 'class')""", + }, + ) + def test_errors(self): with self.assertRaises(TokenError): transpile("'\\'", read="bigquery") @@ -2000,3 +2096,23 @@ OPTIONS ( "bigquery": f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order}) AS ids FROM colors GROUP BY 1", }, ) + + def test_json_extract_scalar(self): + for func in ("JSON_EXTRACT_SCALAR", "JSON_VALUE"): + with self.subTest(f"Testing BigQuery's {func}"): + self.validate_all( + f"SELECT {func}('5')", + write={ + "bigquery": f"SELECT {func}('5', '$')", + "duckdb": """SELECT '5' ->> '$'""", + }, + ) + + self.validate_all( + f"""SELECT {func}('{{"name": "Jakob", "age": "6"}}', '$.age')""", + write={ + "bigquery": f"""SELECT {func}('{{"name": "Jakob", "age": "6"}}', '$.age')""", + "duckdb": """SELECT '{"name": "Jakob", "age": "6"}' ->> '$.age'""", + "snowflake": """SELECT JSON_EXTRACT_PATH_TEXT('{"name": "Jakob", "age": "6"}', 'age')""", + }, + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 56ff06f..a0efb54 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -1,4 +1,4 @@ -from datetime import date +from datetime import date, datetime, timezone from sqlglot import exp, parse_one from sqlglot.dialects import ClickHouse from sqlglot.expressions import convert @@ -88,6 +88,7 @@ class TestClickhouse(Validator): self.validate_identity("CAST(x AS DATETIME)", "CAST(x AS DateTime)") self.validate_identity("CAST(x AS TIMESTAMPTZ)", "CAST(x AS DateTime)") self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)") + self.validate_identity("CAST(x AS DECIMAL(38, 2))", "CAST(x AS Decimal(38, 2))") self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src") self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""") self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b") @@ -96,6 +97,9 @@ class TestClickhouse(Validator): self.validate_identity("TRUNCATE DATABASE db") self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster") self.validate_identity( + "SELECT CAST(1730098800 AS DateTime64) AS DATETIME, 'test' AS interp ORDER BY DATETIME WITH FILL FROM toDateTime64(1730098800, 3) - INTERVAL '7' HOUR TO toDateTime64(1730185140, 3) - INTERVAL '7' HOUR STEP toIntervalSecond(900) INTERPOLATE (interp)" + ) + self.validate_identity( "SELECT number, COUNT() OVER (PARTITION BY number % 3) AS partition_count FROM numbers(10) WINDOW window_name AS (PARTITION BY number) QUALIFY partition_count = 4 ORDER BY number" ) self.validate_identity( @@ -150,6 +154,10 @@ class TestClickhouse(Validator): "CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)" ) self.validate_identity( + "SELECT (toUInt8('1') + toUInt8('2')) IS NOT NULL", + "SELECT NOT ((toUInt8('1') + toUInt8('2')) IS NULL)", + ) + self.validate_identity( "SELECT $1$foo$1$", "SELECT 'foo'", ) @@ -424,8 +432,13 @@ class TestClickhouse(Validator): ) self.validate_all( "SELECT quantile(0.5)(a)", - read={"duckdb": "SELECT quantile(a, 0.5)"}, - write={"clickhouse": "SELECT quantile(0.5)(a)"}, + read={ + "duckdb": "SELECT quantile(a, 0.5)", + "clickhouse": "SELECT median(a)", + }, + write={ + "clickhouse": "SELECT quantile(0.5)(a)", + }, ) self.validate_all( "SELECT quantiles(0.5, 0.4)(a)", @@ -526,6 +539,10 @@ class TestClickhouse(Validator): "SELECT * FROM ABC WHERE hasAny(COLUMNS('.*field') APPLY(toUInt64) APPLY(to), (SELECT groupUniqArray(toUInt64(field))))" ) self.validate_identity("SELECT col apply", "SELECT col AS apply") + self.validate_identity( + "SELECT name FROM data WHERE (SELECT DISTINCT name FROM data) IS NOT NULL", + "SELECT name FROM data WHERE NOT ((SELECT DISTINCT name FROM data) IS NULL)", + ) def test_clickhouse_values(self): values = exp.select("*").from_( @@ -645,6 +662,12 @@ class TestClickhouse(Validator): write={"clickhouse": f"CAST(pow(2, 32) AS {data_type})"}, ) + def test_geom_types(self): + data_types = ["Point", "Ring", "LineString", "MultiLineString", "Polygon", "MultiPolygon"] + for data_type in data_types: + with self.subTest(f"Casting to ClickHouse {data_type}"): + self.validate_identity(f"SELECT CAST(val AS {data_type})") + def test_ddl(self): db_table_expr = exp.Table(this=None, db=exp.to_identifier("foo"), catalog=None) create_with_cluster = exp.Create( @@ -678,6 +701,7 @@ class TestClickhouse(Validator): "CREATE TABLE foo ENGINE=Memory AS (SELECT * FROM db.other_table) COMMENT 'foo'", ) + self.validate_identity("CREATE FUNCTION linear_equation AS (x, k, b) -> k * x + b") self.validate_identity("CREATE MATERIALIZED VIEW a.b TO a.c (c Int32) AS SELECT * FROM a.d") self.validate_identity("""CREATE TABLE ip_data (ip4 IPv4, ip6 IPv6) ENGINE=TinyLog()""") self.validate_identity("""CREATE TABLE dates (dt1 Date32) ENGINE=TinyLog()""") @@ -702,6 +726,10 @@ class TestClickhouse(Validator): "CREATE TABLE foo (x UInt32) TTL time_column + INTERVAL '1' MONTH DELETE WHERE column = 'value'" ) self.validate_identity( + "CREATE FUNCTION parity_str AS (n) -> IF(n % 2, 'odd', 'even')", + "CREATE FUNCTION parity_str AS n -> CASE WHEN n % 2 THEN 'odd' ELSE 'even' END", + ) + self.validate_identity( "CREATE TABLE a ENGINE=Memory AS SELECT 1 AS c COMMENT 'foo'", "CREATE TABLE a ENGINE=Memory AS (SELECT 1 AS c) COMMENT 'foo'", ) @@ -1094,6 +1122,92 @@ LIFETIME(MIN 0 MAX 0)""", convert(date(2020, 1, 1)).sql(dialect=self.dialect), "toDate('2020-01-01')" ) + # no fractional seconds + self.assertEqual( + convert(datetime(2020, 1, 1, 0, 0, 1)).sql(dialect=self.dialect), + "CAST('2020-01-01 00:00:01' AS DateTime64(6))", + ) + self.assertEqual( + convert(datetime(2020, 1, 1, 0, 0, 1, tzinfo=timezone.utc)).sql(dialect=self.dialect), + "CAST('2020-01-01 00:00:01' AS DateTime64(6, 'UTC'))", + ) + + # with fractional seconds + self.assertEqual( + convert(datetime(2020, 1, 1, 0, 0, 1, 1)).sql(dialect=self.dialect), + "CAST('2020-01-01 00:00:01.000001' AS DateTime64(6))", + ) + self.assertEqual( + convert(datetime(2020, 1, 1, 0, 0, 1, 1, tzinfo=timezone.utc)).sql( + dialect=self.dialect + ), + "CAST('2020-01-01 00:00:01.000001' AS DateTime64(6, 'UTC'))", + ) + + def test_timestr_to_time(self): + # no fractional seconds + time_strings = [ + "2020-01-01 00:00:01", + "2020-01-01 00:00:01+01:00", + " 2020-01-01 00:00:01-01:00 ", + "2020-01-01T00:00:01+01:00", + ] + for time_string in time_strings: + with self.subTest(f"'{time_string}'"): + self.assertEqual( + exp.TimeStrToTime(this=exp.Literal.string(time_string)).sql( + dialect=self.dialect + ), + f"CAST('{time_string}' AS DateTime64(6))", + ) + + time_strings_no_utc = ["2020-01-01 00:00:01" for i in range(4)] + for utc, no_utc in zip(time_strings, time_strings_no_utc): + with self.subTest(f"'{time_string}' with UTC timezone"): + self.assertEqual( + exp.TimeStrToTime( + this=exp.Literal.string(utc), zone=exp.Literal.string("UTC") + ).sql(dialect=self.dialect), + f"CAST('{no_utc}' AS DateTime64(6, 'UTC'))", + ) + + # with fractional seconds + time_strings = [ + "2020-01-01 00:00:01.001", + "2020-01-01 00:00:01.000001", + "2020-01-01 00:00:01.001+00:00", + "2020-01-01 00:00:01.000001-00:00", + "2020-01-01 00:00:01.0001", + "2020-01-01 00:00:01.1+00:00", + ] + + for time_string in time_strings: + with self.subTest(f"'{time_string}'"): + self.assertEqual( + exp.TimeStrToTime(this=exp.Literal.string(time_string[0])).sql( + dialect=self.dialect + ), + f"CAST('{time_string[0]}' AS DateTime64(6))", + ) + + time_strings_no_utc = [ + "2020-01-01 00:00:01.001000", + "2020-01-01 00:00:01.000001", + "2020-01-01 00:00:01.001000", + "2020-01-01 00:00:01.000001", + "2020-01-01 00:00:01.000100", + "2020-01-01 00:00:01.100000", + ] + + for utc, no_utc in zip(time_strings, time_strings_no_utc): + with self.subTest(f"'{time_string}' with UTC timezone"): + self.assertEqual( + exp.TimeStrToTime( + this=exp.Literal.string(utc), zone=exp.Literal.string("UTC") + ).sql(dialect=self.dialect), + f"CAST('{no_utc}' AS DateTime64(6, 'UTC'))", + ) + def test_grant(self): self.validate_identity("GRANT SELECT(x, y) ON db.table TO john WITH GRANT OPTION") self.validate_identity("GRANT INSERT(x, y) ON db.table TO john") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index f7ec756..d0090b9 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -7,6 +7,7 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("SELECT * FROM stream") self.validate_identity("SELECT t.current_time FROM t") self.validate_identity("ALTER TABLE labels ADD COLUMN label_score FLOAT") self.validate_identity("DESCRIBE HISTORY a.b") @@ -116,6 +117,17 @@ class TestDatabricks(Validator): }, ) + self.validate_all( + "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + read={ + "databricks": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + "spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + }, + write={ + "spark": "SELECT ANY(col) FROM VALUES (TRUE), (FALSE) AS tab(col)", + }, + ) + # https://docs.databricks.com/sql/language-manual/functions/colonsign.html def test_json(self): self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 84a4ff3..85402e2 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -526,7 +526,7 @@ class TestDialect(Validator): write={ "": "SELECT NVL2(a, b, c)", "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", - "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "clickhouse": "SELECT CASE WHEN NOT (a IS NULL) THEN b ELSE c END", "databricks": "SELECT NVL2(a, b, c)", "doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", "drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", @@ -552,7 +552,7 @@ class TestDialect(Validator): write={ "": "SELECT NVL2(a, b)", "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END", - "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "clickhouse": "SELECT CASE WHEN NOT (a IS NULL) THEN b END", "databricks": "SELECT NVL2(a, b)", "doris": "SELECT CASE WHEN NOT a IS NULL THEN b END", "drill": "SELECT CASE WHEN NOT a IS NULL THEN b END", @@ -651,7 +651,7 @@ class TestDialect(Validator): "snowflake": "CAST('2020-01-01' AS TIMESTAMP)", "spark": "CAST('2020-01-01' AS TIMESTAMP)", "trino": "CAST('2020-01-01' AS TIMESTAMP)", - "clickhouse": "CAST('2020-01-01' AS Nullable(DateTime))", + "clickhouse": "CAST('2020-01-01' AS DateTime64(6))", "drill": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "CAST('2020-01-01' AS TIMESTAMP)", @@ -688,7 +688,7 @@ class TestDialect(Validator): "snowflake": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMPTZ)", "spark": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)", "trino": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP WITH TIME ZONE)", - "clickhouse": "CAST('2020-01-01 12:13:14' AS Nullable(DateTime('America/Los_Angeles')))", + "clickhouse": "CAST('2020-01-01 12:13:14' AS DateTime64(6, 'America/Los_Angeles'))", "drill": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)", "hive": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP)", "presto": "CAST('2020-01-01 12:13:14-08:00' AS TIMESTAMP WITH TIME ZONE)", @@ -709,7 +709,7 @@ class TestDialect(Validator): "snowflake": "CAST(col AS TIMESTAMPTZ)", "spark": "CAST(col AS TIMESTAMP)", "trino": "CAST(col AS TIMESTAMP WITH TIME ZONE)", - "clickhouse": "CAST(col AS Nullable(DateTime('America/Los_Angeles')))", + "clickhouse": "CAST(col AS DateTime64(6, 'America/Los_Angeles'))", "drill": "CAST(col AS TIMESTAMP)", "hive": "CAST(col AS TIMESTAMP)", "presto": "CAST(col AS TIMESTAMP WITH TIME ZONE)", @@ -2893,3 +2893,121 @@ FROM subquery2""", "snowflake": "UUID_STRING()", }, ) + + def test_escaped_identifier_delimiter(self): + for dialect in ("databricks", "hive", "mysql", "spark2", "spark"): + with self.subTest(f"Testing escaped backtick in identifier name for {dialect}"): + self.validate_all( + 'SELECT 1 AS "x`"', + read={ + dialect: "SELECT 1 AS `x```", + }, + write={ + dialect: "SELECT 1 AS `x```", + }, + ) + + for dialect in ( + "", + "clickhouse", + "duckdb", + "postgres", + "presto", + "trino", + "redshift", + "snowflake", + "sqlite", + ): + with self.subTest(f"Testing escaped double-quote in identifier name for {dialect}"): + self.validate_all( + 'SELECT 1 AS "x"""', + read={ + dialect: 'SELECT 1 AS "x"""', + }, + write={ + dialect: 'SELECT 1 AS "x"""', + }, + ) + + for dialect in ("clickhouse", "sqlite"): + with self.subTest(f"Testing escaped backtick in identifier name for {dialect}"): + self.validate_all( + 'SELECT 1 AS "x`"', + read={ + dialect: "SELECT 1 AS `x```", + }, + write={ + dialect: 'SELECT 1 AS "x`"', + }, + ) + + self.validate_all( + 'SELECT 1 AS "x`"', + read={ + "clickhouse": "SELECT 1 AS `x\\``", + }, + write={ + "clickhouse": 'SELECT 1 AS "x`"', + }, + ) + for name in ('"x\\""', '`x"`'): + with self.subTest(f"Testing ClickHouse delimiter escaping: {name}"): + self.validate_all( + 'SELECT 1 AS "x"""', + read={ + "clickhouse": f"SELECT 1 AS {name}", + }, + write={ + "clickhouse": 'SELECT 1 AS "x"""', + }, + ) + + for name in ("[[x]]]", '"[x]"'): + with self.subTest(f"Testing T-SQL delimiter escaping: {name}"): + self.validate_all( + 'SELECT 1 AS "[x]"', + read={ + "tsql": f"SELECT 1 AS {name}", + }, + write={ + "tsql": "SELECT 1 AS [[x]]]", + }, + ) + for name in ('[x"]', '"x"""'): + with self.subTest(f"Testing T-SQL delimiter escaping: {name}"): + self.validate_all( + 'SELECT 1 AS "x"""', + read={ + "tsql": f"SELECT 1 AS {name}", + }, + write={ + "tsql": 'SELECT 1 AS [x"]', + }, + ) + + def test_median(self): + for suffix in ( + "", + " OVER ()", + ): + self.validate_all( + f"MEDIAN(x){suffix}", + read={ + "snowflake": f"MEDIAN(x){suffix}", + "duckdb": f"MEDIAN(x){suffix}", + "spark": f"MEDIAN(x){suffix}", + "databricks": f"MEDIAN(x){suffix}", + "redshift": f"MEDIAN(x){suffix}", + "oracle": f"MEDIAN(x){suffix}", + }, + write={ + "snowflake": f"MEDIAN(x){suffix}", + "duckdb": f"MEDIAN(x){suffix}", + "spark": f"MEDIAN(x){suffix}", + "databricks": f"MEDIAN(x){suffix}", + "redshift": f"MEDIAN(x){suffix}", + "oracle": f"MEDIAN(x){suffix}", + "clickhouse": f"MEDIAN(x){suffix}", + "postgres": f"PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 1f8fb81..b59ac9f 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -620,12 +620,6 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "IF((y) <> 0, (x) / (y), NULL)", - read={ - "bigquery": "SAFE_DIVIDE(x, y)", - }, - ) - self.validate_all( "STRUCT_PACK(x := 1, y := '2')", write={ "bigquery": "STRUCT(1 AS x, '2' AS y)", @@ -758,16 +752,9 @@ class TestDuckDB(Validator): "snowflake": "SELECT PERCENTILE_DISC(q) WITHIN GROUP (ORDER BY x) FROM t", }, ) - self.validate_all( - "SELECT MEDIAN(x) FROM t", - write={ - "duckdb": "SELECT QUANTILE_CONT(x, 0.5) FROM t", - "postgres": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t", - "snowflake": "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t", - }, - ) with self.assertRaises(UnsupportedError): + # bq has the position arg, but duckdb doesn't transpile( "SELECT REGEXP_EXTRACT(a, 'pattern', 1) from table", read="bigquery", @@ -775,6 +762,36 @@ class TestDuckDB(Validator): unsupported_level=ErrorLevel.IMMEDIATE, ) + self.validate_all( + "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t", + read={ + "duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t", + "bigquery": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t", + "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern') FROM t", + }, + write={ + "duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t", + "bigquery": "SELECT REGEXP_EXTRACT(a, 'pattern') FROM t", + "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern') FROM t", + }, + ) + self.validate_all( + "SELECT REGEXP_EXTRACT(a, 'pattern', 2, 'i') FROM t", + read={ + "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, 'i', 2) FROM t", + }, + write={ + "duckdb": "SELECT REGEXP_EXTRACT(a, 'pattern', 2, 'i') FROM t", + "snowflake": "SELECT REGEXP_SUBSTR(a, 'pattern', 1, 1, 'i', 2) FROM t", + }, + ) + self.validate_identity( + "SELECT REGEXP_EXTRACT(a, 'pattern', 0)", + "SELECT REGEXP_EXTRACT(a, 'pattern')", + ) + self.validate_identity("SELECT REGEXP_EXTRACT(a, 'pattern', 0, 'i')") + self.validate_identity("SELECT REGEXP_EXTRACT(a, 'pattern', 1, 'i')") + self.validate_identity("SELECT ISNAN(x)") self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index e40a85a..f13d92c 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -1,5 +1,7 @@ from tests.dialects.test_dialect import Validator +from sqlglot import exp + class TestHive(Validator): dialect = "hive" @@ -787,6 +789,23 @@ class TestHive(Validator): }, ) + self.validate_identity("EXISTS(col, x -> x % 2 = 0)").assert_is(exp.Exists) + + self.validate_all( + "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + read={ + "hive": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + }, + write={ + "spark2": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "spark": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + "databricks": "SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0)", + }, + ) + def test_escapes(self) -> None: self.validate_identity("'\n'", "'\\n'") self.validate_identity("'\\n'") diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index bd0d6c3..52b04ea 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -388,7 +388,7 @@ class TestMySQL(Validator): "sqlite": "SELECT x'CC'", "starrocks": "SELECT x'CC'", "tableau": "SELECT 204", - "teradata": "SELECT 204", + "teradata": "SELECT X'CC'", "trino": "SELECT X'CC'", "tsql": "SELECT 0xCC", } @@ -409,7 +409,7 @@ class TestMySQL(Validator): "sqlite": "SELECT x'0000CC'", "starrocks": "SELECT x'0000CC'", "tableau": "SELECT 204", - "teradata": "SELECT 204", + "teradata": "SELECT X'0000CC'", "trino": "SELECT X'0000CC'", "tsql": "SELECT 0x0000CC", } diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 36ce5d0..0784810 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -120,13 +120,6 @@ class TestOracle(Validator): }, ) self.validate_all( - "TRUNC(SYSDATE, 'YEAR')", - write={ - "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())", - "oracle": "TRUNC(SYSDATE, 'YEAR')", - }, - ) - self.validate_all( "SELECT * FROM test WHERE MOD(col1, 4) = 3", read={ "duckdb": "SELECT * FROM test WHERE col1 % 4 = 3", @@ -632,3 +625,20 @@ WHERE self.validate_identity("GRANT UPDATE, TRIGGER ON TABLE t TO anita, zhi") self.validate_identity("GRANT EXECUTE ON PROCEDURE p TO george") self.validate_identity("GRANT USAGE ON SEQUENCE order_id TO sales_role") + + def test_datetrunc(self): + self.validate_all( + "TRUNC(SYSDATE, 'YEAR')", + write={ + "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())", + "oracle": "TRUNC(SYSDATE, 'YEAR')", + }, + ) + + # Make sure units are not normalized e.g 'Q' -> 'QUARTER' and 'W' -> 'WEEK' + # https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/ROUND-and-TRUNC-Date-Functions.html + for unit in ( + "'Q'", + "'W'", + ): + self.validate_identity(f"TRUNC(x, {unit})") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 62ae247..4b54cd0 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -797,6 +797,24 @@ class TestPostgres(Validator): self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1 FOR 1)") self.validate_identity("ARRAY[1, 2, 3] && ARRAY[1, 2]").assert_is(exp.ArrayOverlaps) + self.validate_all( + """SELECT JSONB_EXISTS('{"a": [1,2,3]}', 'a')""", + write={ + "postgres": """SELECT JSONB_EXISTS('{"a": [1,2,3]}', 'a')""", + "duckdb": """SELECT JSON_EXISTS('{"a": [1,2,3]}', '$.a')""", + }, + ) + self.validate_all( + "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)", + write={ + "postgres": "WITH t AS (SELECT ARRAY[1, 2, 3] AS col) SELECT * FROM t WHERE 1 <= ANY(col) AND 2 = ANY(col)", + "hive": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + "spark2": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + "spark": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + "databricks": "WITH t AS (SELECT ARRAY(1, 2, 3) AS col) SELECT * FROM t WHERE EXISTS(col, x -> 1 <= x) AND EXISTS(col, x -> 2 = x)", + }, + ) + def test_ddl(self): # Checks that user-defined types are parsed into DataType instead of Identifier self.parse_one("CREATE TABLE t (a udt)").this.expressions[0].args["kind"].assert_is( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 4c10a45..31a078c 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -414,13 +414,6 @@ class TestPresto(Validator): "CAST(x AS TIMESTAMP)", read={"mysql": "TIMESTAMP(x)"}, ) - self.validate_all( - "TIMESTAMP(x, 'America/Los_Angeles')", - write={ - "duckdb": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'", - "presto": "AT_TIMEZONE(CAST(x AS TIMESTAMP), 'America/Los_Angeles')", - }, - ) # this case isn't really correct, but it's a fall back for mysql's version self.validate_all( "TIMESTAMP(x, '12:00:00')", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 409a5a6..8357642 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -331,10 +331,15 @@ WHERE "snowflake": "SELECT TIME_FROM_PARTS(12, 34, 56, 987654321)", }, ) + self.validate_identity( + "SELECT TIMESTAMPNTZFROMPARTS(2013, 4, 5, 12, 00, 00)", + "SELECT TIMESTAMP_FROM_PARTS(2013, 4, 5, 12, 00, 00)", + ) self.validate_all( "SELECT TIMESTAMP_FROM_PARTS(2013, 4, 5, 12, 00, 00)", read={ "duckdb": "SELECT MAKE_TIMESTAMP(2013, 4, 5, 12, 00, 00)", + "snowflake": "SELECT TIMESTAMP_NTZ_FROM_PARTS(2013, 4, 5, 12, 00, 00)", }, write={ "duckdb": "SELECT MAKE_TIMESTAMP(2013, 4, 5, 12, 00, 00)", @@ -519,7 +524,6 @@ WHERE self.validate_all( f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", read={ - "snowflake": f"SELECT MEDIAN(x){suffix}", "postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", }, write={ @@ -529,15 +533,6 @@ WHERE "snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", }, ) - self.validate_all( - f"SELECT MEDIAN(x){suffix}", - write={ - "": f"SELECT PERCENTILE_CONT(x, 0.5){suffix}", - "duckdb": f"SELECT QUANTILE_CONT(x, 0.5){suffix}", - "postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", - "snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", - }, - ) for func in ( "CORR", "COVAR_POP", @@ -1768,7 +1763,6 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene "REGEXP_SUBSTR(subject, pattern)", read={ "bigquery": "REGEXP_EXTRACT(subject, pattern)", - "snowflake": "REGEXP_EXTRACT(subject, pattern)", }, write={ "bigquery": "REGEXP_EXTRACT(subject, pattern)", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 01859c6..486bf79 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -315,6 +315,20 @@ TBLPROPERTIES ( }, ) self.validate_all( + "SELECT ARRAY_AGG(1)", + write={ + "duckdb": "SELECT ARRAY_AGG(1)", + "spark": "SELECT COLLECT_LIST(1)", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT STRUCT('a'))", + write={ + "duckdb": "SELECT ARRAY_AGG(DISTINCT {'col1': 'a'})", + "spark": "SELECT COLLECT_LIST(DISTINCT STRUCT('a' AS col1))", + }, + ) + self.validate_all( "SELECT DATE_FORMAT(DATE '2020-01-01', 'EEEE') AS weekday", write={ "presto": "SELECT DATE_FORMAT(CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP), '%W') AS weekday", @@ -875,3 +889,9 @@ TBLPROPERTIES ( "databricks": "SELECT * FROM db.table1 EXCEPT SELECT * FROM db.table2", }, ) + + def test_string(self): + for dialect in ("hive", "spark2", "spark", "databricks"): + with self.subTest(f"Testing STRING() for {dialect}"): + query = parse_one("STRING(a)", dialect=dialect) + self.assertEqual(query.sql(dialect), "CAST(a AS STRING)") diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 230c0e8..e37cdc8 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -222,3 +222,7 @@ class TestSQLite(Validator): "mysql": "CREATE TABLE `x` (`Name` VARCHAR(200) NOT NULL)", }, ) + + self.validate_identity( + "CREATE TABLE store (store_id INTEGER PRIMARY KEY AUTOINCREMENT, mgr_id INTEGER NOT NULL UNIQUE REFERENCES staff ON UPDATE CASCADE)" + ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 466f5d5..8951ebe 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -32,6 +32,10 @@ class TestTeradata(Validator): }, ) + self.validate_identity("SELECT 0x1d", "SELECT X'1d'") + self.validate_identity("SELECT X'1D'", "SELECT X'1D'") + self.validate_identity("SELECT x'1d'", "SELECT X'1d'") + self.validate_identity( "RENAME TABLE emp TO employee", check_command_warning=True ).assert_is(exp.Command) diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py index 33a0229..8e968e9 100644 --- a/tests/dialects/test_trino.py +++ b/tests/dialects/test_trino.py @@ -9,9 +9,30 @@ class TestTrino(Validator): self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')") self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)") self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)") + + def test_listagg(self): self.validate_identity( "SELECT LISTAGG(DISTINCT col, ',') WITHIN GROUP (ORDER BY col ASC) FROM tbl" ) + self.validate_identity( + "SELECT LISTAGG(col, '; ' ON OVERFLOW ERROR) WITHIN GROUP (ORDER BY col ASC) FROM tbl" + ) + self.validate_identity( + "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE WITH COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl" + ) + self.validate_identity( + "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE WITHOUT COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl" + ) + self.validate_identity( + "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE '...' WITH COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl" + ) + self.validate_identity( + "SELECT LISTAGG(col, '; ' ON OVERFLOW TRUNCATE '...' WITHOUT COUNT) WITHIN GROUP (ORDER BY col ASC) FROM tbl" + ) + self.validate_identity( + "SELECT LISTAGG(col) WITHIN GROUP (ORDER BY col DESC) FROM tbl", + "SELECT LISTAGG(col, ',') WITHIN GROUP (ORDER BY col DESC) FROM tbl", + ) def test_trim(self): self.validate_identity("SELECT TRIM('!' FROM '!foo!')") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 042891a..e4bd9a7 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -411,6 +411,7 @@ class TestTSQL(Validator): }, ) self.validate_identity("HASHBYTES('MD2', 'x')") + self.validate_identity("LOG(n)") self.validate_identity("LOG(n, b)") self.validate_all( @@ -921,6 +922,12 @@ class TestTSQL(Validator): }, ) self.validate_all( + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'baz' AND table_schema = 'bar' AND table_catalog = 'foo') EXEC('WITH cte1 AS (SELECT 1 AS col_a), cte2 AS (SELECT 1 AS col_b) SELECT * INTO foo.bar.baz FROM (SELECT col_a FROM cte1 UNION ALL SELECT col_b FROM cte2) AS temp')", + read={ + "": "CREATE TABLE IF NOT EXISTS foo.bar.baz AS WITH cte1 AS (SELECT 1 AS col_a), cte2 AS (SELECT 1 AS col_b) SELECT col_a FROM cte1 UNION ALL SELECT col_b FROM cte2" + }, + ) + self.validate_all( "CREATE OR ALTER VIEW a.b AS SELECT 1", read={ "": "CREATE OR REPLACE VIEW a.b AS SELECT 1", @@ -1567,7 +1574,7 @@ WHERE "SELECT DATEDIFF(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo", write={ "tsql": "SELECT DATEDIFF(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo", - "clickhouse": "SELECT DATE_DIFF(DAY, CAST(a AS Nullable(DateTime)), CAST(b AS Nullable(DateTime))) AS x FROM foo", + "clickhouse": "SELECT DATE_DIFF(DAY, CAST(CAST(a AS Nullable(DateTime)) AS DateTime64(6)), CAST(CAST(b AS Nullable(DateTime)) AS DateTime64(6))) AS x FROM foo", }, ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index bed2502..33199de 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -250,7 +250,6 @@ SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x SELECT X((a, b) -> a + b, z -> z) AS x SELECT X(a -> a + ("z" - 1)) -SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0) SELECT test.* FROM test SELECT a AS b FROM test SELECT "a"."b" FROM "a" diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 1f59a5a..1dd7c2d 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -307,3 +307,11 @@ ARRAY<STRING>; # dialect: bigquery SPLIT(tbl.bin_col, delim); ARRAY<BINARY>; + +# dialect: bigquery +STRING(json_expr); +STRING; + +# dialect: bigquery +STRING(timestamp_expr, timezone); +STRING;
\ No newline at end of file diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index 49e07fa..03e8dbe 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -14,6 +14,26 @@ SELECT 1 FROM x.y.z AS z; SELECT 1 FROM x.y.z AS z; SELECT 1 FROM x.y.z AS z; +# title: only information schema +# dialect: bigquery +SELECT * FROM information_schema.tables; +SELECT * FROM c.db.`information_schema.tables` AS tables; + +# title: information schema with db +# dialect: bigquery +SELECT * FROM y.information_schema.tables; +SELECT * FROM c.y.`information_schema.tables` AS tables; + +# title: information schema with db, catalog +# dialect: bigquery +SELECT * FROM x.y.information_schema.tables; +SELECT * FROM x.y.`information_schema.tables` AS tables; + +# title: information schema with db, catalog, alias +# dialect: bigquery +SELECT * FROM x.y.information_schema.tables AS z; +SELECT * FROM x.y.`information_schema.tables` AS z; + # title: redshift unnest syntax, z.a should be a column, not a table # dialect: redshift SELECT 1 FROM y.z AS z, z.a; diff --git a/tests/test_diff.py b/tests/test_diff.py index f0e0747..440502e 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -2,7 +2,6 @@ import unittest from sqlglot import exp, parse_one from sqlglot.diff import Insert, Move, Remove, Update, diff -from sqlglot.expressions import Join, to_table def diff_delta_only(source, target, matchings=None, **kwargs): @@ -14,22 +13,24 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT a - b")), [ - Remove(parse_one("a + b")), # the Add node - Insert(parse_one("a - b")), # the Sub node + Remove(expression=parse_one("a + b")), # the Add node + Insert(expression=parse_one("a - b")), # the Sub node + Move(source=parse_one("a"), target=parse_one("a")), # the `a` Column node + Move(source=parse_one("b"), target=parse_one("b")), # the `b` Column node ], ) self._validate_delta_only( diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), [ - Remove(parse_one("b")), # the Column node + Remove(expression=parse_one("b")), # the Column node ], ) self._validate_delta_only( diff_delta_only(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), [ - Insert(parse_one("c")), # the Column node + Insert(expression=parse_one("c")), # the Column node ], ) @@ -40,8 +41,8 @@ class TestDiff(unittest.TestCase): ), [ Update( - to_table("table_one", quoted=False), - to_table("table_two", quoted=False), + source=exp.to_table("table_one", quoted=False), + target=exp.to_table("table_two", quoted=False), ), # the Table node ], ) @@ -53,8 +54,12 @@ class TestDiff(unittest.TestCase): ), [ Update( - exp.Lambda(this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")]), - exp.Lambda(this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")]), + source=exp.Lambda( + this=exp.to_identifier("a"), expressions=[exp.to_identifier("a")] + ), + target=exp.Lambda( + this=exp.to_identifier("b"), expressions=[exp.to_identifier("b")] + ), ), ], ) @@ -65,8 +70,8 @@ class TestDiff(unittest.TestCase): parse_one('SELECT a, b, "my.udf1"()'), parse_one('SELECT a, b, "my.udf2"()') ), [ - Insert(parse_one('"my.udf2"()')), - Remove(parse_one('"my.udf1"()')), + Insert(expression=parse_one('"my.udf2"()')), + Remove(expression=parse_one('"my.udf1"()')), ], ) self._validate_delta_only( @@ -75,41 +80,73 @@ class TestDiff(unittest.TestCase): parse_one('SELECT a, b, "my.udf"(x, y, w)'), ), [ - Insert(exp.column("w")), - Remove(exp.column("z")), + Insert(expression=exp.column("w")), + Remove(expression=exp.column("z")), ], ) def test_node_position_changed(self): + expr_src = parse_one("SELECT a, b, c") + expr_tgt = parse_one("SELECT c, a, b") + self._validate_delta_only( - diff_delta_only(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")), + diff_delta_only(expr_src, expr_tgt), [ - Move(parse_one("c")), # the Column node + Move(source=expr_src.selects[2], target=expr_tgt.selects[0]), ], ) + expr_src = parse_one("SELECT a + b") + expr_tgt = parse_one("SELECT b + a") + self._validate_delta_only( - diff_delta_only(parse_one("SELECT a + b"), parse_one("SELECT b + a")), + diff_delta_only(expr_src, expr_tgt), [ - Move(parse_one("a")), # the Column node + Move(source=expr_src.selects[0].left, target=expr_tgt.selects[0].right), ], ) + expr_src = parse_one("SELECT aaaa AND bbbb") + expr_tgt = parse_one("SELECT bbbb AND aaaa") + self._validate_delta_only( - diff_delta_only(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")), + diff_delta_only(expr_src, expr_tgt), [ - Move(parse_one("aaaa")), # the Column node + Move(source=expr_src.selects[0].left, target=expr_tgt.selects[0].right), ], ) + expr_src = parse_one("SELECT aaaa OR bbbb OR cccc") + expr_tgt = parse_one("SELECT cccc OR bbbb OR aaaa") + self._validate_delta_only( - diff_delta_only( - parse_one("SELECT aaaa OR bbbb OR cccc"), - parse_one("SELECT cccc OR bbbb OR aaaa"), - ), + diff_delta_only(expr_src, expr_tgt), + [ + Move(source=expr_src.selects[0].left.left, target=expr_tgt.selects[0].right), + Move(source=expr_src.selects[0].right, target=expr_tgt.selects[0].left.left), + ], + ) + + expr_src = parse_one("SELECT a, b FROM t WHERE CONCAT('a', 'b') = 'ab'") + expr_tgt = parse_one("SELECT a FROM t WHERE CONCAT('a', 'b', b) = 'ab'") + + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), + [ + Move(source=expr_src.selects[1], target=expr_tgt.find(exp.Concat).expressions[-1]), + ], + ) + + expr_src = parse_one("SELECT a as a, b as b FROM t WHERE CONCAT('a', 'b') = 'ab'") + expr_tgt = parse_one("SELECT a as a FROM t WHERE CONCAT('a', 'b', b) = 'ab'") + + b_alias = expr_src.selects[1] + + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), [ - Move(parse_one("aaaa")), # the Column node - Move(parse_one("cccc")), # the Column node + Remove(expression=b_alias), + Move(source=b_alias.this, target=expr_tgt.find(exp.Concat).expressions[-1]), ], ) @@ -130,23 +167,30 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)), [ - Remove(parse_one("LOWER(c) AS c")), # the Alias node - Remove(parse_one("LOWER(c)")), # the Lower node - Remove(parse_one("'filter'")), # the Literal node - Insert(parse_one("'different_filter'")), # the Literal node + Remove(expression=parse_one("LOWER(c) AS c")), # the Alias node + Remove(expression=parse_one("LOWER(c)")), # the Lower node + Remove(expression=parse_one("'filter'")), # the Literal node + Insert(expression=parse_one("'different_filter'")), # the Literal node + Move(source=parse_one("c"), target=parse_one("c")), # the new Column c ], ) def test_join(self): - expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key" - expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key" + expr_src = parse_one("SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key") + expr_tgt = parse_one("SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key") - changes = diff_delta_only(parse_one(expr_src), parse_one(expr_tgt)) + src_join = expr_src.find(exp.Join) + tgt_join = expr_tgt.find(exp.Join) - self.assertEqual(len(changes), 2) - self.assertTrue(isinstance(changes[0], Remove)) - self.assertTrue(isinstance(changes[1], Insert)) - self.assertTrue(all(isinstance(c.expression, Join) for c in changes)) + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), + [ + Remove(expression=src_join), + Insert(expression=tgt_join), + Move(source=exp.to_table("t2"), target=exp.to_table("t2")), + Move(source=src_join.args["on"], target=tgt_join.args["on"]), + ], + ) def test_window_functions(self): expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)") @@ -157,8 +201,8 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff_delta_only(expr_src, expr_tgt), [ - Remove(parse_one("ROW_NUMBER()")), - Insert(parse_one("RANK()")), + Remove(expression=parse_one("ROW_NUMBER()")), + Insert(expression=parse_one("RANK()")), Update(source=expr_src.selects[0], target=expr_tgt.selects[0]), ], ) @@ -178,20 +222,21 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff_delta_only(expr_src, expr_tgt), [ - Remove(expr_src), - Insert(expr_tgt), - Insert(exp.Literal.number(2)), - Insert(exp.Literal.number(3)), - Insert(exp.Literal.number(4)), + Remove(expression=expr_src), + Insert(expression=expr_tgt), + Insert(expression=exp.Literal.number(2)), + Insert(expression=exp.Literal.number(3)), + Insert(expression=exp.Literal.number(4)), + Move(source=exp.Literal.number(1), target=exp.Literal.number(1)), ], ) self._validate_delta_only( diff_delta_only(expr_src, expr_tgt, matchings=[(expr_src, expr_tgt)]), [ - Insert(exp.Literal.number(2)), - Insert(exp.Literal.number(3)), - Insert(exp.Literal.number(4)), + Insert(expression=exp.Literal.number(2)), + Insert(expression=exp.Literal.number(3)), + Insert(expression=exp.Literal.number(4)), ], ) @@ -274,7 +319,7 @@ class TestDiff(unittest.TestCase): source=expr_src.find(exp.Order).expressions[0], target=expr_tgt.find(exp.Order).expressions[0], ), - Move(parse_one("a")), + Move(source=expr_src.selects[0], target=expr_tgt.selects[1]), ], ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 9313285..0fa4ff6 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -315,7 +315,7 @@ class TestOptimizer(unittest.TestCase): ), dialect="bigquery", ).sql(), - 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA"."COLUMNS" AS "COLUMNS") SELECT "x"."a" AS "a" FROM "x" AS "x"', + 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA.COLUMNS" AS "columns") SELECT "x"."a" AS "a" FROM "x" AS "x"', ) self.assertEqual( @@ -1337,6 +1337,47 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(union_by_name.selects[0].type.this, exp.DataType.Type.BIGINT) self.assertEqual(union_by_name.selects[1].type.this, exp.DataType.Type.DOUBLE) + # Test chained UNIONs + sql = """ + WITH t AS + ( + SELECT NULL AS col + UNION + SELECT NULL AS col + UNION + SELECT 'a' AS col + UNION + SELECT NULL AS col + UNION + SELECT NULL AS col + ) + SELECT col FROM t; + """ + self.assertEqual(optimizer.optimize(sql).selects[0].type.this, exp.DataType.Type.VARCHAR) + + # Test UNIONs with nested subqueries + sql = """ + WITH t AS + ( + SELECT NULL AS col + UNION + (SELECT NULL AS col UNION ALL SELECT 'a' AS col) + ) + SELECT col FROM t; + """ + self.assertEqual(optimizer.optimize(sql).selects[0].type.this, exp.DataType.Type.VARCHAR) + + sql = """ + WITH t AS + ( + (SELECT NULL AS col UNION ALL SELECT 'a' AS col) + UNION + SELECT NULL AS col + ) + SELECT col FROM t; + """ + self.assertEqual(optimizer.optimize(sql).selects[0].type.this, exp.DataType.Type.VARCHAR) + def test_recursive_cte(self): query = parse_one( """ |