summaryrefslogtreecommitdiffstats
path: root/tests/dialects
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dialects/test_clickhouse.py1
-rw-r--r--tests/dialects/test_dialect.py54
-rw-r--r--tests/dialects/test_duckdb.py23
-rw-r--r--tests/dialects/test_hive.py8
-rw-r--r--tests/dialects/test_mysql.py35
-rw-r--r--tests/dialects/test_presto.py10
-rw-r--r--tests/dialects/test_redshift.py11
-rw-r--r--tests/dialects/test_snowflake.py15
-rw-r--r--tests/dialects/test_spark.py10
-rw-r--r--tests/dialects/test_sqlite.py2
-rw-r--r--tests/dialects/test_tsql.py14
11 files changed, 167 insertions, 16 deletions
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index d206bb1..40a3a04 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -5,6 +5,7 @@ class TestClickhouse(Validator):
dialect = "clickhouse"
def test_clickhouse(self):
+ self.validate_identity("SELECT match('abc', '([a-z]+)')")
self.validate_identity("dictGet(x, 'y')")
self.validate_identity("SELECT * FROM x FINAL")
self.validate_identity("SELECT * FROM x AS y FINAL")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 0805e9c..3558d62 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -290,6 +290,60 @@ class TestDialect(Validator):
read={"postgres": "INET '127.0.0.1/32'"},
)
+ def test_decode(self):
+ self.validate_identity("DECODE(bin, charset)")
+
+ self.validate_all(
+ "SELECT DECODE(a, 1, 'one')",
+ write={
+ "": "SELECT CASE WHEN a = 1 THEN 'one' END",
+ "oracle": "SELECT CASE WHEN a = 1 THEN 'one' END",
+ "redshift": "SELECT CASE WHEN a = 1 THEN 'one' END",
+ "snowflake": "SELECT CASE WHEN a = 1 THEN 'one' END",
+ "spark": "SELECT CASE WHEN a = 1 THEN 'one' END",
+ },
+ )
+ self.validate_all(
+ "SELECT DECODE(a, 1, 'one', 'default')",
+ write={
+ "": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END",
+ "oracle": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END",
+ "redshift": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END",
+ "snowflake": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END",
+ "spark": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END",
+ },
+ )
+ self.validate_all(
+ "SELECT DECODE(a, NULL, 'null')",
+ write={
+ "": "SELECT CASE WHEN a IS NULL THEN 'null' END",
+ "oracle": "SELECT CASE WHEN a IS NULL THEN 'null' END",
+ "redshift": "SELECT CASE WHEN a IS NULL THEN 'null' END",
+ "snowflake": "SELECT CASE WHEN a IS NULL THEN 'null' END",
+ "spark": "SELECT CASE WHEN a IS NULL THEN 'null' END",
+ },
+ )
+ self.validate_all(
+ "SELECT DECODE(a, b, c)",
+ write={
+ "": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END",
+ "oracle": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END",
+ "redshift": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END",
+ "snowflake": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END",
+ "spark": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END",
+ },
+ )
+ self.validate_all(
+ "SELECT DECODE(tbl.col, 'some_string', 'foo')",
+ write={
+ "": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END",
+ "oracle": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END",
+ "redshift": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END",
+ "snowflake": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END",
+ "spark": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END",
+ },
+ )
+
def test_if_null(self):
self.validate_all(
"SELECT IFNULL(1, NULL) FROM foo",
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index a15e6b4..245d82a 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -74,6 +74,17 @@ class TestDuckDB(Validator):
"spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')",
},
)
+ self.validate_all(
+ "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'",
+ read={
+ "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))",
+ },
+ write={
+ "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMPTZ), 'America/New_York'))",
+ "duckdb": "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'",
+ "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))",
+ },
+ )
def test_sample(self):
self.validate_all(
@@ -421,6 +432,18 @@ class TestDuckDB(Validator):
"snowflake": "CAST(COL AS ARRAY)",
},
)
+ self.validate_all(
+ "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])",
+ write={
+ "duckdb": "CAST(LIST_VALUE({'a': 1}) AS STRUCT(a BIGINT)[])",
+ },
+ )
+ self.validate_all(
+ "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])",
+ write={
+ "duckdb": "CAST(LIST_VALUE(LIST_VALUE({'a': 1})) AS STRUCT(a BIGINT)[][])",
+ },
+ )
def test_bool_or(self):
self.validate_all(
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 0161f1e..1a83575 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -689,3 +689,11 @@ class TestHive(Validator):
self.validate_identity("'\\z'")
self.validate_identity("'\\\z'")
self.validate_identity("'\\\\z'")
+
+ def test_data_type(self):
+ self.validate_all(
+ "CAST(a AS BIT)",
+ write={
+ "hive": "CAST(a AS BOOLEAN)",
+ },
+ )
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 5059d05..f618728 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -204,6 +204,41 @@ class TestMySQL(Validator):
},
)
+ def test_match_against(self):
+ self.validate_all(
+ "MATCH(col1, col2, col3) AGAINST('abc')",
+ read={
+ "": "MATCH(col1, col2, col3) AGAINST('abc')",
+ "mysql": "MATCH(col1, col2, col3) AGAINST('abc')",
+ },
+ write={
+ "": "MATCH(col1, col2, col3) AGAINST('abc')",
+ "mysql": "MATCH(col1, col2, col3) AGAINST('abc')",
+ },
+ )
+ self.validate_all(
+ "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE)",
+ write={"mysql": "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE)"},
+ )
+ self.validate_all(
+ "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)",
+ write={
+ "mysql": "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)"
+ },
+ )
+ self.validate_all(
+ "MATCH(col1, col2) AGAINST('abc' IN BOOLEAN MODE)",
+ write={"mysql": "MATCH(col1, col2) AGAINST('abc' IN BOOLEAN MODE)"},
+ )
+ self.validate_all(
+ "MATCH(col1, col2) AGAINST('abc' WITH QUERY EXPANSION)",
+ write={"mysql": "MATCH(col1, col2) AGAINST('abc' WITH QUERY EXPANSION)"},
+ )
+ self.validate_all(
+ "MATCH(a.b) AGAINST('abc')",
+ write={"mysql": "MATCH(a.b) AGAINST('abc')"},
+ )
+
def test_mysql(self):
self.validate_all(
"SELECT a FROM tbl FOR UPDATE",
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 1762e7a..1007899 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -539,6 +539,16 @@ class TestPresto(Validator):
"hive": "SELECT a, b, c, d, SUM(y) FROM z GROUP BY d, GROUPING SETS ((b, c)), CUBE (a), ROLLUP (a)",
},
)
+ self.validate_all(
+ "JSON_FORMAT(x)",
+ read={
+ "spark": "TO_JSON(x)",
+ },
+ write={
+ "presto": "JSON_FORMAT(x)",
+ "spark": "TO_JSON(x)",
+ },
+ )
def test_encode_decode(self):
self.validate_all(
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index ff730f8..0933051 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -69,9 +69,11 @@ class TestRedshift(Validator):
self.validate_all(
"DECODE(x, a, b, c, d)",
write={
- "": "MATCHES(x, a, b, c, d)",
- "oracle": "DECODE(x, a, b, c, d)",
- "snowflake": "DECODE(x, a, b, c, d)",
+ "": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END",
+ "oracle": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END",
+ "redshift": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END",
+ "snowflake": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END",
+ "spark": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END",
},
)
self.validate_all(
@@ -97,9 +99,6 @@ class TestRedshift(Validator):
)
def test_identity(self):
- self.validate_identity(
- "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')"
- )
self.validate_identity("CAST('bla' AS SUPER)")
self.validate_identity("CREATE TABLE real1 (realcol REAL)")
self.validate_identity("CAST('foo' AS HLLSKETCH)")
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 940fa50..eb423a5 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -6,6 +6,11 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
+ self.validate_identity("SELECT HLL(*)")
+ self.validate_identity("SELECT HLL(a)")
+ self.validate_identity("SELECT HLL(DISTINCT t.a)")
+ self.validate_identity("SELECT HLL(a, b, c)")
+ self.validate_identity("SELECT HLL(DISTINCT a, b, c)")
self.validate_identity("$x")
self.validate_identity("SELECT REGEXP_LIKE(a, b, c)")
self.validate_identity("PUT file:///dir/tmp.csv @%table")
@@ -376,14 +381,10 @@ class TestSnowflake(Validator):
},
)
self.validate_all(
- "DECODE(x, a, b, c, d)",
- read={
- "": "MATCHES(x, a, b, c, d)",
- },
+ "DECODE(x, a, b, c, d, e)",
write={
- "": "MATCHES(x, a, b, c, d)",
- "oracle": "DECODE(x, a, b, c, d)",
- "snowflake": "DECODE(x, a, b, c, d)",
+ "": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d ELSE e END",
+ "snowflake": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d ELSE e END",
},
)
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index b12f272..0da2931 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -345,3 +345,13 @@ TBLPROPERTIES (
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
write={"spark": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
)
+
+ def test_current_user(self):
+ self.validate_all(
+ "CURRENT_USER",
+ write={"spark": "CURRENT_USER()"},
+ )
+ self.validate_all(
+ "CURRENT_USER()",
+ write={"spark": "CURRENT_USER()"},
+ )
diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py
index fd9e52b..583d5be 100644
--- a/tests/dialects/test_sqlite.py
+++ b/tests/dialects/test_sqlite.py
@@ -12,7 +12,7 @@ class TestSQLite(Validator):
self.validate_identity("INSERT OR ROLLBACK INTO foo (x, y) VALUES (1, 2)")
self.validate_all(
- "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)",
+ "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)",
write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"},
)
self.validate_all(
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 60867be..d9ee4ae 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -547,11 +547,11 @@ WHERE
def test_string(self):
self.validate_all(
"SELECT N'test'",
- write={"spark": "SELECT N'test'"},
+ write={"spark": "SELECT 'test'"},
)
self.validate_all(
"SELECT n'test'",
- write={"spark": "SELECT N'test'"},
+ write={"spark": "SELECT 'test'"},
)
self.validate_all(
"SELECT '''test'''",
@@ -621,3 +621,13 @@ WHERE
"tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""",
},
)
+
+ def test_current_user(self):
+ self.validate_all(
+ "SUSER_NAME()",
+ write={"spark": "CURRENT_USER()"},
+ )
+ self.validate_all(
+ "SUSER_SNAME()",
+ write={"spark": "CURRENT_USER()"},
+ )