diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dataframe/integration/dataframe_validator.py | 5 | ||||
-rw-r--r-- | tests/dataframe/unit/test_column.py | 3 | ||||
-rw-r--r-- | tests/dataframe/unit/test_functions.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 33 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 22 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 12 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 94 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_columns.sql | 17 | ||||
-rw-r--r-- | tests/test_build.py | 36 | ||||
-rw-r--r-- | tests/test_expressions.py | 6 |
10 files changed, 223 insertions, 9 deletions
diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 6c4642f..4a89c78 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,3 +1,4 @@ +import sys import typing as t import unittest import warnings @@ -9,7 +10,9 @@ if t.TYPE_CHECKING: from pyspark.sql import DataFrame as SparkDataFrame -@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") +@unittest.skipIf( + SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" +) class DataFrameValidator(unittest.TestCase): spark = None sqlglot = None diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index df0ebff..977971e 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -146,7 +146,8 @@ class TestDataframeColumn(unittest.TestCase): F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), ) self.assertEqual( - "cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)", + "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " + "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), ) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 97753bd..eadbb93 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -30,7 +30,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.lit(datetime.date(2022, 1, 1)) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.lit({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) @@ -52,7 +52,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.col(datetime.date(2022, 1, 1)) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.col({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py new file mode 100644 index 0000000..e242e73 --- /dev/null +++ b/tests/dialects/test_databricks.py @@ -0,0 +1,33 @@ +from tests.dialects.test_dialect import Validator + + +class TestDatabricks(Validator): + dialect = "databricks" + + def test_datediff(self): + self.validate_all( + "SELECT DATEDIFF(year, 'start', 'end')", + write={ + "tsql": "SELECT DATEDIFF(year, 'start', 'end')", + "databricks": "SELECT DATEDIFF(year, 'start', 'end')", + }, + ) + + def test_add_date(self): + self.validate_all( + "SELECT DATEADD(year, 1, '2020-01-01')", + write={ + "tsql": "SELECT DATEADD(year, 1, '2020-01-01')", + "databricks": "SELECT DATEADD(year, 1, '2020-01-01')", + }, + ) + self.validate_all( + "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"} + ) + self.validate_all( + "SELECT DATE_ADD('2020-01-01', 1)", + write={ + "tsql": "SELECT DATEADD(DAY, 1, '2020-01-01')", + "databricks": "SELECT DATEADD(DAY, 1, '2020-01-01')", + }, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 5d1cf13..3b837df 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -82,6 +82,28 @@ class TestDialect(Validator): }, ) self.validate_all( + "CAST(a AS BINARY(4))", + read={ + "presto": "CAST(a AS VARBINARY(4))", + "sqlite": "CAST(a AS VARBINARY(4))", + }, + write={ + "bigquery": "CAST(a AS BINARY(4))", + "clickhouse": "CAST(a AS BINARY(4))", + "duckdb": "CAST(a AS BINARY(4))", + "mysql": "CAST(a AS BINARY(4))", + "hive": "CAST(a AS BINARY(4))", + "oracle": "CAST(a AS BLOB(4))", + "postgres": "CAST(a AS BYTEA(4))", + "presto": "CAST(a AS VARBINARY(4))", + "redshift": "CAST(a AS VARBYTE(4))", + "snowflake": "CAST(a AS BINARY(4))", + "sqlite": "CAST(a AS BLOB(4))", + "spark": "CAST(a AS BINARY(4))", + "starrocks": "CAST(a AS BINARY(4))", + }, + ) + self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 159b643..fea2311 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -293,6 +293,18 @@ class TestSnowflake(Validator): "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" ) self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") + self.validate_all( + "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + read={ + "postgres": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + "snowflake": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + }, + write={ + "postgres": "CREATE OR REPLACE TABLE a (id INT)", + "mysql": "CREATE OR REPLACE TABLE a (id INT)", + "snowflake": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + }, + ) def test_user_defined_functions(self): self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 2a20163..d22a9c2 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -260,6 +260,20 @@ class TestTSQL(Validator): "spark": "CAST(x AS INT)", }, ) + self.validate_all( + "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) y FROM testdb.dbo.test", + write={ + "mysql": "SELECT CAST(TIME_TO_STR(testdb.dbo.test.x, '%Y-%m-%d %H:%M:%S') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + "spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + }, + ) + self.validate_all( + "SELECT CONVERT(VARCHAR(10), y.x) z FROM testdb.dbo.test y", + write={ + "mysql": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", + "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", + }, + ) def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") @@ -267,7 +281,10 @@ class TestTSQL(Validator): "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} ) self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) - self.validate_all("SELECT DATEADD(wk, 1, '2017/08/25')", write={"spark": "SELECT DATE_ADD('2017/08/25', 7)"}) + self.validate_all( + "SELECT DATEADD(wk, 1, '2017/08/25')", + write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"}, + ) def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") @@ -279,11 +296,19 @@ class TestTSQL(Validator): }, ) self.validate_all( - "SELECT DATEDIFF(month, 'start','end')", - write={"spark": "SELECT MONTHS_BETWEEN('end', 'start')", "tsql": "SELECT DATEDIFF(month, 'start', 'end')"}, + "SELECT DATEDIFF(mm, 'start','end')", + write={ + "spark": "SELECT MONTHS_BETWEEN('end', 'start')", + "tsql": "SELECT DATEDIFF(month, 'start', 'end')", + "databricks": "SELECT DATEDIFF(month, 'start', 'end')", + }, ) self.validate_all( - "SELECT DATEDIFF(quarter, 'start', 'end')", write={"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"} + "SELECT DATEDIFF(quarter, 'start', 'end')", + write={ + "spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3", + "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", + }, ) def test_iif(self): @@ -294,3 +319,64 @@ class TestTSQL(Validator): "spark": "SELECT IF(cond, 'True', 'False')", }, ) + + def test_lateral_subquery(self): + self.validate_all( + "SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)", + write={ + "spark": "SELECT x.a, x.b, t.v, t.y FROM x JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", + }, + ) + self.validate_all( + "SELECT x.a, x.b, t.v, t.y FROM x OUTER APPLY (SELECT v, y FROM t) t(v, y)", + write={ + "spark": "SELECT x.a, x.b, t.v, t.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", + }, + ) + + def test_lateral_table_valued_function(self): + self.validate_all( + "SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)", + write={ + "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) y AS z", + }, + ) + self.validate_all( + "SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)", + write={ + "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) y AS z", + }, + ) + + def test_top(self): + self.validate_all( + "SELECT TOP 3 * FROM A", + write={ + "spark": "SELECT * FROM A LIMIT 3", + }, + ) + self.validate_all( + "SELECT TOP (3) * FROM A", + write={ + "spark": "SELECT * FROM A LIMIT 3", + }, + ) + + def test_format(self): + self.validate_identity("SELECT FORMAT('01-01-1991', 'd.mm.yyyy')") + self.validate_identity("SELECT FORMAT(12345, '###.###.###')") + self.validate_identity("SELECT FORMAT(1234567, 'f')") + self.validate_all( + "SELECT FORMAT(1000000.01,'###,###.###')", + write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, + ) + self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}) + self.validate_all( + "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", + write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"}, + ) + self.validate_all( + "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"} + ) + self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}) + self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}) diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 858f232..a958c08 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -31,6 +31,23 @@ SELECT x.a + x.b AS "_col_0" FROM x AS x; SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; +SELECT SUM(a) AS c FROM x HAVING SUM(a) > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3; + +SELECT SUM(a) AS a FROM x HAVING SUM(a) > 3; +SELECT SUM(x.a) AS a FROM x AS x HAVING SUM(x.a) > 3; + +SELECT SUM(a) AS c FROM x HAVING c > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING c > 3; + +# execute: false +SELECT SUM(a) AS a FROM x HAVING a > 3; +SELECT SUM(x.a) AS a FROM x AS x HAVING a > 3; + +# execute: false +SELECT SUM(a) AS c FROM x HAVING SUM(c) > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(c) > 3; + SELECT a AS j, b FROM x ORDER BY j; SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; diff --git a/tests/test_build.py b/tests/test_build.py index f51996d..b7b6865 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -154,6 +154,42 @@ class TestBuild(unittest.TestCase): "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b", ), ( + lambda: select("x", "y", "z") + .from_("merged_df") + .join("vte_diagnosis_df", using=["patient_id", "encounter_id"]), + "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)", + ), + ( + lambda: select("x", "y", "z") + .from_("merged_df") + .join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]), + "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).on("y = 1", "z = 1"), + "JOIN x ON y = 1 AND z = 1", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).on("y = 1"), + "JOIN x ON y = 1", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).using("bar", "bob"), + "JOIN x USING (bar, bob)", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).using("bar"), + "JOIN x USING (bar)", + ), + ( + lambda: select("x").from_("foo").join("bla", using="bob"), + "SELECT x FROM foo JOIN bla USING (bob)", + ), + ( + lambda: select("x").from_("foo").join("bla", using="bob"), + "SELECT x FROM foo JOIN bla USING (bob)", + ), + ( lambda: select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0"), "SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0", ), diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 9af59d9..adfd329 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -538,7 +538,11 @@ class TestExpressions(unittest.TestCase): ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), ({"x": None}, "MAP('x', NULL)"), - (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01')"), + (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"), + ( + datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", + ), (datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"), ]: with self.subTest(value): |