diff options
Diffstat (limited to 'tests/dialects/test_spark.py')
-rw-r--r-- | tests/dialects/test_spark.py | 57 |
1 file changed, 55 insertions, 2 deletions
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 2afa868..a892b0f 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -1,5 +1,6 @@ from unittest import mock +from sqlglot import exp, parse_one from tests.dialects.test_dialect import Validator @@ -224,6 +225,10 @@ TBLPROPERTIES ( ) def test_spark(self): + expr = parse_one("any_value(col, true)", read="spark") + self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) + self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH table a.b.c") @@ -234,8 +239,46 @@ TBLPROPERTIES ( self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("SPLIT(str, pattern, lim)") + self.validate_identity( + "SELECT STR_TO_MAP('a:1,b:2,c:3')", + "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + ) self.validate_all( + "foo.bar", + read={ + "": "STRUCT_EXTRACT(foo, bar)", + }, + ) + self.validate_all( + "MAP(1, 2, 3, 4)", + write={ + "spark": "MAP(1, 2, 3, 4)", + "trino": "MAP(ARRAY[1, 3], ARRAY[2, 4])", + }, + ) + self.validate_all( + "MAP()", + read={ + "spark": "MAP()", + "trino": "MAP()", + }, + write={ + "trino": "MAP(ARRAY[], ARRAY[])", + }, + ) + self.validate_all( + "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + read={ + "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')", + "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + }, + write={ + "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')", + "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + }, + ) + self.validate_all( "SELECT DATEDIFF(month, CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))", read={ "duckdb": "SELECT DATEDIFF('month', CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 
10:30:00' AS TIMESTAMP))", @@ -399,7 +442,7 @@ TBLPROPERTIES ( "ARRAY(0, 1, 2)", write={ "bigquery": "[0, 1, 2]", - "duckdb": "LIST_VALUE(0, 1, 2)", + "duckdb": "[0, 1, 2]", "presto": "ARRAY[0, 1, 2]", "hive": "ARRAY(0, 1, 2)", "spark": "ARRAY(0, 1, 2)", @@ -466,7 +509,7 @@ TBLPROPERTIES ( self.validate_all( "MAP_FROM_ARRAYS(ARRAY(1), c)", write={ - "duckdb": "MAP(LIST_VALUE(1), c)", + "duckdb": "MAP([1], c)", "presto": "MAP(ARRAY[1], c)", "hive": "MAP(ARRAY(1), c)", "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", @@ -522,3 +565,13 @@ TBLPROPERTIES ( self.validate_identity( "SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500" ) + + def test_insert_cte(self): + self.validate_all( + "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", + write={ + "spark": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", + "spark2": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", + "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", + }, + ) |