diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_clickhouse.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 54 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 23 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 35 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 14 |
11 files changed, 167 insertions, 16 deletions
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index d206bb1..40a3a04 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -5,6 +5,7 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_identity("SELECT match('abc', '([a-z]+)')") self.validate_identity("dictGet(x, 'y')") self.validate_identity("SELECT * FROM x FINAL") self.validate_identity("SELECT * FROM x AS y FINAL") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 0805e9c..3558d62 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -290,6 +290,60 @@ class TestDialect(Validator): read={"postgres": "INET '127.0.0.1/32'"}, ) + def test_decode(self): + self.validate_identity("DECODE(bin, charset)") + + self.validate_all( + "SELECT DECODE(a, 1, 'one')", + write={ + "": "SELECT CASE WHEN a = 1 THEN 'one' END", + "oracle": "SELECT CASE WHEN a = 1 THEN 'one' END", + "redshift": "SELECT CASE WHEN a = 1 THEN 'one' END", + "snowflake": "SELECT CASE WHEN a = 1 THEN 'one' END", + "spark": "SELECT CASE WHEN a = 1 THEN 'one' END", + }, + ) + self.validate_all( + "SELECT DECODE(a, 1, 'one', 'default')", + write={ + "": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "oracle": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "redshift": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "snowflake": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "spark": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + }, + ) + self.validate_all( + "SELECT DECODE(a, NULL, 'null')", + write={ + "": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "oracle": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "redshift": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "snowflake": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "spark": "SELECT CASE WHEN a IS NULL THEN 'null' END", + }, + ) + self.validate_all( + "SELECT DECODE(a, b, c)", + write={ + "": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "oracle": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "redshift": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "snowflake": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "spark": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + }, + ) + self.validate_all( + "SELECT DECODE(tbl.col, 'some_string', 'foo')", + write={ + "": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "oracle": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "redshift": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "snowflake": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "spark": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + }, + ) + def test_if_null(self): self.validate_all( "SELECT IFNULL(1, NULL) FROM foo", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index a15e6b4..245d82a 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -74,6 +74,17 @@ class TestDuckDB(Validator): "spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')", }, ) + self.validate_all( + "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'", + read={ + "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", + }, + write={ + "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMPTZ), 'America/New_York'))", + "duckdb": "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'", + "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", + }, + ) def test_sample(self): self.validate_all( @@ -421,6 +432,18 @@ class TestDuckDB(Validator): "snowflake": "CAST(COL AS ARRAY)", }, ) + self.validate_all( + "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", + write={ + "duckdb": "CAST(LIST_VALUE({'a': 1}) AS STRUCT(a BIGINT)[])", + }, + ) + self.validate_all( + "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", + write={ + "duckdb": "CAST(LIST_VALUE(LIST_VALUE({'a': 1})) AS STRUCT(a BIGINT)[][])", + }, + ) def test_bool_or(self): self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 0161f1e..1a83575 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -689,3 +689,11 @@ class TestHive(Validator): self.validate_identity("'\\z'") self.validate_identity("'\\\z'") self.validate_identity("'\\\\z'") + + def test_data_type(self): + self.validate_all( + "CAST(a AS BIT)", + write={ + "hive": "CAST(a AS BOOLEAN)", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 5059d05..f618728 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -204,6 +204,41 @@ class TestMySQL(Validator): }, ) + def test_match_against(self): + self.validate_all( + "MATCH(col1, col2, col3) AGAINST('abc')", + read={ + "": "MATCH(col1, col2, col3) AGAINST('abc')", + "mysql": "MATCH(col1, col2, col3) AGAINST('abc')", + }, + write={ + "": "MATCH(col1, col2, col3) AGAINST('abc')", + "mysql": "MATCH(col1, col2, col3) AGAINST('abc')", + }, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE)", + write={"mysql": "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE)"}, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)", + write={ + "mysql": "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)" + }, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' IN BOOLEAN MODE)", + write={"mysql": "MATCH(col1, col2) AGAINST('abc' IN BOOLEAN MODE)"}, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' WITH QUERY EXPANSION)", + write={"mysql": "MATCH(col1, col2) AGAINST('abc' WITH QUERY EXPANSION)"}, + ) + self.validate_all( + "MATCH(a.b) AGAINST('abc')", + write={"mysql": "MATCH(a.b) AGAINST('abc')"}, + ) + def test_mysql(self): self.validate_all( "SELECT a FROM tbl FOR UPDATE", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 1762e7a..1007899 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -539,6 +539,16 @@ class TestPresto(Validator): "hive": "SELECT a, b, c, d, SUM(y) FROM z GROUP BY d, GROUPING SETS ((b, c)), CUBE (a), ROLLUP (a)", }, ) + self.validate_all( + "JSON_FORMAT(x)", + read={ + "spark": "TO_JSON(x)", + }, + write={ + "presto": "JSON_FORMAT(x)", + "spark": "TO_JSON(x)", + }, + ) def test_encode_decode(self): self.validate_all( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index ff730f8..0933051 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -69,9 +69,11 @@ class TestRedshift(Validator): self.validate_all( "DECODE(x, a, b, c, d)", write={ - "": "MATCHES(x, a, b, c, d)", - "oracle": "DECODE(x, a, b, c, d)", - "snowflake": "DECODE(x, a, b, c, d)", + "": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "oracle": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "redshift": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "snowflake": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "spark": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", }, ) self.validate_all( @@ -97,9 +99,6 @@ class TestRedshift(Validator): ) def test_identity(self): - self.validate_identity( - "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')" - ) self.validate_identity("CAST('bla' AS SUPER)") self.validate_identity("CREATE TABLE real1 (realcol REAL)") self.validate_identity("CAST('foo' AS HLLSKETCH)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 940fa50..eb423a5 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,6 +6,11 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("SELECT HLL(*)") + self.validate_identity("SELECT HLL(a)") + self.validate_identity("SELECT HLL(DISTINCT t.a)") + self.validate_identity("SELECT HLL(a, b, c)") + self.validate_identity("SELECT HLL(DISTINCT a, b, c)") self.validate_identity("$x") self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") self.validate_identity("PUT file:///dir/tmp.csv @%table") @@ -376,14 +381,10 @@ class TestSnowflake(Validator): }, ) self.validate_all( - "DECODE(x, a, b, c, d)", - read={ - "": "MATCHES(x, a, b, c, d)", - }, + "DECODE(x, a, b, c, d, e)", write={ - "": "MATCHES(x, a, b, c, d)", - "oracle": "DECODE(x, a, b, c, d)", - "snowflake": "DECODE(x, a, b, c, d)", + "": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d ELSE e END", + "snowflake": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d ELSE e END", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index b12f272..0da2931 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -345,3 +345,13 @@ TBLPROPERTIES ( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", write={"spark": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, ) + + def test_current_user(self): + self.validate_all( + "CURRENT_USER", + write={"spark": "CURRENT_USER()"}, + ) + self.validate_all( + "CURRENT_USER()", + write={"spark": "CURRENT_USER()"}, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index fd9e52b..583d5be 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -12,7 +12,7 @@ class TestSQLite(Validator): self.validate_identity("INSERT OR ROLLBACK INTO foo (x, y) VALUES (1, 2)") self.validate_all( - "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", + "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"}, ) self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 60867be..d9ee4ae 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -547,11 +547,11 @@ WHERE def test_string(self): self.validate_all( "SELECT N'test'", - write={"spark": "SELECT N'test'"}, + write={"spark": "SELECT 'test'"}, ) self.validate_all( "SELECT n'test'", - write={"spark": "SELECT N'test'"}, + write={"spark": "SELECT 'test'"}, ) self.validate_all( "SELECT '''test'''", @@ -621,3 +621,13 @@ WHERE "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""", }, ) + + def test_current_user(self): + self.validate_all( + "SUSER_NAME()", + write={"spark": "CURRENT_USER()"}, + ) + self.validate_all( + "SUSER_SNAME()", + write={"spark": "CURRENT_USER()"}, + ) |