summaryrefslogtreecommitdiffstats
path: root/tests/dialects
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-07-13 11:11:42 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-07-13 11:11:42 +0000
commit721d458d4c24741ccbc5519b7ca39234a1a21ff6 (patch)
treeb9f72e1d00aba012f06cdf7b0d75ec5e53640eaf /tests/dialects
parentAdding upstream version 25.1.0. (diff)
downloadsqlglot-721d458d4c24741ccbc5519b7ca39234a1a21ff6.tar.xz
sqlglot-721d458d4c24741ccbc5519b7ca39234a1a21ff6.zip
Adding upstream version 25.5.1.upstream/25.5.1
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dialects')
-rw-r--r--tests/dialects/test_bigquery.py55
-rw-r--r--tests/dialects/test_clickhouse.py96
-rw-r--r--tests/dialects/test_databricks.py41
-rw-r--r--tests/dialects/test_dialect.py21
-rw-r--r--tests/dialects/test_doris.py28
-rw-r--r--tests/dialects/test_duckdb.py11
-rw-r--r--tests/dialects/test_mysql.py1
-rw-r--r--tests/dialects/test_oracle.py71
-rw-r--r--tests/dialects/test_postgres.py20
-rw-r--r--tests/dialects/test_presto.py22
-rw-r--r--tests/dialects/test_redshift.py18
-rw-r--r--tests/dialects/test_snowflake.py49
-rw-r--r--tests/dialects/test_spark.py9
-rw-r--r--tests/dialects/test_teradata.py9
-rw-r--r--tests/dialects/test_tsql.py48
15 files changed, 345 insertions, 154 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index ae8ed16..803ac11 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -103,6 +103,7 @@ LANGUAGE js AS
select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")
+ self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
self.validate_identity("assert.true(1 = 1)")
self.validate_identity("SELECT ARRAY_TO_STRING(list, '--') AS text")
self.validate_identity("SELECT jsondoc['some_key']")
@@ -294,6 +295,20 @@ LANGUAGE js AS
)
self.validate_all(
+ "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')",
+ write={
+ "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'DD MONTH YYYY')",
+ "duckdb": "CAST(TRY_STRPTIME(some_date, '%d %B %Y') AS DATE)",
+ },
+ )
+ self.validate_all(
+ "SAFE_CAST(some_date AS DATE FORMAT 'YYYY-MM-DD') AS some_date",
+ write={
+ "bigquery": "SAFE_CAST(some_date AS DATE FORMAT 'YYYY-MM-DD') AS some_date",
+ "duckdb": "CAST(TRY_STRPTIME(some_date, '%Y-%m-%d') AS DATE) AS some_date",
+ },
+ )
+ self.validate_all(
"SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s",
write={
"bigquery": "SELECT t.c1, h.c2, s.c3 FROM t1 AS t, UNNEST(t.t2) AS h, UNNEST(h.t3) AS s",
@@ -1345,6 +1360,46 @@ WHERE
"bigquery": "SELECT CAST(x AS DATETIME)",
},
)
+ self.validate_all(
+ "SELECT TIME(foo, 'America/Los_Angeles')",
+ write={
+ "duckdb": "SELECT CAST(CAST(foo AS TIMESTAMPTZ) AT TIME ZONE 'America/Los_Angeles' AS TIME)",
+ "bigquery": "SELECT TIME(foo, 'America/Los_Angeles')",
+ },
+ )
+ self.validate_all(
+ "SELECT DATETIME('2020-01-01')",
+ write={
+ "duckdb": "SELECT CAST('2020-01-01' AS TIMESTAMP)",
+ "bigquery": "SELECT DATETIME('2020-01-01')",
+ },
+ )
+ self.validate_all(
+ "SELECT DATETIME('2020-01-01', TIME '23:59:59')",
+ write={
+ "duckdb": "SELECT CAST(CAST('2020-01-01' AS DATE) + CAST('23:59:59' AS TIME) AS TIMESTAMP)",
+ "bigquery": "SELECT DATETIME('2020-01-01', CAST('23:59:59' AS TIME))",
+ },
+ )
+ self.validate_all(
+ "SELECT DATETIME('2020-01-01', 'America/Los_Angeles')",
+ write={
+ "duckdb": "SELECT CAST(CAST('2020-01-01' AS TIMESTAMPTZ) AT TIME ZONE 'America/Los_Angeles' AS TIMESTAMP)",
+ "bigquery": "SELECT DATETIME('2020-01-01', 'America/Los_Angeles')",
+ },
+ )
+ self.validate_all(
+ "SELECT LENGTH(foo)",
+ read={
+ "bigquery": "SELECT LENGTH(foo)",
+ "snowflake": "SELECT LENGTH(foo)",
+ },
+ write={
+ "duckdb": "SELECT CASE TYPEOF(foo) WHEN 'VARCHAR' THEN LENGTH(CAST(foo AS TEXT)) WHEN 'BLOB' THEN OCTET_LENGTH(CAST(foo AS BLOB)) END",
+ "snowflake": "SELECT LENGTH(foo)",
+ "": "SELECT LENGTH(foo)",
+ },
+ )
def test_errors(self):
with self.assertRaises(TokenError):
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 72634a8..ef84d48 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -7,23 +7,6 @@ class TestClickhouse(Validator):
dialect = "clickhouse"
def test_clickhouse(self):
- self.validate_all(
- "SELECT * FROM x PREWHERE y = 1 WHERE z = 2",
- write={
- "": "SELECT * FROM x WHERE z = 2",
- "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2",
- },
- )
- self.validate_all(
- "SELECT * FROM x AS prewhere",
- read={
- "clickhouse": "SELECT * FROM x AS prewhere",
- "duckdb": "SELECT * FROM x prewhere",
- },
- )
-
- self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y")
-
string_types = [
"BLOB",
"LONGBLOB",
@@ -42,6 +25,9 @@ class TestClickhouse(Validator):
self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
self.assertIsNone(expr._meta)
+ self.validate_identity("SELECT EXTRACT(YEAR FROM toDateTime('2023-02-01'))")
+ self.validate_identity("extract(haystack, pattern)")
+ self.validate_identity("SELECT * FROM x LIMIT 1 UNION ALL SELECT * FROM y")
self.validate_identity("SELECT CAST(x AS Tuple(String, Array(Nullable(Float64))))")
self.validate_identity("countIf(x, y)")
self.validate_identity("x = y")
@@ -94,18 +80,12 @@ class TestClickhouse(Validator):
self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""")
self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b")
self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b")
+ self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster")
+ self.validate_identity("TRUNCATE DATABASE db")
+ self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster")
self.validate_identity(
"SELECT id, quantileGK(100, 0.95)(reading) OVER (PARTITION BY id ORDER BY id RANGE BETWEEN 30000 PRECEDING AND CURRENT ROW) AS window FROM table"
)
-
- self.validate_identity(
- "SELECT $1$foo$1$",
- "SELECT 'foo'",
- )
- self.validate_identity(
- "SELECT * FROM table LIMIT 1, 2 BY a, b",
- "SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b",
- )
self.validate_identity(
"SELECT * FROM table LIMIT 1 BY CONCAT(datalayerVariantNo, datalayerProductId, warehouse)"
)
@@ -134,10 +114,6 @@ class TestClickhouse(Validator):
"SELECT sum(1) AS impressions, (arrayJoin(arrayZip(cities, browsers)) AS t).1 AS city, t.2 AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 'Chrome'] AS browsers) GROUP BY 2, 3"
)
self.validate_identity(
- "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']",
- "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')",
- )
- self.validate_identity(
'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))'
)
self.validate_identity(
@@ -155,12 +131,43 @@ class TestClickhouse(Validator):
self.validate_identity(
"CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data"
)
- self.validate_identity("TRUNCATE TABLE t1 ON CLUSTER test_cluster")
- self.validate_identity("TRUNCATE DATABASE db")
- self.validate_identity("TRUNCATE DATABASE db ON CLUSTER test_cluster")
self.validate_identity(
"CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)"
)
+ self.validate_identity(
+ "SELECT $1$foo$1$",
+ "SELECT 'foo'",
+ )
+ self.validate_identity(
+ "SELECT * FROM table LIMIT 1, 2 BY a, b",
+ "SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b",
+ )
+ self.validate_identity(
+ "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']",
+ "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')",
+ )
+
+ self.validate_all(
+ "SELECT * FROM x PREWHERE y = 1 WHERE z = 2",
+ write={
+ "": "SELECT * FROM x WHERE z = 2",
+ "clickhouse": "SELECT * FROM x PREWHERE y = 1 WHERE z = 2",
+ },
+ )
+ self.validate_all(
+ "SELECT * FROM x AS prewhere",
+ read={
+ "clickhouse": "SELECT * FROM x AS prewhere",
+ "duckdb": "SELECT * FROM x prewhere",
+ },
+ )
+ self.validate_all(
+ "SELECT a, b FROM (SELECT * FROM x) AS t",
+ read={
+ "clickhouse": "SELECT a, b FROM (SELECT * FROM x) AS t",
+ "duckdb": "SELECT a, b FROM (SELECT * FROM x) AS t(a, b)",
+ },
+ )
self.validate_all(
"SELECT arrayJoin([1,2,3])",
write={
@@ -880,3 +887,26 @@ LIFETIME(MIN 0 MAX 0)""",
for creatable in ("DATABASE", "TABLE", "VIEW", "DICTIONARY", "FUNCTION"):
with self.subTest(f"Test DROP {creatable} ON CLUSTER"):
self.validate_identity(f"DROP {creatable} test ON CLUSTER test_cluster")
+
+ def test_datetime_funcs(self):
+ # Each datetime func has an alias that is roundtripped to the original name e.g. (DATE_SUB, DATESUB) -> DATE_SUB
+ datetime_funcs = (("DATE_SUB", "DATESUB"), ("DATE_ADD", "DATEADD"))
+
+ # 2-arg functions of type <func>(date, unit)
+ for func in (*datetime_funcs, ("TIMESTAMP_ADD", "TIMESTAMPADD")):
+ func_name = func[0]
+ for func_alias in func:
+ self.validate_identity(
+ f"""SELECT {func_alias}(date, INTERVAL '3' YEAR)""",
+ f"""SELECT {func_name}(date, INTERVAL '3' YEAR)""",
+ )
+
+ # 3-arg functions of type <func>(unit, value, date)
+ for func in (*datetime_funcs, ("DATE_DIFF", "DATEDIFF"), ("TIMESTAMP_SUB", "TIMESTAMPSUB")):
+ func_name = func[0]
+ for func_alias in func:
+ with self.subTest(f"Test 3-arg date-time function {func_alias}"):
+ self.validate_identity(
+ f"SELECT {func_alias}(SECOND, 1, bar)",
+ f"SELECT {func_name}(SECOND, 1, bar)",
+ )
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index 9ef3b86..471830f 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -1,4 +1,4 @@
-from sqlglot import transpile
+from sqlglot import exp, transpile
from sqlglot.errors import ParseError
from tests.dialects.test_dialect import Validator
@@ -25,6 +25,7 @@ class TestDatabricks(Validator):
self.validate_identity("CREATE FUNCTION a AS b")
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))")
+ self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)")
self.validate_identity(
"CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA"
)
@@ -37,22 +38,26 @@ class TestDatabricks(Validator):
self.validate_identity(
"SELECT * FROM sales UNPIVOT EXCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))"
)
-
self.validate_identity(
"CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $$def add_one(x):\n return x+1$$"
)
-
self.validate_identity(
"CREATE FUNCTION add_one(x INT) RETURNS INT LANGUAGE PYTHON AS $FOO$def add_one(x):\n return x+1$FOO$"
)
-
- self.validate_identity("TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', address)")
self.validate_identity(
"TRUNCATE TABLE t1 PARTITION(age = 10, name = 'test', city LIKE 'LA')"
)
self.validate_identity(
"COPY INTO target FROM `s3://link` FILEFORMAT = AVRO VALIDATE = ALL FILES = ('file1', 'file2') FORMAT_OPTIONS ('opt1'='true', 'opt2'='test') COPY_OPTIONS ('mergeSchema'='true')"
)
+ self.validate_identity(
+ "DATE_DIFF(day, created_at, current_date())",
+ "DATEDIFF(DAY, created_at, CURRENT_DATE)",
+ ).args["unit"].assert_is(exp.Var)
+ self.validate_identity(
+ r'SELECT r"\\foo.bar\"',
+ r"SELECT '\\\\foo.bar\\'",
+ )
self.validate_all(
"CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))",
@@ -67,7 +72,6 @@ class TestDatabricks(Validator):
"teradata": "CREATE TABLE t1 AS (SELECT c FROM t2) WITH DATA",
},
)
-
self.validate_all(
"SELECT X'1A2B'",
read={
@@ -96,33 +100,30 @@ class TestDatabricks(Validator):
# https://docs.databricks.com/sql/language-manual/functions/colonsign.html
def test_json(self):
+ self.validate_identity("SELECT c1:price, c1:price.foo, c1:price.bar[1]")
self.validate_identity(
- """SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""",
- """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ """SELECT c1:item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)"""
)
self.validate_identity(
- """SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""",
- """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ """SELECT c1:item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)"""
)
self.validate_identity(
- """SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- """SELECT GET_JSON_OBJECT(c1, '$.item[1].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
+ """SELECT FROM_JSON(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)"""
)
self.validate_identity(
- """SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- """SELECT GET_JSON_OBJECT(c1, '$.item[*].price') FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
+ """SELECT INLINE(FROM_JSON(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)"""
)
self.validate_identity(
- """SELECT from_json(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- """SELECT FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*].price'), 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
+ """SELECT c1:['price'] FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ """SELECT c1:price FROM VALUES ('{ "price": 5 }') AS T(c1)""",
)
self.validate_identity(
- """SELECT inline(from_json(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
- """SELECT INLINE(FROM_JSON(GET_JSON_OBJECT(c1, '$.item[*]'), 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""",
+ """SELECT GET_JSON_OBJECT(c1, '$.price') FROM VALUES ('{ "price": 5 }') AS T(c1)""",
+ """SELECT c1:price FROM VALUES ('{ "price": 5 }') AS T(c1)""",
)
self.validate_identity(
- "SELECT c1 : price",
- "SELECT GET_JSON_OBJECT(c1, '$.price')",
+ """SELECT raw:`zip code`, raw:`fb:testid`, raw:store['bicycle'], raw:store["zip code"]""",
+ """SELECT raw:["zip code"], raw:["fb:testid"], raw:store.bicycle, raw:store["zip code"]""",
)
def test_datediff(self):
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index aaeb7b0..c0afb2f 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -102,14 +102,10 @@ class TestDialect(Validator):
lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase")
self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE")
- with self.assertRaises(ValueError) as cm:
+ with self.assertRaises(AttributeError) as cm:
Dialect.get_or_raise("mysql, normalization_strategy")
- self.assertEqual(
- str(cm.exception),
- "Invalid dialect format: 'mysql, normalization_strategy'. "
- "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.",
- )
+ self.assertEqual(str(cm.exception), "'bool' object has no attribute 'upper'")
with self.assertRaises(ValueError) as cm:
Dialect.get_or_raise("myqsl")
@@ -121,6 +117,18 @@ class TestDialect(Validator):
self.assertEqual(str(cm.exception), "Unknown dialect 'asdfjasodiufjsd'.")
+ oracle_with_settings = Dialect.get_or_raise(
+ "oracle, normalization_strategy = lowercase, version = 19.5"
+ )
+ self.assertEqual(oracle_with_settings.normalization_strategy.value, "LOWERCASE")
+ self.assertEqual(oracle_with_settings.settings, {"version": "19.5"})
+
+ bool_settings = Dialect.get_or_raise("oracle, s1=TruE, s2=1, s3=FaLse, s4=0, s5=nonbool")
+ self.assertEqual(
+ bool_settings.settings,
+ {"s1": True, "s2": True, "s3": False, "s4": False, "s5": "nonbool"},
+ )
+
def test_compare_dialects(self):
bigquery_class = Dialect["bigquery"]
bigquery_object = BigQuery()
@@ -1150,7 +1158,6 @@ class TestDialect(Validator):
write={
"bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
- "oracle": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC, lname NULLS FIRST",
"hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
"spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py
index 8180d05..99076ba 100644
--- a/tests/dialects/test_doris.py
+++ b/tests/dialects/test_doris.py
@@ -56,6 +56,34 @@ class TestDoris(Validator):
"postgres": "SELECT STRING_AGG('aa', ',')",
},
)
+ self.validate_all(
+ "SELECT LAG(1, 1, NULL) OVER (ORDER BY 1)",
+ read={
+ "doris": "SELECT LAG(1, 1, NULL) OVER (ORDER BY 1)",
+ "postgres": "SELECT LAG(1) OVER (ORDER BY 1)",
+ },
+ )
+ self.validate_all(
+ "SELECT LAG(1, 2, NULL) OVER (ORDER BY 1)",
+ read={
+ "doris": "SELECT LAG(1, 2, NULL) OVER (ORDER BY 1)",
+ "postgres": "SELECT LAG(1, 2) OVER (ORDER BY 1)",
+ },
+ )
+ self.validate_all(
+ "SELECT LEAD(1, 1, NULL) OVER (ORDER BY 1)",
+ read={
+ "doris": "SELECT LEAD(1, 1, NULL) OVER (ORDER BY 1)",
+ "postgres": "SELECT LEAD(1) OVER (ORDER BY 1)",
+ },
+ )
+ self.validate_all(
+ "SELECT LEAD(1, 2, NULL) OVER (ORDER BY 1)",
+ read={
+ "doris": "SELECT LEAD(1, 2, NULL) OVER (ORDER BY 1)",
+ "postgres": "SELECT LEAD(1, 2) OVER (ORDER BY 1)",
+ },
+ )
def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 2bde478..e0b0131 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -19,6 +19,13 @@ class TestDuckDB(Validator):
)
self.validate_all(
+ """SELECT CASE WHEN JSON_VALID('{"x: 1}') THEN '{"x: 1}' ELSE NULL END""",
+ read={
+ "duckdb": """SELECT CASE WHEN JSON_VALID('{"x: 1}') THEN '{"x: 1}' ELSE NULL END""",
+ "snowflake": """SELECT TRY_PARSE_JSON('{"x: 1}')""",
+ },
+ )
+ self.validate_all(
"SELECT straight_join",
write={
"duckdb": "SELECT straight_join",
@@ -786,6 +793,8 @@ class TestDuckDB(Validator):
},
)
+ self.validate_identity("SELECT LENGTH(foo)")
+
def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
@@ -847,7 +856,7 @@ class TestDuckDB(Validator):
read={"bigquery": "SELECT DATE(PARSE_DATE('%m/%d/%Y', '05/06/2020'))"},
)
self.validate_all(
- "SELECT CAST('2020-01-01' AS DATE) + INTERVAL (-1) DAY",
+ "SELECT CAST('2020-01-01' AS DATE) + INTERVAL '-1' DAY",
read={"mysql": "SELECT DATE '2020-01-01' + INTERVAL -1 DAY"},
)
self.validate_all(
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 280ebbf..bfdb2a6 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -117,6 +117,7 @@ class TestMySQL(Validator):
)
def test_identity(self):
+ self.validate_identity("SELECT CAST(COALESCE(`id`, 'NULL') AS CHAR CHARACTER SET binary)")
self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')")
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 7cc4d72..1d9fd99 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -1,5 +1,4 @@
from sqlglot import exp, UnsupportedError
-from sqlglot.dialects.oracle import eliminate_join_marks
from tests.dialects.test_dialect import Validator
@@ -10,7 +9,7 @@ class TestOracle(Validator):
self.validate_all(
"SELECT CONNECT_BY_ROOT x y",
write={
- "": "SELECT CONNECT_BY_ROOT(x) AS y",
+ "": "SELECT CONNECT_BY_ROOT x AS y",
"oracle": "SELECT CONNECT_BY_ROOT x AS y",
},
)
@@ -87,9 +86,9 @@ class TestOracle(Validator):
"SELECT DISTINCT col1, col2 FROM table",
)
self.validate_identity(
- "SELECT * FROM T ORDER BY I OFFSET nvl(:variable1, 10) ROWS FETCH NEXT nvl(:variable2, 10) ROWS ONLY",
- "SELECT * FROM T ORDER BY I OFFSET COALESCE(:variable1, 10) ROWS FETCH NEXT COALESCE(:variable2, 10) ROWS ONLY",
+ "SELECT * FROM T ORDER BY I OFFSET NVL(:variable1, 10) ROWS FETCH NEXT NVL(:variable2, 10) ROWS ONLY",
)
+ self.validate_identity("NVL(x, y)").assert_is(exp.Anonymous)
self.validate_identity(
"SELECT * FROM t SAMPLE (.25)",
"SELECT * FROM t SAMPLE (0.25)",
@@ -191,13 +190,6 @@ class TestOracle(Validator):
},
)
self.validate_all(
- "NVL(NULL, 1)",
- write={
- "": "COALESCE(NULL, 1)",
- "oracle": "COALESCE(NULL, 1)",
- },
- )
- self.validate_all(
"DATE '2022-01-01'",
write={
"": "DATE_STR_TO_DATE('2022-01-01')",
@@ -245,6 +237,10 @@ class TestOracle(Validator):
"duckdb": "SELECT CAST(STRPTIME('2024-12-12', '%Y-%m-%d') AS DATE)",
},
)
+ self.validate_identity(
+ """SELECT * FROM t ORDER BY a ASC NULLS LAST, b ASC NULLS FIRST, c DESC NULLS LAST, d DESC NULLS FIRST""",
+ """SELECT * FROM t ORDER BY a ASC, b ASC NULLS FIRST, c DESC NULLS LAST, d DESC""",
+ )
def test_join_marker(self):
self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y")
@@ -416,59 +412,6 @@ WHERE
for query in (f"{body}{start}{connect}", f"{body}{connect}{start}"):
self.validate_identity(query, pretty, pretty=True)
- def test_eliminate_join_marks(self):
- test_sql = [
- (
- "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) > 5",
- "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y > 5",
- ),
- (
- "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y (+) IS NULL",
- "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x AND T2.y IS NULL",
- ),
- (
- "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T2.y IS NULL",
- "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T2.y IS NULL",
- ),
- (
- "SELECT T1.d, T2.c FROM T1, T2 WHERE T1.x = T2.x (+) and T1.Z > 4",
- "SELECT T1.d, T2.c FROM T1 LEFT JOIN T2 ON T1.x = T2.x WHERE T1.Z > 4",
- ),
- (
- "SELECT * FROM table1, table2 WHERE table1.column = table2.column(+)",
- "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column",
- ),
- (
- "SELECT * FROM table1, table2, table3, table4 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+) and table1.column = table4.column(+)",
- "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column LEFT JOIN table4 ON table1.column = table4.column",
- ),
- (
- "SELECT * FROM table1, table2, table3 WHERE table1.column = table2.column(+) and table2.column >= table3.column(+)",
- "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column LEFT JOIN table3 ON table2.column >= table3.column",
- ),
- (
- "SELECT table1.id, table2.cloumn1, table3.id FROM table1, table2, (SELECT tableInner1.id FROM tableInner1, tableInner2 WHERE tableInner1.id = tableInner2.id(+)) AS table3 WHERE table1.id = table2.id(+) and table1.id = table3.id(+)",
- "SELECT table1.id, table2.cloumn1, table3.id FROM table1 LEFT JOIN table2 ON table1.id = table2.id LEFT JOIN (SELECT tableInner1.id FROM tableInner1 LEFT JOIN tableInner2 ON tableInner1.id = tableInner2.id) table3 ON table1.id = table3.id",
- ),
- # 2 join marks on one side of predicate
- (
- "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + table2.column2(+)",
- "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + table2.column2",
- ),
- # join mark and expression
- (
- "SELECT * FROM table1, table2 WHERE table1.column = table2.column1(+) + 25",
- "SELECT * FROM table1 LEFT JOIN table2 ON table1.column = table2.column1 + 25",
- ),
- ]
-
- for original, expected in test_sql:
- with self.subTest(original):
- self.assertEqual(
- eliminate_join_marks(self.parse_one(original)).sql(dialect=self.dialect),
- expected,
- )
-
def test_query_restrictions(self):
for restriction in ("READ ONLY", "CHECK OPTION"):
for constraint_name in (" CONSTRAINT name", ""):
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 071677d..816a283 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -8,6 +8,14 @@ class TestPostgres(Validator):
dialect = "postgres"
def test_postgres(self):
+ self.validate_all(
+ "x ? y",
+ write={
+ "": "JSONB_CONTAINS(x, y)",
+ "postgres": "x ? y",
+ },
+ )
+
self.validate_identity("SHA384(x)")
self.validate_identity(
'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)"
@@ -68,10 +76,6 @@ class TestPostgres(Validator):
self.validate_identity("SELECT CURRENT_USER")
self.validate_identity("SELECT * FROM ONLY t1")
self.validate_identity(
- "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]",
- "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]",
- )
- self.validate_identity(
"""UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)"""
)
self.validate_identity(
@@ -128,6 +132,14 @@ class TestPostgres(Validator):
"ORDER BY 2, 3"
)
self.validate_identity(
+ "/*+ some comment*/ SELECT b.foo, b.bar FROM baz AS b",
+ "/* + some comment */ SELECT b.foo, b.bar FROM baz AS b",
+ )
+ self.validate_identity(
+ "SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]",
+ "SELECT ARRAY[1, 2] @> ARRAY[1, 2, 3]",
+ )
+ self.validate_identity(
"SELECT ARRAY[]::INT[] AS foo",
"SELECT CAST(ARRAY[] AS INT[]) AS foo",
)
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index ebb270a..dbe3abc 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -581,6 +581,13 @@ class TestPresto(Validator):
)
def test_presto(self):
+ self.assertEqual(
+ exp.func("md5", exp.func("concat", exp.cast("x", "text"), exp.Literal.string("s"))).sql(
+ dialect="presto"
+ ),
+ "LOWER(TO_HEX(MD5(TO_UTF8(CONCAT(CAST(x AS VARCHAR), CAST('s' AS VARCHAR))))))",
+ )
+
with self.assertLogs(helper_logger):
self.validate_all(
"SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table",
@@ -1192,3 +1199,18 @@ MATCH_RECOGNIZE (
"starrocks": "SIGN(x)",
},
)
+
+ def test_json_vs_row_extract(self):
+ for dialect in ("trino", "presto"):
+ s = parse_one('SELECT col:x:y."special string"', read="snowflake")
+
+ dialect_json_extract_setting = f"{dialect}, variant_extract_is_json_extract=True"
+ dialect_row_access_setting = f"{dialect}, variant_extract_is_json_extract=False"
+
+ # By default, Snowflake VARIANT will generate JSON_EXTRACT() in Presto/Trino
+ json_extract_result = """SELECT JSON_EXTRACT(col, '$.x.y["special string"]')"""
+ self.assertEqual(s.sql(dialect), json_extract_result)
+ self.assertEqual(s.sql(dialect_json_extract_setting), json_extract_result)
+
+ # If the setting is overriden to False, then generate ROW access (dot notation)
+ self.assertEqual(s.sql(dialect_row_access_setting), 'SELECT col.x.y."special string"')
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 69793c7..c4e7073 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -28,7 +28,7 @@ class TestRedshift(Validator):
"""SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""",
write={
"bigquery": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""",
- "databricks": """SELECT GET_JSON_OBJECT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""",
+ "databricks": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}':farm.barn.color""",
"duckdb": """SELECT '{ "farm": {"barn": { "color": "red", "feed stocked": true }}}' ->> '$.farm.barn.color'""",
"postgres": """SELECT JSON_EXTRACT_PATH_TEXT('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', 'farm', 'barn', 'color')""",
"presto": """SELECT JSON_EXTRACT_SCALAR('{ "farm": {"barn": { "color": "red", "feed stocked": true }}}', '$.farm.barn.color')""",
@@ -228,7 +228,7 @@ class TestRedshift(Validator):
"drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
"hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
"mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
- "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1",
+ "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) _t WHERE _row_number = 1",
"presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
"redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
"snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
@@ -259,6 +259,12 @@ class TestRedshift(Validator):
"postgres": "COALESCE(a, b, c, d)",
},
)
+
+ self.validate_identity(
+ "DATEDIFF(days, a, b)",
+ "DATEDIFF(DAY, a, b)",
+ )
+
self.validate_all(
"DATEDIFF('day', a, b)",
write={
@@ -300,6 +306,14 @@ class TestRedshift(Validator):
},
)
+ self.validate_all(
+ "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
+ write={
+ "snowflake": "SELECT DATE_PART(EPOCH, CURRENT_DATE)",
+ "redshift": "SELECT EXTRACT(EPOCH FROM CURRENT_DATE)",
+ },
+ )
+
def test_identity(self):
self.validate_identity("LISTAGG(DISTINCT foo, ', ')")
self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1")
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 1286436..88b2148 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -11,6 +11,12 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
+ self.assertEqual(
+ # Ensures we don't fail when generating ParseJSON with the `safe` arg set to `True`
+ self.validate_identity("""SELECT TRY_PARSE_JSON('{"x: 1}')""").sql(),
+ """SELECT PARSE_JSON('{"x: 1}')""",
+ )
+
self.validate_identity(
"transform(x, a int -> a + a + 1)",
"TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
@@ -49,6 +55,8 @@ WHERE
)""",
)
+ self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))")
+ self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias")
self.validate_identity("SELECT number").selects[0].assert_is(exp.Column)
self.validate_identity("INTERVAL '4 years, 5 months, 3 hours'")
self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
@@ -183,18 +191,6 @@ WHERE
"""SELECT CAST(GET_PATH(PARSE_JSON('{"food":{"fruit":"banana"}}'), 'food.fruit') AS VARCHAR)""",
)
self.validate_identity(
- "SELECT * FROM foo at",
- "SELECT * FROM foo AS at",
- )
- self.validate_identity(
- "SELECT * FROM foo before",
- "SELECT * FROM foo AS before",
- )
- self.validate_identity(
- "SELECT * FROM foo at (col)",
- "SELECT * FROM foo AS at(col)",
- )
- self.validate_identity(
"SELECT * FROM unnest(x) with ordinality",
"SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS _u(seq, key, path, index, value, this)",
)
@@ -337,7 +333,7 @@ WHERE
"""SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""",
write={
"bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""",
- "databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
+ "databricks": """SELECT '{"fruit":"banana"}':fruit""",
"duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""",
"mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""",
"presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""",
@@ -1196,6 +1192,17 @@ WHERE
"SELECT oldt.*, newt.* FROM my_table BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS oldt FULL OUTER JOIN my_table AT (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS newt ON oldt.id = newt.id WHERE oldt.id IS NULL OR newt.id IS NULL",
)
+ # Make sure that the historical data keywords can still be used as aliases
+ for historical_data_prefix in ("AT", "BEFORE", "END", "CHANGES"):
+ for schema_suffix in ("", "(col)"):
+ with self.subTest(
+ f"Testing historical data prefix alias: {historical_data_prefix}{schema_suffix}"
+ ):
+ self.validate_identity(
+ f"SELECT * FROM foo {historical_data_prefix}{schema_suffix}",
+ f"SELECT * FROM foo AS {historical_data_prefix}{schema_suffix}",
+ )
+
def test_ddl(self):
for constraint_prefix in ("WITH ", ""):
with self.subTest(f"Constraint prefix: {constraint_prefix}"):
@@ -1216,6 +1223,7 @@ WHERE
"CREATE TABLE t (id INT TAG (key1='value_1', key2='value_2'))",
)
+ self.validate_identity("CREATE SECURE VIEW table1 AS (SELECT a FROM table2)")
self.validate_identity(
"""create external table et2(
col1 date as (parse_json(metadata$external_table_partition):COL1::date),
@@ -1241,6 +1249,9 @@ WHERE
"CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'"
).this.assert_is(exp.Identifier)
self.validate_identity(
+ "CREATE DYNAMIC TABLE product (pre_tax_profit, taxes, after_tax_profit) TARGET_LAG='20 minutes' WAREHOUSE=mywh AS SELECT revenue - cost, (revenue - cost) * tax_rate, (revenue - cost) * (1.0 - tax_rate) FROM staging_table"
+ )
+ self.validate_identity(
"ALTER TABLE db_name.schmaName.tblName ADD COLUMN COLUMN_1 VARCHAR NOT NULL TAG (key1='value_1')"
)
self.validate_identity(
@@ -2021,3 +2032,15 @@ SINGLE = TRUE""",
self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING")
+
+ def test_from_changes(self):
+ self.validate_identity(
+ """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) AT (STREAM => 's1') END (TIMESTAMP => $ts2)"""
+ )
+ self.validate_identity(
+ """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) BEFORE (STATEMENT => 'STMT_ID') END (TIMESTAMP => $ts2)"""
+ )
+ self.validate_identity(
+ """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 00:00:00+00:00')) END (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 14:28:59.999999+00:00'))""",
+ """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => CAST('2024-07-01 00:00:00+00:00' AS TIMESTAMPTZ)) END (TIMESTAMP => CAST('2024-07-01 14:28:59.999999+00:00' AS TIMESTAMPTZ))""",
+ )
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index bff91bf..4e62b32 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -245,7 +245,7 @@ TBLPROPERTIES (
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH TABLE a.b.c")
- self.validate_identity("INTERVAL -86 DAYS")
+ self.validate_identity("INTERVAL '-86' DAYS")
self.validate_identity("TRIM(' SparkSQL ')")
self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
@@ -801,3 +801,10 @@ TBLPROPERTIES (
self.assertEqual(query.sql(name), with_modifiers)
else:
self.assertEqual(query.sql(name), without_modifiers)
+
+ def test_schema_binding_options(self):
+ for schema_binding in ("BINDING", "COMPENSATION", "TYPE EVOLUTION", "EVOLUTION"):
+ with self.subTest(f"Test roundtrip of VIEW schema binding {schema_binding}"):
+ self.validate_identity(
+ f"CREATE VIEW emp_v WITH SCHEMA {schema_binding} AS SELECT * FROM emp"
+ )
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 74d5f88..3945ca3 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -5,6 +5,13 @@ class TestTeradata(Validator):
dialect = "teradata"
def test_teradata(self):
+ self.validate_all(
+ "RANDOM(l, u)",
+ write={
+ "": "(u - l) * RAND() + l",
+ "teradata": "RANDOM(l, u)",
+ },
+ )
self.validate_identity("TO_NUMBER(expr, fmt, nlsparam)")
self.validate_identity("SELECT TOP 10 * FROM tbl")
self.validate_identity("SELECT * FROM tbl SAMPLE 5")
@@ -212,6 +219,8 @@ class TestTeradata(Validator):
)
def test_time(self):
+ self.validate_identity("CAST(CURRENT_TIMESTAMP(6) AS TIMESTAMP WITH TIME ZONE)")
+
self.validate_all(
"CURRENT_TIMESTAMP",
read={
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 7455650..11d60e7 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,4 +1,4 @@
-from sqlglot import exp, parse
+from sqlglot import exp, parse, parse_one
from tests.dialects.test_dialect import Validator
from sqlglot.errors import ParseError
from sqlglot.optimizer.annotate_types import annotate_types
@@ -8,19 +8,14 @@ class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
- self.assertEqual(
- annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
- "SELECT 1 WHERE EXISTS(SELECT 1)",
- )
+ # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
+ # tsql allows .. which means use the default schema
+ self.validate_identity("SELECT * FROM a..b")
self.validate_identity("CREATE view a.b.c", "CREATE VIEW b.c")
self.validate_identity("DROP view a.b.c", "DROP VIEW b.c")
self.validate_identity("ROUND(x, 1, 0)")
self.validate_identity("EXEC MyProc @id=7, @name='Lochristi'", check_command_warning=True)
- # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN
- # tsql allows .. which means use the default schema
- self.validate_identity("SELECT * FROM a..b")
-
self.validate_identity("SELECT TRIM(' test ') AS Result")
self.validate_identity("SELECT TRIM('.,! ' FROM ' # test .') AS Result")
self.validate_identity("SELECT * FROM t TABLESAMPLE (10 PERCENT)")
@@ -37,8 +32,21 @@ class TestTSQL(Validator):
self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0")
self.validate_identity("TRUNCATE TABLE t1 WITH (PARTITIONS(1, 2 TO 5, 10 TO 20, 84))")
self.validate_identity(
+ "CREATE CLUSTERED INDEX [IX_OfficeTagDetail_TagDetailID] ON [dbo].[OfficeTagDetail]([TagDetailID] ASC)"
+ )
+ self.validate_identity(
+ "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON X([y])"
+ )
+ self.validate_identity(
+ "CREATE INDEX [x] ON [y]([z] ASC) WITH (allow_page_locks=on) ON PRIMARY"
+ )
+ self.validate_identity(
"COPY INTO test_1 FROM 'path' WITH (FORMAT_NAME = test, FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY='Shared Access Signature', SECRET='token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')"
)
+ self.assertEqual(
+ annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"),
+ "SELECT 1 WHERE EXISTS(SELECT 1)",
+ )
self.validate_all(
"SELECT IIF(cond <> 0, 'True', 'False')",
@@ -1868,3 +1876,25 @@ FROM OPENJSON(@json) WITH (
"DECLARE vendor_cursor CURSOR FOR SELECT VendorID, Name FROM Purchasing.Vendor WHERE PreferredVendorStatus = 1 ORDER BY VendorID",
check_command_warning=True,
)
+
+ def test_scope_resolution_op(self):
+ # we still want to support :: casting shorthand for tsql
+ self.validate_identity("x::int", "CAST(x AS INTEGER)")
+ self.validate_identity("x::varchar", "CAST(x AS VARCHAR)")
+ self.validate_identity("x::varchar(MAX)", "CAST(x AS VARCHAR(MAX))")
+
+ for lhs, rhs in (
+ ("", "FOO(a, b)"),
+ ("bar", "baZ(1, 2)"),
+ ("LOGIN", "EricKurjan"),
+ ("GEOGRAPHY", "Point(latitude, longitude, 4326)"),
+ (
+ "GEOGRAPHY",
+ "STGeomFromText('POLYGON((-122.358 47.653 , -122.348 47.649, -122.348 47.658, -122.358 47.658, -122.358 47.653))', 4326)",
+ ),
+ ):
+ with self.subTest(f"Scope resolution, LHS: {lhs}, RHS: {rhs}"):
+ expr = self.validate_identity(f"{lhs}::{rhs}")
+ base_sql = expr.sql()
+ self.assertEqual(base_sql, f"SCOPE_RESOLUTION({lhs + ', ' if lhs else ''}{rhs})")
+ self.assertEqual(parse_one(base_sql).sql("tsql"), f"{lhs}::{rhs}")