summaryrefslogtreecommitdiffstats
path: root/tests/dialects/test_snowflake.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dialects/test_snowflake.py126
1 files changed, 99 insertions, 27 deletions
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index a11c21a..1d55f35 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -21,27 +21,6 @@ class TestSnowflake(Validator):
expr.selects[0].assert_is(exp.AggFunc)
self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t")
- self.assertEqual(
- exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql(
- "snowflake", pretty=True
- ),
- """SELECT
- IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y"
-FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (
- GREATEST(ARRAY_SIZE(x)) - 1
-) + 1))) AS _u(seq, key, path, index, pos, this)
-CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this)
-WHERE
- _u.pos = _u_2.pos_2
- OR (
- _u.pos > (
- ARRAY_SIZE(x) - 1
- ) AND _u_2.pos_2 = (
- ARRAY_SIZE(x) - 1
- )
- )""",
- )
-
self.validate_identity("exclude := [foo]")
self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))")
self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias")
@@ -976,12 +955,15 @@ WHERE
"snowflake": "EDITDISTANCE(col1, col2, 3)",
},
)
- self.validate_identity("SELECT BITOR(a, b) FROM table")
-
- self.validate_identity("SELECT BIT_OR(a, b) FROM table", "SELECT BITOR(a, b) FROM table")
-
- # Test BITOR with three arguments, padding on the left
- self.validate_identity("SELECT BITOR(a, b, 'LEFT') FROM table_name")
+ self.validate_identity("SELECT BITOR(a, b)")
+ self.validate_identity("SELECT BIT_OR(a, b)", "SELECT BITOR(a, b)")
+ self.validate_identity("SELECT BITOR(a, b, 'LEFT')")
+ self.validate_identity("SELECT BITXOR(a, b, 'LEFT')")
+ self.validate_identity("SELECT BIT_XOR(a, b)", "SELECT BITXOR(a, b)")
+ self.validate_identity("SELECT BIT_XOR(a, b, 'LEFT')", "SELECT BITXOR(a, b, 'LEFT')")
+ self.validate_identity("SELECT BITSHIFTLEFT(a, 1)")
+ self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)")
+ self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)")
def test_null_treatment(self):
self.validate_all(
@@ -1600,6 +1582,27 @@ WHERE
)
def test_flatten(self):
+ self.assertEqual(
+ exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql(
+ "snowflake", pretty=True
+ ),
+ """SELECT
+ IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y"
+FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (
+ GREATEST(ARRAY_SIZE(x)) - 1
+) + 1))) AS _u(seq, key, path, index, pos, this)
+CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this)
+WHERE
+ _u.pos = _u_2.pos_2
+ OR (
+ _u.pos > (
+ ARRAY_SIZE(x) - 1
+ ) AND _u_2.pos_2 = (
+ ARRAY_SIZE(x) - 1
+ )
+ )""",
+ )
+
self.validate_all(
"""
select
@@ -1624,6 +1627,75 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS
},
pretty=True,
)
+ self.validate_all(
+ """
+ SELECT
+ uc.user_id,
+ uc.start_ts AS ts,
+ CASE
+ WHEN uc.start_ts::DATE >= '2023-01-01' AND uc.country_code IN ('US') AND uc.user_id NOT IN (
+ SELECT DISTINCT
+ _id
+ FROM
+ users,
+ LATERAL FLATTEN(INPUT => PARSE_JSON(flags)) datasource
+ WHERE datasource.value:name = 'something'
+ )
+ THEN 'Sample1'
+ ELSE 'Sample2'
+ END AS entity
+ FROM user_countries AS uc
+ LEFT JOIN (
+ SELECT user_id, MAX(IFF(service_entity IS NULL,1,0)) AS le_null
+ FROM accepted_user_agreements
+ GROUP BY 1
+ ) AS aua
+ ON uc.user_id = aua.user_id
+ """,
+ write={
+ "snowflake": """SELECT
+ uc.user_id,
+ uc.start_ts AS ts,
+ CASE
+ WHEN CAST(uc.start_ts AS DATE) >= '2023-01-01'
+ AND uc.country_code IN ('US')
+ AND uc.user_id <> ALL (
+ SELECT DISTINCT
+ _id
+ FROM users, LATERAL IFF(_u.pos = _u_2.pos_2, _u_2.entity, NULL) AS datasource(SEQ, KEY, PATH, INDEX, VALUE, THIS)
+ WHERE
+ GET_PATH(datasource.value, 'name') = 'something'
+ )
+ THEN 'Sample1'
+ ELSE 'Sample2'
+ END AS entity
+FROM user_countries AS uc
+LEFT JOIN (
+ SELECT
+ user_id,
+ MAX(IFF(service_entity IS NULL, 1, 0)) AS le_null
+ FROM accepted_user_agreements
+ GROUP BY
+ 1
+) AS aua
+ ON uc.user_id = aua.user_id
+CROSS JOIN TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (
+ GREATEST(ARRAY_SIZE(INPUT => PARSE_JSON(flags))) - 1
+) + 1))) AS _u(seq, key, path, index, pos, this)
+CROSS JOIN TABLE(FLATTEN(INPUT => PARSE_JSON(flags))) AS _u_2(seq, key, path, pos_2, entity, this)
+WHERE
+ _u.pos = _u_2.pos_2
+ OR (
+ _u.pos > (
+ ARRAY_SIZE(INPUT => PARSE_JSON(flags)) - 1
+ )
+ AND _u_2.pos_2 = (
+ ARRAY_SIZE(INPUT => PARSE_JSON(flags)) - 1
+ )
+ )""",
+ },
+ pretty=True,
+ )
# All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax
self.validate_all(