summary refs log tree commit diff stats
path: root/tests/dialects/test_spark.py
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-09-07 11:39:43 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-09-07 11:39:43 +0000
commit: 341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87 (patch)
tree: 61fb7eca2238fb5d41d3906f4af41de03abd25ea /tests/dialects/test_spark.py
parent: Adding upstream version 17.12.0. (diff)
downloadsqlglot-341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87.tar.xz
sqlglot-341eb1a6bdf0dd5b015e5140d3b068c6fd3f4d87.zip
Adding upstream version 18.2.0.upstream/18.2.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests/dialects/test_spark.py')
-rw-r--r-- tests/dialects/test_spark.py | 57
1 file changed, 55 insertions, 2 deletions
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index 2afa868..a892b0f 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -1,5 +1,6 @@
from unittest import mock
+from sqlglot import exp, parse_one
from tests.dialects.test_dialect import Validator
@@ -224,6 +225,10 @@ TBLPROPERTIES (
)
def test_spark(self):
+ expr = parse_one("any_value(col, true)", read="spark")
+ self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean)
+ self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)")
+
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)")
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH table a.b.c")
@@ -234,8 +239,46 @@ TBLPROPERTIES (
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("SPLIT(str, pattern, lim)")
+ self.validate_identity(
+ "SELECT STR_TO_MAP('a:1,b:2,c:3')",
+ "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ )
self.validate_all(
+ "foo.bar",
+ read={
+ "": "STRUCT_EXTRACT(foo, bar)",
+ },
+ )
+ self.validate_all(
+ "MAP(1, 2, 3, 4)",
+ write={
+ "spark": "MAP(1, 2, 3, 4)",
+ "trino": "MAP(ARRAY[1, 3], ARRAY[2, 4])",
+ },
+ )
+ self.validate_all(
+ "MAP()",
+ read={
+ "spark": "MAP()",
+ "trino": "MAP()",
+ },
+ write={
+ "trino": "MAP(ARRAY[], ARRAY[])",
+ },
+ )
+ self.validate_all(
+ "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ read={
+ "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ },
+ write={
+ "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')",
+ },
+ )
+ self.validate_all(
"SELECT DATEDIFF(month, CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))",
read={
"duckdb": "SELECT DATEDIFF('month', CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))",
@@ -399,7 +442,7 @@ TBLPROPERTIES (
"ARRAY(0, 1, 2)",
write={
"bigquery": "[0, 1, 2]",
- "duckdb": "LIST_VALUE(0, 1, 2)",
+ "duckdb": "[0, 1, 2]",
"presto": "ARRAY[0, 1, 2]",
"hive": "ARRAY(0, 1, 2)",
"spark": "ARRAY(0, 1, 2)",
@@ -466,7 +509,7 @@ TBLPROPERTIES (
self.validate_all(
"MAP_FROM_ARRAYS(ARRAY(1), c)",
write={
- "duckdb": "MAP(LIST_VALUE(1), c)",
+ "duckdb": "MAP([1], c)",
"presto": "MAP(ARRAY[1], c)",
"hive": "MAP(ARRAY(1), c)",
"spark": "MAP_FROM_ARRAYS(ARRAY(1), c)",
@@ -522,3 +565,13 @@ TBLPROPERTIES (
self.validate_identity(
"SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500"
)
+
+ def test_insert_cte(self):
+ self.validate_all(
+ "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte",
+ write={
+ "spark": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte",
+ "spark2": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte",
+ "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte",
+ },
+ )