diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 17 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 126 | ||||
-rw-r--r-- | tests/dialects/test_starrocks.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 5 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_columns.sql | 55 | ||||
-rw-r--r-- | tests/test_lineage.py | 81 | ||||
-rw-r--r-- | tests/test_optimizer.py | 4 |
11 files changed, 280 insertions, 27 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index ec16dba..dbe4401 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -316,6 +316,14 @@ LANGUAGE js AS ) self.validate_all( + "SELECT DATE_SUB(DATE '2008-12-25', INTERVAL 5 DAY)", + write={ + "bigquery": "SELECT DATE_SUB(CAST('2008-12-25' AS DATE), INTERVAL '5' DAY)", + "duckdb": "SELECT CAST('2008-12-25' AS DATE) - INTERVAL '5' DAY", + "snowflake": "SELECT DATEADD(DAY, '5' * -1, CAST('2008-12-25' AS DATE))", + }, + ) + self.validate_all( "EDIT_DISTANCE(col1, col2, max_distance => 3)", write={ "bigquery": "EDIT_DISTANCE(col1, col2, max_distance => 3)", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 19b3ce3..d3d363e 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -155,6 +155,10 @@ class TestClickhouse(Validator): "CREATE TABLE t (foo String CODEC(LZ4HC(9), ZSTD, DELTA), size String ALIAS formatReadableSize(size_bytes), INDEX idx1 a TYPE bloom_filter(0.001) GRANULARITY 1, INDEX idx2 a TYPE set(100) GRANULARITY 2, INDEX idx3 a TYPE minmax GRANULARITY 3)" ) self.validate_identity( + "INSERT INTO tab VALUES ({'key1': 1, 'key2': 10}), ({'key1': 2, 'key2': 20}), ({'key1': 3, 'key2': 30})", + "INSERT INTO tab VALUES (map('key1', 1, 'key2', 10)), (map('key1', 2, 'key2', 20)), (map('key1', 3, 'key2', 30))", + ) + self.validate_identity( "SELECT (toUInt8('1') + toUInt8('2')) IS NOT NULL", "SELECT NOT ((toUInt8('1') + toUInt8('2')) IS NULL)", ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 465c231..5eb89f3 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -18,6 +18,7 @@ class TestMySQL(Validator): self.validate_identity("CREATE TABLE foo (a BIGINT, UNIQUE (b) USING BTREE)") self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("CREATE TABLE 00f (1d BIGINT)") + self.validate_identity("CREATE TABLE temp (id SERIAL PRIMARY KEY)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 66ded23..acdb2d4 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -72,6 +72,9 @@ class TestPostgres(Validator): self.validate_identity("SELECT CURRENT_USER") self.validate_identity("SELECT * FROM ONLY t1") self.validate_identity( + "SELECT * FROM t WHERE some_column >= CURRENT_DATE + INTERVAL '1 day 1 hour' AND some_another_column IS TRUE" + ) + self.validate_identity( """UPDATE "x" SET "y" = CAST('0 days 60.000000 seconds' AS INTERVAL) WHERE "x"."id" IN (2, 3)""" ) self.validate_identity( @@ -1289,3 +1292,17 @@ CROSS JOIN JSON_ARRAY_ELEMENTS(CAST(JSON_EXTRACT_PATH(tbox, 'boxes') AS JSON)) A "clickhouse": UnsupportedError, }, ) + + def test_xmlelement(self): + self.validate_identity("SELECT XMLELEMENT(NAME foo)") + self.validate_identity("SELECT XMLELEMENT(NAME foo, XMLATTRIBUTES('xyz' AS bar))") + self.validate_identity("SELECT XMLELEMENT(NAME test, XMLATTRIBUTES(a, b)) FROM test") + self.validate_identity( + "SELECT XMLELEMENT(NAME foo, XMLATTRIBUTES(CURRENT_DATE AS bar), 'cont', 'ent')" + ) + self.validate_identity( + """SELECT XMLELEMENT(NAME "foo$bar", XMLATTRIBUTES('xyz' AS "a&b"))""" + ) + self.validate_identity( + "SELECT XMLELEMENT(NAME foo, XMLATTRIBUTES('xyz' AS bar), XMLELEMENT(NAME abc), XMLCOMMENT('test'), XMLELEMENT(NAME xyz))" + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 971d81b..4a70859 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -320,6 +320,7 @@ class TestRedshift(Validator): ) def test_identity(self): + self.validate_identity("SELECT CAST(value AS FLOAT(8))") self.validate_identity("1 div", "1 AS div") self.validate_identity("LISTAGG(DISTINCT foo, ', ')") self.validate_identity("CREATE MATERIALIZED VIEW orders AUTO REFRESH YES AS SELECT 1") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index a11c21a..1d55f35 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -21,27 +21,6 @@ class TestSnowflake(Validator): expr.selects[0].assert_is(exp.AggFunc) self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t") - self.assertEqual( - exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql( - "snowflake", pretty=True - ), - """SELECT - IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y" -FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, ( - GREATEST(ARRAY_SIZE(x)) - 1 -) + 1))) AS _u(seq, key, path, index, pos, this) -CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this) -WHERE - _u.pos = _u_2.pos_2 - OR ( - _u.pos > ( - ARRAY_SIZE(x) - 1 - ) AND _u_2.pos_2 = ( - ARRAY_SIZE(x) - 1 - ) - )""", - ) - self.validate_identity("exclude := [foo]") self.validate_identity("SELECT CAST([1, 2, 3] AS VECTOR(FLOAT, 3))") self.validate_identity("SELECT CONNECT_BY_ROOT test AS test_column_alias") @@ -976,12 +955,15 @@ WHERE "snowflake": "EDITDISTANCE(col1, col2, 3)", }, ) - self.validate_identity("SELECT BITOR(a, b) FROM table") - - self.validate_identity("SELECT BIT_OR(a, b) FROM table", "SELECT BITOR(a, b) FROM table") - - # Test BITOR with three arguments, padding on the left - self.validate_identity("SELECT BITOR(a, b, 'LEFT') FROM table_name") + self.validate_identity("SELECT BITOR(a, b)") + self.validate_identity("SELECT BIT_OR(a, b)", "SELECT BITOR(a, b)") + self.validate_identity("SELECT BITOR(a, b, 'LEFT')") + self.validate_identity("SELECT BITXOR(a, b, 'LEFT')") + self.validate_identity("SELECT BIT_XOR(a, b)", "SELECT BITXOR(a, b)") + self.validate_identity("SELECT BIT_XOR(a, b, 'LEFT')", "SELECT BITXOR(a, b, 'LEFT')") + self.validate_identity("SELECT BITSHIFTLEFT(a, 1)") + self.validate_identity("SELECT BIT_SHIFTLEFT(a, 1)", "SELECT BITSHIFTLEFT(a, 1)") + self.validate_identity("SELECT BIT_SHIFTRIGHT(a, 1)", "SELECT BITSHIFTRIGHT(a, 1)") def test_null_treatment(self): self.validate_all( @@ -1600,6 +1582,27 @@ WHERE ) def test_flatten(self): + self.assertEqual( + exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql( + "snowflake", pretty=True + ), + """SELECT + IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y" +FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, ( + GREATEST(ARRAY_SIZE(x)) - 1 +) + 1))) AS _u(seq, key, path, index, pos, this) +CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this) +WHERE + _u.pos = _u_2.pos_2 + OR ( + _u.pos > ( + ARRAY_SIZE(x) - 1 + ) AND _u_2.pos_2 = ( + ARRAY_SIZE(x) - 1 + ) + )""", + ) + self.validate_all( """ select @@ -1624,6 +1627,75 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS }, pretty=True, ) + self.validate_all( + """ + SELECT + uc.user_id, + uc.start_ts AS ts, + CASE + WHEN uc.start_ts::DATE >= '2023-01-01' AND uc.country_code IN ('US') AND uc.user_id NOT IN ( + SELECT DISTINCT + _id + FROM + users, + LATERAL FLATTEN(INPUT => PARSE_JSON(flags)) datasource + WHERE datasource.value:name = 'something' + ) + THEN 'Sample1' + ELSE 'Sample2' + END AS entity + FROM user_countries AS uc + LEFT JOIN ( + SELECT user_id, MAX(IFF(service_entity IS NULL,1,0)) AS le_null + FROM accepted_user_agreements + GROUP BY 1 + ) AS aua + ON uc.user_id = aua.user_id + """, + write={ + "snowflake": """SELECT + uc.user_id, + uc.start_ts AS ts, + CASE + WHEN CAST(uc.start_ts AS DATE) >= '2023-01-01' + AND uc.country_code IN ('US') + AND uc.user_id <> ALL ( + SELECT DISTINCT + _id + FROM users, LATERAL IFF(_u.pos = _u_2.pos_2, _u_2.entity, NULL) AS datasource(SEQ, KEY, PATH, INDEX, VALUE, THIS) + WHERE + GET_PATH(datasource.value, 'name') = 'something' + ) + THEN 'Sample1' + ELSE 'Sample2' + END AS entity +FROM user_countries AS uc +LEFT JOIN ( + SELECT + user_id, + MAX(IFF(service_entity IS NULL, 1, 0)) AS le_null + FROM accepted_user_agreements + GROUP BY + 1 +) AS aua + ON uc.user_id = aua.user_id +CROSS JOIN TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, ( + GREATEST(ARRAY_SIZE(INPUT => PARSE_JSON(flags))) - 1 +) + 1))) AS _u(seq, key, path, index, pos, this) +CROSS JOIN TABLE(FLATTEN(INPUT => PARSE_JSON(flags))) AS _u_2(seq, key, path, pos_2, entity, this) +WHERE + _u.pos = _u_2.pos_2 + OR ( + _u.pos > ( + ARRAY_SIZE(INPUT => PARSE_JSON(flags)) - 1 + ) + AND _u_2.pos_2 = ( + ARRAY_SIZE(INPUT => PARSE_JSON(flags)) - 1 + ) + )""", + }, + pretty=True, + ) # All examples from https://docs.snowflake.com/en/sql-reference/functions/flatten.html#syntax self.validate_all( diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index bf72485..1b7360d 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -18,6 +18,8 @@ class TestStarrocks(Validator): "DISTRIBUTED BY HASH (col1) PROPERTIES ('replication_num'='1')", "PRIMARY KEY (col1) DISTRIBUTED BY HASH (col1)", "DUPLICATE KEY (col1, col2) DISTRIBUTED BY HASH (col1)", + "UNIQUE KEY (col1, col2) PARTITION BY RANGE (col1) (START ('2024-01-01') END ('2024-01-31') EVERY (INTERVAL 1 DAY)) DISTRIBUTED BY HASH (col1)", + "UNIQUE KEY (col1, col2) PARTITION BY RANGE (col1, col2) (START ('1') END ('10') EVERY (1), START ('10') END ('100') EVERY (10)) DISTRIBUTED BY HASH (col1)", ] for properties in ddl_sqls: @@ -31,6 +33,9 @@ class TestStarrocks(Validator): self.validate_identity( "CREATE TABLE foo (col0 DECIMAL(9, 1), col1 DECIMAL32(9, 1), col2 DECIMAL64(18, 10), col3 DECIMAL128(38, 10)) DISTRIBUTED BY HASH (col1) BUCKETS 1" ) + self.validate_identity( + "CREATE TABLE foo (col1 LARGEINT) DISTRIBUTED BY HASH (col1) BUCKETS 1" + ) def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 6136599..e8cd696 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1579,6 +1579,11 @@ WHERE }, ) + self.validate_identity( + "SELECT DATEADD(DAY, DATEDIFF(DAY, -3, GETDATE()), '08:00:00')", + "SELECT DATEADD(DAY, DATEDIFF(DAY, CAST('1899-12-29' AS DATETIME2), CAST(GETDATE() AS DATETIME2)), '08:00:00')", + ) + def test_lateral_subquery(self): self.validate_all( "SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)", diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 2640145..ecb6eee 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -191,6 +191,10 @@ SELECT SOME_UDF(data).* FROM t; SELECT SOME_UDF(t.data).* FROM t AS t; # execute: false +SELECT p.* FROM p UNION ALL SELECT p2.* FROM p2; +SELECT p.* FROM p AS p UNION ALL SELECT p2.* FROM p2 AS p2; + +# execute: false # allow_partial_qualification: true # validate_qualify_columns: false SELECT a + 1 AS i, missing_column FROM x; @@ -201,6 +205,30 @@ SELECT x.a + 1 AS i, missing_column AS missing_column FROM x AS x; SELECT s, arr1, arr2 FROM arrays_test LEFT ARRAY JOIN arr1, arrays_test.arr2; SELECT arrays_test.s AS s, arrays_test.arr1 AS arr1, arrays_test.arr2 AS arr2 FROM arrays_test AS arrays_test LEFT ARRAY JOIN arrays_test.arr1, arrays_test.arr2; +# execute: false +# dialect: snowflake +WITH employees AS ( + SELECT * + FROM (VALUES ('President', 1, NULL), + ('Vice President Engineering', 10, 1), + ('Programmer', 100, 10), + ('QA Engineer', 101, 10), + ('Vice President HR', 20, 1), + ('Health Insurance Analyst', 200, 20) + ) AS t(title, employee_ID, manager_ID) +) +SELECT + employee_ID, + manager_ID, + title, + level +FROM employees +START WITH title = 'President' +CONNECT BY manager_ID = PRIOR employee_id +ORDER BY + employee_ID NULLS LAST; +WITH EMPLOYEES AS (SELECT T.TITLE AS TITLE, T.EMPLOYEE_ID AS EMPLOYEE_ID, T.MANAGER_ID AS MANAGER_ID FROM (VALUES ('President', 1, NULL), ('Vice President Engineering', 10, 1), ('Programmer', 100, 10), ('QA Engineer', 101, 10), ('Vice President HR', 20, 1), ('Health Insurance Analyst', 200, 20)) AS T(TITLE, EMPLOYEE_ID, MANAGER_ID)) SELECT EMPLOYEES.EMPLOYEE_ID AS EMPLOYEE_ID, EMPLOYEES.MANAGER_ID AS MANAGER_ID, EMPLOYEES.TITLE AS TITLE, EMPLOYEES.LEVEL AS LEVEL FROM EMPLOYEES AS EMPLOYEES START WITH EMPLOYEES.TITLE = 'President' CONNECT BY EMPLOYEES.MANAGER_ID = PRIOR EMPLOYEES.EMPLOYEE_ID ORDER BY EMPLOYEE_ID; + -------------------------------------- -- Derived tables -------------------------------------- @@ -727,3 +755,30 @@ SELECT y.b AS b FROM ((SELECT x.a AS a FROM x AS x) AS _q_0 INNER JOIN y AS y ON SELECT a, c FROM x TABLESAMPLE SYSTEM (10 ROWS) CROSS JOIN y TABLESAMPLE SYSTEM (10 ROWS); SELECT x.a AS a, y.c AS c FROM x AS x TABLESAMPLE SYSTEM (10 ROWS) CROSS JOIN y AS y TABLESAMPLE SYSTEM (10 ROWS); + +-------------------------------------- +-- Snowflake allows column alias to be used in almost all clauses +-------------------------------------- +# title: Snowflake column alias in JOIN +# dialect: snowflake +# execute: false +SELECT x.a AS foo FROM x JOIN y ON foo = y.b; +SELECT X.A AS FOO FROM X AS X JOIN Y AS Y ON X.A = Y.B; + +# title: Snowflake column alias in QUALIFY +# dialect: snowflake +# execute: false +SELECT x.a AS foo FROM x QUALIFY foo = 1; +SELECT X.A AS FOO FROM X AS X QUALIFY X.A = 1; + +# title: Snowflake column alias in GROUP BY +# dialect: snowflake +# execute: false +SELECT x.a AS foo FROM x GROUP BY foo = 1; +SELECT X.A AS FOO FROM X AS X GROUP BY X.A = 1; + +# title: Snowflake column alias in WHERE +# dialect: snowflake +# execute: false +SELECT x.a AS foo FROM x WHERE foo = 1; +SELECT X.A AS FOO FROM X AS X WHERE X.A = 1;
\ No newline at end of file diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 036f146..095ee80 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -495,3 +495,84 @@ class TestLineage(unittest.TestCase): self.assertEqual(len(node.downstream), 1) self.assertEqual(len(node.downstream[0].downstream), 1) self.assertEqual(node.downstream[0].downstream[0].name, "t1.x") + + def test_pivot_without_alias(self) -> None: + sql = """ + SELECT + a as other_a + FROM (select value,category from sample_data) + PIVOT ( + sum(value) + FOR category IN ('a', 'b') + ); + """ + node = lineage("other_a", sql) + + self.assertEqual(node.downstream[0].name, "_q_0.value") + self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value") + + def test_pivot_with_alias(self) -> None: + sql = """ + SELECT + cat_a_s as other_as + FROM sample_data + PIVOT ( + sum(value) as s, max(price) + FOR category IN ('a' as cat_a, 'b') + ) + """ + node = lineage("other_as", sql) + + self.assertEqual(len(node.downstream), 1) + self.assertEqual(node.downstream[0].name, "sample_data.value") + + def test_pivot_with_cte(self) -> None: + sql = """ + WITH t as ( + SELECT + a as other_a + FROM sample_data + PIVOT ( + sum(value) + FOR category IN ('a', 'b') + ) + ) + select other_a from t + """ + node = lineage("other_a", sql) + + self.assertEqual(node.downstream[0].name, "t.other_a") + self.assertEqual(node.downstream[0].reference_node_name, "t") + self.assertEqual(node.downstream[0].downstream[0].name, "sample_data.value") + + def test_pivot_with_implicit_column_of_pivoted_source(self) -> None: + sql = """ + SELECT empid + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ( + '2023_Q1', + '2023_Q2', + '2023_Q3')) + ORDER BY empid; + """ + node = lineage("empid", sql) + + self.assertEqual(node.downstream[0].name, "quarterly_sales.empid") + + def test_pivot_with_implicit_column_of_pivoted_source_and_cte(self) -> None: + sql = """ + WITH t as ( + SELECT empid + FROM quarterly_sales + PIVOT(SUM(amount) FOR quarter IN ( + '2023_Q1', + '2023_Q2', + '2023_Q3')) + ) + select empid from t + """ + node = lineage("empid", sql) + + self.assertEqual(node.downstream[0].name, "t.empid") + self.assertEqual(node.downstream[0].reference_node_name, "t") + self.assertEqual(node.downstream[0].downstream[0].name, "quarterly_sales.empid") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 4a41e4a..7f2ed0d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -551,6 +551,10 @@ class TestOptimizer(unittest.TestCase): SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expression,SELECT :expressions,2,:distinct,True,:alias, AS cte,CTE :this,SELECT :expressions,WINDOW :this,ROW(),:partition_by,y,:over,OVER,:from,FROM ((SELECT :expressions,1):limit,LIMIT :expression,10),:alias, AS cte2,:expressions,STAR,a + 1,a DIV 1,FILTER("B",LAMBDA :this,x + y,:expressions,x,y),:from,FROM (z AS z:joins,JOIN :this,z,:kind,CROSS) AS f(a),:joins,JOIN :this,a.b.c.d.e.f.g,:side,LEFT,:using,n,:order,ORDER :expressions,ORDERED :this,1,:nulls_first,True """.strip(), ) + self.assertEqual( + optimizer.simplify.gen(parse_one("select item_id /* description */"), comments=True), + "SELECT :expressions,item_id /* description */", + ) def test_unnest_subqueries(self): self.check_file("unnest_subqueries", optimizer.unnest_subqueries.unnest_subqueries) |