diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 55 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 96 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 41 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_doris.py | 28 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 71 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 22 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 49 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 48 |
15 files changed, 345 insertions, 154 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index ae8ed16..803ac11 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -103,6 +103,7 @@ LANGUAGE js AS select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`") self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF") + self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)") self.validate_identity("assert.true(1 = 1)") self.validate_identity("SELECT ARRAY_TO_STRING(list, '--') AS text") self.validate_identity("SELECT jsondoc['some_key']") @@ -294,6 +295,20 @@ LANGUAGE js AS ) self.validate_all( + "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')", + write={ + "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')", + "duckdb": "CAST(TRY_STRPTIME(some_date, '%d %B %Y') AS DATE)", + }, + ) + self.validate_all( + "SAFE_CAST(some_date AS DATE FORMAT 'YYYY-MM-DD') AS some_date", + write={ + "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'YYYY-MM-DD') AS some_date", + "duckdb": "CAST(TRY_STRPTIME(some_date, '%Y-%m-%d') AS DATE) AS some_date", + }, + ) + self.validate_all( "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s", write={ "bigquery": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s", @@ -1345,6 +1360,46 @@ WHERE "bigquery": "SELECT CAST(x AS DATETIME)", }, ) + self.validate_all( + "SELECT TIME(foo, 'America/Los_Angeles')", + write={ + "duckdb": "SELECT CAST(CAST(foo AS TIMESTAMPTZ) AT TIME ZONE 'America/Los_Angeles' AS TIME)", + "bigquery": "SELECT TIME(foo, 'America/Los_Angeles')", + }, + ) + self.validate_all( + "SELECT DATETIME('2020-01-01')", + write={ + "duckdb": "SELECT CAST('2020-01-01' AS TIMESTAMP)", + "bigquery": "SELECT DATETIME('2020-01-01')", + }, + ) + self.validate_all( + "SELECT DATETIME('2020-01-01', TIME '23:59:59')", + write={ + "duckdb": "SELECT CAST(CAST('2020-01-01' AS DATE) + CAST('23:59:59' AS TIME) AS TIMESTAMP)", + "bigquery": "SELECT DATETIME('2020-01-01', CAST('23:59:59' AS TIME))", + }, + ) + self.validate_all( + "SELECT DATETIME('2020-01-01', 'America/Los_Angeles')", + write={ + "duckdb": "SELECT CAST(CAST('2020-01-01' AS TIMESTAMPTZ) AT TIME ZONE 'America/Los_Angeles' AS TIMESTAMP)", + "bigquery": "SELECT DATETIME('2020-01-01', 'America/Los_Angeles')", + }, + ) + self.validate_all( + "SELECT LENGTH(foo)", + read={ + "bigquery": "SELECT LENGTH(foo)", + "snowflake": "SELECT LENGTH(foo)", + }, + write={ + "duckdb": "SELECT CASE TYPEOF(foo) WHEN 'VARCHAR' THEN LENGTH(CAST(foo AS TEXT)) WHEN 'BLOB' THEN OCTET_LENGTH(CAST(foo AS BLOB)) END", + "snowflake": "SELECT LENGTH(foo)", + "": "SELECT LENGTH(foo)", + }, + ) def test_errors(self): with self.assertRaises(TokenError): diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 72634a8..ef84d48 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -7,23 +7,6 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): - self.validate_all( - "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", - write={ - "": "SELECT * FROM x WHERE z = 2", - "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", - }, - ) - self.validate_all( - "SELECT * FROM x AS prewhere", - read={ - "clickhouse": "SELECT * FROM x AS prewhere", - "duckdb": "SELECT * FROM x prewhere", - }, - ) - - self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y") - string_types = [ "BLOB", "LONGBLOB", @@ -42,6 +25,9 @@ class TestClickhouse(Validator): self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) + self.validate_identity("SELECT EXTRACT(YEAR FROM toDateTime('2023-02-01'))") + self.validate_identity("extract(haystack, pattern)") + self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y") self.validate_identity("SELECT CAST(x AS Tuple(String, Array(Nullable(Float64))))") self.validate_identity("countIf(x, y)") self.validate_identity("x = y") @@ -94,18 +80,12 @@ class TestClickhouse(Validator): self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""") self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b") self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b") + self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster") + self.validate_identity("TRUNCATE DATABASE db") + self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster") self.validate_identity( "SELECT id, quantileGK(100, 0.95)(reading) OVER (PARTITION BY id ORDER BY id RANGE BETWEEN 30000 PRECEDING AND CURRENT ROW) AS window FROM table" ) - - self.validate_identity( - "SELECT $1$foo$1$", - "SELECT 'foo'", - ) - self.validate_identity( - "SELECT * FROM table LIMIT 1, 2 BY a, b", - "SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b", - ) self.validate_identity( "SELECT * FROM table LIMIT 1 BY CONCAT(datalayerVariantNo, datalayerProductId, warehouse)" ) @@ -134,10 +114,6 @@ class TestClickhouse(Validator): "SELECT sum(1) AS impressions, (arrayJoin(arrayZip(cities, browsers)) AS t).1 AS city, t.2 AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 'Chrome'] AS browsers) GROUP BY 2, 3" ) self.validate_identity( - "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']", - "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')", - ) - self.validate_identity( 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))' ) self.validate_identity( @@ -155,12 +131,43 @@ class TestClickhouse(Validator): self.validate_identity( "CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data" ) - self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster") - self.validate_identity("TRUNCATE DATABASE db") - self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster") self.validate_identity( "CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)" ) + self.validate_identity( + "SELECT $1$foo$1$", + "SELECT 'foo'", + ) + self.validate_identity( + "SELECT * FROM table LIMIT 1, 2 BY a, b", + "SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b", + ) + self.validate_identity( + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']", + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')", + ) + + self.validate_all( + "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", + write={ + "": "SELECT * FROM x WHERE z = 2", + "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2", + }, + ) + self.validate_all( + "SELECT * FROM x AS prewhere", + read={ + "clickhouse": "SELECT * FROM x AS prewhere", + "duckdb": "SELECT * FROM x prewhere", + }, + ) + self.validate_all( + "SELECT a, b FROM (SELECT * FROM x) AS t", + read={ + "clickhouse": "SELECT a, b FROM (SELECT * FROM x) AS t", + "duckdb": "SELECT a, b FROM (SELECT * FROM x) AS t(a, b)", + }, + ) self.validate_all( "SELECT arrayJoin([1,2,3])", write={ @@ -880,3 +887,26 @@ LIFETIME(MIN 0 MAX 0)""", for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"): with self.subTest(f"Test DROP {creatable} ON CLUSTER"): self.validate_identity(f"DROP {creatable} test ON CLUSTER test_cluster") + + def test_datetime_funcs(self): + # Each datetime func has an alias that is roundtripped to the original name e.g. (DATE_SUB, DATESUB) -> DATE_SUB + datetime_funcs = (("DATE_SUB", "DATESUB"), ("DATE_ADD", "DATEADD")) + + # 2-arg functions of type <func>(date, unit) + for func in (*datetime_funcs, ("TIMESTAMP_ADD", "TIMESTAMPADD")): + func_name = func[0] + for func_alias in func: + self.validate_identity( + f"""SELECT {func_alias}(date, INTERVAL '3' YEAR)""", + f"""SELECT {func_name}(date, INTERVAL '3' YEAR)""", + ) + + # 3-arg functions of type <func>(unit, value, date) + for func in (*datetime_funcs, ("DATE_DIFF", "DATEDIFF"), ("TIMESTAMP_SUB", "TIMESTAMPSUB")): + func_name = func[0] + for func_alias in func: + with self.subTest(f"Test 3-arg date-time function {func_alias}"): + self.validate_identity( + f"SELECT {func_alias}(SECOND, 1, bar)", + f"SELECT {func_name}(SECOND, 1, bar)", + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 9ef3b86..471830f 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -1,4 +1,4 @@ -from sqlglot import transpile +from sqlglot import exp, transpile from sqlglot.errors import ParseError from tests.dialects.test_dialect import Validator @@ -25,6 +25,7 @@ class TestDatabricks(Validator): self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)") self.validate_identity( "CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA" ) @@ -37,22 +38,26 @@ class TestDatabricks(Validator): self.validate_identity( "SELECT * FROM sales UNPIVOT EXCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))" ) - self.validate_identity( "CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $$def add_one(x):\n return x+1$$" ) - self.validate_identity( "CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $FOO$def add_one(x):\n return x+1$FOO$" ) - - self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)") self.validate_identity( "TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')" ) self.validate_identity( "COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS ('opt1'='true', 'opt2'='test') COPY_OPTIONS ('mergeSchema'='true')" ) + self.validate_identity( + "DATE_DIFF(day, created_at, current_date())", + "DATEDIFF(DAY, created_at, CURRENT_DATE)", + ).args["unit"].assert_is(exp.Var) + self.validate_identity( + r'SELECT r"\\foo.bar\"', + r"SELECT '\\\\foo.bar\\'", + ) self.validate_all( "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", @@ -67,7 +72,6 @@ class TestDatabricks(Validator): "teradata": "CREATE TABLE t1 AS (SELECT c FROM t2) WITH DATA", }, ) - self.validate_all( "SELECT X'1A2B'", read={ @@ -96,33 +100,30 @@ class TestDatabricks(Validator): # https://docs.databricks.com/sql/language-manual/functions/colonsign.html def test_json(self): + self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]") self.validate_identity( - """SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.item[1].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT FROM_JSON(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT GET_JSON_OBJECT(c1, '$.item[*].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT INLINE(FROM_JSON(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""" ) self.validate_identity( - """SELECT from_json(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*].price'), 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT c1:['price'] FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:price FROM VALUES ('{ "price": 5 }') AS T(c1)""", ) self.validate_identity( - """SELECT inline(from_json(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", - """SELECT INLINE(FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*]'), 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + """SELECT c1:price FROM VALUES ('{ "price": 5 }') AS T(c1)""", ) self.validate_identity( - "SELECT c1 : price", - "SELECT GET_JSON_OBJECT(c1, '$.price')", + """SELECT raw:`zip code`, raw:`fb:testid`, raw:store['bicycle'], raw:store["zip code"]""", + """SELECT raw:["zip code"], raw:["fb:testid"], raw:store.bicycle, raw:store["zip code"]""", ) def test_datediff(self): diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index aaeb7b0..c0afb2f 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -102,14 +102,10 @@ class TestDialect(Validator): lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase") self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE") - with self.assertRaises(ValueError) as cm: + with self.assertRaises(AttributeError) as cm: Dialect.get_or_raise("mysql, normalization_strategy") - self.assertEqual( - str(cm.exception), - "Invalid dialect format: 'mysql, normalization_strategy'. " - "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.", - ) + self.assertEqual(str(cm.exception), "'bool' object has no attribute 'upper'") with self.assertRaises(ValueError) as cm: Dialect.get_or_raise("myqsl") @@ -121,6 +117,18 @@ class TestDialect(Validator): self.assertEqual(str(cm.exception), "Unknown dialect 'asdfjasodiufjsd'.") + oracle_with_settings = Dialect.get_or_raise( + "oracle, normalization_strategy = lowercase, version = 19.5" + ) + self.assertEqual(oracle_with_settings.normalization_strategy.value, "LOWERCASE") + self.assertEqual(oracle_with_settings.settings, {"version": "19.5"}) + + bool_settings = Dialect.get_or_raise("oracle, s1=TruE, s2=1, s3=FaLse, s4=0, s5=nonbool") + self.assertEqual( + bool_settings.settings, + {"s1": True, "s2": True, "s3": False, "s4": False, "s5": "nonbool"}, + ) + def test_compare_dialects(self): bigquery_class = Dialect["bigquery"] bigquery_object = BigQuery() @@ -1150,7 +1158,6 @@ class TestDialect(Validator): write={ "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", - "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST", "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py index 8180d05..99076ba 100644 --- a/tests/dialects/test_doris.py +++ b/tests/dialects/test_doris.py @@ -56,6 +56,34 @@ class TestDoris(Validator): "postgres": "SELECT STRING_AGG('aa', ',')", }, ) + self.validate_all( + "SELECT LAG(1, 1, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LAG(1, 1, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LAG(1) OVER (ORDER BY 1)", + }, + ) + self.validate_all( + "SELECT LAG(1, 2, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LAG(1, 2, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LAG(1, 2) OVER (ORDER BY 1)", + }, + ) + self.validate_all( + "SELECT LEAD(1, 1, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LEAD(1, 1, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LEAD(1) OVER (ORDER BY 1)", + }, + ) + self.validate_all( + "SELECT LEAD(1, 2, NULL) OVER (ORDER BY 1)", + read={ + "doris": "SELECT LEAD(1, 2, NULL) OVER (ORDER BY 1)", + "postgres": "SELECT LEAD(1, 2) OVER (ORDER BY 1)", + }, + ) def test_identity(self): self.validate_identity("COALECSE(a, b, c, d)") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 2bde478..e0b0131 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -19,6 +19,13 @@ class TestDuckDB(Validator): ) self.validate_all( + """SELECT CASE WHEN JSON_VALID('{"x: 1}') THEN '{"x: 1}' ELSE NULL END""", + read={ + "duckdb": """SELECT CASE WHEN JSON_VALID('{"x: 1}') THEN '{"x: 1}' ELSE NULL END""", + "snowflake": """SELECT TRY_PARSE_JSON('{"x: 1}')""", + }, + ) + self.validate_all( "SELECT straight_join", write={ "duckdb": "SELECT straight_join", @@ -786,6 +793,8 @@ class TestDuckDB(Validator): }, ) + self.validate_identity("SELECT LENGTH(foo)") + def test_array_index(self): with self.assertLogs(helper_logger) as cm: self.validate_all( @@ -847,7 +856,7 @@ class TestDuckDB(Validator): read={"bigquery": "SELECT DATE(PARSE_DATE('%m/%d/%Y', '05/06/2020'))"}, ) self.validate_all( - "SELECT CAST('2020-01-01' AS DATE) + INTERVAL (-1) DAY", + "SELECT CAST('2020-01-01' AS DATE) + INTERVAL '-1' DAY", read={"mysql": "SELECT DATE '2020-01-01' + INTERVAL -1 DAY"}, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 280ebbf..bfdb2a6 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -117,6 +117,7 @@ class TestMySQL(Validator): ) def test_identity(self): + self.validate_identity("SELECT CAST(COALESCE(`id`, 'NULL') AS CHAR CHARACTER SET binary)") self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y") self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 7cc4d72..1d9fd99 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -1,5 +1,4 @@ from sqlglot import exp, UnsupportedError -from sqlglot.dialects.oracle import eliminate_join_marks from tests.dialects.test_dialect import Validator @@ -10,7 +9,7 @@ class TestOracle(Validator): self.validate_all( "SELECT CONNECT_BY_ROOT x y", write={ - "": "SELECT CONNECT_BY_ROOT(x) AS y", + "": "SELECT CONNECT_BY_ROOT x AS y", "oracle": "SELECT CONNECT_BY_ROOT x AS y", }, ) @@ -87,9 +86,9 @@ class TestOracle(Validator): "SELECT DISTINCT col1, col2 FROM table", ) self.validate_identity( - "SELECT * FROM T ORDER BY I OFFSET nvl(:variable1, 10) ROWS FETCH NEXT nvl(:variable2, 10) ROWS ONLY", - "SELECT * FROM T ORDER BY I OFFSET COALESCE(:variable1, 10) ROWS FETCH NEXT COALESCE(:variable2, 10) ROWS ONLY", + "SELECT * FROM T ORDER BY I OFFSET NVL(:variable1, 10) ROWS FETCH NEXT NVL(:variable2, 10) ROWS ONLY", ) + self.validate_identity("NVL(x, y)").assert_is(exp.Anonymous) self.validate_identity( "SELECT * FROM t SAMPLE (.25)", "SELECT * FROM t SAMPLE (0.25)", @@ -191,13 +190,6 @@ class TestOracle(Validator): }, ) self.validate_all( - "NVL(NULL, 1)", - write={ - "": "COALESCE(NULL, 1)", - "oracle": "COALESCE(NULL, 1)", - }, - ) - self.validate_all( "DATE '2022-01-01'", write={ "": "DATE_STR_TO_DATE('2022-01-01')", @@ -245,6 +237,10 @@ class TestOracle(Validator): "duckdb": "SELECT CAST(STRPTIME('2024-12-12', '%Y-%m-%d') AS DATE)", }, ) + self.validate_identity( + """SELECT * FROM t ORDER BY a ASC NULLS LAST, b ASC NULLS FIRST, c DESC NULLS LAST, d DESC NULLS FIRST""", + """SELECT * FROM t ORDER BY a ASC, b ASC NULLS FIRST, c DESC NULLS LAST, d DESC""", + ) def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") @@ -416,59 +412,6 @@ WHERE for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"): self.validate_identity(query, pretty, pretty=True) - def test_eliminate_join_marks(self): - test_sql = [ - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5", - ), - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL", - ), - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL", - ), - ( - "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4", - "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4", - ), - ( - "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column", - ), - ( - "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column", - ), - ( - "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column", - ), - ( - "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)", - "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id", - ), - # 2 join marks on one side of predicate - ( - "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2", - ), - # join mark and expression - ( - "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25", - "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25", - ), - ] - - for original, expected in test_sql: - with self.subTest(original): - self.assertEqual( - eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect), - expected, - ) - def test_query_restrictions(self): for restriction in ("READ ONLY", "CHECK OPTION"): for constraint_name in (" CONSTRAINT name", ""): diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 071677d..816a283 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,6 +8,14 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_all( + "x ? y", + write={ + "": "JSONB_CONTAINS(x, y)", + "postgres": "x ? y", + }, + ) + self.validate_identity("SHA384(x)") self.validate_identity( 'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)" @@ -68,10 +76,6 @@ class TestPostgres(Validator): self.validate_identity("SELECT CURRENT_USER") self.validate_identity("SELECT * FROM ONLY t1") self.validate_identity( - "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]", - "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]", - ) - self.validate_identity( """UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)""" ) self.validate_identity( @@ -128,6 +132,14 @@ class TestPostgres(Validator): "ORDER BY 2, 3" ) self.validate_identity( + "/*+ some comment*/ SELECT b.foo, b.bar FROM baz AS b", + "/* + some comment */ SELECT b.foo, b.bar FROM baz AS b", + ) + self.validate_identity( + "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]", + "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]", + ) + self.validate_identity( "SELECT ARRAY[]::INT[] AS foo", "SELECT CAST(ARRAY[] AS INT[]) AS foo", ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ebb270a..dbe3abc 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -581,6 +581,13 @@ class TestPresto(Validator): ) def test_presto(self): + self.assertEqual( + exp.func("md5", exp.func("concat", exp.cast("x", "text"), exp.Literal.string("s"))).sql( + dialect="presto" + ), + "LOWER(TO_HEX(MD5(TO_UTF8(CONCAT(CAST(x AS VARCHAR), CAST('s' AS VARCHAR))))))", + ) + with self.assertLogs(helper_logger): self.validate_all( "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", @@ -1192,3 +1199,18 @@ MATCH_RECOGNIZE ( "starrocks": "SIGN(x)", }, ) + + def test_json_vs_row_extract(self): + for dialect in ("trino", "presto"): + s = parse_one('SELECT col:x:y."special string"', read="snowflake") + + dialect_json_extract_setting = f"{dialect}, variant_extract_is_json_extract=True" + dialect_row_access_setting = f"{dialect}, variant_extract_is_json_extract=False" + + # By default, Snowflake VARIANT will generate JSON_EXTRACT() in Presto/Trino + json_extract_result = """SELECT JSON_EXTRACT(col, '$.x.y["special string"]')""" + self.assertEqual(s.sql(dialect), json_extract_result) + self.assertEqual(s.sql(dialect_json_extract_setting), json_extract_result) + + # If the setting is overriden to False, then generate ROW access (dot notation) + self.assertEqual(s.sql(dialect_row_access_setting), 'SELECT col.x.y."special string"') diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 69793c7..c4e7073 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -28,7 +28,7 @@ class TestRedshift(Validator): """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", write={ "bigquery": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", - "databricks": """SELECT GET_JSON_OBJECT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", + "databricks": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}':farm.barn.color""", "duckdb": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}' ->> '$.farm.barn.color'""", "postgres": """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""", "presto": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""", @@ -228,7 +228,7 @@ class TestRedshift(Validator): "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1", + "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1", "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", @@ -259,6 +259,12 @@ class TestRedshift(Validator): "postgres": "COALESCE(a, b, c, d)", }, ) + + self.validate_identity( + "DATEDIFF(days, a, b)", + "DATEDIFF(DAY, a, b)", + ) + self.validate_all( "DATEDIFF('day', a, b)", write={ @@ -300,6 +306,14 @@ class TestRedshift(Validator): }, ) + self.validate_all( + "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)", + write={ + "snowflake": "SELECT DATE_PART(EPOCH, CURRENT_DATE)", + "redshift": "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)", + }, + ) + def test_identity(self): self.validate_identity("LISTAGG(DISTINCT foo, ', ')") self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 1286436..88b2148 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -11,6 +11,12 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.assertEqual( + # Ensures we don't fail when generating ParseJSON with the `safe` arg set to `True` + self.validate_identity("""SELECT TRY_PARSE_JSON('{"x: 1}')""").sql(), + """SELECT PARSE_JSON('{"x: 1}')""", + ) + self.validate_identity( "transform(x, a int -> a + a + 1)", "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)", @@ -49,6 +55,8 @@ WHERE )""", ) + self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))") + self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias") self.validate_identity("SELECT number").selects[0].assert_is(exp.Column) self.validate_identity("INTERVAL '4 years, 5 months, 3 hours'") self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)") @@ -183,18 +191,6 @@ WHERE """SELECT CAST(GET_PATH(PARSE_JSON('{"food":{"fruit":"banana"}}'), 'food.fruit') AS VARCHAR)""", ) self.validate_identity( - "SELECT * FROM foo at", - "SELECT * FROM foo AS at", - ) - self.validate_identity( - "SELECT * FROM foo before", - "SELECT * FROM foo AS before", - ) - self.validate_identity( - "SELECT * FROM foo at (col)", - "SELECT * FROM foo AS at(col)", - ) - self.validate_identity( "SELECT * FROM unnest(x) with ordinality", "SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS _u(seq, key, path, index, value, this)", ) @@ -337,7 +333,7 @@ WHERE """SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""", write={ "bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""", - "databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""", + "databricks": """SELECT '{"fruit":"banana"}':fruit""", "duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""", "mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""", "presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""", @@ -1196,6 +1192,17 @@ WHERE "SELECT oldt.*, newt.* FROM my_table BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS oldt FULL OUTER JOIN my_table AT (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS newt ON oldt.id = newt.id WHERE oldt.id IS NULL OR newt.id IS NULL", ) + # Make sure that the historical data keywords can still be used as aliases + for historical_data_prefix in ("AT", "BEFORE", "END", "CHANGES"): + for schema_suffix in ("", "(col)"): + with self.subTest( + f"Testing historical data prefix alias: {historical_data_prefix}{schema_suffix}" + ): + self.validate_identity( + f"SELECT * FROM foo {historical_data_prefix}{schema_suffix}", + f"SELECT * FROM foo AS {historical_data_prefix}{schema_suffix}", + ) + def test_ddl(self): for constraint_prefix in ("WITH ", ""): with self.subTest(f"Constraint prefix: {constraint_prefix}"): @@ -1216,6 +1223,7 @@ WHERE "CREATE TABLE t (id INT TAG (key1='value_1', key2='value_2'))", ) + self.validate_identity("CREATE SECURE VIEW table1 AS (SELECT a FROM table2)") self.validate_identity( """create external table et2( col1 date as (parse_json(metadata$external_table_partition):COL1::date), @@ -1241,6 +1249,9 @@ WHERE "CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'" ).this.assert_is(exp.Identifier) self.validate_identity( + "CREATE DYNAMIC TABLE product (pre_tax_profit, taxes, after_tax_profit) TARGET_LAG='20 minutes' WAREHOUSE=mywh AS SELECT revenue - cost, (revenue - cost) * tax_rate, (revenue - cost) * (1.0 - tax_rate) FROM staging_table" + ) + self.validate_identity( "ALTER TABLE db_name.schmaName.tblName ADD COLUMN COLUMN_1 VARCHAR NOT NULL TAG (key1='value_1')" ) self.validate_identity( @@ -2021,3 +2032,15 @@ SINGLE = TRUE""", self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c") self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING") + + def test_from_changes(self): + self.validate_identity( + """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) AT (STREAM => 's1') END (TIMESTAMP => $ts2)""" + ) + self.validate_identity( + """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) BEFORE (STATEMENT => 'STMT_ID') END (TIMESTAMP => $ts2)""" + ) + self.validate_identity( + """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 00:00:00+00:00')) END (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 14:28:59.999999+00:00'))""", + """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => CAST('2024-07-01 00:00:00+00:00' AS TIMESTAMPTZ)) END (TIMESTAMP => CAST('2024-07-01 14:28:59.999999+00:00' AS TIMESTAMPTZ))""", + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index bff91bf..4e62b32 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -245,7 +245,7 @@ TBLPROPERTIES ( self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH TABLE a.b.c") - self.validate_identity("INTERVAL -86 DAYS") + self.validate_identity("INTERVAL '-86' DAYS") self.validate_identity("TRIM(' SparkSQL ')") self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") @@ -801,3 +801,10 @@ TBLPROPERTIES ( self.assertEqual(query.sql(name), with_modifiers) else: self.assertEqual(query.sql(name), without_modifiers) + + def test_schema_binding_options(self): + for schema_binding in ("BINDING", "COMPENSATION", "TYPE EVOLUTION", "EVOLUTION"): + with self.subTest(f"Test roundtrip of VIEW schema binding {schema_binding}"): + self.validate_identity( + f"CREATE VIEW emp_v WITH SCHEMA {schema_binding} AS SELECT * FROM emp" + ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 74d5f88..3945ca3 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -5,6 +5,13 @@ class TestTeradata(Validator): dialect = "teradata" def test_teradata(self): + self.validate_all( + "RANDOM(l, u)", + write={ + "": "(u - l) * RAND() + l", + "teradata": "RANDOM(l, u)", + }, + ) self.validate_identity("TO_NUMBER(expr, fmt, nlsparam)") self.validate_identity("SELECT TOP 10 * FROM tbl") self.validate_identity("SELECT * FROM tbl SAMPLE 5") @@ -212,6 +219,8 @@ class TestTeradata(Validator): ) def test_time(self): + self.validate_identity("CAST(CURRENT_TIMESTAMP(6) AS TIMESTAMP WITH TIME ZONE)") + self.validate_all( "CURRENT_TIMESTAMP", read={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 7455650..11d60e7 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1,4 +1,4 @@ -from sqlglot import exp, parse +from sqlglot import exp, parse, parse_one from tests.dialects.test_dialect import Validator from sqlglot.errors import ParseError from sqlglot.optimizer.annotate_types import annotate_types @@ -8,19 +8,14 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): - self.assertEqual( - annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), - "SELECT 1 WHERE EXISTS(SELECT 1)", - ) + # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN + # tsql allows .. which means use the default schema + self.validate_identity("SELECT * FROM a..b") self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c") self.validate_identity("DROP view a.b.c", "DROP VIEW b.c") self.validate_identity("ROUND(x, 1, 0)") self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True) - # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN - # tsql allows .. which means use the default schema - self.validate_identity("SELECT * FROM a..b") - self.validate_identity("SELECT TRIM(' test ') AS Result") self.validate_identity("SELECT TRIM('.,! ' FROM ' # test .') AS Result") self.validate_identity("SELECT * FROM t TABLESAMPLE (10 PERCENT)") @@ -37,8 +32,21 @@ class TestTSQL(Validator): self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))") self.validate_identity( + "CREATE CLUSTERED INDEX [IX_OfficeTagDetail_TagDetailID] ON [dbo].[OfficeTagDetail]([TagDetailID] ASC)" + ) + self.validate_identity( + "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON X([y])" + ) + self.validate_identity( + "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON PRIMARY" + ) + self.validate_identity( "COPY INTO test_1 FROM 'path' WITH (FORMAT_NAME = test, FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY='Shared Access Signature', SECRET='token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')" ) + self.assertEqual( + annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), + "SELECT 1 WHERE EXISTS(SELECT 1)", + ) self.validate_all( "SELECT IIF(cond <> 0, 'True', 'False')", @@ -1868,3 +1876,25 @@ FROM OPENJSON(@json) WITH ( "DECLARE vendor_cursor CURSOR FOR SELECT VendorID, Name FROM Purchasing.Vendor WHERE PreferredVendorStatus = 1 ORDER BY VendorID", check_command_warning=True, ) + + def test_scope_resolution_op(self): + # we still want to support :: casting shorthand for tsql + self.validate_identity("x::int", "CAST(x AS INTEGER)") + self.validate_identity("x::varchar", "CAST(x AS VARCHAR)") + self.validate_identity("x::varchar(MAX)", "CAST(x AS VARCHAR(MAX))") + + for lhs, rhs in ( + ("", "FOO(a, b)"), + ("bar", "baZ(1, 2)"), + ("LOGIN", "EricKurjan"), + ("GEOGRAPHY", "Point(latitude, longitude, 4326)"), + ( + "GEOGRAPHY", + "STGeomFromText('POLYGON((-122.358 47.653 , -122.348 47.649, -122.348 47.658, -122.358 47.658, -122.358 47.653))', 4326)", + ), + ): + with self.subTest(f"Scope resolution, LHS: {lhs}, RHS: {rhs}"): + expr = self.validate_identity(f"{lhs}::{rhs}") + base_sql = expr.sql() + self.assertEqual(base_sql, f"SCOPE_RESOLUTION({lhs + ', ' if lhs else ''}{rhs})") + self.assertEqual(parse_one(base_sql).sql("tsql"), f"{lhs}::{rhs}") |