summaryrefslogtreecommitdiffstats
path: root/tests/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects')
-rw-r--r--tests/dialects/test_bigquery.py105
-rw-r--r--tests/dialects/test_clickhouse.py53
-rw-r--r--tests/dialects/test_databricks.py1
-rw-r--r--tests/dialects/test_dialect.py116
-rw-r--r--tests/dialects/test_doris.py1
-rw-r--r--tests/dialects/test_duckdb.py50
-rw-r--r--tests/dialects/test_hive.py15
-rw-r--r--tests/dialects/test_mysql.py15
-rw-r--r--tests/dialects/test_oracle.py14
-rw-r--r--tests/dialects/test_postgres.py52
-rw-r--r--tests/dialects/test_presto.py86
-rw-r--r--tests/dialects/test_redshift.py40
-rw-r--r--tests/dialects/test_snowflake.py67
-rw-r--r--tests/dialects/test_spark.py57
-rw-r--r--tests/dialects/test_teradata.py12
-rw-r--r--tests/dialects/test_tsql.py111
16 files changed, 705 insertions, 90 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index 52f86bd..b776bdd 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -6,8 +6,36 @@ from tests.dialects.test_dialect import Validator
class TestBigQuery(Validator):
dialect = "bigquery"
+ maxDiff = None
def test_bigquery(self):
+ self.validate_identity("""SELECT JSON '"foo"' AS json_data""")
+ self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z")
+
+ self.validate_all(
+ """SELECT
+ `u`.`harness_user_email` AS `harness_user_email`,
+ `d`.`harness_user_id` AS `harness_user_id`,
+ `harness_account_id` AS `harness_account_id`
+FROM `analytics_staging`.`stg_mongodb__users` AS `u`, UNNEST(`u`.`harness_cluster_details`) AS `d`, UNNEST(`d`.`harness_account_ids`) AS `harness_account_id`
+WHERE
+ NOT `harness_account_id` IS NULL""",
+ read={
+ "": """
+ SELECT
+ "u"."harness_user_email" AS "harness_user_email",
+ "_q_0"."d"."harness_user_id" AS "harness_user_id",
+ "_q_1"."harness_account_id" AS "harness_account_id"
+ FROM
+ "analytics_staging"."stg_mongodb__users" AS "u",
+ UNNEST("u"."harness_cluster_details") AS "_q_0"("d"),
+ UNNEST("_q_0"."d"."harness_account_ids") AS "_q_1"("harness_account_id")
+ WHERE
+ NOT "_q_1"."harness_account_id" IS NULL
+ """
+ },
+ pretty=True,
+ )
with self.assertRaises(TokenError):
transpile("'\\'", read="bigquery")
@@ -57,6 +85,10 @@ class TestBigQuery(Validator):
self.validate_identity("SELECT * FROM my-table")
self.validate_identity("SELECT * FROM my-project.mydataset.mytable")
self.validate_identity("SELECT * FROM pro-ject_id.c.d CROSS JOIN foo-bar")
+ self.validate_identity("SELECT * FROM foo.bar.25", "SELECT * FROM foo.bar.`25`")
+ self.validate_identity("SELECT * FROM foo.bar.25_", "SELECT * FROM foo.bar.`25_`")
+ self.validate_identity("SELECT * FROM foo.bar.25x a", "SELECT * FROM foo.bar.`25x` AS a")
+ self.validate_identity("SELECT * FROM foo.bar.25ab c", "SELECT * FROM foo.bar.`25ab` AS c")
self.validate_identity("x <> ''")
self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))")
self.validate_identity("SELECT b'abc'")
@@ -105,6 +137,34 @@ class TestBigQuery(Validator):
self.validate_all('x <> """"""', write={"bigquery": "x <> ''"})
self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"})
self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"})
+ self.validate_all(
+ "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)",
+ write={
+ "bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)",
+ "databricks": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')",
+ },
+ ),
+ self.validate_all(
+ "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
+ write={
+ "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
+ "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1, '2023-01-01T00:00:00')",
+ },
+ ),
+ self.validate_all(
+ "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
+ write={
+ "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)",
+ "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1 * -1, '2023-01-01T00:00:00')",
+ },
+ ),
+ self.validate_all(
+ "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)",
+ write={
+ "bigquery": "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)",
+ "databricks": "SELECT DATE_TRUNC('HOUR', '2023-01-01T01:01:01')",
+ },
+ ),
self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"})
self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"})
self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"})
@@ -141,6 +201,20 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ "SHA256(x)",
+ write={
+ "bigquery": "SHA256(x)",
+ "spark2": "SHA2(x, 256)",
+ },
+ )
+ self.validate_all(
+ "SHA512(x)",
+ write={
+ "bigquery": "SHA512(x)",
+ "spark2": "SHA2(x, 512)",
+ },
+ )
+ self.validate_all(
"SELECT CAST('20201225' AS TIMESTAMP FORMAT 'YYYYMMDD' AT TIME ZONE 'America/New_York')",
write={"bigquery": "SELECT PARSE_TIMESTAMP('%Y%m%d', '20201225', 'America/New_York')"},
)
@@ -249,7 +323,7 @@ class TestBigQuery(Validator):
self.validate_all(
"r'x\\y'",
write={
- "bigquery": "'x\\\y'",
+ "bigquery": "'x\\\\y'",
"hive": "'x\\\\y'",
},
)
@@ -329,14 +403,14 @@ class TestBigQuery(Validator):
self.validate_all(
"[1, 2, 3]",
read={
- "duckdb": "LIST_VALUE(1, 2, 3)",
+ "duckdb": "[1, 2, 3]",
"presto": "ARRAY[1, 2, 3]",
"hive": "ARRAY(1, 2, 3)",
"spark": "ARRAY(1, 2, 3)",
},
write={
"bigquery": "[1, 2, 3]",
- "duckdb": "LIST_VALUE(1, 2, 3)",
+ "duckdb": "[1, 2, 3]",
"presto": "ARRAY[1, 2, 3]",
"hive": "ARRAY(1, 2, 3)",
"spark": "ARRAY(1, 2, 3)",
@@ -710,3 +784,28 @@ class TestBigQuery(Validator):
"WITH cte AS (SELECT 1 AS foo UNION ALL SELECT 2) SELECT foo FROM cte",
read={"postgres": "WITH cte(foo) AS (SELECT 1 UNION ALL SELECT 2) SELECT foo FROM cte"},
)
+
+ def test_json_object(self):
+ self.validate_identity("SELECT JSON_OBJECT() AS json_data")
+ self.validate_identity("SELECT JSON_OBJECT('foo', 10, 'bar', TRUE) AS json_data")
+ self.validate_identity("SELECT JSON_OBJECT('foo', 10, 'bar', ['a', 'b']) AS json_data")
+ self.validate_identity("SELECT JSON_OBJECT('a', 10, 'a', 'foo') AS json_data")
+ self.validate_identity(
+ "SELECT JSON_OBJECT(['a', 'b'], [10, NULL]) AS json_data",
+ "SELECT JSON_OBJECT('a', 10, 'b', NULL) AS json_data",
+ )
+ self.validate_identity(
+ """SELECT JSON_OBJECT(['a', 'b'], [JSON '10', JSON '"foo"']) AS json_data""",
+ """SELECT JSON_OBJECT('a', JSON '10', 'b', JSON '"foo"') AS json_data""",
+ )
+ self.validate_identity(
+ "SELECT JSON_OBJECT(['a', 'b'], [STRUCT(10 AS id, 'Red' AS color), STRUCT(20 AS id, 'Blue' AS color)]) AS json_data",
+ "SELECT JSON_OBJECT('a', STRUCT(10 AS id, 'Red' AS color), 'b', STRUCT(20 AS id, 'Blue' AS color)) AS json_data",
+ )
+ self.validate_identity(
+ "SELECT JSON_OBJECT(['a', 'b'], [TO_JSON(10), TO_JSON(['foo', 'bar'])]) AS json_data",
+ "SELECT JSON_OBJECT('a', TO_JSON(10), 'b', TO_JSON(['foo', 'bar'])) AS json_data",
+ )
+
+ with self.assertRaises(ParseError):
+ transpile("SELECT JSON_OBJECT('a', 1, 'b') AS json_data", read="bigquery")
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 583be3e..ab2379d 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -6,6 +6,31 @@ class TestClickhouse(Validator):
dialect = "clickhouse"
def test_clickhouse(self):
+ self.validate_all(
+ "DATE_ADD('day', 1, x)",
+ read={
+ "clickhouse": "dateAdd(day, 1, x)",
+ "presto": "DATE_ADD('day', 1, x)",
+ },
+ write={
+ "clickhouse": "DATE_ADD('day', 1, x)",
+ "presto": "DATE_ADD('day', 1, x)",
+ "": "DATE_ADD(x, 1, 'day')",
+ },
+ )
+ self.validate_all(
+ "DATE_DIFF('day', a, b)",
+ read={
+ "clickhouse": "dateDiff('day', a, b)",
+ "presto": "DATE_DIFF('day', a, b)",
+ },
+ write={
+ "clickhouse": "DATE_DIFF('day', a, b)",
+ "presto": "DATE_DIFF('day', a, b)",
+ "": "DATEDIFF(b, a, day)",
+ },
+ )
+
expr = parse_one("count(x)")
self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
self.assertIsNone(expr._meta)
@@ -47,8 +72,10 @@ class TestClickhouse(Validator):
self.validate_identity("position(haystack, needle)")
self.validate_identity("position(haystack, needle, position)")
self.validate_identity("CAST(x AS DATETIME)")
+ self.validate_identity("CAST(x AS VARCHAR(255))", "CAST(x AS String)")
+ self.validate_identity("CAST(x AS BLOB)", "CAST(x AS String)")
self.validate_identity(
- 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(TEXT))'
+ 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))'
)
self.validate_identity(
"CREATE TABLE test (id UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()"
@@ -95,11 +122,11 @@ class TestClickhouse(Validator):
},
)
self.validate_all(
- "CONCAT(CASE WHEN COALESCE(CAST(a AS TEXT), '') IS NULL THEN COALESCE(CAST(a AS TEXT), '') ELSE CAST(COALESCE(CAST(a AS TEXT), '') AS TEXT) END, CASE WHEN COALESCE(CAST(b AS TEXT), '') IS NULL THEN COALESCE(CAST(b AS TEXT), '') ELSE CAST(COALESCE(CAST(b AS TEXT), '') AS TEXT) END)",
+ "CONCAT(CASE WHEN COALESCE(CAST(a AS String), '') IS NULL THEN COALESCE(CAST(a AS String), '') ELSE CAST(COALESCE(CAST(a AS String), '') AS String) END, CASE WHEN COALESCE(CAST(b AS String), '') IS NULL THEN COALESCE(CAST(b AS String), '') ELSE CAST(COALESCE(CAST(b AS String), '') AS String) END)",
read={"postgres": "CONCAT(a, b)"},
)
self.validate_all(
- "CONCAT(CASE WHEN a IS NULL THEN a ELSE CAST(a AS TEXT) END, CASE WHEN b IS NULL THEN b ELSE CAST(b AS TEXT) END)",
+ "CONCAT(CASE WHEN a IS NULL THEN a ELSE CAST(a AS String) END, CASE WHEN b IS NULL THEN b ELSE CAST(b AS String) END)",
read={"mysql": "CONCAT(a, b)"},
)
self.validate_all(
@@ -233,7 +260,7 @@ class TestClickhouse(Validator):
self.validate_all(
"SELECT {abc: UInt32}, {b: String}, {c: DateTime},{d: Map(String, Array(UInt8))}, {e: Tuple(UInt8, String)}",
write={
- "clickhouse": "SELECT {abc: UInt32}, {b: TEXT}, {c: DATETIME}, {d: Map(TEXT, Array(UInt8))}, {e: Tuple(UInt8, String)}",
+ "clickhouse": "SELECT {abc: UInt32}, {b: String}, {c: DATETIME}, {d: Map(String, Array(UInt8))}, {e: Tuple(UInt8, String)}",
"": "SELECT :abc, :b, :c, :d, :e",
},
)
@@ -283,8 +310,8 @@ class TestClickhouse(Validator):
"clickhouse": """CREATE TABLE example1 (
timestamp DATETIME,
x UInt32 TTL now() + INTERVAL '1' MONTH,
- y TEXT TTL timestamp + INTERVAL '1' DAY,
- z TEXT
+ y String TTL timestamp + INTERVAL '1' DAY,
+ z String
)
ENGINE=MergeTree
ORDER BY tuple()""",
@@ -305,7 +332,7 @@ ORDER BY tuple()""",
"clickhouse": """CREATE TABLE test (
id UInt64,
timestamp DateTime64,
- data TEXT,
+ data String,
max_hits UInt64,
sum_hits UInt64
)
@@ -332,8 +359,8 @@ SET
""",
write={
"clickhouse": """CREATE TABLE test (
- id TEXT,
- data TEXT
+ id String,
+ data String
)
ENGINE=AggregatingMergeTree()
ORDER BY tuple()
@@ -416,7 +443,7 @@ WHERE
"clickhouse": """CREATE TABLE table_for_recompression (
d DATETIME,
key UInt64,
- value TEXT
+ value String
)
ENGINE=MergeTree()
ORDER BY tuple()
@@ -512,9 +539,9 @@ RANGE(MIN discount_start_date MAX discount_end_date)""",
""",
write={
"clickhouse": """CREATE DICTIONARY my_ip_trie_dictionary (
- prefix TEXT,
+ prefix String,
asn UInt32,
- cca2 TEXT DEFAULT '??'
+ cca2 String DEFAULT '??'
)
PRIMARY KEY (prefix)
SOURCE(CLICKHOUSE(
@@ -540,7 +567,7 @@ LIFETIME(MIN 0 MAX 3600)""",
write={
"clickhouse": """CREATE DICTIONARY polygons_test_dictionary (
key Array(Array(Array(Tuple(Float64, Float64)))),
- name TEXT
+ name String
)
PRIMARY KEY (key)
SOURCE(CLICKHOUSE(
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index 38a7952..f13d0f2 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -5,6 +5,7 @@ class TestDatabricks(Validator):
dialect = "databricks"
def test_databricks(self):
+ self.validate_identity("CREATE TABLE target SHALLOW CLONE source")
self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)")
self.validate_identity("SELECT c1 : price")
self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 63f789f..6a41218 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1,6 +1,13 @@
import unittest
-from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one
+from sqlglot import (
+ Dialect,
+ Dialects,
+ ErrorLevel,
+ ParseError,
+ UnsupportedError,
+ parse_one,
+)
from sqlglot.dialects import Hive
@@ -23,9 +30,10 @@ class Validator(unittest.TestCase):
Args:
sql (str): Main SQL expression
- dialect (str): dialect of `sql`
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
+ pretty (bool): prettify both read and write
+ identify (bool): quote identifiers in both read and write
"""
expression = self.parse_one(sql)
@@ -78,7 +86,7 @@ class TestDialect(Validator):
"CAST(a AS TEXT)",
write={
"bigquery": "CAST(a AS STRING)",
- "clickhouse": "CAST(a AS TEXT)",
+ "clickhouse": "CAST(a AS String)",
"drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS CHAR)",
@@ -116,7 +124,7 @@ class TestDialect(Validator):
"CAST(a AS VARBINARY(4))",
write={
"bigquery": "CAST(a AS BYTES)",
- "clickhouse": "CAST(a AS VARBINARY(4))",
+ "clickhouse": "CAST(a AS String)",
"duckdb": "CAST(a AS BLOB(4))",
"mysql": "CAST(a AS VARBINARY(4))",
"hive": "CAST(a AS BINARY(4))",
@@ -133,7 +141,7 @@ class TestDialect(Validator):
self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
- "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
+ "clickhouse": "CAST(map('a', '1') AS Map(String, String))",
},
)
self.validate_all(
@@ -367,6 +375,60 @@ class TestDialect(Validator):
},
)
+ def test_nvl2(self):
+ self.validate_all(
+ "SELECT NVL2(a, b, c)",
+ write={
+ "": "SELECT NVL2(a, b, c)",
+ "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "databricks": "SELECT NVL2(a, b, c)",
+ "doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "hive": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "oracle": "SELECT NVL2(a, b, c)",
+ "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "presto": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "redshift": "SELECT NVL2(a, b, c)",
+ "snowflake": "SELECT NVL2(a, b, c)",
+ "spark": "SELECT NVL2(a, b, c)",
+ "spark2": "SELECT NVL2(a, b, c)",
+ "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "teradata": "SELECT NVL2(a, b, c)",
+ "trino": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END",
+ },
+ )
+ self.validate_all(
+ "SELECT NVL2(a, b)",
+ write={
+ "": "SELECT NVL2(a, b)",
+ "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "databricks": "SELECT NVL2(a, b)",
+ "doris": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "drill": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "hive": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "oracle": "SELECT NVL2(a, b)",
+ "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "presto": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "redshift": "SELECT NVL2(a, b)",
+ "snowflake": "SELECT NVL2(a, b)",
+ "spark": "SELECT NVL2(a, b)",
+ "spark2": "SELECT NVL2(a, b)",
+ "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "teradata": "SELECT NVL2(a, b)",
+ "trino": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b END",
+ },
+ )
+
def test_time(self):
self.validate_all(
"STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')",
@@ -860,7 +922,7 @@ class TestDialect(Validator):
"ARRAY(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
- "duckdb": "LIST_VALUE(0, 1, 2)",
+ "duckdb": "[0, 1, 2]",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
@@ -879,7 +941,7 @@ class TestDialect(Validator):
"ARRAY_SUM(ARRAY(1, 2))",
write={
"trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)",
- "duckdb": "LIST_SUM(LIST_VALUE(1, 2))",
+ "duckdb": "LIST_SUM([1, 2])",
"hive": "ARRAY_SUM(ARRAY(1, 2))",
"presto": "ARRAY_SUM(ARRAY[1, 2])",
"spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)",
@@ -1403,27 +1465,27 @@ class TestDialect(Validator):
},
)
self.validate_all(
- "CREATE INDEX my_idx ON tbl (a, b)",
+ "CREATE INDEX my_idx ON tbl(a, b)",
read={
- "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
- "sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)",
+ "sqlite": "CREATE INDEX my_idx ON tbl(a, b)",
},
write={
- "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)",
- "postgres": "CREATE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)",
- "sqlite": "CREATE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)",
+ "postgres": "CREATE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)",
+ "sqlite": "CREATE INDEX my_idx ON tbl(a, b)",
},
)
self.validate_all(
- "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
+ "CREATE UNIQUE INDEX my_idx ON tbl(a, b)",
read={
- "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
- "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)",
+ "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)",
},
write={
- "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)",
- "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)",
- "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)",
+ "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)",
+ "postgres": "CREATE UNIQUE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)",
+ "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)",
},
)
self.validate_all(
@@ -1710,3 +1772,19 @@ SELECT
"tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo",
},
)
+
+ def test_cast_to_user_defined_type(self):
+ self.validate_all(
+ "CAST(x AS some_udt)",
+ write={
+ "": "CAST(x AS some_udt)",
+ "oracle": "CAST(x AS some_udt)",
+ "postgres": "CAST(x AS some_udt)",
+ "presto": "CAST(x AS some_udt)",
+ "teradata": "CAST(x AS some_udt)",
+ "tsql": "CAST(x AS some_udt)",
+ },
+ )
+
+ with self.assertRaises(ParseError):
+ parse_one("CAST(x AS some_udt)", read="bigquery")
diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py
index 63325a6..9591269 100644
--- a/tests/dialects/test_doris.py
+++ b/tests/dialects/test_doris.py
@@ -5,6 +5,7 @@ class TestDoris(Validator):
dialect = "doris"
def test_identity(self):
+ self.validate_identity("COALECSE(a, b, c, d)")
self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo")
self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x")
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index c33c899..aca0d7a 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -6,6 +6,9 @@ class TestDuckDB(Validator):
dialect = "duckdb"
def test_duckdb(self):
+ self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]")
+ self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
+ self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)")
# https://github.com/duckdb/duckdb/releases/tag/v0.8.0
@@ -50,6 +53,7 @@ class TestDuckDB(Validator):
"SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias"
)
+ self.validate_identity("FROM x SELECT x UNION SELECT 1", "SELECT x FROM x UNION SELECT 1")
self.validate_all("FROM (FROM tbl)", write={"duckdb": "SELECT * FROM (SELECT * FROM tbl)"})
self.validate_all("FROM tbl", write={"duckdb": "SELECT * FROM tbl"})
self.validate_all("0b1010", write={"": "0 AS b1010"})
@@ -123,20 +127,20 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "LIST_VALUE(0, 1, 2)",
+ "[0, 1, 2]",
read={
"spark": "ARRAY(0, 1, 2)",
},
write={
"bigquery": "[0, 1, 2]",
- "duckdb": "LIST_VALUE(0, 1, 2)",
+ "duckdb": "[0, 1, 2]",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
)
self.validate_all(
"SELECT ARRAY_LENGTH([0], 1) AS x",
- write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"},
+ write={"duckdb": "SELECT ARRAY_LENGTH([0], 1) AS x"},
)
self.validate_all(
"REGEXP_MATCHES(x, y)",
@@ -178,18 +182,18 @@ class TestDuckDB(Validator):
"STRUCT_EXTRACT(x, 'abc')",
write={
"duckdb": "STRUCT_EXTRACT(x, 'abc')",
- "presto": 'x."abc"',
- "hive": "x.`abc`",
- "spark": "x.`abc`",
+ "presto": "x.abc",
+ "hive": "x.abc",
+ "spark": "x.abc",
},
)
self.validate_all(
"STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
write={
"duckdb": "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')",
- "presto": 'x."y"."abc"',
- "hive": "x.`y`.`abc`",
- "spark": "x.`y`.`abc`",
+ "presto": "x.y.abc",
+ "hive": "x.y.abc",
+ "spark": "x.y.abc",
},
)
self.validate_all(
@@ -226,7 +230,7 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "LIST_SUM(LIST_VALUE(1, 2))",
+ "LIST_SUM([1, 2])",
read={
"spark": "ARRAY_SUM(ARRAY(1, 2))",
},
@@ -304,14 +308,20 @@ class TestDuckDB(Validator):
},
)
self.validate_all(
- "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
+ "ARRAY_CONCAT([1, 2], [3, 4])",
+ read={
+ "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
+ "postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])",
+ "snowflake": "ARRAY_CAT([1, 2], [3, 4])",
+ },
write={
- "duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))",
- "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
+ "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
+ "duckdb": "ARRAY_CONCAT([1, 2], [3, 4])",
"hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
- "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
+ "postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])",
+ "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])",
"snowflake": "ARRAY_CAT([1, 2], [3, 4])",
- "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])",
+ "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))",
},
)
self.validate_all(
@@ -502,6 +512,10 @@ class TestDuckDB(Validator):
self.validate_identity("CAST(x AS INT128)")
self.validate_identity("CAST(x AS DOUBLE)")
self.validate_identity("CAST(x AS DECIMAL(15, 4))")
+ self.validate_identity("CAST(x AS STRUCT(number BIGINT))")
+ self.validate_identity(
+ "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))"
+ )
self.validate_all("CAST(x AS NUMERIC(1, 2))", write={"duckdb": "CAST(x AS DECIMAL(1, 2))"})
self.validate_all("CAST(x AS HUGEINT)", write={"duckdb": "CAST(x AS INT128)"})
@@ -552,7 +566,7 @@ class TestDuckDB(Validator):
self.validate_all(
"cast([[1]] as int[][])",
write={
- "duckdb": "CAST(LIST_VALUE(LIST_VALUE(1)) AS INT[][])",
+ "duckdb": "CAST([[1]] AS INT[][])",
"spark": "CAST(ARRAY(ARRAY(1)) AS ARRAY<ARRAY<INT>>)",
},
)
@@ -587,13 +601,13 @@ class TestDuckDB(Validator):
self.validate_all(
"CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])",
write={
- "duckdb": "CAST(LIST_VALUE({'a': 1}) AS STRUCT(a BIGINT)[])",
+ "duckdb": "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])",
},
)
self.validate_all(
"CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])",
write={
- "duckdb": "CAST(LIST_VALUE(LIST_VALUE({'a': 1})) AS STRUCT(a BIGINT)[][])",
+ "duckdb": "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])",
},
)
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 4c463f7..70a05fd 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -390,6 +390,13 @@ class TestHive(Validator):
)
def test_hive(self):
+ self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)")
+ self.validate_identity("SELECT * FROM my_table VERSION AS OF DATE_ADD(CURRENT_DATE, -1)")
+
+ self.validate_identity(
+ "SELECT ROW() OVER (DISTRIBUTE BY x SORT BY y)",
+ "SELECT ROW() OVER (PARTITION BY x ORDER BY y)",
+ )
self.validate_identity("SELECT transform")
self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l")
self.validate_identity(
@@ -591,7 +598,7 @@ class TestHive(Validator):
read={
"": "VAR_MAP(a, b, c, d)",
"clickhouse": "map(a, b, c, d)",
- "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
+ "duckdb": "MAP([a, c], [b, d])",
"hive": "MAP(a, b, c, d)",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"spark": "MAP(a, b, c, d)",
@@ -599,7 +606,7 @@ class TestHive(Validator):
write={
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
"clickhouse": "map(a, b, c, d)",
- "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))",
+ "duckdb": "MAP([a, c], [b, d])",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"hive": "MAP(a, b, c, d)",
"spark": "MAP(a, b, c, d)",
@@ -609,7 +616,7 @@ class TestHive(Validator):
self.validate_all(
"MAP(a, b)",
write={
- "duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))",
+ "duckdb": "MAP([a], [b])",
"presto": "MAP(ARRAY[a], ARRAY[b])",
"hive": "MAP(a, b)",
"spark": "MAP(a, b)",
@@ -717,9 +724,7 @@ class TestHive(Validator):
self.validate_identity("'\\\\n'")
self.validate_identity("''")
self.validate_identity("'\\\\'")
- self.validate_identity("'\z'")
self.validate_identity("'\\z'")
- self.validate_identity("'\\\z'")
self.validate_identity("'\\\\z'")
def test_data_type(self):
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index d60f09d..fc63f9f 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -7,12 +7,16 @@ class TestMySQL(Validator):
def test_ddl(self):
self.validate_identity("CREATE TABLE foo (id BIGINT)")
+ self.validate_identity("CREATE TABLE 00f (1d BIGINT)")
self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10")
self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10")
self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))")
self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))")
self.validate_identity(
+ "UPDATE items SET items.price = 0 WHERE items.id >= 5 ORDER BY items.id LIMIT 10"
+ )
+ self.validate_identity(
"CREATE TABLE foo (a BIGINT, INDEX b USING HASH (c) COMMENT 'd' VISIBLE ENGINE_ATTRIBUTE = 'e' WITH PARSER foo)"
)
self.validate_identity(
@@ -81,6 +85,9 @@ class TestMySQL(Validator):
)
def test_identity(self):
+ self.validate_identity(
+ "SELECT * FROM x ORDER BY BINARY a", "SELECT * FROM x ORDER BY CAST(a AS BINARY)"
+ )
self.validate_identity("SELECT 1 XOR 0")
self.validate_identity("SELECT 1 && 0", "SELECT 1 AND 0")
self.validate_identity("SELECT /*+ BKA(t1) NO_BKA(t2) */ * FROM t1 INNER JOIN t2")
@@ -171,8 +178,12 @@ class TestMySQL(Validator):
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
+ self.validate_identity("INTERVAL '1' YEAR")
+ self.validate_identity("DATE_ADD(x, INTERVAL 1 YEAR)")
def test_types(self):
+ self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))")
+
self.validate_all(
"CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
read={
@@ -353,6 +364,7 @@ class TestMySQL(Validator):
write={
"": "MATCH(col1, col2, col3) AGAINST('abc')",
"mysql": "MATCH(col1, col2, col3) AGAINST('abc')",
+ "postgres": "(col1 @@ 'abc' OR col2 @@ 'abc' OR col3 @@ 'abc')", # not quite correct because it's not ts_query
},
)
self.validate_all(
@@ -818,3 +830,6 @@ COMMENT='客户账户表'"""
cmd = self.parse_one("SET x = 1, y = 2")
self.assertEqual(len(cmd.expressions), 2)
+
+ def test_json_object(self):
+ self.validate_identity("SELECT JSON_OBJECT('id', 87, 'name', 'carrot')")
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index 0c3b09f..01a9ca3 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -6,6 +6,11 @@ class TestOracle(Validator):
dialect = "oracle"
def test_oracle(self):
+ self.validate_identity("SELECT JSON_OBJECT('name': first_name || ' ' || last_name) FROM t")
+ self.validate_identity("COALESCE(c1, c2, c3)")
+ self.validate_identity("SELECT * FROM TABLE(foo)")
+ self.validate_identity("SELECT a$x#b")
+ self.validate_identity("SELECT :OBJECT")
self.validate_identity("SELECT * FROM t FOR UPDATE")
self.validate_identity("SELECT * FROM t FOR UPDATE WAIT 5")
self.validate_identity("SELECT * FROM t FOR UPDATE NOWAIT")
@@ -21,6 +26,9 @@ class TestOracle(Validator):
self.validate_identity("SELECT * FROM table_name SAMPLE (25) s")
self.validate_identity("SELECT * FROM V$SESSION")
self.validate_identity(
+ "SELECT COUNT(1) INTO V_Temp FROM TABLE(CAST(somelist AS data_list)) WHERE col LIKE '%contact'"
+ )
+ self.validate_identity(
"SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name"
)
self.validate_identity(
@@ -28,12 +36,16 @@ class TestOracle(Validator):
'OVER (PARTITION BY department_id) AS "Worst", MAX(salary) KEEP (DENSE_RANK LAST ORDER BY commission_pct) '
'OVER (PARTITION BY department_id) AS "Best" FROM employees ORDER BY department_id, salary, last_name'
)
+ self.validate_identity(
+ "SELECT UNIQUE col1, col2 FROM table",
+ "SELECT DISTINCT col1, col2 FROM table",
+ )
self.validate_all(
"NVL(NULL, 1)",
write={
"": "COALESCE(NULL, 1)",
- "oracle": "NVL(NULL, 1)",
+ "oracle": "COALESCE(NULL, 1)",
},
)
self.validate_all(
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index a7719a9..8740aca 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -126,6 +126,8 @@ class TestPostgres(Validator):
)
def test_postgres(self):
+ self.validate_identity("x @@ y")
+
expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)")
unnest = expr.args["joins"][0].this.this
unnest.assert_is(exp.Unnest)
@@ -535,6 +537,54 @@ class TestPostgres(Validator):
write={"postgres": "CAST(x AS CSTRING)"},
)
self.validate_all(
+ "x::oid",
+ write={"postgres": "CAST(x AS OID)"},
+ )
+ self.validate_all(
+ "x::regclass",
+ write={"postgres": "CAST(x AS REGCLASS)"},
+ )
+ self.validate_all(
+ "x::regcollation",
+ write={"postgres": "CAST(x AS REGCOLLATION)"},
+ )
+ self.validate_all(
+ "x::regconfig",
+ write={"postgres": "CAST(x AS REGCONFIG)"},
+ )
+ self.validate_all(
+ "x::regdictionary",
+ write={"postgres": "CAST(x AS REGDICTIONARY)"},
+ )
+ self.validate_all(
+ "x::regnamespace",
+ write={"postgres": "CAST(x AS REGNAMESPACE)"},
+ )
+ self.validate_all(
+ "x::regoper",
+ write={"postgres": "CAST(x AS REGOPER)"},
+ )
+ self.validate_all(
+ "x::regoperator",
+ write={"postgres": "CAST(x AS REGOPERATOR)"},
+ )
+ self.validate_all(
+ "x::regproc",
+ write={"postgres": "CAST(x AS REGPROC)"},
+ )
+ self.validate_all(
+ "x::regprocedure",
+ write={"postgres": "CAST(x AS REGPROCEDURE)"},
+ )
+ self.validate_all(
+ "x::regrole",
+ write={"postgres": "CAST(x AS REGROLE)"},
+ )
+ self.validate_all(
+ "x::regtype",
+ write={"postgres": "CAST(x AS REGTYPE)"},
+ )
+ self.validate_all(
"TRIM(BOTH 'as' FROM 'as string as')",
write={
"postgres": "TRIM(BOTH 'as' FROM 'as string as')",
@@ -606,7 +656,7 @@ class TestPostgres(Validator):
"a || b",
write={
"": "a || b",
- "clickhouse": "CONCAT(CAST(a AS TEXT), CAST(b AS TEXT))",
+ "clickhouse": "CONCAT(CAST(a AS String), CAST(b AS String))",
"duckdb": "a || b",
"postgres": "a || b",
"presto": "CONCAT(CAST(a AS VARCHAR), CAST(b AS VARCHAR))",
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 5091540..dbca5b3 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -88,7 +88,7 @@ class TestPresto(Validator):
"CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
write={
"bigquery": "CAST([1, 2] AS ARRAY<INT64>)",
- "duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])",
+ "duckdb": "CAST([1, 2] AS BIGINT[])",
"presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))",
"spark": "CAST(ARRAY(1, 2) AS ARRAY<BIGINT>)",
"snowflake": "CAST([1, 2] AS ARRAY)",
@@ -98,7 +98,7 @@ class TestPresto(Validator):
"CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INT,INT))",
write={
"bigquery": "CAST(MAP([1], [1]) AS MAP<INT64, INT64>)",
- "duckdb": "CAST(MAP(LIST_VALUE(1), LIST_VALUE(1)) AS MAP(INT, INT))",
+ "duckdb": "CAST(MAP([1], [1]) AS MAP(INT, INT))",
"presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))",
"hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)",
@@ -109,7 +109,7 @@ class TestPresto(Validator):
"CAST(MAP(ARRAY['a','b','c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INT)))",
write={
"bigquery": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP<STRING, ARRAY<INT64>>)",
- "duckdb": "CAST(MAP(LIST_VALUE('a', 'b', 'c'), LIST_VALUE(LIST_VALUE(1), LIST_VALUE(2), LIST_VALUE(3))) AS MAP(TEXT, INT[]))",
+ "duckdb": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP(TEXT, INT[]))",
"presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))",
"hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)",
"spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)",
@@ -138,6 +138,13 @@ class TestPresto(Validator):
def test_regex(self):
self.validate_all(
+ "REGEXP_REPLACE('abcd', '[ab]')",
+ write={
+ "presto": "REGEXP_REPLACE('abcd', '[ab]', '')",
+ "spark": "REGEXP_REPLACE('abcd', '[ab]', '')",
+ },
+ )
+ self.validate_all(
"REGEXP_LIKE(a, 'x')",
write={
"duckdb": "REGEXP_MATCHES(a, 'x')",
@@ -289,6 +296,13 @@ class TestPresto(Validator):
},
)
self.validate_all(
+ "DATE_ADD('DAY', 1 * -1, x)",
+ write={
+ "presto": "DATE_ADD('DAY', 1 * -1, x)",
+ },
+ read={"mysql": "DATE_SUB(x, INTERVAL 1 DAY)"},
+ )
+ self.validate_all(
"NOW()",
write={
"presto": "CURRENT_TIMESTAMP",
@@ -339,6 +353,11 @@ class TestPresto(Validator):
"presto": "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'",
},
)
+ self.validate_all(
+ "CAST(x AS TIMESTAMP)",
+ write={"presto": "CAST(x AS TIMESTAMP)"},
+ read={"mysql": "CAST(x AS DATETIME)", "clickhouse": "CAST(x AS DATETIME64)"},
+ )
def test_ddl(self):
self.validate_all(
@@ -480,6 +499,13 @@ class TestPresto(Validator):
@mock.patch("sqlglot.helper.logger")
def test_presto(self, logger):
+ self.validate_identity(
+ "SELECT * FROM example.testdb.customer_orders FOR VERSION AS OF 8954597067493422955"
+ )
+ self.validate_identity(
+ "SELECT * FROM example.testdb.customer_orders FOR TIMESTAMP AS OF CAST('2022-03-23 09:59:29.803 Europe/Vienna' AS TIMESTAMP)"
+ )
+
self.validate_identity("SELECT * FROM x OFFSET 1 LIMIT 1")
self.validate_identity("SELECT * FROM x OFFSET 1 FETCH FIRST 1 ROWS ONLY")
self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)")
@@ -487,8 +513,58 @@ class TestPresto(Validator):
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
+ self.validate_identity(
+ "SELECT SPLIT_TO_MAP('a:1;b:2;a:3', ';', ':', (k, v1, v2) -> CONCAT(v1, v2))"
+ )
self.validate_all(
+ "SELECT ROW(1, 2)",
+ read={
+ "spark": "SELECT STRUCT(1, 2)",
+ },
+ write={
+ "presto": "SELECT ROW(1, 2)",
+ "spark": "SELECT STRUCT(1, 2)",
+ },
+ )
+ self.validate_all(
+ "ARBITRARY(x)",
+ read={
+ "bigquery": "ANY_VALUE(x)",
+ "clickhouse": "any(x)",
+ "databricks": "ANY_VALUE(x)",
+ "doris": "ANY_VALUE(x)",
+ "drill": "ANY_VALUE(x)",
+ "duckdb": "ANY_VALUE(x)",
+ "hive": "FIRST(x)",
+ "mysql": "ANY_VALUE(x)",
+ "oracle": "ANY_VALUE(x)",
+ "redshift": "ANY_VALUE(x)",
+ "snowflake": "ANY_VALUE(x)",
+ "spark": "ANY_VALUE(x)",
+ "spark2": "FIRST(x)",
+ },
+ write={
+ "bigquery": "ANY_VALUE(x)",
+ "clickhouse": "any(x)",
+ "databricks": "ANY_VALUE(x)",
+ "doris": "ANY_VALUE(x)",
+ "drill": "ANY_VALUE(x)",
+ "duckdb": "ANY_VALUE(x)",
+ "hive": "FIRST(x)",
+ "mysql": "ANY_VALUE(x)",
+ "oracle": "ANY_VALUE(x)",
+ "postgres": "MAX(x)",
+ "presto": "ARBITRARY(x)",
+ "redshift": "ANY_VALUE(x)",
+ "snowflake": "ANY_VALUE(x)",
+ "spark": "ANY_VALUE(x)",
+ "spark2": "FIRST(x)",
+ "sqlite": "MAX(x)",
+ "tsql": "MAX(x)",
+ },
+ )
+ self.validate_all(
"STARTS_WITH('abc', 'a')",
read={"spark": "STARTSWITH('abc', 'a')"},
write={
@@ -596,7 +672,7 @@ class TestPresto(Validator):
"SELECT ARRAY[1, 2]",
write={
"bigquery": "SELECT [1, 2]",
- "duckdb": "SELECT LIST_VALUE(1, 2)",
+ "duckdb": "SELECT [1, 2]",
"presto": "SELECT ARRAY[1, 2]",
"spark": "SELECT ARRAY(1, 2)",
},
@@ -748,7 +824,7 @@ class TestPresto(Validator):
self.validate_all(
"""JSON_FORMAT(JSON '"x"')""",
write={
- "bigquery": """TO_JSON_STRING(CAST('"x"' AS JSON))""",
+ "bigquery": """TO_JSON_STRING(JSON '"x"')""",
"duckdb": """CAST(TO_JSON(CAST('"x"' AS JSON)) AS TEXT)""",
"presto": """JSON_FORMAT(CAST('"x"' AS JSON))""",
"spark": """REGEXP_EXTRACT(TO_JSON(FROM_JSON('["x"]', SCHEMA_OF_JSON('["x"]'))), '^.(.*).$', 1)""",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 3af27d4..245adf3 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -6,6 +6,14 @@ class TestRedshift(Validator):
def test_redshift(self):
self.validate_all(
+ "x ~* 'pat'",
+ write={
+ "redshift": "x ~* 'pat'",
+ "snowflake": "REGEXP_LIKE(x, 'pat', 'i')",
+ },
+ )
+
+ self.validate_all(
"SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)",
read={
"postgres": "SELECT CAST('01:03:05.124' AS TIMETZ(2))",
@@ -163,22 +171,22 @@ class TestRedshift(Validator):
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
- "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
- "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
- "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
- "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
- "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
- "oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
- "presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
- "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
- "snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1',
- "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1",
- "sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
- "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1",
- "tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
- "teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
- "trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
- "tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1',
+ "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
+ "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
+ "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
+ "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
+ "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
},
)
self.validate_all(
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 3053d47..30a1f03 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -8,6 +8,35 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
+ self.validate_identity(
+ 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage'
+ )
+
+ self.validate_all(
+ "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ read={
+ "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ },
+ write={
+ "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ "snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d",
+ },
+ )
+ self.validate_all(
+ "SELECT INSERT(a, 0, 0, 'b')",
+ read={
+ "mysql": "SELECT INSERT(a, 0, 0, 'b')",
+ "snowflake": "SELECT INSERT(a, 0, 0, 'b')",
+ "tsql": "SELECT STUFF(a, 0, 0, 'b')",
+ },
+ write={
+ "mysql": "SELECT INSERT(a, 0, 0, 'b')",
+ "snowflake": "SELECT INSERT(a, 0, 0, 'b')",
+ "tsql": "SELECT STUFF(a, 0, 0, 'b')",
+ },
+ )
+
+ self.validate_identity("LISTAGG(data['some_field'], ',')")
self.validate_identity("WEEKOFYEAR(tstamp)")
self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL")
self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')")
@@ -383,12 +412,6 @@ class TestSnowflake(Validator):
},
)
self.validate_all(
- "SELECT NVL2(a, b, c)",
- write={
- "snowflake": "SELECT NVL2(a, b, c)",
- },
- )
- self.validate_all(
"SELECT $$a$$",
write={
"snowflake": "SELECT 'a'",
@@ -598,7 +621,7 @@ class TestSnowflake(Validator):
write={
"snowflake": "[0, 1, 2]",
"bigquery": "[0, 1, 2]",
- "duckdb": "LIST_VALUE(0, 1, 2)",
+ "duckdb": "[0, 1, 2]",
"presto": "ARRAY[0, 1, 2]",
"spark": "ARRAY(0, 1, 2)",
},
@@ -1011,3 +1034,33 @@ MATCH_RECOGNIZE (
)""",
pretty=True,
)
+
+ def test_show(self):
+ # Parsed as Command
+ self.validate_identity("SHOW COLUMNS IN TABLE dt_test")
+ self.validate_identity("SHOW TABLES LIKE 'line%' IN tpch.public")
+
+ ast = parse_one("SHOW TABLES HISTORY IN tpch.public")
+ self.assertIsInstance(ast, exp.Command)
+
+ # Parsed as Show
+ self.validate_identity("SHOW PRIMARY KEYS")
+ self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT")
+ self.validate_identity("SHOW PRIMARY KEYS IN DATABASE")
+ self.validate_identity("SHOW PRIMARY KEYS IN DATABASE foo")
+ self.validate_identity("SHOW PRIMARY KEYS IN TABLE")
+ self.validate_identity("SHOW PRIMARY KEYS IN TABLE foo")
+ self.validate_identity(
+ 'SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"',
+ 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"',
+ )
+ self.validate_identity(
+ 'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."customers"',
+ 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"',
+ )
+
+ ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', read="snowflake")
+ table = ast.find(exp.Table)
+
+ self.assertIsNotNone(table)
+ self.assertEqual(table.sql(dialect="snowflake"), '"TEST"."PUBLIC"."customers"')
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 2afa868..a892b0f 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -1,5 +1,6 @@
from unittest import mock
+from sqlglot import exp, parse_one
from tests.dialects.test_dialect import Validator
@@ -224,6 +225,10 @@ TBLPROPERTIES (
)
def test_spark(self):
+ expr = parse_one("any_value(col, true)", read="spark")
+ self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean)
+ self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)")
+
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH table a.b.c")
@@ -234,8 +239,46 @@ TBLPROPERTIES (
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("SPLIT(str, pattern, lim)")
+ self.validate_identity(
+ "SELECT STR_TO_MAP('a:1,b:2,c:3')",
+ "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ )
self.validate_all(
+ "foo.bar",
+ read={
+ "": "STRUCT_EXTRACT(foo, bar)",
+ },
+ )
+ self.validate_all(
+ "MAP(1, 2, 3, 4)",
+ write={
+ "spark": "MAP(1, 2, 3, 4)",
+ "trino": "MAP(ARRAY[1, 3], ARRAY[2, 4])",
+ },
+ )
+ self.validate_all(
+ "MAP()",
+ read={
+ "spark": "MAP()",
+ "trino": "MAP()",
+ },
+ write={
+ "trino": "MAP(ARRAY[], ARRAY[])",
+ },
+ )
+ self.validate_all(
+ "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ read={
+ "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ },
+ write={
+ "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ },
+ )
+ self.validate_all(
"SELECT DATEDIFF(month, CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))",
read={
"duckdb": "SELECT DATEDIFF('month', CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))",
@@ -399,7 +442,7 @@ TBLPROPERTIES (
"ARRAY(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
- "duckdb": "LIST_VALUE(0, 1, 2)",
+ "duckdb": "[0, 1, 2]",
"presto": "ARRAY[0, 1, 2]",
"hive": "ARRAY(0, 1, 2)",
"spark": "ARRAY(0, 1, 2)",
@@ -466,7 +509,7 @@ TBLPROPERTIES (
self.validate_all(
"MAP_FROM_ARRAYS(ARRAY(1), c)",
write={
- "duckdb": "MAP(LIST_VALUE(1), c)",
+ "duckdb": "MAP([1], c)",
"presto": "MAP(ARRAY[1], c)",
"hive": "MAP(ARRAY(1), c)",
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
@@ -522,3 +565,13 @@ TBLPROPERTIES (
self.validate_identity(
"SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500"
)
+
+ def test_insert_cte(self):
+ self.validate_all(
+ "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte",
+ write={
+ "spark": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte",
+ "spark2": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte",
+ "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte",
+ },
+ )
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 4d32241..32bdc71 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -4,6 +4,18 @@ from tests.dialects.test_dialect import Validator
class TestTeradata(Validator):
dialect = "teradata"
+ def test_teradata(self):
+ self.validate_all(
+ "DATABASE tduser",
+ read={
+ "databricks": "USE tduser",
+ },
+ write={
+ "databricks": "USE tduser",
+ "teradata": "DATABASE tduser",
+ },
+ )
+
def test_translate(self):
self.validate_all(
"TRANSLATE(x USING LATIN_TO_UNICODE)",
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index f43b41b..c27b7fa 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -6,10 +6,55 @@ class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
+ self.validate_all(
+ "CREATE TABLE x ( A INTEGER NOT NULL, B INTEGER NULL )",
+ write={
+ "tsql": "CREATE TABLE x (A INTEGER NOT NULL, B INTEGER NULL)",
+ "hive": "CREATE TABLE x (A INT NOT NULL, B INT)",
+ },
+ )
+
+ self.validate_identity(
+ 'CREATE TABLE x (CONSTRAINT "pk_mytable" UNIQUE NONCLUSTERED (a DESC)) ON b (c)'
+ )
+
+ self.validate_identity(
+ """
+ CREATE TABLE x(
+ [zip_cd] [varchar](5) NULL NOT FOR REPLICATION
+ CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED
+ ([zip_cd_mkey] ASC)
+ WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [PRIMARY]
+ ) ON [PRIMARY]
+ """,
+ 'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey") WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "PRIMARY") ON "PRIMARY"',
+ )
+
+ self.validate_identity(
+ "CREATE TABLE tbl (a AS (x + 1) PERSISTED, b AS (y + 2), c AS (y / 3) PERSISTED NOT NULL)"
+ )
+
+ self.validate_identity(
+ "CREATE TABLE [db].[tbl]([a] [int])", 'CREATE TABLE "db"."tbl" ("a" INTEGER)'
+ )
+
projection = parse_one("SELECT a = 1", read="tsql").selects[0]
projection.assert_is(exp.Alias)
projection.args["alias"].assert_is(exp.Identifier)
+ self.validate_all(
+ "IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName",
+ write={
+ "tsql": "DROP TABLE IF EXISTS #TempTableName",
+ "spark": "DROP TABLE IF EXISTS TempTableName",
+ },
+ )
+
+ self.validate_identity(
+ "MERGE INTO mytable WITH (HOLDLOCK) AS T USING mytable_merge AS S "
+ "ON (T.user_id = S.user_id) WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES (S.c1, S.c2)"
+ )
+ self.validate_identity("UPDATE STATISTICS x")
self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b INTO @y FROM y")
self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b FROM y")
self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b INTO l SELECT * FROM z")
@@ -397,8 +442,68 @@ class TestTSQL(Validator):
},
)
+ self.validate_all(
+ "CAST(x AS BOOLEAN)",
+ write={"tsql": "CAST(x AS BIT)"},
+ )
+
def test_ddl(self):
self.validate_all(
+ "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)",
+ read={
+ "mysql": "CREATE TABLE tbl (id INT AUTO_INCREMENT PRIMARY KEY)",
+ "tsql": "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE tbl (id INTEGER NOT NULL IDENTITY(10, 1) PRIMARY KEY)",
+ read={
+ "postgres": "CREATE TABLE tbl (id INT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10) PRIMARY KEY)",
+ "tsql": "CREATE TABLE tbl (id INTEGER NOT NULL IDENTITY(10, 1) PRIMARY KEY)",
+ },
+ )
+ self.validate_all(
+ "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id('db.tbl') AND name = 'idx') EXEC('CREATE INDEX idx ON db.tbl')",
+ read={
+ "": "CREATE INDEX IF NOT EXISTS idx ON db.tbl",
+ },
+ )
+
+ self.validate_all(
+ "IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = 'foo') EXEC('CREATE SCHEMA foo')",
+ read={
+ "": "CREATE SCHEMA IF NOT EXISTS foo",
+ },
+ )
+ self.validate_all(
+ "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'foo') EXEC('CREATE TABLE foo (a INTEGER)')",
+ read={
+ "": "CREATE TABLE IF NOT EXISTS foo (a INTEGER)",
+ },
+ )
+
+ self.validate_all(
+ "CREATE OR ALTER VIEW a.b AS SELECT 1",
+ read={
+ "": "CREATE OR REPLACE VIEW a.b AS SELECT 1",
+ },
+ write={
+ "tsql": "CREATE OR ALTER VIEW a.b AS SELECT 1",
+ },
+ )
+
+ self.validate_all(
+ "ALTER TABLE a ADD b INTEGER, c INTEGER",
+ read={
+ "": "ALTER TABLE a ADD COLUMN b INT, ADD COLUMN c INT",
+ },
+ write={
+ "": "ALTER TABLE a ADD COLUMN b INT, ADD COLUMN c INT",
+ "tsql": "ALTER TABLE a ADD b INTEGER, c INTEGER",
+ },
+ )
+
+ self.validate_all(
"CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))",
write={
"spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)",
@@ -898,6 +1003,9 @@ WHERE
)
def test_iif(self):
+ self.validate_identity(
+ "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond, 'True', 'False')"
+ )
self.validate_identity("SELECT IIF(cond, 'True', 'False')")
self.validate_all(
"SELECT IIF(cond, 'True', 'False');",
@@ -961,9 +1069,12 @@ WHERE
)
def test_format(self):
+ self.validate_identity("SELECT FORMAT(foo, 'dddd', 'de-CH')")
+ self.validate_identity("SELECT FORMAT(EndOfDayRate, 'N', 'en-us')")
self.validate_identity("SELECT FORMAT('01-01-1991', 'd.mm.yyyy')")
self.validate_identity("SELECT FORMAT(12345, '###.###.###')")
self.validate_identity("SELECT FORMAT(1234567, 'f')")
+
self.validate_all(
"SELECT FORMAT(1000000.01,'###,###.###')",
write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},