summaryrefslogtreecommitdiffstats
path: root/tests/dialects
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects')
-rw-r--r--tests/dialects/test_databricks.py33
-rw-r--r--tests/dialects/test_dialect.py22
-rw-r--r--tests/dialects/test_snowflake.py12
-rw-r--r--tests/dialects/test_tsql.py94
4 files changed, 157 insertions, 4 deletions
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
new file mode 100644
index 0000000..e242e73
--- /dev/null
+++ b/tests/dialects/test_databricks.py
@@ -0,0 +1,33 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestDatabricks(Validator):
+ dialect = "databricks"
+
+ def test_datediff(self):
+ self.validate_all(
+ "SELECT DATEDIFF(year, 'start', 'end')",
+ write={
+ "tsql": "SELECT DATEDIFF(year, 'start', 'end')",
+ "databricks": "SELECT DATEDIFF(year, 'start', 'end')",
+ },
+ )
+
+ def test_add_date(self):
+ self.validate_all(
+ "SELECT DATEADD(year, 1, '2020-01-01')",
+ write={
+ "tsql": "SELECT DATEADD(year, 1, '2020-01-01')",
+ "databricks": "SELECT DATEADD(year, 1, '2020-01-01')",
+ },
+ )
+ self.validate_all(
+ "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"}
+ )
+ self.validate_all(
+ "SELECT DATE_ADD('2020-01-01', 1)",
+ write={
+ "tsql": "SELECT DATEADD(DAY, 1, '2020-01-01')",
+ "databricks": "SELECT DATEADD(DAY, 1, '2020-01-01')",
+ },
+ )
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 5d1cf13..3b837df 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -82,6 +82,28 @@ class TestDialect(Validator):
},
)
self.validate_all(
+ "CAST(a AS BINARY(4))",
+ read={
+ "presto": "CAST(a AS VARBINARY(4))",
+ "sqlite": "CAST(a AS VARBINARY(4))",
+ },
+ write={
+ "bigquery": "CAST(a AS BINARY(4))",
+ "clickhouse": "CAST(a AS BINARY(4))",
+ "duckdb": "CAST(a AS BINARY(4))",
+ "mysql": "CAST(a AS BINARY(4))",
+ "hive": "CAST(a AS BINARY(4))",
+ "oracle": "CAST(a AS BLOB(4))",
+ "postgres": "CAST(a AS BYTEA(4))",
+ "presto": "CAST(a AS VARBINARY(4))",
+ "redshift": "CAST(a AS VARBYTE(4))",
+ "snowflake": "CAST(a AS BINARY(4))",
+ "sqlite": "CAST(a AS BLOB(4))",
+ "spark": "CAST(a AS BINARY(4))",
+ "starrocks": "CAST(a AS BINARY(4))",
+ },
+ )
+ self.validate_all(
"CAST(MAP('a', '1') AS MAP(TEXT, TEXT))",
write={
"clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))",
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 159b643..fea2311 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -293,6 +293,18 @@ class TestSnowflake(Validator):
"CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))"
)
self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x")
+ self.validate_all(
+ "CREATE OR REPLACE TRANSIENT TABLE a (id INT)",
+ read={
+ "postgres": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)",
+ "snowflake": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)",
+ },
+ write={
+ "postgres": "CREATE OR REPLACE TABLE a (id INT)",
+ "mysql": "CREATE OR REPLACE TABLE a (id INT)",
+ "snowflake": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)",
+ },
+ )
def test_user_defined_functions(self):
self.validate_all(
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 2a20163..d22a9c2 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -260,6 +260,20 @@ class TestTSQL(Validator):
"spark": "CAST(x AS INT)",
},
)
+ self.validate_all(
+ "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) y FROM testdb.dbo.test",
+ write={
+ "mysql": "SELECT CAST(TIME_TO_STR(testdb.dbo.test.x, '%Y-%m-%d %H:%M:%S') AS VARCHAR(10)) AS y FROM testdb.dbo.test",
+ "spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test",
+ },
+ )
+ self.validate_all(
+ "SELECT CONVERT(VARCHAR(10), y.x) z FROM testdb.dbo.test y",
+ write={
+ "mysql": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
+ "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y",
+ },
+ )
def test_add_date(self):
self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')")
@@ -267,7 +281,10 @@ class TestTSQL(Validator):
"SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"}
)
self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"})
- self.validate_all("SELECT DATEADD(wk, 1, '2017/08/25')", write={"spark": "SELECT DATE_ADD('2017/08/25', 7)"})
+ self.validate_all(
+ "SELECT DATEADD(wk, 1, '2017/08/25')",
+ write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"},
+ )
def test_date_diff(self):
self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')")
@@ -279,11 +296,19 @@ class TestTSQL(Validator):
},
)
self.validate_all(
- "SELECT DATEDIFF(month, 'start','end')",
- write={"spark": "SELECT MONTHS_BETWEEN('end', 'start')", "tsql": "SELECT DATEDIFF(month, 'start', 'end')"},
+ "SELECT DATEDIFF(mm, 'start','end')",
+ write={
+ "spark": "SELECT MONTHS_BETWEEN('end', 'start')",
+ "tsql": "SELECT DATEDIFF(month, 'start', 'end')",
+ "databricks": "SELECT DATEDIFF(month, 'start', 'end')",
+ },
)
self.validate_all(
- "SELECT DATEDIFF(quarter, 'start', 'end')", write={"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"}
+ "SELECT DATEDIFF(quarter, 'start', 'end')",
+ write={
+ "spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3",
+ "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')",
+ },
)
def test_iif(self):
@@ -294,3 +319,64 @@ class TestTSQL(Validator):
"spark": "SELECT IF(cond, 'True', 'False')",
},
)
+
+ def test_lateral_subquery(self):
+ self.validate_all(
+ "SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)",
+ write={
+ "spark": "SELECT x.a, x.b, t.v, t.y FROM x JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)",
+ },
+ )
+ self.validate_all(
+ "SELECT x.a, x.b, t.v, t.y FROM x OUTER APPLY (SELECT v, y FROM t) t(v, y)",
+ write={
+ "spark": "SELECT x.a, x.b, t.v, t.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)",
+ },
+ )
+
+ def test_lateral_table_valued_function(self):
+ self.validate_all(
+ "SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)",
+ write={
+ "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) y AS z",
+ },
+ )
+ self.validate_all(
+ "SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)",
+ write={
+ "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) y AS z",
+ },
+ )
+
+ def test_top(self):
+ self.validate_all(
+ "SELECT TOP 3 * FROM A",
+ write={
+ "spark": "SELECT * FROM A LIMIT 3",
+ },
+ )
+ self.validate_all(
+ "SELECT TOP (3) * FROM A",
+ write={
+ "spark": "SELECT * FROM A LIMIT 3",
+ },
+ )
+
+ def test_format(self):
+ self.validate_identity("SELECT FORMAT('01-01-1991', 'd.mm.yyyy')")
+ self.validate_identity("SELECT FORMAT(12345, '###.###.###')")
+ self.validate_identity("SELECT FORMAT(1234567, 'f')")
+ self.validate_all(
+ "SELECT FORMAT(1000000.01,'###,###.###')",
+ write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"},
+ )
+ self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"})
+ self.validate_all(
+ "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')",
+ write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"},
+ )
+ self.validate_all(
+ "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"}
+ )
+ self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"})
+ self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"})