 tests/dialects/test_snowflake.py | 245
 1 file changed, 209 insertions(+), 36 deletions(-)
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 1286436..6cde86b 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -11,17 +11,10 @@ class TestSnowflake(Validator):
dialect = "snowflake"
def test_snowflake(self):
- self.validate_identity(
- "transform(x, a int -> a + a + 1)",
- "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
- )
-
- self.validate_all(
- "ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
- write={
- "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))",
- "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)",
- },
+ self.assertEqual(
+ # Ensures we don't fail when generating ParseJSON with the `safe` arg set to `True`
+ self.validate_identity("""SELECT TRY_PARSE_JSON('{"x: 1}')""").sql(),
+ """SELECT PARSE_JSON('{"x: 1}')""",
)
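For context, a minimal sketch of what this assertion exercises (assuming sqlglot is importable):

    import sqlglot

    # TRY_PARSE_JSON parses into ParseJSON with safe=True; generating it with
    # the default dialect must not fail, and renders as plain PARSE_JSON.
    ast = sqlglot.parse_one("""SELECT TRY_PARSE_JSON('{"x: 1}')""", read="snowflake")
    print(ast.sql())  # per the test: SELECT PARSE_JSON('{"x: 1}')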
expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
@@ -49,6 +42,9 @@ WHERE
)""",
)
+ self.validate_identity("exclude := [foo]")
+ self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))")
+ self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias")
self.validate_identity("SELECT number").selects[0].assert_is(exp.Column)
self.validate_identity("INTERVAL '4 years, 5 months, 3 hours'")
self.validate_identity("ALTER TABLE table1 CLUSTER BY (name DESC)")
@@ -84,7 +80,6 @@ WHERE
self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')")
self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x")
self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')")
- self.validate_identity("CAST(x AS GEOMETRY)")
self.validate_identity("OBJECT_CONSTRUCT(*)")
self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'")
self.validate_identity("SELECT HLL(*)")
@@ -101,6 +96,22 @@ WHERE
self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity("SELECT MATCH_CONDITION")
self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
+ self.validate_identity("1 /* /* */")
+ self.validate_identity(
+ "SELECT * FROM table AT (TIMESTAMP => '2024-07-24') UNPIVOT(a FOR b IN (c)) AS pivot_table"
+ )
+ self.validate_identity(
+ "SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN ('2023_Q1', '2023_Q2', '2023_Q3', '2023_Q4', '2024_Q1') DEFAULT ON NULL (0)) ORDER BY empid"
+ )
+ self.validate_identity(
+ "SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (SELECT DISTINCT quarter FROM ad_campaign_types_by_quarter WHERE television = TRUE ORDER BY quarter)) ORDER BY empid"
+ )
+ self.validate_identity(
+ "SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter)) ORDER BY empid"
+ )
+ self.validate_identity(
+ "SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (ANY)) ORDER BY empid"
+ )
self.validate_identity(
"MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
)
@@ -114,6 +125,30 @@ WHERE
"SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID"
)
self.validate_identity(
+ "CAST(x AS GEOGRAPHY)",
+ "TO_GEOGRAPHY(x)",
+ )
+ self.validate_identity(
+ "CAST(x AS GEOMETRY)",
+ "TO_GEOMETRY(x)",
+ )
+ self.validate_identity(
+ "transform(x, a int -> a + a + 1)",
+ "TRANSFORM(x, a -> CAST(a AS INT) + CAST(a AS INT) + 1)",
+ )
+ self.validate_identity(
+ "SELECT * FROM s WHERE c NOT IN (1, 2, 3)",
+ "SELECT * FROM s WHERE NOT c IN (1, 2, 3)",
+ )
+ self.validate_identity(
+ "SELECT * FROM s WHERE c NOT IN (SELECT * FROM t)",
+ "SELECT * FROM s WHERE c <> ALL (SELECT * FROM t)",
+ )
+ self.validate_identity(
+ "SELECT * FROM t1 INNER JOIN t2 USING (t1.col)",
+ "SELECT * FROM t1 INNER JOIN t2 USING (col)",
+ )
+ self.validate_identity(
"CURRENT_TIMESTAMP - INTERVAL '1 w' AND (1 = 1)",
"CURRENT_TIMESTAMP() - INTERVAL '1 WEEK' AND (1 = 1)",
)
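A sketch of the new geospatial cast rewrites (assuming sqlglot is importable):

    import sqlglot

    # Snowflake rewrites CAST(... AS GEOMETRY/GEOGRAPHY) into the
    # corresponding conversion functions, per the expectations above.
    print(sqlglot.transpile("CAST(x AS GEOMETRY)", read="snowflake", write="snowflake")[0])
    # TO_GEOMETRY(x)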
@@ -183,18 +218,6 @@ WHERE
"""SELECT CAST(GET_PATH(PARSE_JSON('{"food":{"fruit":"banana"}}'), 'food.fruit') AS VARCHAR)""",
)
self.validate_identity(
- "SELECT * FROM foo at",
- "SELECT * FROM foo AS at",
- )
- self.validate_identity(
- "SELECT * FROM foo before",
- "SELECT * FROM foo AS before",
- )
- self.validate_identity(
- "SELECT * FROM foo at (col)",
- "SELECT * FROM foo AS at(col)",
- )
- self.validate_identity(
"SELECT * FROM unnest(x) with ordinality",
"SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS _u(seq, key, path, index, value, this)",
)
@@ -283,6 +306,13 @@ WHERE
)
self.validate_all(
+ "ARRAY_CONSTRUCT_COMPACT(1, null, 2)",
+ write={
+ "spark": "ARRAY_COMPACT(ARRAY(1, NULL, 2))",
+ "snowflake": "ARRAY_CONSTRUCT_COMPACT(1, NULL, 2)",
+ },
+ )
+ self.validate_all(
"OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)",
read={
"bigquery": "JSON_OBJECT(['key_1', 'key_2'], ['one', NULL])",
@@ -337,7 +367,7 @@ WHERE
"""SELECT PARSE_JSON('{"fruit":"banana"}'):fruit""",
write={
"bigquery": """SELECT JSON_EXTRACT(PARSE_JSON('{"fruit":"banana"}'), '$.fruit')""",
- "databricks": """SELECT GET_JSON_OBJECT('{"fruit":"banana"}', '$.fruit')""",
+ "databricks": """SELECT '{"fruit":"banana"}':fruit""",
"duckdb": """SELECT JSON('{"fruit":"banana"}') -> '$.fruit'""",
"mysql": """SELECT JSON_EXTRACT('{"fruit":"banana"}', '$.fruit')""",
"presto": """SELECT JSON_EXTRACT(JSON_PARSE('{"fruit":"banana"}'), '$.fruit')""",
@@ -572,12 +602,12 @@ WHERE
self.validate_all(
"DIV0(foo, bar)",
write={
- "snowflake": "IFF(bar = 0, 0, foo / bar)",
- "sqlite": "IIF(bar = 0, 0, CAST(foo AS REAL) / bar)",
- "presto": "IF(bar = 0, 0, CAST(foo AS DOUBLE) / bar)",
- "spark": "IF(bar = 0, 0, foo / bar)",
- "hive": "IF(bar = 0, 0, foo / bar)",
- "duckdb": "CASE WHEN bar = 0 THEN 0 ELSE foo / bar END",
+ "snowflake": "IFF(bar = 0 AND NOT foo IS NULL, 0, foo / bar)",
+ "sqlite": "IIF(bar = 0 AND NOT foo IS NULL, 0, CAST(foo AS REAL) / bar)",
+ "presto": "IF(bar = 0 AND NOT foo IS NULL, 0, CAST(foo AS DOUBLE) / bar)",
+ "spark": "IF(bar = 0 AND NOT foo IS NULL, 0, foo / bar)",
+ "hive": "IF(bar = 0 AND NOT foo IS NULL, 0, foo / bar)",
+ "duckdb": "CASE WHEN bar = 0 AND NOT foo IS NULL THEN 0 ELSE foo / bar END",
},
)
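The updated DIV0 expectations preserve NULL propagation: DIV0(NULL, 0) is NULL in Snowflake, so the zero guard now also requires a non-NULL numerator. A sketch (assuming sqlglot is importable):

    import sqlglot

    # Dialects without DIV0 get a guarded division instead.
    print(sqlglot.transpile("DIV0(foo, bar)", read="snowflake", write="duckdb")[0])
    # CASE WHEN bar = 0 AND NOT foo IS NULL THEN 0 ELSE foo / bar END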
self.validate_all(
@@ -725,6 +755,8 @@ WHERE
write={
"spark": "SELECT COLLECT_LIST(DISTINCT a)",
"snowflake": "SELECT ARRAY_AGG(DISTINCT a)",
+ "duckdb": "SELECT ARRAY_AGG(DISTINCT a) FILTER(WHERE a IS NOT NULL)",
+ "presto": "SELECT ARRAY_AGG(DISTINCT a) FILTER(WHERE a IS NOT NULL)",
},
)
self.validate_all(
@@ -831,6 +863,71 @@ WHERE
},
)
+ self.validate_all(
+ "SELECT OBJECT_INSERT(OBJECT_INSERT(OBJECT_INSERT(OBJECT_CONSTRUCT('key5', 'value5'), 'key1', 5), 'key2', 2.2), 'key3', 'value3')",
+ write={
+ "snowflake": "SELECT OBJECT_INSERT(OBJECT_INSERT(OBJECT_INSERT(OBJECT_CONSTRUCT('key5', 'value5'), 'key1', 5), 'key2', 2.2), 'key3', 'value3')",
+ "duckdb": "SELECT STRUCT_INSERT(STRUCT_INSERT(STRUCT_INSERT({'key5': 'value5'}, key1 := 5), key2 := 2.2), key3 := 'value3')",
+ },
+ )
+
+ self.validate_all(
+ "SELECT OBJECT_INSERT(OBJECT_INSERT(OBJECT_INSERT(OBJECT_CONSTRUCT(), 'key1', 5), 'key2', 2.2), 'key3', 'value3')",
+ write={
+ "snowflake": "SELECT OBJECT_INSERT(OBJECT_INSERT(OBJECT_INSERT(OBJECT_CONSTRUCT(), 'key1', 5), 'key2', 2.2), 'key3', 'value3')",
+ "duckdb": "SELECT STRUCT_INSERT(STRUCT_INSERT(STRUCT_PACK(key1 := 5), key2 := 2.2), key3 := 'value3')",
+ },
+ )
+
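A sketch of the OBJECT_INSERT lowering above (assuming sqlglot is importable): an insert into an empty OBJECT_CONSTRUCT() collapses into DuckDB's STRUCT_PACK, while inserts into non-empty objects become STRUCT_INSERT.

    import sqlglot

    sql = "SELECT OBJECT_INSERT(OBJECT_CONSTRUCT(), 'key1', 5)"
    print(sqlglot.transpile(sql, read="snowflake", write="duckdb")[0])
    # SELECT STRUCT_PACK(key1 := 5)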
+ self.validate_identity(
+ """SELECT ARRAY_CONSTRUCT('foo')::VARIANT[0]""",
+ """SELECT CAST(['foo'] AS VARIANT)[0]""",
+ )
+
+ self.validate_all(
+ "SELECT CONVERT_TIMEZONE('America/New_York', '2024-08-06 09:10:00.000')",
+ write={
+ "snowflake": "SELECT CONVERT_TIMEZONE('America/New_York', '2024-08-06 09:10:00.000')",
+ "spark": "SELECT CONVERT_TIMEZONE('America/New_York', '2024-08-06 09:10:00.000')",
+ "databricks": "SELECT CONVERT_TIMEZONE('America/New_York', '2024-08-06 09:10:00.000')",
+ "redshift": "SELECT CONVERT_TIMEZONE('America/New_York', '2024-08-06 09:10:00.000')",
+ },
+ )
+
+ self.validate_all(
+ "SELECT CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2024-08-06 09:10:00.000')",
+ write={
+ "snowflake": "SELECT CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2024-08-06 09:10:00.000')",
+ "spark": "SELECT CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2024-08-06 09:10:00.000')",
+ "databricks": "SELECT CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2024-08-06 09:10:00.000')",
+ "redshift": "SELECT CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2024-08-06 09:10:00.000')",
+ "mysql": "SELECT CONVERT_TZ('2024-08-06 09:10:00.000', 'America/Los_Angeles', 'America/New_York')",
+ "duckdb": "SELECT CAST('2024-08-06 09:10:00.000' AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles' AT TIME ZONE 'America/New_York'",
+ },
+ )
+
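A sketch of the three-argument CONVERT_TIMEZONE (assuming sqlglot is importable): the explicit source zone lets targets without the function fall back to equivalent constructs, e.g. DuckDB's chained AT TIME ZONE shown above.

    import sqlglot

    sql = "SELECT CONVERT_TIMEZONE('America/Los_Angeles', 'America/New_York', '2024-08-06 09:10:00.000')"
    print(sqlglot.transpile(sql, read="snowflake", write="duckdb")[0])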
+ self.validate_identity(
+ "SELECT UUID_STRING(), UUID_STRING('fe971b24-9572-4005-b22f-351e9c09274d', 'foo')"
+ )
+
+ self.validate_all(
+ "UUID_STRING('fe971b24-9572-4005-b22f-351e9c09274d', 'foo')",
+ read={
+ "snowflake": "UUID_STRING('fe971b24-9572-4005-b22f-351e9c09274d', 'foo')",
+ },
+ write={
+ "hive": "UUID()",
+ "spark2": "UUID()",
+ "spark": "UUID()",
+ "databricks": "UUID()",
+ "duckdb": "UUID()",
+ "presto": "UUID()",
+ "trino": "UUID()",
+ "postgres": "GEN_RANDOM_UUID()",
+ "bigquery": "GENERATE_UUID()",
+ },
+ )
+
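UUID_STRING's two-argument (namespace plus name) form has no portable equivalent, so the write expectations above fall back to each target's random-UUID function. Sketch (assuming sqlglot is importable):

    import sqlglot

    sql = "UUID_STRING('fe971b24-9572-4005-b22f-351e9c09274d', 'foo')"
    print(sqlglot.transpile(sql, read="snowflake", write="postgres")[0])
    # GEN_RANDOM_UUID(), per the mapping above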
def test_null_treatment(self):
self.validate_all(
r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1",
@@ -903,6 +1000,11 @@ WHERE
"SELECT * FROM @foo/bar (FILE_FORMAT => ds_sandbox.test.my_csv_format, PATTERN => 'test') AS bla",
)
+ self.validate_identity(
+ "SELECT * FROM @test.public.thing/location/somefile.csv( FILE_FORMAT => 'fmt' )",
+ "SELECT * FROM @test.public.thing/location/somefile.csv (FILE_FORMAT => 'fmt')",
+ )
+
def test_sample(self):
self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)")
self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)")
@@ -1196,6 +1298,17 @@ WHERE
"SELECT oldt.*, newt.* FROM my_table BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS oldt FULL OUTER JOIN my_table AT (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726') AS newt ON oldt.id = newt.id WHERE oldt.id IS NULL OR newt.id IS NULL",
)
+ # Make sure that the historical data keywords can still be used as aliases
+ for historical_data_prefix in ("AT", "BEFORE", "END", "CHANGES"):
+ for schema_suffix in ("", "(col)"):
+ with self.subTest(
+ f"Testing historical data prefix alias: {historical_data_prefix}{schema_suffix}"
+ ):
+ self.validate_identity(
+ f"SELECT * FROM foo {historical_data_prefix}{schema_suffix}",
+ f"SELECT * FROM foo AS {historical_data_prefix}{schema_suffix}",
+ )
+
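One concrete instance of the loop above (assuming sqlglot is importable): AT is a historical-data keyword, but without a time-travel clause it still parses as a plain table alias.

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM foo AT(col)", read="snowflake", write="snowflake")[0])
    # SELECT * FROM foo AS AT(col)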
def test_ddl(self):
for constraint_prefix in ("WITH ", ""):
with self.subTest(f"Constraint prefix: {constraint_prefix}"):
@@ -1216,6 +1329,7 @@ WHERE
"CREATE TABLE t (id INT TAG (key1='value_1', key2='value_2'))",
)
+ self.validate_identity("CREATE SECURE VIEW table1 AS (SELECT a FROM table2)")
self.validate_identity(
"""create external table et2(
col1 date as (parse_json(metadata$external_table_partition):COL1::date),
@@ -1241,6 +1355,9 @@ WHERE
"CREATE OR REPLACE TAG IF NOT EXISTS cost_center COMMENT='cost_center tag'"
).this.assert_is(exp.Identifier)
self.validate_identity(
+ "CREATE DYNAMIC TABLE product (pre_tax_profit, taxes, after_tax_profit) TARGET_LAG='20 minutes' WAREHOUSE=mywh AS SELECT revenue - cost, (revenue - cost) * tax_rate, (revenue - cost) * (1.0 - tax_rate) FROM staging_table"
+ )
+ self.validate_identity(
"ALTER TABLE db_name.schmaName.tblName ADD COLUMN COLUMN_1 VARCHAR NOT NULL TAG (key1='value_1')"
)
self.validate_identity(
@@ -1330,6 +1447,12 @@ WHERE
},
)
+ self.assertIsNotNone(
+ self.validate_identity("CREATE TABLE foo (bar INT AS (foo))").find(
+ exp.TransformColumnConstraint
+ )
+ )
+
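A sketch of the computed-column assertion above (assuming sqlglot is importable):

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("CREATE TABLE foo (bar INT AS (foo))", read="snowflake")
    # The generated-column clause parses as a TransformColumnConstraint node.
    assert ast.find(exp.TransformColumnConstraint) is not None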
def test_user_defined_functions(self):
self.validate_all(
"CREATE FUNCTION a(x DATE, y BIGINT) RETURNS ARRAY LANGUAGE JAVASCRIPT AS $$ SELECT 1 $$",
@@ -1608,16 +1731,27 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattene
"REGEXP_SUBSTR(subject, pattern)",
read={
"bigquery": "REGEXP_EXTRACT(subject, pattern)",
+ "snowflake": "REGEXP_EXTRACT(subject, pattern)",
+ },
+ write={
+ "bigquery": "REGEXP_EXTRACT(subject, pattern)",
+ "snowflake": "REGEXP_SUBSTR(subject, pattern)",
+ },
+ )
+ self.validate_all(
+ "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)",
+ read={
"hive": "REGEXP_EXTRACT(subject, pattern)",
- "presto": "REGEXP_EXTRACT(subject, pattern)",
+ "spark2": "REGEXP_EXTRACT(subject, pattern)",
"spark": "REGEXP_EXTRACT(subject, pattern)",
+ "databricks": "REGEXP_EXTRACT(subject, pattern)",
},
write={
- "bigquery": "REGEXP_EXTRACT(subject, pattern)",
"hive": "REGEXP_EXTRACT(subject, pattern)",
- "presto": "REGEXP_EXTRACT(subject, pattern)",
- "snowflake": "REGEXP_SUBSTR(subject, pattern)",
+ "spark2": "REGEXP_EXTRACT(subject, pattern)",
"spark": "REGEXP_EXTRACT(subject, pattern)",
+ "databricks": "REGEXP_EXTRACT(subject, pattern)",
+ "snowflake": "REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)",
},
)
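A sketch of the six-argument mapping (assuming sqlglot is importable): Hive-family REGEXP_EXTRACT defaults to capture group 1, which Snowflake matches only when REGEXP_SUBSTR is given explicit position, occurrence, flags, and group arguments.

    import sqlglot

    print(sqlglot.transpile("REGEXP_EXTRACT(subject, pattern)", read="spark", write="snowflake")[0])
    # REGEXP_SUBSTR(subject, pattern, 1, 1, 'c', 1)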
self.validate_all(
@@ -1885,7 +2019,7 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""",
def test_swap(self):
ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake")
- assert isinstance(ast, exp.AlterTable)
+ assert isinstance(ast, exp.Alter)
assert isinstance(ast.args["actions"][0], exp.SwapTable)
def test_try_cast(self):
@@ -2005,6 +2139,16 @@ SINGLE = TRUE""",
self.validate_identity("SELECT t.$23:a.b", "SELECT GET_PATH(t.$23, 'a.b')")
self.validate_identity("SELECT t.$17:a[0].b[0].c", "SELECT GET_PATH(t.$17, 'a[0].b[0].c')")
+ self.validate_all(
+ """
+ SELECT col:"customer's department"
+ """,
+ write={
+ "snowflake": """SELECT GET_PATH(col, '["customer\\'s department"]')""",
+ "postgres": "SELECT JSON_EXTRACT_PATH(col, 'customer''s department')",
+ },
+ )
+
def test_alter_set_unset(self):
self.validate_identity("ALTER TABLE tbl SET DATA_RETENTION_TIME_IN_DAYS=1")
self.validate_identity("ALTER TABLE tbl SET DEFAULT_DDL_COLLATION='test'")
@@ -2021,3 +2165,32 @@ SINGLE = TRUE""",
self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c")
self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING")
+
+ def test_from_changes(self):
+ self.validate_identity(
+ """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) AT (STREAM => 's1') END (TIMESTAMP => $ts2)"""
+ )
+ self.validate_identity(
+ """SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) BEFORE (STATEMENT => 'STMT_ID') END (TIMESTAMP => $ts2)"""
+ )
+ self.validate_identity(
+ """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 00:00:00+00:00')) END (TIMESTAMP => TO_TIMESTAMP_TZ('2024-07-01 14:28:59.999999+00:00'))""",
+ """SELECT 1 FROM some_table CHANGES (INFORMATION => APPEND_ONLY) AT (TIMESTAMP => CAST('2024-07-01 00:00:00+00:00' AS TIMESTAMPTZ)) END (TIMESTAMP => CAST('2024-07-01 14:28:59.999999+00:00' AS TIMESTAMPTZ))""",
+ )
+
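The CHANGES clause reads change-tracking data between two points in time; these identities assert it round-trips, with TO_TIMESTAMP_TZ bounds normalized to casts. Sketch (assuming sqlglot is importable):

    import sqlglot

    sql = "SELECT C1 FROM t1 CHANGES (INFORMATION => APPEND_ONLY) AT (STREAM => 's1') END (TIMESTAMP => $ts2)"
    assert sqlglot.transpile(sql, read="snowflake", write="snowflake")[0] == sql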
+ def test_grant(self):
+ grant_cmds = [
+ "GRANT SELECT ON FUTURE TABLES IN DATABASE d1 TO ROLE r1",
+ "GRANT INSERT, DELETE ON FUTURE TABLES IN SCHEMA d1.s1 TO ROLE r2",
+ "GRANT SELECT ON ALL TABLES IN SCHEMA mydb.myschema to ROLE analyst",
+ "GRANT SELECT, INSERT ON FUTURE TABLES IN SCHEMA mydb.myschema TO ROLE role1",
+ "GRANT CREATE MATERIALIZED VIEW ON SCHEMA mydb.myschema TO DATABASE ROLE mydb.dr1",
+ ]
+
+ for sql in grant_cmds:
+ with self.subTest(f"Testing Snowflake's GRANT command statement: {sql}"):
+ self.validate_identity(sql, check_command_warning=True)
+
+ self.validate_identity(
+ "GRANT ALL PRIVILEGES ON FUNCTION mydb.myschema.ADD5(number) TO ROLE analyst"
+ )