summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/dialects/test_databricks.py6
-rw-r--r--tests/dialects/test_dialect.py6
-rw-r--r--tests/dialects/test_mysql.py22
-rw-r--r--tests/dialects/test_postgres.py15
-rw-r--r--tests/dialects/test_redshift.py32
-rw-r--r--tests/dialects/test_spark.py8
-rw-r--r--tests/dialects/test_teradata.py6
-rw-r--r--tests/dialects/test_tsql.py47
-rw-r--r--tests/test_transforms.py8
9 files changed, 118 insertions, 32 deletions
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index f13d0f2..d06e0f1 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -5,6 +5,12 @@ class TestDatabricks(Validator):
dialect = "databricks"
def test_databricks(self):
+ self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO HOUR)")
+ self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO MINUTE)")
+ self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO SECOND)")
+ self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL HOUR TO MINUTE)")
+ self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL HOUR TO SECOND)")
+ self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL MINUTE TO SECOND)")
self.validate_identity("CREATE TABLE target SHALLOW CLONE source")
self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)")
self.validate_identity("SELECT c1 : price")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 6a41218..47e1ec7 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -282,14 +282,16 @@ class TestDialect(Validator):
"starrocks": "CAST(a AS DATETIME)",
"redshift": "CAST(a AS TIMESTAMP)",
"doris": "CAST(a AS DATETIME)",
+ "mysql": "CAST(a AS DATETIME)",
},
)
self.validate_all(
"CAST(a AS TIMESTAMPTZ)",
write={
- "starrocks": "CAST(a AS DATETIME)",
+ "starrocks": "TIMESTAMP(a)",
"redshift": "CAST(a AS TIMESTAMP WITH TIME ZONE)",
"doris": "CAST(a AS DATETIME)",
+ "mysql": "TIMESTAMP(a)",
},
)
self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"})
@@ -870,7 +872,7 @@ class TestDialect(Validator):
"TIMESTAMP '2022-01-01'",
write={
"drill": "CAST('2022-01-01' AS TIMESTAMP)",
- "mysql": "CAST('2022-01-01' AS TIMESTAMP)",
+ "mysql": "CAST('2022-01-01' AS DATETIME)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
"doris": "CAST('2022-01-01' AS DATETIME)",
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index fc63f9f..1f1c2e9 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -6,6 +6,19 @@ class TestMySQL(Validator):
dialect = "mysql"
def test_ddl(self):
+ int_types = {"BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"}
+
+ for t in int_types:
+ self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)")
+ self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)")
+
+ self.validate_all(
+ "CREATE TABLE t (id INT UNSIGNED)",
+ write={
+ "duckdb": "CREATE TABLE t (id UINTEGER)",
+ },
+ )
+
self.validate_identity("CREATE TABLE foo (id BIGINT)")
self.validate_identity("CREATE TABLE 00f (1d BIGINT)")
self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10")
@@ -83,6 +96,12 @@ class TestMySQL(Validator):
"mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))",
},
)
+ self.validate_all(
+ "CREATE TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)",
+ write={
+ "mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)",
+ },
+ )
def test_identity(self):
self.validate_identity(
@@ -202,6 +221,9 @@ class TestMySQL(Validator):
"spark": "CAST(x AS BLOB) + CAST(y AS BLOB)",
},
)
+ self.validate_all("CAST(x AS TIMESTAMP)", write={"mysql": "CAST(x AS DATETIME)"})
+ self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"mysql": "TIMESTAMP(x)"})
+ self.validate_all("CAST(x AS TIMESTAMPLTZ)", write={"mysql": "TIMESTAMP(x)"})
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)")
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 8740aca..21196b7 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -126,12 +126,17 @@ class TestPostgres(Validator):
)
def test_postgres(self):
- self.validate_identity("x @@ y")
-
expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)")
unnest = expr.args["joins"][0].this.this
unnest.assert_is(exp.Unnest)
+ alter_table_only = """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE NO ACTION ON UPDATE NO ACTION"""
+ expr = parse_one(alter_table_only)
+
+ self.assertIsInstance(expr, exp.AlterTable)
+ self.assertEqual(expr.sql(dialect="postgres"), alter_table_only)
+
+ self.validate_identity("x @@ y")
self.validate_identity("CAST(x AS MONEY)")
self.validate_identity("CAST(x AS INT4RANGE)")
self.validate_identity("CAST(x AS INT4MULTIRANGE)")
@@ -619,6 +624,12 @@ class TestPostgres(Validator):
"snowflake": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b",
},
)
+ self.validate_all(
+ "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
+ write={
+ "postgres": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)",
+ },
+ )
self.validate_all(
"x / y ^ z",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 245adf3..aea8b69 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -171,22 +171,22 @@ class TestRedshift(Validator):
self.validate_all(
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
write={
- "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
- "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
- "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
- "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
- "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
- "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1",
+ "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1",
+ "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
+ "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1",
},
)
self.validate_all(
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index a892b0f..becb66a 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -240,6 +240,14 @@ TBLPROPERTIES (
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("SPLIT(str, pattern, lim)")
self.validate_identity(
+ "SELECT REGEXP_REPLACE('100-200', r'([^0-9])', '')",
+ "SELECT REGEXP_REPLACE('100-200', '([^0-9])', '')",
+ )
+ self.validate_identity(
+ "SELECT REGEXP_REPLACE('100-200', R'([^0-9])', '')",
+ "SELECT REGEXP_REPLACE('100-200', '([^0-9])', '')",
+ )
+ self.validate_identity(
"SELECT STR_TO_MAP('a:1,b:2,c:3')",
"SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
)
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
index 32bdc71..f3615ff 100644
--- a/tests/dialects/test_teradata.py
+++ b/tests/dialects/test_teradata.py
@@ -5,6 +5,12 @@ class TestTeradata(Validator):
dialect = "teradata"
def test_teradata(self):
+ self.validate_identity("SELECT * FROM tbl SAMPLE 5")
+ self.validate_identity(
+ "SELECT * FROM tbl SAMPLE 0.33, .25, .1",
+ "SELECT * FROM tbl SAMPLE 0.33, 0.25, 0.1",
+ )
+
self.validate_all(
"DATABASE tduser",
read={
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index c27b7fa..acf8b79 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -368,6 +368,14 @@ class TestTSQL(Validator):
},
)
+ self.validate_all(
+ "CAST(x as UNIQUEIDENTIFIER)",
+ write={
+ "spark": "CAST(x AS STRING)",
+ "tsql": "CAST(x AS UNIQUEIDENTIFIER)",
+ },
+ )
+
def test_types_date(self):
self.validate_all(
"CAST(x as DATE)",
@@ -427,14 +435,6 @@ class TestTSQL(Validator):
)
self.validate_all(
- "CAST(x as UNIQUEIDENTIFIER)",
- write={
- "spark": "CAST(x AS STRING)",
- "tsql": "CAST(x AS UNIQUEIDENTIFIER)",
- },
- )
-
- self.validate_all(
"CAST(x as VARBINARY)",
write={
"spark": "CAST(x AS BINARY)",
@@ -447,6 +447,37 @@ class TestTSQL(Validator):
write={"tsql": "CAST(x AS BIT)"},
)
+ self.validate_all("a = TRUE", write={"tsql": "a = 1"})
+
+ self.validate_all("a != FALSE", write={"tsql": "a <> 0"})
+
+ self.validate_all("a IS TRUE", write={"tsql": "a = 1"})
+
+ self.validate_all("a IS NOT FALSE", write={"tsql": "NOT a = 0"})
+
+ self.validate_all(
+ "CASE WHEN a IN (TRUE) THEN 'y' ELSE 'n' END",
+ write={"tsql": "CASE WHEN a IN (1) THEN 'y' ELSE 'n' END"},
+ )
+
+ self.validate_all(
+ "CASE WHEN a NOT IN (FALSE) THEN 'y' ELSE 'n' END",
+ write={"tsql": "CASE WHEN NOT a IN (0) THEN 'y' ELSE 'n' END"},
+ )
+
+ self.validate_all("SELECT TRUE, FALSE", write={"tsql": "SELECT 1, 0"})
+
+ self.validate_all("SELECT TRUE AS a, FALSE AS b", write={"tsql": "SELECT 1 AS a, 0 AS b"})
+
+ self.validate_all(
+ "SELECT 1 FROM a WHERE TRUE", write={"tsql": "SELECT 1 FROM a WHERE (1 = 1)"}
+ )
+
+ self.validate_all(
+ "CASE WHEN TRUE THEN 'y' WHEN FALSE THEN 'n' ELSE NULL END",
+ write={"tsql": "CASE WHEN (1 = 1) THEN 'y' WHEN (1 = 0) THEN 'n' ELSE NULL END"},
+ )
+
def test_ddl(self):
self.validate_all(
"CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)",
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 2109f53..8f14ae4 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -52,17 +52,17 @@ class TestTransforms(unittest.TestCase):
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC",
- "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
+ "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a) a, b FROM x",
- "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) WHERE _row_number = 1",
+ "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1",
)
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC",
- "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1",
+ "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1",
)
self.validate(
eliminate_distinct_on,
@@ -72,7 +72,7 @@ class TestTransforms(unittest.TestCase):
self.validate(
eliminate_distinct_on,
"SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC",
- "SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE _row_number_2 = 1",
+ "SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1",
)
def test_eliminate_qualify(self):