diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_databricks.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 22 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 32 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 47 | ||||
-rw-r--r-- | tests/test_transforms.py | 8 |
9 files changed, 118 insertions, 32 deletions
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index f13d0f2..d06e0f1 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -5,6 +5,12 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO HOUR)") + self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO MINUTE)") + self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO SECOND)") + self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL HOUR TO MINUTE)") + self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL HOUR TO SECOND)") + self.validate_identity("SELECT CAST('23:00:00' AS INTERVAL MINUTE TO SECOND)") self.validate_identity("CREATE TABLE target SHALLOW CLONE source") self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)") self.validate_identity("SELECT c1 : price") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6a41218..47e1ec7 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -282,14 +282,16 @@ class TestDialect(Validator): "starrocks": "CAST(a AS DATETIME)", "redshift": "CAST(a AS TIMESTAMP)", "doris": "CAST(a AS DATETIME)", + "mysql": "CAST(a AS DATETIME)", }, ) self.validate_all( "CAST(a AS TIMESTAMPTZ)", write={ - "starrocks": "CAST(a AS DATETIME)", + "starrocks": "TIMESTAMP(a)", "redshift": "CAST(a AS TIMESTAMP WITH TIME ZONE)", "doris": "CAST(a AS DATETIME)", + "mysql": "TIMESTAMP(a)", }, ) self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"}) @@ -870,7 +872,7 @@ class TestDialect(Validator): "TIMESTAMP '2022-01-01'", write={ "drill": "CAST('2022-01-01' AS TIMESTAMP)", - "mysql": "CAST('2022-01-01' AS TIMESTAMP)", + "mysql": "CAST('2022-01-01' AS DATETIME)", "starrocks": "CAST('2022-01-01' AS DATETIME)", "hive": "CAST('2022-01-01' AS TIMESTAMP)", "doris": "CAST('2022-01-01' AS DATETIME)", diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index fc63f9f..1f1c2e9 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -6,6 +6,19 @@ class TestMySQL(Validator): dialect = "mysql" def test_ddl(self): + int_types = {"BIGINT", "INT", "MEDIUMINT", "SMALLINT", "TINYINT"} + + for t in int_types: + self.validate_identity(f"CREATE TABLE t (id {t} UNSIGNED)") + self.validate_identity(f"CREATE TABLE t (id {t}(10) UNSIGNED)") + + self.validate_all( + "CREATE TABLE t (id INT UNSIGNED)", + write={ + "duckdb": "CREATE TABLE t (id UINTEGER)", + }, + ) + self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("CREATE TABLE 00f (1d BIGINT)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") @@ -83,6 +96,12 @@ class TestMySQL(Validator): "mysql": "CREATE TABLE IF NOT EXISTS industry_info (a BIGINT(20) NOT NULL AUTO_INCREMENT, b BIGINT(20) NOT NULL, c VARCHAR(1000), PRIMARY KEY (a), UNIQUE d (b), INDEX e (b))", }, ) + self.validate_all( + "CREATE TABLE test (ts TIMESTAMP, ts_tz TIMESTAMPTZ, ts_ltz TIMESTAMPLTZ)", + write={ + "mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)", + }, + ) def test_identity(self): self.validate_identity( @@ -202,6 +221,9 @@ class TestMySQL(Validator): "spark": "CAST(x AS BLOB) + CAST(y AS BLOB)", }, ) + self.validate_all("CAST(x AS TIMESTAMP)", write={"mysql": "CAST(x AS DATETIME)"}) + self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"mysql": "TIMESTAMP(x)"}) + self.validate_all("CAST(x AS TIMESTAMPLTZ)", write={"mysql": "TIMESTAMP(x)"}) def test_canonical_functions(self): self.validate_identity("SELECT LEFT('str', 2)", "SELECT LEFT('str', 2)") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 8740aca..21196b7 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -126,12 +126,17 @@ class TestPostgres(Validator): ) def test_postgres(self): - self.validate_identity("x @@ y") - expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)") unnest = expr.args["joins"][0].this.this unnest.assert_is(exp.Unnest) + alter_table_only = """ALTER TABLE ONLY "Album" ADD CONSTRAINT "FK_AlbumArtistId" FOREIGN KEY ("ArtistId") REFERENCES "Artist" ("ArtistId") ON DELETE NO ACTION ON UPDATE NO ACTION""" + expr = parse_one(alter_table_only) + + self.assertIsInstance(expr, exp.AlterTable) + self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) + + self.validate_identity("x @@ y") self.validate_identity("CAST(x AS MONEY)") self.validate_identity("CAST(x AS INT4RANGE)") self.validate_identity("CAST(x AS INT4MULTIRANGE)") @@ -619,6 +624,12 @@ class TestPostgres(Validator): "snowflake": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b", }, ) + self.validate_all( + "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", + write={ + "postgres": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b WHEN NOT MATCHED THEN INSERT (a, b) VALUES (y.a, y.b)", + }, + ) self.validate_all( "x / y ^ z", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 245adf3..aea8b69 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -171,22 +171,22 @@ class TestRedshift(Validator): self.validate_all( "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", write={ - "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", - "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", - "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", - "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", - "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", - "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1", + "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", }, ) self.validate_all( diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index a892b0f..becb66a 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -240,6 +240,14 @@ TBLPROPERTIES ( self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("SPLIT(str, pattern, lim)") self.validate_identity( + "SELECT REGEXP_REPLACE('100-200', r'([^0-9])', '')", + "SELECT REGEXP_REPLACE('100-200', '([^0-9])', '')", + ) + self.validate_identity( + "SELECT REGEXP_REPLACE('100-200', R'([^0-9])', '')", + "SELECT REGEXP_REPLACE('100-200', '([^0-9])', '')", + ) + self.validate_identity( "SELECT STR_TO_MAP('a:1,b:2,c:3')", "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 32bdc71..f3615ff 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -5,6 +5,12 @@ class TestTeradata(Validator): dialect = "teradata" def test_teradata(self): + self.validate_identity("SELECT * FROM tbl SAMPLE 5") + self.validate_identity( + "SELECT * FROM tbl SAMPLE 0.33, .25, .1", + "SELECT * FROM tbl SAMPLE 0.33, 0.25, 0.1", + ) + self.validate_all( "DATABASE tduser", read={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index c27b7fa..acf8b79 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -368,6 +368,14 @@ class TestTSQL(Validator): }, ) + self.validate_all( + "CAST(x as UNIQUEIDENTIFIER)", + write={ + "spark": "CAST(x AS STRING)", + "tsql": "CAST(x AS UNIQUEIDENTIFIER)", + }, + ) + def test_types_date(self): self.validate_all( "CAST(x as DATE)", @@ -427,14 +435,6 @@ class TestTSQL(Validator): ) self.validate_all( - "CAST(x as UNIQUEIDENTIFIER)", - write={ - "spark": "CAST(x AS STRING)", - "tsql": "CAST(x AS UNIQUEIDENTIFIER)", - }, - ) - - self.validate_all( "CAST(x as VARBINARY)", write={ "spark": "CAST(x AS BINARY)", @@ -447,6 +447,37 @@ class TestTSQL(Validator): write={"tsql": "CAST(x AS BIT)"}, ) + self.validate_all("a = TRUE", write={"tsql": "a = 1"}) + + self.validate_all("a != FALSE", write={"tsql": "a <> 0"}) + + self.validate_all("a IS TRUE", write={"tsql": "a = 1"}) + + self.validate_all("a IS NOT FALSE", write={"tsql": "NOT a = 0"}) + + self.validate_all( + "CASE WHEN a IN (TRUE) THEN 'y' ELSE 'n' END", + write={"tsql": "CASE WHEN a IN (1) THEN 'y' ELSE 'n' END"}, + ) + + self.validate_all( + "CASE WHEN a NOT IN (FALSE) THEN 'y' ELSE 'n' END", + write={"tsql": "CASE WHEN NOT a IN (0) THEN 'y' ELSE 'n' END"}, + ) + + self.validate_all("SELECT TRUE, FALSE", write={"tsql": "SELECT 1, 0"}) + + self.validate_all("SELECT TRUE AS a, FALSE AS b", write={"tsql": "SELECT 1 AS a, 0 AS b"}) + + self.validate_all( + "SELECT 1 FROM a WHERE TRUE", write={"tsql": "SELECT 1 FROM a WHERE (1 = 1)"} + ) + + self.validate_all( + "CASE WHEN TRUE THEN 'y' WHEN FALSE THEN 'n' ELSE NULL END", + write={"tsql": "CASE WHEN (1 = 1) THEN 'y' WHEN (1 = 0) THEN 'n' ELSE NULL END"}, + ) + def test_ddl(self): self.validate_all( "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)", diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 2109f53..8f14ae4 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -52,17 +52,17 @@ class TestTransforms(unittest.TestCase): self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", - "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", + "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", ) self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a) a, b FROM x", - "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) WHERE _row_number = 1", + "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) AS _t WHERE _row_number = 1", ) self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC", - "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", + "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", ) self.validate( eliminate_distinct_on, @@ -72,7 +72,7 @@ class TestTransforms(unittest.TestCase): self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", - "SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE _row_number_2 = 1", + "SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) AS _t WHERE _row_number_2 = 1", ) def test_eliminate_qualify(self): |