From 6d546bfddf465f629d17ee52f78b477eb632fd91 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 13 Jul 2024 13:12:05 +0200 Subject: Merging upstream version 25.5.1. Signed-off-by: Daniel Baumann --- tests/dialects/test_bigquery.py | 55 ++++++++++++++++ tests/dialects/test_clickhouse.py | 96 ++++++++++++++++++---------- tests/dialects/test_databricks.py | 41 ++++++------ tests/dialects/test_dialect.py | 21 ++++-- tests/dialects/test_doris.py | 28 ++++++++ tests/dialects/test_duckdb.py | 11 +++- tests/dialects/test_mysql.py | 1 + tests/dialects/test_oracle.py | 71 ++------------------ tests/dialects/test_postgres.py | 20 ++++-- tests/dialects/test_presto.py | 22 +++++++ tests/dialects/test_redshift.py | 18 +++++- tests/dialects/test_snowflake.py | 49 ++++++++++---- tests/dialects/test_spark.py | 9 ++- tests/dialects/test_teradata.py | 9 +++ tests/dialects/test_tsql.py | 48 +++++++++++--- tests/fixtures/identity.sql | 6 +- tests/fixtures/optimizer/qualify_columns.sql | 12 ++++ tests/fixtures/optimizer/tpc-h/tpc-h.sql | 14 ++-- tests/test_executor.py | 17 +++++ tests/test_expressions.py | 18 ++++-- tests/test_jsonpath.py | 11 ++-- tests/test_optimizer.py | 41 +++++++++++- tests/test_parser.py | 9 +++ tests/test_transforms.py | 82 +++++++++++++++++++++++- tests/test_transpile.py | 18 ++++-- 25 files changed, 543 insertions(+), 184 deletions(-) (limited to 'tests') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index ae8ed16..803ac11 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -103,6 +103,7 @@ LANGUAGE js AS select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`") self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF") + self.validate_identity("CAST(x AS STRUCT>)") self.validate_identity("assert.true(1 = 1)") self.validate_identity("SELECT ARRAY_TO_STRING(list, '--') AS text") self.validate_identity("SELECT jsondoc['some_key']") @@ -293,6 +294,20 @@ LANGUAGE js AS r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", ) + self.validate_all( + "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')", + write={ + "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')", + "duckdb": "CAST(TRY_STRPTIME(some_date, '%d %B %Y') AS DATE)", + }, + ) + self.validate_all( + "SAFE_CAST(some_date AS DATE FORMAT 'YYYY-MM-DD') AS some_date", + write={ + "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'YYYY-MM-DD') AS some_date", + "duckdb": "CAST(TRY_STRPTIME(some_date, '%Y-%m-%d') AS DATE) AS some_date", + }, + ) self.validate_all( "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s", write={ @@ -1345,6 +1360,46 @@ WHERE "bigquery": "SELECT CAST(x AS DATETIME)", }, ) + self.validate_all( + "SELECT TIME(foo, 'America/Los_Angeles')", + write={ + "duckdb": "SELECT CAST(CAST(foo AS TIMESTAMPTZ) AT TIME ZONE 'America/Los_Angeles' AS TIME)", + "bigquery": "SELECT TIME(foo, 'America/Los_Angeles')", + }, + ) + self.validate_all( + "SELECT DATETIME('2020-01-01')", + write={ + "duckdb": "SELECT CAST('2020-01-01' AS TIMESTAMP)", + "bigquery": "SELECT DATETIME('2020-01-01')", + }, + ) + self.validate_all( + "SELECT DATETIME('2020-01-01', TIME '23:59:59')", + write={ + "duckdb": "SELECT CAST(CAST('2020-01-01' AS DATE) + CAST('23:59:59' AS TIME) AS TIMESTAMP)", + "bigquery": "SELECT DATETIME('2020-01-01', CAST('23:59:59' AS TIME))", + }, + ) + self.validate_all( + "SELECT DATETIME('2020-01-01', 'America/Los_Angeles')", + write={ + "duckdb": "SELECT CAST(CAST('2020-01-01' AS TIMESTAMPTZ) AT TIME ZONE 'America/Los_Angeles' AS TIMESTAMP)", + "bigquery": "SELECT DATETIME('2020-01-01', 'America/Los_Angeles')", + }, + ) + self.validate_all( + "SELECT LENGTH(foo)", + read={ + "bigquery": "SELECT LENGTH(foo)", + "snowflake": "SELECT LENGTH(foo)", + }, + write={ + "duckdb": "SELECT CASE TYPEOF(foo) WHEN 'VARCHAR' THEN LENGTH(CAST(foo AS TEXT)) WHEN 'BLOB' THEN OCTET_LENGTH(CAST(foo AS BLOB)) END", + "snowflake": "SELECT LENGTH(foo)", + "": "SELECT LENGTH(foo)", + }, + ) def test_errors(self): with self.assertRaises(TokenError): diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 72634a8..ef84d48 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -7,23 +7,6 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): - self.validate_all( - "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", - write={ - "": "SELECT * FROM x WHERE z = 2", - "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", - }, - ) - self.validate_all( - "SELECT * FROM x AS prewhere", - read={ - "clickhouse": "SELECT * FROM x AS prewhere", - "duckdb": "SELECT * FROM x prewhere", - }, - ) - - self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y") - string_types = [ "BLOB", "LONGBLOB", @@ -42,6 +25,9 @@ class TestClickhouse(Validator): self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) + self.validate_identity("SELECT EXTRACT(YEAR FROM toDateTime('2023-02-01'))") + self.validate_identity("extract(haystack, pattern)") + self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y") self.validate_identity("SELECT CAST(x AS Tuple(String, Array(Nullable(Float64))))") self.validate_identity("countIf(x, y)") self.validate_identity("x = y") @@ -94,18 +80,12 @@ class TestClickhouse(Validator): self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""") self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b") self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b") + self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster") + self.validate_identity("TRUNCATE DATABASE db") + self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster") self.validate_identity( "SELECT id, quantileGK(100, 0.95)(reading) OVER (PARTITION BY id ORDER BY id RANGE BETWEEN 30000 PRECEDING AND CURRENT ROW) AS window FROM table" ) - - self.validate_identity( - "SELECT $1$foo$1$", - "SELECT 'foo'", - ) - self.validate_identity( - "SELECT * FROM table LIMIT 1, 2 BY a, b", - "SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b", - ) self.validate_identity( "SELECT * FROM table LIMIT 1 BY CONCAT(datalayerVariantNo, datalayerProductId, warehouse)" ) @@ -133,10 +113,6 @@ class TestClickhouse(Validator): self.validate_identity( "SELECT sum(1) AS impressions, (arrayJoin(arrayZip(cities, browsers)) AS t).1 AS city, t.2 AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 'Chrome'] AS browsers) GROUP BY 2, 3" ) - self.validate_identity( - "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']", - "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')", - ) self.validate_identity( 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))' ) @@ -155,12 +131,43 @@ class TestClickhouse(Validator): self.validate_identity( "CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data" ) - self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster") - self.validate_identity("TRUNCATE DATABASE db") - self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster") self.validate_identity( "CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)" ) + self.validate_identity( + "SELECT $1$foo$1$", + "SELECT 'foo'", + ) + self.validate_identity( + "SELECT * FROM table LIMIT 1, 2 BY a, b", + "SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b", + ) + self.validate_identity( + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']", + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')", + ) + + self.validate_all( + "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", + write={ + "": "SELECT * FROM x WHERE z = 2", + "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", + }, + ) + self.validate_all( + "SELECT * FROM x AS prewhere", + read={ + "clickhouse": "SELECT * FROM x AS prewhere", + "duckdb": "SELECT * FROM x prewhere", + }, + ) + self.validate_all( + "SELECT a, b FROM (SELECT * FROM x) AS t", + read={ + "clickhouse": "SELECT a, b FROM (SELECT * FROM x) AS t", + "duckdb": "SELECT a, b FROM (SELECT * FROM x) AS t(a, b)", + }, + ) self.validate_all( "SELECT arrayJoin([1,2,3])", write={ @@ -880,3 +887,26 @@ LIFETIME(MIN 0 MAX 0)""", for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"): with self.subTest(f"Test DROP {creatable} ON CLUSTER"): self.validate_identity(f"DROP {creatable} test ON CLUSTER test_cluster") + + def test_datetime_funcs(self): + # Each datetime func has an alias that is roundtripped to the original name e.g. (DATE_SUB, DATESUB) -> DATE_SUB + datetime_funcs = (("DATE_SUB", "DATESUB"), ("DATE_ADD", "DATEADD")) + + # 2-arg functions of type (date, unit) + for func in (*datetime_funcs, ("TIMESTAMP_ADD", "TIMESTAMPADD")): + func_name = func[0] + for func_alias in func: + self.validate_identity( + f"""SELECT {func_alias}(date, INTERVAL '3' YEAR)""", + f"""SELECT {func_name}(date, INTERVAL '3' YEAR)""", + ) + + # 3-arg functions of type (unit, value, date) + for func in (*datetime_funcs, ("DATE_DIFF", "DATEDIFF"), ("TIMESTAMP_SUB", "TIMESTAMPSUB")): + func_name = func[0] + for func_alias in func: + with self.subTest(f"Test 3-arg date-time function {func_alias}"): + self.validate_identity( + f"SELECT {func_alias}(SECOND, 1, bar)", + f"SELECT {func_name}(SECOND, 1, bar)", + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 9ef3b86..471830f 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -1,4 +1,4 @@ -from sqlglot import transpile +from sqlglot import exp, transpile from sqlglot.errors import ParseError from tests.dialects.test_dialect import Validator @@ -25,6 +25,7 @@ class TestDatabricks(Validator): self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)") self.validate_identity( "CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA" ) @@ -37,22 +38,26 @@ class TestDatabricks(Validator): self.validate_identity( "SELECT * FROM sales UNPIVOT EXCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))" ) - self.validate_identity( "CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $$def add_one(x):\n return x+1$$" ) - self.validate_identity( "CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $FOO$def add_one(x):\n return x+1$FOO$" ) - - self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)") self.validate_identity( "TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')" ) self.validate_identity( "COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS ('opt1'='true', 'opt2'='test') COPY_OPTIONS ('mergeSchema'='true')" ) + self.validate_identity( + "DATE_DIFF(day, created_at, current_date())", + "DATEDIFF(DAY, created_at, CURRENT_DATE)", + ).args["unit"].assert_is(exp.Var) + self.validate_identity( + r'SELECT r"\\foo.bar\"', + r"SELECT '\\\\foo.bar\\'", + ) self.validate_all( "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", @@ -67,7 +72,6 @@ class TestDatabricks(Validator): "teradata": "CREATE TABLE t1 AS (SELECT c FROM t2) WITH DATA", }, ) - self.validate_all( "SELECT X'1A2B'", read={ @@ -96,33 +100,30 @@ class TestDatabricks(Validator): # https://docs.databricks.com/sql/language-manual/functions/colonsign.html def test_json(self): + self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]") self.validate_identity( - """SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.item[1].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT FROM_JSON(c1:item[*].price, 'ARRAY')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.item[*].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT INLINE(FROM_JSON(c1:item[*], 'ARRAY>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT from_json(c1:item[*].price, 'ARRAY')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*].price'), 'ARRAY')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT c1:['price'] FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:price FROM VALUES ('{ "price": 5 }') AS T(c1)""", ) self.validate_identity( - """SELECT inline(from_json(c1:item[*], 'ARRAY>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT INLINE(FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*]'), 'ARRAY>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:price FROM VALUES ('{ "price": 5 }') AS T(c1)""", ) self.validate_identity( - "SELECT c1 : price", - "SELECT GET_JSON_OBJECT(c1, '$.price')", + """SELECT raw:`zip code`, raw:`fb:testid`, raw:store['bicycle'], raw:store["zip code"]""", + """SELECT raw:["zip code"], raw:["fb:testid"], raw:store.bicycle, raw:store["zip code"]""", ) def test_datediff(self): diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index aaeb7b0..c0afb2f 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -102,14 +102,10 @@ class TestDialect(Validator): lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase") self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE") - with self.assertRaises(ValueError) as cm: + with self.assertRaises(AttributeError) as cm: Dialect.get_or_raise("mysql, normalization_strategy") - self.assertEqual( - str(cm.exception), - "Invalid dialect format: 'mysql, normalization_strategy'. " - "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.", - ) + self.assertEqual(str(cm.exception), "'bool' object has no attribute 'upper'") with self.assertRaises(ValueError) as cm: Dialect.get_or_raise("myqsl") @@ -121,6 +117,18 @@ class TestDialect(Validator): self.assertEqual(str(cm.exception), "Unknown dialect 'asdfjasodiufjsd'.") + oracle_with_settings = Dialect.get_or_raise( + "oracle, normalization_strategy = lowercase, version = 19.5" + ) + self.assertEqual(oracle_with_settings.normalization_strategy.value, "LOWERCASE") + self.assertEqual(oracle_with_settings.settings, {"version": "19.5"}) + + bool_settings = Dialect.get_or_raise("oracle, s1=TruE, s2=1, s3=FaLse, s4=0, s5=nonbool") + self.assertEqual( + bool_settings.settings, + {"s1": True, "s2": True, "s3": False, "s4": False, "s5": "nonbool"}, + ) + def test_compare_dialects(self): bigquery_class = Dialect["bigquery"] bigquery_object = BigQuery() @@ -1150,7 +1158,6 @@ class TestDialect(Validator): write={ "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", - "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py index 8180d05..99076ba 100644 --- a/tests/dialects/test_doris.py +++ b/tests/dialects/test_doris.py @@ -56,6 +56,34 @@ class TestDoris(Validator): "postgres": "SELECT STRING_AGG('aa', ',')", }, ) + self.validate_all( + "SELECT LAG(1, 1, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LAG(1, 1, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LAG(1) OVER (ORDER BY 1)", + }, + ) + self.validate_all( + "SELECT LAG(1, 2, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LAG(1, 2, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LAG(1, 2) OVER (ORDER BY 1)", + }, + ) + self.validate_all( + "SELECT LEAD(1, 1, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LEAD(1, 1, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LEAD(1) OVER (ORDER BY 1)", + }, + ) + self.validate_all( + "SELECT LEAD(1, 2, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LEAD(1, 2, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LEAD(1, 2) OVER (ORDER BY 1)", + }, + ) def test_identity(self): self.validate_identity("COALECSE(a, b, c, d)") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 2bde478..e0b0131 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -18,6 +18,13 @@ class TestDuckDB(Validator): "WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1", ) + self.validate_all( + """SELECT CASE WHEN JSON_VALID('{"x: 1}') THEN '{"x: 1}' ELSE NULL END""", + read={ + "duckdb": """SELECT CASE WHEN JSON_VALID('{"x: 1}') THEN '{"x: 1}' ELSE NULL END""", + "snowflake": """SELECT TRY_PARSE_JSON('{"x: 1}')""", + }, + ) self.validate_all( "SELECT straight_join", write={ @@ -786,6 +793,8 @@ class TestDuckDB(Validator): }, ) + self.validate_identity("SELECT LENGTH(foo)") + def test_array_index(self): with self.assertLogs(helper_logger) as cm: self.validate_all( @@ -847,7 +856,7 @@ class TestDuckDB(Validator): read={"bigquery": "SELECT DATE(PARSE_DATE('%m/%d/%Y', '05/06/2020'))"}, ) self.validate_all( - "SELECT CAST('2020-01-01' AS DATE) + INTERVAL (-1) DAY", + "SELECT CAST('2020-01-01' AS DATE) + INTERVAL '-1' DAY", read={"mysql": "SELECT DATE '2020-01-01' + INTERVAL -1 DAY"}, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 280ebbf..bfdb2a6 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -117,6 +117,7 @@ class TestMySQL(Validator): ) def test_identity(self): + self.validate_identity("SELECT CAST(COALESCE(`id`, 'NULL') AS CHAR CHARACTER SET binary)") self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y") self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 7cc4d72..1d9fd99 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -1,5 +1,4 @@ from sqlglot import exp, UnsupportedError -from sqlglot.dialects.oracle import eliminate_join_marks from tests.dialects.test_dialect import Validator @@ -10,7 +9,7 @@ class TestOracle(Validator): self.validate_all( "SELECT CONNECT_BY_ROOT x y", write={ - "": "SELECT CONNECT_BY_ROOT(x) AS y", + "": "SELECT CONNECT_BY_ROOT x AS y", "oracle": "SELECT CONNECT_BY_ROOT x AS y", }, ) @@ -87,9 +86,9 @@ class TestOracle(Validator): "SELECT DISTINCT col1, col2 FROM table", ) self.validate_identity( - "SELECT * FROM T ORDER BY I OFFSET nvl(:variable1, 10) ROWS FETCH NEXT nvl(:variable2, 10) ROWS ONLY", - "SELECT * FROM T ORDER BY I OFFSET COALESCE(:variable1, 10) ROWS FETCH NEXT COALESCE(:variable2, 10) ROWS ONLY", + "SELECT * FROM T ORDER BY I OFFSET NVL(:variable1, 10) ROWS FETCH NEXT NVL(:variable2, 10) ROWS ONLY", ) + self.validate_identity("NVL(x, y)").assert_is(exp.Anonymous) self.validate_identity( "SELECT * FROM t SAMPLE (.25)", "SELECT * FROM t SAMPLE (0.25)", @@ -190,13 +189,6 @@ class TestOracle(Validator): "spark": "SELECT CAST(NULL AS VARCHAR(2328)) AS COL1", }, ) - self.validate_all( - "NVL(NULL, 1)", - write={ - "": "COALESCE(NULL, 1)", - "oracle": "COALESCE(NULL, 1)", - }, - ) self.validate_all( "DATE '2022-01-01'", write={ @@ -245,6 +237,10 @@ class TestOracle(Validator): "duckdb": "SELECT CAST(STRPTIME('2024-12-12', '%Y-%m-%d') AS DATE)", }, ) + self.validate_identity( + """SELECT * FROM t ORDER BY a ASC NULLS LAST, b ASC NULLS FIRST, c DESC NULLS LAST, d DESC NULLS FIRST""", + """SELECT * FROM t ORDER BY a ASC, b ASC NULLS FIRST, c DESC NULLS LAST, d DESC""", + ) def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") @@ -416,59 +412,6 @@ WHERE for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"): self.validate_identity(query, pretty, pretty=True) - def test_eliminate_join_marks(self): - test_sql = [ - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5", - ), - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL", - ), - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL", - ), - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4", - ), - ( - "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column", - ), - ( - "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column", - ), - ( - "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column", - ), - ( - "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)", - "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id", - ), - # 2 join marks on one side of predicate - ( - "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2", - ), - # join mark and expression - ( - "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25", - ), - ] - - for original, expected in test_sql: - with self.subTest(original): - self.assertEqual( - eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect), - expected, - ) - def test_query_restrictions(self): for restriction in ("READ ONLY", "CHECK OPTION"): for constraint_name in (" CONSTRAINT name", ""): diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 071677d..816a283 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,6 +8,14 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_all( + "x ? y", + write={ + "": "JSONB_CONTAINS(x, y)", + "postgres": "x ? y", + }, + ) + self.validate_identity("SHA384(x)") self.validate_identity( 'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)" @@ -67,10 +75,6 @@ class TestPostgres(Validator): self.validate_identity("EXEC AS myfunc @id = 123", check_command_warning=True) self.validate_identity("SELECT CURRENT_USER") self.validate_identity("SELECT * FROM ONLY t1") - self.validate_identity( - "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]", - "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]", - ) self.validate_identity( """UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)""" ) @@ -127,6 +131,14 @@ class TestPostgres(Validator): "pg_catalog.PG_TABLE_IS_VISIBLE(c.oid) " "ORDER BY 2, 3" ) + self.validate_identity( + "/*+ some comment*/ SELECT b.foo, b.bar FROM baz AS b", + "/* + some comment */ SELECT b.foo, b.bar FROM baz AS b", + ) + self.validate_identity( + "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]", + "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]", + ) self.validate_identity( "SELECT ARRAY[]::INT[] AS foo", "SELECT CAST(ARRAY[] AS INT[]) AS foo", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ebb270a..dbe3abc 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -581,6 +581,13 @@ class TestPresto(Validator): ) def test_presto(self): + self.assertEqual( + exp.func("md5", exp.func("concat", exp.cast("x", "text"), exp.Literal.string("s"))).sql( + dialect="presto" + ), + "LOWER(TO_HEX(MD5(TO_UTF8(CONCAT(CAST(x AS VARCHAR), CAST('s' AS VARCHAR))))))", + ) + with self.assertLogs(helper_logger): self.validate_all( "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", @@ -1192,3 +1199,18 @@ MATCH_RECOGNIZE ( "starrocks": "SIGN(x)", }, ) + + def test_json_vs_row_extract(self): + for dialect in ("trino", "presto"): + s = parse_one('SELECT col:x:y."special string"', read="snowflake") + + dialect_json_extract_setting = f"{dialect}, variant_extract_is_json_extract=True" + dialect_row_access_setting = f"{dialect}, variant_extract_is_json_extract=False" + + # By default, Snowflake VARIANT will generate JSON_EXTRACT() in Presto/Trino + json_extract_result = """SELECT JSON_EXTRACT(col, '$.x.y["special string"]')""" + self.assertEqual(s.sql(dialect), json_extract_result) + self.assertEqual(s.sql(dialect_json_extract_setting), json_extract_result) + + # If the setting is overriden to False, then generate ROW access (dot notation) + self.assertEqual(s.sql(dialect_row_access_setting), 'SELECT col.x.y."special string"') diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 69793c7..c4e7073 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -28,7 +28,7 @@ class TestRedshift(Validator): """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", write={ "bigquery": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", - "databricks": """SELECT GET_JSON_OBJECT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", + "databricks": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}':farm.barn.color""", "duckdb": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}' ->> '$.farm.barn.color'""", "postgres": """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", "presto": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", @@ -228,7 +228,7 @@ class TestRedshift(Validator): "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1", + "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1", "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", @@ -259,6 +259,12 @@ class TestRedshift(Validator): "postgres": "COALESCE(a, b, c, d)", }, ) + + self.validate_identity( + "DATEDIFF(days, a, b)", + "DATEDIFF(DAY, a, b)", + ) + self.validate_all( "DATEDIFF('day', a, b)", write={ @@ -300,6 +306,14 @@ class TestRedshift(Validator): }, ) + self.validate_all( + "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)", + write={ + "snowflake": "SELECT DATE_PART(EPOCH, CURRENT_DATE)", + "redshift": "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)", + }, + ) + def test_identity(self): self.validate_identity("LISTAGG(DISTINCT foo, ', ')") self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 1286436..88b2148 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -11,6 +11,12 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.assertEqual( + # Ensures we don't fail when generating ParseJSON with the `safe` arg set to `True` + self.validate_identity("""SELECT TRY_PARSE_JSON('{"x: 1}')""").sql(), + """SELECT PARSE_JSON('{"x: 1}')""", + ) + self.validate_identity( "transform(x, a int -> a + a + 1)", "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)", @@ -49,6 +55,8 @@ WHERE )""", ) + self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))") + self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias") self.validate_identity("SELECT number").selects[0].assert_is(exp.Column) self.validate_identity("INTERVAL '4 years, 5 months, 3 hours'") self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)") @@ -182,18 +190,6 @@ WHERE """SELECT PARSE_JSON('{"food":{"fruit":"banana"}}'):food.fruit::VARCHAR""", """SELECT CAST(GET_PATH(PARSE_JSON('{"food":{"fruit":"banana"}}'), 'food.fruit') AS VARCHAR)""", ) - self.validate_identity( - "SELECT * FROM foo at", - "SELECT * FROM foo AS at", - ) - self.validate_identity( - "SELECT * FROM foo before", - "SELECT * FROM foo AS before", - ) - self.validate_identity( - "SELECT * FROM foo at (col)", - "SELECT * FROM foo AS at(col)", - ) self.validate_identity( "SELECT * FROM unnest(x) with ordinality", "SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS _u(seq, key, path, index, value, this)", @@ -337,7 +333,7 @@ WHERE """SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""", write={ "bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""", - "databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""", + "databricks": """SELECT '{"fruit":"banana"}':fruit""", "duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""", "mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""", "presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""", @@ -1196,6 +1192,17 @@ WHERE "SELECT oldt.*, newt.* FROM my_table BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS oldt FULL OUTER JOIN my_table AT (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS newt ON oldt.id = newt.id WHERE oldt.id IS NULL OR newt.id IS NULL", ) + # Make sure that the historical data keywords can still be used as aliases + for historical_data_prefix in ("AT", "BEFORE", "END", "CHANGES"): + for schema_suffix in ("", "(col)"): + with self.subTest( + f"Testing historical data prefix alias: {historical_data_prefix}{schema_suffix}" + ): + self.validate_identity( + f"SELECT * FROM foo {historical_data_prefix}{schema_suffix}", + f"SELECT * FROM foo AS {historical_data_prefix}{schema_suffix}", + ) + def test_ddl(self): for constraint_prefix in ("WITH ", ""): with self.subTest(f"Constraint prefix: {constraint_prefix}"): @@ -1216,6 +1223,7 @@ WHERE "CREATE TABLE t (id INT TAG (key1='value_1', key2='value_2'))", ) + self.validate_identity("CREATE SECURE VIEW table1 AS (SELECT a FROM table2)") self.validate_identity( """create external table et2( col1 date as (parse_json(metadata$external_table_partition):COL1::date), @@ -1240,6 +1248,9 @@ WHERE self.validate_identity( "CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'" ).this.assert_is(exp.Identifier) + self.validate_identity( + "CREATE DYNAMIC TABLE product (pre_tax_profit, taxes, after_tax_profit) TARGET_LAG='20 minutes' WAREHOUSE=mywh AS SELECT revenue - cost, (revenue - cost) * tax_rate, (revenue - cost) * (1.0 - tax_rate) FROM staging_table" + ) self.validate_identity( "ALTER TABLE db_name.schmaName.tblName ADD COLUMN COLUMN_1 VARCHAR NOT NULL TAG (key1='value_1')" ) @@ -2021,3 +2032,15 @@ SINGLE = TRUE""", self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c") self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING") + + def test_from_changes(self): + self.validate_identity( + """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) AT (STREAM => 's1') END (TIMESTAMP => $ts2)""" + ) + self.validate_identity( + """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) BEFORE (STATEMENT => 'STMT_ID') END (TIMESTAMP => $ts2)""" + ) + self.validate_identity( + """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 00:00:00+00:00')) END (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 14:28:59.999999+00:00'))""", + """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => CAST('2024-07-01 00:00:00+00:00' AS TIMESTAMPTZ)) END (TIMESTAMP => CAST('2024-07-01 14:28:59.999999+00:00' AS TIMESTAMPTZ))""", + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index bff91bf..4e62b32 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -245,7 +245,7 @@ TBLPROPERTIES ( self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH TABLE a.b.c") - self.validate_identity("INTERVAL -86 DAYS") + self.validate_identity("INTERVAL '-86' DAYS") self.validate_identity("TRIM(' SparkSQL ')") self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") @@ -801,3 +801,10 @@ TBLPROPERTIES ( self.assertEqual(query.sql(name), with_modifiers) else: self.assertEqual(query.sql(name), without_modifiers) + + def test_schema_binding_options(self): + for schema_binding in ("BINDING", "COMPENSATION", "TYPE EVOLUTION", "EVOLUTION"): + with self.subTest(f"Test roundtrip of VIEW schema binding {schema_binding}"): + self.validate_identity( + f"CREATE VIEW emp_v WITH SCHEMA {schema_binding} AS SELECT * FROM emp" + ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 74d5f88..3945ca3 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -5,6 +5,13 @@ class TestTeradata(Validator): dialect = "teradata" def test_teradata(self): + self.validate_all( + "RANDOM(l, u)", + write={ + "": "(u - l) * RAND() + l", + "teradata": "RANDOM(l, u)", + }, + ) self.validate_identity("TO_NUMBER(expr, fmt, nlsparam)") self.validate_identity("SELECT TOP 10 * FROM tbl") self.validate_identity("SELECT * FROM tbl SAMPLE 5") @@ -212,6 +219,8 @@ class TestTeradata(Validator): ) def test_time(self): + self.validate_identity("CAST(CURRENT_TIMESTAMP(6) AS TIMESTAMP WITH TIME ZONE)") + self.validate_all( "CURRENT_TIMESTAMP", read={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 7455650..11d60e7 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,4 +1,4 @@ -from sqlglot import exp, parse +from sqlglot import exp, parse, parse_one from tests.dialects.test_dialect import Validator from sqlglot.errors import ParseError from sqlglot.optimizer.annotate_types import annotate_types @@ -8,19 +8,14 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): - self.assertEqual( - annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), - "SELECT 1 WHERE EXISTS(SELECT 1)", - ) + # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN + # tsql allows .. which means use the default schema + self.validate_identity("SELECT * FROM a..b") self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c") self.validate_identity("DROP view a.b.c", "DROP VIEW b.c") self.validate_identity("ROUND(x, 1, 0)") self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True) - # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN - # tsql allows .. which means use the default schema - self.validate_identity("SELECT * FROM a..b") - self.validate_identity("SELECT TRIM(' test ') AS Result") self.validate_identity("SELECT TRIM('.,! ' FROM ' # test .') AS Result") self.validate_identity("SELECT * FROM t TABLESAMPLE (10 PERCENT)") @@ -36,9 +31,22 @@ class TestTSQL(Validator): self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)") self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))") + self.validate_identity( + "CREATE CLUSTERED INDEX [IX_OfficeTagDetail_TagDetailID] ON [dbo].[OfficeTagDetail]([TagDetailID] ASC)" + ) + self.validate_identity( + "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON X([y])" + ) + self.validate_identity( + "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON PRIMARY" + ) self.validate_identity( "COPY INTO test_1 FROM 'path' WITH (FORMAT_NAME = test, FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY='Shared Access Signature', SECRET='token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')" ) + self.assertEqual( + annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), + "SELECT 1 WHERE EXISTS(SELECT 1)", + ) self.validate_all( "SELECT IIF(cond <> 0, 'True', 'False')", @@ -1868,3 +1876,25 @@ FROM OPENJSON(@json) WITH ( "DECLARE vendor_cursor CURSOR FOR SELECT VendorID, Name FROM Purchasing.Vendor WHERE PreferredVendorStatus = 1 ORDER BY VendorID", check_command_warning=True, ) + + def test_scope_resolution_op(self): + # we still want to support :: casting shorthand for tsql + self.validate_identity("x::int", "CAST(x AS INTEGER)") + self.validate_identity("x::varchar", "CAST(x AS VARCHAR)") + self.validate_identity("x::varchar(MAX)", "CAST(x AS VARCHAR(MAX))") + + for lhs, rhs in ( + ("", "FOO(a, b)"), + ("bar", "baZ(1, 2)"), + ("LOGIN", "EricKurjan"), + ("GEOGRAPHY", "Point(latitude, longitude, 4326)"), + ( + "GEOGRAPHY", + "STGeomFromText('POLYGON((-122.358 47.653 , -122.348 47.649, -122.348 47.658, -122.358 47.658, -122.358 47.653))', 4326)", + ), + ): + with self.subTest(f"Scope resolution, LHS: {lhs}, RHS: {rhs}"): + expr = self.validate_identity(f"{lhs}::{rhs}") + base_sql = expr.sql() + self.assertEqual(base_sql, f"SCOPE_RESOLUTION({lhs + ', ' if lhs else ''}{rhs})") + self.assertEqual(parse_one(base_sql).sql("tsql"), f"{lhs}::{rhs}") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 4dc4aa1..433c23d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -115,7 +115,7 @@ ARRAY(foo, time) ARRAY(LENGTH(waiter_name) > 0) ARRAY_CONTAINS(x, 1) x.EXTRACT(1) -EXTRACT(x FROM y) +EXTRACT(X FROM y) EXTRACT(DATE FROM y) EXTRACT(WEEK(monday) FROM created_at) CONCAT_WS('-', 'a', 'b') @@ -733,6 +733,8 @@ SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z SELECT ((SELECT 1) + 1) SELECT * FROM project.dataset.INFORMATION_SCHEMA.TABLES SELECT CAST(x AS INT) /* comment */ FROM foo +SELECT c /* c1 /* c2 */ c3 */ +SELECT c /* c1 /* c2 /* c3 */ */ */ SELECT c /* c1 */ AS alias /* c2 */ SELECT a /* x */, b /* x */ SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */ @@ -873,3 +875,5 @@ SELECT copy SELECT rollup SELECT unnest SELECT * FROM a STRAIGHT_JOIN b +SELECT COUNT(DISTINCT "foo bar") FROM (SELECT 1 AS "foo bar") AS t +SELECT vector diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index ea96fe5..836bcf2 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -480,6 +480,18 @@ SELECT COALESCE(t1.a, t2.a) AS a FROM t1 AS t1 JOIN t2 AS t2 ON t1.a = t2.a; WITH m(a) AS (SELECT 1), n(b) AS (SELECT 1) SELECT * FROM m JOIN n AS foo(a) USING (a); WITH m AS (SELECT 1 AS a), n AS (SELECT 1 AS b) SELECT COALESCE(m.a, foo.a) AS a FROM m AS m JOIN n AS foo(a) ON m.a = foo.a; +# title: coalesce the USING clause's columns (3 joins, 2 join columns) +WITH t1 AS (SELECT 'x' AS id, DATE '2024-01-01' AS foo, 000 AS value), t2 AS (SELECT 'x' AS id, DATE '2024-02-02' AS foo, 123 AS value), t3 AS (SELECT 'x' AS id, DATE '2024-02-02' AS foo, 456 AS value) SELECT * FROM t1 FULL OUTER JOIN t2 USING(id, foo) FULL OUTER JOIN t3 USING(id, foo); +WITH t1 AS (SELECT 'x' AS id, CAST('2024-01-01' AS DATE) AS foo, 000 AS value), t2 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 123 AS value), t3 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 456 AS value) SELECT COALESCE(t1.id, t2.id, t3.id) AS id, COALESCE(t1.foo, t2.foo, t3.foo) AS foo, t1.value AS value, t2.value AS value, t3.value AS value FROM t1 AS t1 FULL OUTER JOIN t2 AS t2 ON t1.id = t2.id AND t1.foo = t2.foo FULL OUTER JOIN t3 AS t3 ON COALESCE(t1.id, t2.id) = t3.id AND COALESCE(t1.foo, t2.foo) = t3.foo; + +# title: coalesce the USING clause's columns (3 joins, 3 join columns) +WITH t1 AS (SELECT 'x' AS id, CAST('2024-01-01' AS DATE) AS foo, 000 AS value), t2 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 123 AS value), t3 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 456 AS value) SELECT * FROM t1 FULL OUTER JOIN t2 USING (id, foo, value) FULL OUTER JOIN t3 USING (id, foo, value); +WITH t1 AS (SELECT 'x' AS id, CAST('2024-01-01' AS DATE) AS foo, 000 AS value), t2 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 123 AS value), t3 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 456 AS value) SELECT COALESCE(t1.id, t2.id, t3.id) AS id, COALESCE(t1.foo, t2.foo, t3.foo) AS foo, COALESCE(t1.value, t2.value, t3.value) AS value FROM t1 AS t1 FULL OUTER JOIN t2 AS t2 ON t1.id = t2.id AND t1.foo = t2.foo AND t1.value = t2.value FULL OUTER JOIN t3 AS t3 ON COALESCE(t1.id, t2.id) = t3.id AND COALESCE(t1.foo, t2.foo) = t3.foo AND COALESCE(t1.value, t2.value) = t3.value; + +# title: coalesce the USING clause's columns (4 joins, 2 join columns) +WITH t1 AS (SELECT 'x' AS id, CAST('2024-01-01' AS DATE) AS foo, 000 AS value), t2 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 123 AS value), t3 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 456 AS value), t4 AS (SELECT 'x' AS id, CAST('2024-03-03' AS DATE) AS foo, 789 AS value) SELECT * FROM t1 FULL OUTER JOIN t2 USING (id, foo) FULL OUTER JOIN t3 USING (id, foo) FULL OUTER JOIN t4 USING (id, foo); +WITH t1 AS (SELECT 'x' AS id, CAST('2024-01-01' AS DATE) AS foo, 000 AS value), t2 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 123 AS value), t3 AS (SELECT 'x' AS id, CAST('2024-02-02' AS DATE) AS foo, 456 AS value), t4 AS (SELECT 'x' AS id, CAST('2024-03-03' AS DATE) AS foo, 789 AS value) SELECT COALESCE(t1.id, t2.id, t3.id, t4.id) AS id, COALESCE(t1.foo, t2.foo, t3.foo, t4.foo) AS foo, t1.value AS value, t2.value AS value, t3.value AS value, t4.value AS value FROM t1 AS t1 FULL OUTER JOIN t2 AS t2 ON t1.id = t2.id AND t1.foo = t2.foo FULL OUTER JOIN t3 AS t3 ON COALESCE(t1.id, t2.id) = t3.id AND COALESCE(t1.foo, t2.foo) = t3.foo FULL OUTER JOIN t4 AS t4 ON COALESCE(t1.id, t2.id, t3.id) = t4.id AND COALESCE(t1.foo, t2.foo, t3.foo) = t4.foo; + -------------------------------------- -- Hint with table reference -------------------------------------- diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index c131643..ed7a689 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -375,7 +375,7 @@ order by SELECT "n1"."n_name" AS "supp_nation", "n2"."n_name" AS "cust_nation", - EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year", + EXTRACT(YEAR FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year", SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" @@ -407,7 +407,7 @@ JOIN "nation" AS "n2" GROUP BY "n1"."n_name", "n2"."n_name", - EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) + EXTRACT(YEAR FROM CAST("lineitem"."l_shipdate" AS DATE)) ORDER BY "supp_nation", "cust_nation", @@ -425,7 +425,7 @@ select from ( select - extract(year from cast(o_orderdate as date)) as o_year, + extract(YEAR from cast(o_orderdate as date)) as o_year, l_extendedprice * (1 - l_discount) as volume, n2.n_name as nation from @@ -454,7 +454,7 @@ group by order by o_year; SELECT - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", + EXTRACT(YEAR FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", SUM( CASE WHEN "n2"."n_name" = 'BRAZIL' @@ -486,7 +486,7 @@ JOIN "region" AS "region" WHERE "part"."p_type" = 'ECONOMY ANODIZED STEEL' GROUP BY - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) + EXTRACT(YEAR FROM CAST("orders"."o_orderdate" AS DATE)) ORDER BY "o_year"; @@ -527,7 +527,7 @@ order by o_year desc; SELECT "nation"."n_name" AS "nation", - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", + EXTRACT(YEAR FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", SUM( "lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" @@ -549,7 +549,7 @@ WHERE "part"."p_name" LIKE '%green%' GROUP BY "nation"."n_name", - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) + EXTRACT(YEAR FROM CAST("orders"."o_orderdate" AS DATE)) ORDER BY "nation", "o_year" DESC; diff --git a/tests/test_executor.py b/tests/test_executor.py index 317b930..e80fb1e 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -14,6 +14,8 @@ from sqlglot.errors import ExecuteError from sqlglot.executor import execute from sqlglot.executor.python import Python from sqlglot.executor.table import Table, ensure_tables +from sqlglot.optimizer import optimize +from sqlglot.planner import Plan from tests.helpers import ( FIXTURES_DIR, SKIP_INTEGRATION, @@ -862,3 +864,18 @@ class TestExecutor(unittest.TestCase): result = execute("SELECT x FROM t", dialect="duckdb", tables=tables) self.assertEqual(result.columns, ("x",)) self.assertEqual(result.rows, [([1, 2, 3],)]) + + def test_agg_order(self): + plan = Plan( + optimize(""" + SELECT + AVG(bill_length_mm) AS avg_bill_length, + AVG(bill_depth_mm) AS avg_bill_depth + FROM penguins + """) + ) + + assert [agg.alias for agg in plan.root.aggregations] == [ + "avg_bill_length", + "avg_bill_depth", + ] diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 1395b24..b3617ee 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1011,12 +1011,18 @@ FROM foo""", "ALTER TABLE t1 RENAME TO t2", ) - def test_is_negative(self): - self.assertTrue(parse_one("-1").is_negative) - self.assertTrue(parse_one("- 1.0").is_negative) - self.assertTrue(exp.Literal.number("-1").is_negative) - self.assertFalse(parse_one("1").is_negative) - self.assertFalse(parse_one("x").is_negative) + def test_to_py(self): + self.assertEqual(parse_one("- -1").to_py(), 1) + self.assertIs(parse_one("TRUE").to_py(), True) + self.assertIs(parse_one("1").to_py(), 1) + self.assertIs(parse_one("'1'").to_py(), "1") + self.assertIs(parse_one("null").to_py(), None) + + with self.assertRaises(ValueError): + parse_one("x").to_py() + + def test_is_int(self): + self.assertTrue(parse_one("- -1").is_int) def test_is_star(self): assert parse_one("*").is_star diff --git a/tests/test_jsonpath.py b/tests/test_jsonpath.py index 4daf3c1..c939c52 100644 --- a/tests/test_jsonpath.py +++ b/tests/test_jsonpath.py @@ -2,8 +2,9 @@ import json import os import unittest -from sqlglot import exp, jsonpath +from sqlglot import exp from sqlglot.errors import ParseError, TokenError +from sqlglot.jsonpath import parse from tests.helpers import FIXTURES_DIR @@ -25,7 +26,7 @@ class TestJsonpath(unittest.TestCase): exp.JSONPathSelector(this=exp.JSONPathScript(this="@.x)")), ] self.assertEqual( - jsonpath.parse("$.*.a[0]['x'][*, 'y', 1].z[?(@.a == 'b'), 1:][1:5][1,?@.a][(@.x)]"), + parse("$.*.a[0]['x'][*, 'y', 1].z[?(@.a == 'b'), 1:][1:5][1,?@.a][(@.x)]"), exp.JSONPath(expressions=expected_expressions), ) @@ -36,7 +37,7 @@ class TestJsonpath(unittest.TestCase): ("$[((@.length-1))]", "$[((@.length-1))]"), ): with self.subTest(f"{selector} -> {expected}"): - self.assertEqual(jsonpath.parse(selector).sql(), f"'{expected}'") + self.assertEqual(parse(selector).sql(), f"'{expected}'") def test_cts_file(self): with open(os.path.join(FIXTURES_DIR, "jsonpath", "cts.json")) as file: @@ -131,9 +132,9 @@ class TestJsonpath(unittest.TestCase): with self.subTest(f"{selector.strip()} /* {test['name']} */"): if test.get("invalid_selector"): try: - jsonpath.parse(selector) + parse(selector) except (ParseError, TokenError): pass else: - path = jsonpath.parse(selector) + path = parse(selector) self.assertEqual(path.sql(), f"'{overrides.get(selector, selector)}'") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 81b9731..604a364 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -341,6 +341,25 @@ class TestOptimizer(unittest.TestCase): "WITH tbl1 AS (SELECT STRUCT(1 AS `f0`, 2 AS f1) AS col) SELECT tbl1.col.`f0` AS `f0`, tbl1.col.f1 AS f1 FROM tbl1", ) + # can't coalesce USING columns because they don't exist in every already-joined table + self.assertEqual( + optimizer.qualify_columns.qualify_columns( + parse_one( + "SELECT id, dt, v FROM (SELECT t1.id, t1.dt, sum(coalesce(t2.v, 0)) AS v FROM t1 AS t1 LEFT JOIN lkp AS lkp USING (id) LEFT JOIN t2 AS t2 USING (other_id, dt, common) WHERE t1.id > 10 GROUP BY 1, 2) AS _q_0", + dialect="bigquery", + ), + schema=MappingSchema( + schema={ + "t1": {"id": "int64", "dt": "date", "common": "int64"}, + "lkp": {"id": "int64", "other_id": "int64", "common": "int64"}, + "t2": {"other_id": "int64", "dt": "date", "v": "int64", "common": "int64"}, + }, + dialect="bigquery", + ), + ).sql(dialect="bigquery"), + "SELECT _q_0.id AS id, _q_0.dt AS dt, _q_0.v AS v FROM (SELECT t1.id AS id, t1.dt AS dt, sum(coalesce(t2.v, 0)) AS v FROM t1 AS t1 LEFT JOIN lkp AS lkp ON t1.id = lkp.id LEFT JOIN t2 AS t2 ON lkp.other_id = t2.other_id AND t1.dt = t2.dt AND COALESCE(t1.common, lkp.common) = t2.common WHERE t1.id > 10 GROUP BY t1.id, t1.dt) AS _q_0", + ) + self.check_file( "qualify_columns", qualify_columns, @@ -473,15 +492,35 @@ SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expr 'SELECT "x"."a" + 1 AS "d", "x"."a" + 1 + 1 AS "e" FROM "x" AS "x" WHERE ("x"."a" + 2) > 1 GROUP BY "x"."a" + 1 + 1', ) + unused_schema = {"l": {"c": "int"}} self.assertEqual( optimizer.qualify_columns.qualify_columns( parse_one("SELECT CAST(x AS INT) AS y FROM z AS z"), - schema={"l": {"c": "int"}}, + schema=unused_schema, infer_schema=False, ).sql(), "SELECT CAST(x AS INT) AS y FROM z AS z", ) + # BigQuery expands overlapping alias only for GROUP BY + HAVING + sql = "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT id AS my_id, CONCAT(id, name) AS full_name FROM data WHERE my_id = 1 GROUP BY my_id, full_name HAVING my_id = 1" + self.assertEqual( + optimizer.qualify_columns.qualify_columns( + parse_one(sql, dialect="bigquery"), + schema=MappingSchema(schema=unused_schema, dialect="bigquery"), + ).sql(), + "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT data.id AS my_id, CONCAT(data.id, data.name) AS full_name FROM data WHERE data.my_id = 1 GROUP BY data.id, CONCAT(data.id, data.name) HAVING data.id = 1", + ) + + # Clickhouse expands overlapping alias across the entire query + self.assertEqual( + optimizer.qualify_columns.qualify_columns( + parse_one(sql, dialect="clickhouse"), + schema=MappingSchema(schema=unused_schema, dialect="clickhouse"), + ).sql(), + "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT data.id AS my_id, CONCAT(data.id, data.name) AS full_name FROM data WHERE data.id = 1 GROUP BY data.id, CONCAT(data.id, data.name) HAVING data.id = 1", + ) + def test_optimize_joins(self): self.check_file( "optimize_joins", diff --git a/tests/test_parser.py b/tests/test_parser.py index d6849c3..f360b43 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -14,6 +14,8 @@ class TestParser(unittest.TestCase): parse_one("") def test_parse_into(self): + self.assertIsInstance(parse_one("select * from t", into=exp.Select), exp.Select) + self.assertIsInstance(parse_one("select * from t limit 5", into=exp.Select), exp.Select) self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join) self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType) self.assertIsInstance(parse_one("array", into=exp.DataType), exp.DataType) @@ -102,6 +104,13 @@ class TestParser(unittest.TestCase): def test_float(self): self.assertEqual(parse_one(".2"), parse_one("0.2")) + def test_unnest(self): + unnest_sql = "UNNEST(foo)" + expr = parse_one(unnest_sql) + self.assertIsInstance(expr, exp.Unnest) + self.assertIsInstance(expr.expressions, list) + self.assertEqual(expr.sql(), unnest_sql) + def test_unnest_projection(self): expr = parse_one("SELECT foo IN UNNEST(bla) AS bar") self.assertIsInstance(expr.selects[0], exp.Alias) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 73d6705..e7d596c 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -3,6 +3,7 @@ import unittest from sqlglot import parse_one from sqlglot.transforms import ( eliminate_distinct_on, + eliminate_join_marks, eliminate_qualify, remove_precision_parameterized_types, unalias_group, @@ -12,9 +13,11 @@ from sqlglot.transforms import ( class TestTransforms(unittest.TestCase): maxDiff = None - def validate(self, transform, sql, target): - with self.subTest(sql): - self.assertEqual(parse_one(sql).transform(transform).sql(), target) + def validate(self, transform, sql, target, dialect=None): + with self.subTest(f"{dialect} - {sql}"): + self.assertEqual( + parse_one(sql, dialect=dialect).transform(transform).sql(dialect=dialect), target + ) def test_unalias_group(self): self.validate( @@ -138,3 +141,76 @@ class TestTransforms(unittest.TestCase): "SELECT CAST(1 AS DECIMAL(10, 2)), CAST('13' AS VARCHAR(10))", "SELECT CAST(1 AS DECIMAL), CAST('13' AS VARCHAR)", ) + + def test_eliminate_join_marks(self): + for dialect in ("oracle", "redshift"): + self.validate( + eliminate_join_marks, + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x (+) = T2.x and T2.y > 5", + "SELECT T1.d, T2.c FROM T2 LEFT JOIN T1 ON T1.x = T2.x WHERE T2.y > 5", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4", + "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT * FROM table1, table2 WHERE table1.col = table2.col(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT * FROM table1, table2, table3, table4 WHERE table1.col = table2.col(+) and table2.col >= table3.col(+) and table1.col = table4.col(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col LEFT JOIN table3 ON table2.col >= table3.col LEFT JOIN table4 ON table1.col = table4.col", + dialect, + ) + self.validate( + eliminate_join_marks, + "SELECT * FROM table1, table2, table3 WHERE table1.col = table2.col(+) and table2.col >= table3.col(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col LEFT JOIN table3 ON table2.col >= table3.col", + dialect, + ) + # 2 join marks on one side of predicate + self.validate( + eliminate_join_marks, + "SELECT * FROM table1, table2 WHERE table1.col = table2.col1(+) + table2.col2(+)", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col1 + table2.col2", + dialect, + ) + # join mark and expression + self.validate( + eliminate_join_marks, + "SELECT * FROM table1, table2 WHERE table1.col = table2.col1(+) + 25", + "SELECT * FROM table1 LEFT JOIN table2 ON table1.col = table2.col1 + 25", + dialect, + ) + + alias = "AS " if dialect != "oracle" else "" + self.validate( + eliminate_join_marks, + "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)", + f"SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) {alias}table3 ON table1.id = table3.id", + dialect, + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index dea9985..b5e069a 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -111,6 +111,10 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT a\r\nFROM b", "SELECT a FROM b") def test_comments(self): + self.validate( + "select /* asfd /* asdf */ asdf */ 1", + "/* asfd /* asdf */ asdf */ SELECT 1", + ) self.validate( "SELECT c /* foo */ AS alias", "SELECT c AS alias /* foo */", @@ -552,7 +556,7 @@ FROM x""", ) self.validate( - """SELECT X FROM catalog.db.table WHERE Y + """SELECT X FROM catalog.db.table WHERE Y -- AND Z""", """SELECT X FROM catalog.db.table WHERE Y AND Z""", @@ -585,24 +589,24 @@ FROM x""", def test_extract(self): self.validate( "EXTRACT(day FROM '2020-01-01'::TIMESTAMP)", - "EXTRACT(day FROM CAST('2020-01-01' AS TIMESTAMP))", + "EXTRACT(DAY FROM CAST('2020-01-01' AS TIMESTAMP))", ) self.validate( "EXTRACT(timezone FROM '2020-01-01'::TIMESTAMP)", - "EXTRACT(timezone FROM CAST('2020-01-01' AS TIMESTAMP))", + "EXTRACT(TIMEZONE FROM CAST('2020-01-01' AS TIMESTAMP))", ) self.validate( "EXTRACT(year FROM '2020-01-01'::TIMESTAMP WITH TIME ZONE)", - "EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMPTZ))", + "EXTRACT(YEAR FROM CAST('2020-01-01' AS TIMESTAMPTZ))", ) self.validate( "extract(month from '2021-01-31'::timestamp without time zone)", - "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))", + "EXTRACT(MONTH FROM CAST('2021-01-31' AS TIMESTAMP))", ) - self.validate("extract(week from current_date + 2)", "EXTRACT(week FROM CURRENT_DATE + 2)") + self.validate("extract(week from current_date + 2)", "EXTRACT(WEEK FROM CURRENT_DATE + 2)") self.validate( "EXTRACT(minute FROM datetime1 - datetime2)", - "EXTRACT(minute FROM datetime1 - datetime2)", + "EXTRACT(MINUTE FROM datetime1 - datetime2)", ) def test_if(self): -- cgit v1.2.3