summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_spark.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/dialects/test_spark.py')
-rw-r--r--tests/dialects/test_spark.py37
1 files changed, 33 insertions, 4 deletions
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 1cf1ede..18f1fb7 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -1,6 +1,7 @@
from unittest import mock
from sqlglot import exp, parse_one
+from sqlglot.dialects.dialect import Dialects
from tests.dialects.test_dialect import Validator
@@ -245,13 +246,16 @@ TBLPROPERTIES (
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH TABLE a.b.c")
self.validate_identity("INTERVAL -86 DAYS")
- self.validate_identity("SELECT UNIX_TIMESTAMP()")
self.validate_identity("TRIM(' SparkSQL ')")
self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("SPLIT(str, pattern, lim)")
self.validate_identity(
+ "SELECT UNIX_TIMESTAMP()",
+ "SELECT UNIX_TIMESTAMP(CURRENT_TIMESTAMP())",
+ )
+ self.validate_identity(
"SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL 23 HOUR + 59 MINUTE + 59 SECONDS",
"SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL '23' HOUR + INTERVAL '59' MINUTE + INTERVAL '59' SECONDS",
)
@@ -281,6 +285,18 @@ TBLPROPERTIES (
)
self.validate_all(
+ "SELECT SPLIT('123|789', '\\\\|')",
+ read={
+ "duckdb": "SELECT STR_SPLIT_REGEX('123|789', '\\|')",
+ "presto": "SELECT REGEXP_SPLIT('123|789', '\\|')",
+ },
+ write={
+ "duckdb": "SELECT STR_SPLIT_REGEX('123|789', '\\|')",
+ "presto": "SELECT REGEXP_SPLIT('123|789', '\\|')",
+ "spark": "SELECT SPLIT('123|789', '\\\\|')",
+ },
+ )
+ self.validate_all(
"WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
write={
"clickhouse": "WITH tbl AS (SELECT 1 AS id, 'eggy' AS name UNION ALL SELECT NULL AS id, 'jake' AS name) SELECT COUNT(DISTINCT id, name) AS cnt FROM tbl",
@@ -366,7 +382,7 @@ TBLPROPERTIES (
"hive": "SELECT CAST(DATEDIFF(TO_DATE('2020-12-31'), TO_DATE('2020-01-01')) / 7 AS INT)",
"postgres": "SELECT CAST(EXTRACT(days FROM (CAST(CAST('2020-12-31' AS DATE) AS TIMESTAMP) - CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP))) / 7 AS BIGINT)",
"redshift": "SELECT DATEDIFF(WEEK, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))",
- "snowflake": "SELECT DATEDIFF(WEEK, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))",
+ "snowflake": "SELECT DATEDIFF(WEEK, TO_DATE('2020-01-01'), TO_DATE('2020-12-31'))",
"spark": "SELECT DATEDIFF(WEEK, TO_DATE('2020-01-01'), TO_DATE('2020-12-31'))",
},
)
@@ -644,10 +660,10 @@ TBLPROPERTIES (
"SELECT TRANSFORM(zip_code, name, age) USING 'cat' AS (a STRING, b STRING, c STRING) FROM person WHERE zip_code > 94511"
)
self.validate_identity(
- "SELECT TRANSFORM(name, age) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' USING 'cat' AS (name_age STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY '@' LINES TERMINATED BY '\n' NULL DEFINED AS 'NULL' FROM person"
+ "SELECT TRANSFORM(name, age) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\\n' NULL DEFINED AS 'NULL' USING 'cat' AS (name_age STRING) ROW FORMAT DELIMITED FIELDS TERMINATED BY '@' LINES TERMINATED BY '\\n' NULL DEFINED AS 'NULL' FROM person"
)
self.validate_identity(
- "SELECT TRANSFORM(zip_code, name, age) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\t') USING 'cat' AS (a STRING, b STRING, c STRING) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\t') FROM person WHERE zip_code > 94511"
+ "SELECT TRANSFORM(zip_code, name, age) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\\t') USING 'cat' AS (a STRING, b STRING, c STRING) ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES ('field.delim'='\\t') FROM person WHERE zip_code > 94511"
)
self.validate_identity(
"SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500"
@@ -720,3 +736,16 @@ TBLPROPERTIES (
"presto": "SELECT col, pos, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.pos_3) AS pos_3 FROM _u CROSS JOIN UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE _u_2.pos_2 = _u_3.pos_3 OR (_u_2.pos_2 > CARDINALITY(ARRAY[2, 3]) AND _u_3.pos_3 = CARDINALITY(ARRAY[2, 3]))",
},
)
+
+ def test_strip_modifiers(self):
+ without_modifiers = "SELECT * FROM t"
+ with_modifiers = f"{without_modifiers} CLUSTER BY y DISTRIBUTE BY x SORT BY z"
+ query = self.parse_one(with_modifiers)
+
+ for dialect in Dialects:
+ with self.subTest(f"Transpiling query with CLUSTER/DISTRIBUTE/SORT BY to {dialect}"):
+ name = dialect.value
+ if name in ("", "databricks", "hive", "spark", "spark2"):
+ self.assertEqual(query.sql(name), with_modifiers)
+ else:
+ self.assertEqual(query.sql(name), without_modifiers)