diff options
Diffstat (limited to 'tests/dialects/test_spark.py')
-rw-r--r-- | tests/dialects/test_spark.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 1cf1ede..18f1fb7 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -1,6 +1,7 @@ from unittest import mock from sqlglot import exp, parse_one +from sqlglot.dialects.dialect import Dialects from tests.dialects.test_dialect import Validator @@ -245,13 +246,16 @@ TBLPROPERTIES ( self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH TABLE a.b.c") self.validate_identity("INTERVAL -86 DAYS") - self.validate_identity("SELECT UNIX_TIMESTAMP()") self.validate_identity("TRIM(' SparkSQL ')") self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("SPLIT(str, pattern, lim)") self.validate_identity( + "SELECT UNIX_TIMESTAMP()", + "SELECT UNIX_TIMESTAMP(CURRENT_TIMESTAMP())", + ) + self.validate_identity( "SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL 23 HOUR + 59 MINUTE + 59 SECONDS", "SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL '23' HOUR + INTERVAL '59' MINUTE + INTERVAL '59' SECONDS", ) @@ -281,6 +285,18 @@ TBLPROPERTIES ( ) self.validate_all( + "SELECT SPLIT('123|789', '\\\\|')", + read={ + "duckdb": "SELECT STR_SPLIT_REGEX('123|789', '\\|')", + "presto": "SELECT REGEXP_SPLIT('123|789', '\\|')", + }, + write={ + "duckdb": "SELECT STR_SPLIT_REGEX('123|789', '\\|')", + "presto": "SELECT REGEXP_SPLIT('123|789', '\\|')", + "spark": "SELECT SPLIT('123|789', '\\\\|')", + }, + ) + self.validate_all( "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl", write={ "clickhouse": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl", @@ -366,7 +382,7 @@ TBLPROPERTIES ( "hive": "SELECT CAST(DATEDIFF(TO_DATE('2020-12-31'), TO_DATE('2020-01-01')) / 7 AS INT)", "postgres": "SELECT CAST(EXTRACT(days FROM (CAST(CAST('2020-12-31' AS DATE) AS TIMESTAMP) - CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP))) / 7 AS BIGINT)", "redshift": "SELECT DATEDIFF(WEEK, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", - "snowflake": "SELECT DATEDIFF(WEEK, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "snowflake": "SELECT DATEDIFF(WEEK, TO_DATE('2020-01-01'), TO_DATE('2020-12-31'))", "spark": "SELECT DATEDIFF(WEEK, TO_DATE('2020-01-01'), TO_DATE('2020-12-31'))", }, ) @@ -644,10 +660,10 @@ TBLPROPERTIES ( "SELECT TRANSFORM(zip_code, name, age) USING 'cat' AS (a STRING, b STRING, c STRING) FROM person WHERE zip_code > 94511" ) self.validate_identity( - "SELECT TRANSFORM(name, age) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' USING 'cat' AS (name_age STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY '@' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' FROM person" + "SELECT TRANSFORM(name, age) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' NULL DEFINED AS 'NULL' USING 'cat' AS (name_age STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY '@' LINES TERMINATED BY '\\n' NULL DEFINED AS 'NULL' FROM person" ) self.validate_identity( - "SELECT TRANSFORM(zip_code, name, age) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\t') USING 'cat' AS (a STRING, b STRING, c STRING) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\t') FROM person WHERE zip_code > 94511" + "SELECT TRANSFORM(zip_code, name, age) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\\t') USING 'cat' AS (a STRING, b STRING, c STRING) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\\t') FROM person WHERE zip_code > 94511" ) self.validate_identity( "SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500" @@ -720,3 +736,16 @@ TBLPROPERTIES ( "presto": "SELECT col, pos, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.pos_3) AS pos_3 FROM _u CROSS JOIN UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE _u_2.pos_2 = _u_3.pos_3 OR (_u_2.pos_2 > CARDINALITY(ARRAY[2, 3]) AND _u_3.pos_3 = CARDINALITY(ARRAY[2, 3]))", }, ) + + def test_strip_modifiers(self): + without_modifiers = "SELECT * FROM t" + with_modifiers = f"{without_modifiers} CLUSTER BY y DISTRIBUTE BY x SORT BY z" + query = self.parse_one(with_modifiers) + + for dialect in Dialects: + with self.subTest(f"Transpiling query with CLUSTER/DISTRIBUTE/SORT BY to {dialect}"): + name = dialect.value + if name in ("", "databricks", "hive", "spark", "spark2"): + self.assertEqual(query.sql(name), with_modifiers) + else: + self.assertEqual(query.sql(name), without_modifiers) |