Diffstat (limited to 'tests')
-rw-r--r--  tests/dialects/test_bigquery.py                  |   1
-rw-r--r--  tests/dialects/test_clickhouse.py                |  40
-rw-r--r--  tests/dialects/test_databricks.py                |   1
-rw-r--r--  tests/dialects/test_dialect.py                   |   2
-rw-r--r--  tests/dialects/test_duckdb.py                    |   3
-rw-r--r--  tests/dialects/test_hive.py                      |   3
-rw-r--r--  tests/dialects/test_mysql.py                     |  10
-rw-r--r--  tests/dialects/test_oracle.py                    |  51
-rw-r--r--  tests/dialects/test_presto.py                    |  10
-rw-r--r--  tests/dialects/test_snowflake.py                 |  37
-rw-r--r--  tests/dialects/test_starrocks.py                 |   3
-rw-r--r--  tests/dialects/test_trino.py                     |   3
-rw-r--r--  tests/dialects/test_tsql.py                      |  52
-rw-r--r--  tests/fixtures/optimizer/annotate_functions.sql  | 122
-rw-r--r--  tests/fixtures/optimizer/merge_subqueries.sql    |  18
-rw-r--r--  tests/fixtures/optimizer/qualify_columns.sql     |  11
-rw-r--r--  tests/fixtures/optimizer/qualify_tables.sql      |   3
-rw-r--r--  tests/fixtures/optimizer/simplify.sql            |   3
-rw-r--r--  tests/test_build.py                              |   4
-rw-r--r--  tests/test_diff.py                               |  38
-rw-r--r--  tests/test_expressions.py                        |   6
-rw-r--r--  tests/test_optimizer.py                          |  22
22 files changed, 431 insertions(+), 12 deletions(-)
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index d854165..3b317bc 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -107,6 +107,7 @@ LANGUAGE js AS
         select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
         self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")
 
+        self.validate_identity("SELECT ARRAY_CONCAT([1])")
         self.validate_identity("SELECT * FROM READ_CSV('bla.csv')")
         self.validate_identity("CAST(x AS STRUCT<list ARRAY<INT64>>)")
         self.validate_identity("assert.true(1 = 1)")
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index b4fc587..56ff06f 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -2,6 +2,7 @@ from datetime import date
 from sqlglot import exp, parse_one
 from sqlglot.dialects import ClickHouse
 from sqlglot.expressions import convert
+from sqlglot.optimizer import traverse_scope
 from tests.dialects.test_dialect import Validator
 from sqlglot.errors import ErrorLevel
 
@@ -28,6 +29,7 @@ class TestClickhouse(Validator):
         self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)")
         self.assertIsNone(expr._meta)
 
+        self.validate_identity("CAST(1 AS Bool)")
         self.validate_identity("SELECT toString(CHAR(104.1, 101, 108.9, 108.9, 111, 32))")
         self.validate_identity("@macro").assert_is(exp.Parameter).this.assert_is(exp.Var)
         self.validate_identity("SELECT toFloat(like)")
@@ -420,11 +422,6 @@ class TestClickhouse(Validator):
                 " GROUP BY loyalty ORDER BY loyalty ASC"
             },
         )
-        self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
-        self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
-        self.validate_identity(
-            "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
-        )
         self.validate_all(
             "SELECT quantile(0.5)(a)",
             read={"duckdb": "SELECT quantile(a, 0.5)"},
@@ -1100,3 +1097,36 @@ LIFETIME(MIN 0 MAX 0)""",
     def test_grant(self):
         self.validate_identity("GRANT SELECT(x, y) ON db.table TO john WITH GRANT OPTION")
         self.validate_identity("GRANT INSERT(x, y) ON db.table TO john")
+
+    def test_array_join(self):
+        expr = self.validate_identity(
+            "SELECT * FROM arrays_test ARRAY JOIN arr1, arrays_test.arr2 AS foo, ['a', 'b', 'c'] AS elem"
+        )
+        joins = expr.args["joins"]
+        self.assertEqual(len(joins), 1)
+
+        join = joins[0]
+        self.assertEqual(join.kind, "ARRAY")
+        self.assertIsInstance(join.this, exp.Column)
+
+        self.assertEqual(len(join.expressions), 2)
+        self.assertIsInstance(join.expressions[0], exp.Alias)
+        self.assertIsInstance(join.expressions[0].this, exp.Column)
+
+        self.assertIsInstance(join.expressions[1], exp.Alias)
+        self.assertIsInstance(join.expressions[1].this, exp.Array)
+
+        self.validate_identity("SELECT s, arr FROM arrays_test ARRAY JOIN arr")
+        self.validate_identity("SELECT s, arr, a FROM arrays_test LEFT ARRAY JOIN arr AS a")
+        self.validate_identity(
+            "SELECT s, arr_external FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external"
+        )
+        self.validate_identity(
+            "SELECT * FROM arrays_test ARRAY JOIN [1, 2, 3] AS arr_external1, ['a', 'b', 'c'] AS arr_external2, splitByString(',', 'asd,qwerty,zxc') AS arr_external3"
+        )
+
+    def test_traverse_scope(self):
+        sql = "SELECT * FROM t FINAL"
+        scopes = traverse_scope(parse_one(sql, dialect=self.dialect))
+        self.assertEqual(len(scopes), 1)
+        self.assertEqual(set(scopes[0].sources), {"t"})
diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py
index 65e8d5d..f7ec756 100644
--- a/tests/dialects/test_databricks.py
+++ b/tests/dialects/test_databricks.py
@@ -7,6 +7,7 @@ class TestDatabricks(Validator):
     dialect = "databricks"
 
     def test_databricks(self):
+        self.validate_identity("SELECT t.current_time FROM t")
         self.validate_identity("ALTER TABLE labels ADD COLUMN label_score FLOAT")
         self.validate_identity("DESCRIBE HISTORY a.b")
         self.validate_identity("DESCRIBE history.tbl")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 96ce600..84a4ff3 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1762,6 +1762,7 @@ class TestDialect(Validator):
         self.validate_all(
             "LEVENSHTEIN(col1, col2)",
             write={
+                "bigquery": "EDIT_DISTANCE(col1, col2)",
                 "duckdb": "LEVENSHTEIN(col1, col2)",
                 "drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
                 "presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
@@ -1772,6 +1773,7 @@ class TestDialect(Validator):
         self.validate_all(
             "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
             write={
+                "bigquery": "EDIT_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
                 "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py
index 6b58934..1f8fb81 100644
--- a/tests/dialects/test_duckdb.py
+++ b/tests/dialects/test_duckdb.py
@@ -256,6 +256,9 @@ class TestDuckDB(Validator):
             parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
         )
 
+        self.validate_identity("SELECT UNNEST([1, 2])").selects[0].assert_is(exp.UDTF)
+        self.validate_identity("'red' IN flags").args["field"].assert_is(exp.Column)
+        self.validate_identity("'red' IN tbl.flags")
         self.validate_identity("CREATE TABLE tbl1 (u UNION(num INT, str TEXT))")
         self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
         self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 136ea60..e40a85a 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -187,6 +187,7 @@ class TestHive(Validator):
             "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
             write={
                 "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)",
+                "duckdb": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)",
                 "hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
                 "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
             },
@@ -195,6 +196,7 @@ class TestHive(Validator):
             "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
             write={
                 "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
+                "duckdb": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
                 "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
                 "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
             },
@@ -211,6 +213,7 @@ class TestHive(Validator):
             "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
             write={
                 "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)",
+                "duckdb": "SELECT a FROM x CROSS JOIN UNNEST([y]) AS t(a)",
                 "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
                 "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a",
             },
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 0e593ef..bd0d6c3 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -139,6 +139,7 @@ class TestMySQL(Validator):
     )
 
     def test_identity(self):
+        self.validate_identity("SELECT HIGH_PRIORITY STRAIGHT_JOIN SQL_CALC_FOUND_ROWS * FROM t")
         self.validate_identity("SELECT CAST(COALESCE(`id`, 'NULL') AS CHAR CHARACTER SET binary)")
         self.validate_identity("SELECT e.* FROM e STRAIGHT_JOIN p ON e.x = p.y")
         self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1")
@@ -1305,3 +1306,12 @@ COMMENT='客户账户表'"""
         for sql in grant_cmds:
             with self.subTest(f"Testing MySQL's GRANT command statement: {sql}"):
                 self.validate_identity(sql, check_command_warning=True)
+
+    def test_explain(self):
+        self.validate_identity(
+            "EXPLAIN ANALYZE SELECT * FROM t", "DESCRIBE ANALYZE SELECT * FROM t"
+        )
+
+        expression = self.parse_one("EXPLAIN ANALYZE SELECT * FROM t")
+        self.assertIsInstance(expression, exp.Describe)
+        self.assertEqual(expression.text("style"), "ANALYZE")
diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py
index d2bbedc..36ce5d0 100644
--- a/tests/dialects/test_oracle.py
+++ b/tests/dialects/test_oracle.py
@@ -329,6 +329,57 @@ class TestOracle(Validator):
         )
         self.validate_identity("INSERT /*+ APPEND */ INTO IAP_TBL (id, col1) VALUES (2, 'test2')")
         self.validate_identity("INSERT /*+ APPEND_VALUES */ INTO dest_table VALUES (i, 'Value')")
+        self.validate_identity(
+            "SELECT /*+ LEADING(departments employees) USE_NL(employees) */ * FROM employees JOIN departments ON employees.department_id = departments.department_id",
+            """SELECT /*+ LEADING(departments employees)
+  USE_NL(employees) */
+  *
+FROM employees
+JOIN departments
+  ON employees.department_id = departments.department_id""",
+            pretty=True,
+        )
+        self.validate_identity(
+            "SELECT /*+ USE_NL(bbbbbbbbbbbbbbbbbbbbbbbb) LEADING(aaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbb cccccccccccccccccccccccc dddddddddddddddddddddddd) INDEX(cccccccccccccccccccccccc) */ * FROM aaaaaaaaaaaaaaaaaaaaaaaa JOIN bbbbbbbbbbbbbbbbbbbbbbbb ON aaaaaaaaaaaaaaaaaaaaaaaa.id = bbbbbbbbbbbbbbbbbbbbbbbb.a_id JOIN cccccccccccccccccccccccc ON bbbbbbbbbbbbbbbbbbbbbbbb.id = cccccccccccccccccccccccc.b_id JOIN dddddddddddddddddddddddd ON cccccccccccccccccccccccc.id = dddddddddddddddddddddddd.c_id",
+        )
+        self.validate_identity(
+            "SELECT /*+ USE_NL(bbbbbbbbbbbbbbbbbbbbbbbb) LEADING(aaaaaaaaaaaaaaaaaaaaaaaa bbbbbbbbbbbbbbbbbbbbbbbb cccccccccccccccccccccccc dddddddddddddddddddddddd) INDEX(cccccccccccccccccccccccc) */ * FROM aaaaaaaaaaaaaaaaaaaaaaaa JOIN bbbbbbbbbbbbbbbbbbbbbbbb ON aaaaaaaaaaaaaaaaaaaaaaaa.id = bbbbbbbbbbbbbbbbbbbbbbbb.a_id JOIN cccccccccccccccccccccccc ON bbbbbbbbbbbbbbbbbbbbbbbb.id = cccccccccccccccccccccccc.b_id JOIN dddddddddddddddddddddddd ON cccccccccccccccccccccccc.id = dddddddddddddddddddddddd.c_id",
+            """SELECT /*+ USE_NL(bbbbbbbbbbbbbbbbbbbbbbbb)
+  LEADING(
+    aaaaaaaaaaaaaaaaaaaaaaaa
+    bbbbbbbbbbbbbbbbbbbbbbbb
+    cccccccccccccccccccccccc
+    dddddddddddddddddddddddd
+  )
+  INDEX(cccccccccccccccccccccccc) */
+  *
+FROM aaaaaaaaaaaaaaaaaaaaaaaa
+JOIN bbbbbbbbbbbbbbbbbbbbbbbb
+  ON aaaaaaaaaaaaaaaaaaaaaaaa.id = bbbbbbbbbbbbbbbbbbbbbbbb.a_id
+JOIN cccccccccccccccccccccccc
+  ON bbbbbbbbbbbbbbbbbbbbbbbb.id = cccccccccccccccccccccccc.b_id
+JOIN dddddddddddddddddddddddd
+  ON cccccccccccccccccccccccc.id = dddddddddddddddddddddddd.c_id""",
+            pretty=True,
+        )
+        # Test that parsing error with keywords like select where etc falls back
+        self.validate_identity(
+            "SELECT /*+ LEADING(departments employees) USE_NL(employees) select where group by is order by */ * FROM employees JOIN departments ON employees.department_id = departments.department_id",
+            """SELECT /*+ LEADING(departments employees) USE_NL(employees) select where group by is order by */
+  *
+FROM employees
+JOIN departments
+  ON employees.department_id = departments.department_id""",
+            pretty=True,
+        )
+        # Test that parsing error with , inside hint function falls back
+        self.validate_identity(
+            "SELECT /*+ LEADING(departments, employees) */ * FROM employees JOIN departments ON employees.department_id = departments.department_id"
+        )
+        # Test that parsing error with keyword inside hint function falls back
+        self.validate_identity(
+            "SELECT /*+ LEADING(departments select) */ * FROM employees JOIN departments ON employees.department_id = departments.department_id"
+        )
 
     def test_xml_table(self):
         self.validate_identity("XMLTABLE('x')")
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 3d5fbfe..4c10a45 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -14,6 +14,16 @@ class TestPresto(Validator):
         self.validate_identity("CAST(x AS HYPERLOGLOG)")
 
         self.validate_all(
+            "CAST(x AS BOOLEAN)",
+            read={
+                "tsql": "CAST(x AS BIT)",
+            },
+            write={
+                "presto": "CAST(x AS BOOLEAN)",
+                "tsql": "CAST(x AS BIT)",
+            },
+        )
+        self.validate_all(
             "SELECT FROM_ISO8601_TIMESTAMP('2020-05-11T11:15:05')",
             write={
                 "duckdb": "SELECT CAST('2020-05-11T11:15:05' AS TIMESTAMPTZ)",
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 6cde86b..409a5a6 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -927,6 +927,29 @@ WHERE
                 "bigquery": "GENERATE_UUID()",
             },
         )
+        self.validate_identity("TRY_TO_TIMESTAMP(foo)").assert_is(exp.Anonymous)
+        self.validate_identity("TRY_TO_TIMESTAMP('12345')").assert_is(exp.Anonymous)
+        self.validate_all(
+            "SELECT TRY_TO_TIMESTAMP('2024-01-15 12:30:00.000')",
+            write={
+                "snowflake": "SELECT TRY_CAST('2024-01-15 12:30:00.000' AS TIMESTAMP)",
+                "duckdb": "SELECT TRY_CAST('2024-01-15 12:30:00.000' AS TIMESTAMP)",
+            },
+        )
+        self.validate_all(
+            "SELECT TRY_TO_TIMESTAMP('invalid')",
+            write={
+                "snowflake": "SELECT TRY_CAST('invalid' AS TIMESTAMP)",
+                "duckdb": "SELECT TRY_CAST('invalid' AS TIMESTAMP)",
+            },
+        )
+        self.validate_all(
+            "SELECT TRY_TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')",
+            write={
+                "snowflake": "SELECT TRY_TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')",
+                "duckdb": "SELECT CAST(TRY_STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S') AS TIMESTAMP)",
+            },
+        )
 
     def test_null_treatment(self):
         self.validate_all(
@@ -1085,6 +1108,20 @@ WHERE
                 "spark": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1 PERCENT)",
             },
         )
+        self.validate_all(
+            "TO_DOUBLE(expr)",
+            write={
+                "snowflake": "TO_DOUBLE(expr)",
+                "duckdb": "CAST(expr AS DOUBLE)",
+            },
+        )
+        self.validate_all(
+            "TO_DOUBLE(expr, fmt)",
+            write={
+                "snowflake": "TO_DOUBLE(expr, fmt)",
+                "duckdb": UnsupportedError,
+            },
+        )
 
     def test_timestamps(self):
         self.validate_identity("SELECT CAST('12:00:00' AS TIME)")
diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py
index 44c54a6..1edd7c6 100644
--- a/tests/dialects/test_starrocks.py
+++ b/tests/dialects/test_starrocks.py
@@ -5,6 +5,9 @@ from tests.dialects.test_dialect import Validator
 class TestStarrocks(Validator):
     dialect = "starrocks"
 
+    def test_starrocks(self):
+        self.validate_identity("ALTER TABLE a SWAP WITH b")
+
     def test_ddl(self):
         ddl_sqls = [
             "DISTRIBUTED BY HASH (col1) BUCKETS 1",
diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py
index 8c73ec1..33a0229 100644
--- a/tests/dialects/test_trino.py
+++ b/tests/dialects/test_trino.py
@@ -9,6 +9,9 @@ class TestTrino(Validator):
         self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')")
         self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)")
         self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)")
+        self.validate_identity(
+            "SELECT LISTAGG(DISTINCT col, ',') WITHIN GROUP (ORDER BY col ASC) FROM tbl"
+        )
 
     def test_trim(self):
         self.validate_identity("SELECT TRIM('!' FROM '!foo!')")
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index 9be6fcd..042891a 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -1,6 +1,6 @@
 from sqlglot import exp, parse, parse_one
 from tests.dialects.test_dialect import Validator
-from sqlglot.errors import ParseError
+from sqlglot.errors import ParseError, UnsupportedError
 from sqlglot.optimizer.annotate_types import annotate_types
 
 
@@ -1001,6 +1001,17 @@ class TestTSQL(Validator):
         )
         self.validate_identity("CREATE PROC foo AS SELECT BAR() AS baz")
         self.validate_identity("CREATE PROCEDURE foo AS SELECT BAR() AS baz")
+
+        self.validate_identity("CREATE PROCEDURE foo WITH ENCRYPTION AS SELECT 1")
+        self.validate_identity("CREATE PROCEDURE foo WITH RECOMPILE AS SELECT 1")
+        self.validate_identity("CREATE PROCEDURE foo WITH SCHEMABINDING AS SELECT 1")
+        self.validate_identity("CREATE PROCEDURE foo WITH NATIVE_COMPILATION AS SELECT 1")
+        self.validate_identity("CREATE PROCEDURE foo WITH EXECUTE AS OWNER AS SELECT 1")
+        self.validate_identity("CREATE PROCEDURE foo WITH EXECUTE AS 'username' AS SELECT 1")
+        self.validate_identity(
+            "CREATE PROCEDURE foo WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS SELECT 1"
+        )
+
         self.validate_identity("CREATE FUNCTION foo(@bar INTEGER) RETURNS TABLE AS RETURN SELECT 1")
         self.validate_identity("CREATE FUNCTION dbo.ISOweek(@DATE DATETIME2) RETURNS INTEGER")
@@ -1059,6 +1070,7 @@ WHERE
 CREATE procedure [TRANSF].[SP_Merge_Sales_Real]
     @Loadid INTEGER
    ,@NumberOfRows INTEGER
+    WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION
 AS
 BEGIN
     SET XACT_ABORT ON;
@@ -1074,7 +1086,7 @@ WHERE
         """
 
         expected_sqls = [
-            "CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER AS BEGIN SET XACT_ABORT ON",
+            "CREATE PROCEDURE [TRANSF].[SP_Merge_Sales_Real] @Loadid INTEGER, @NumberOfRows INTEGER WITH EXECUTE AS OWNER, SCHEMABINDING, NATIVE_COMPILATION AS BEGIN SET XACT_ABORT ON",
             "DECLARE @DWH_DateCreated AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
             "DECLARE @DWH_DateModified AS DATETIME2 = CONVERT(DATETIME2, GETDATE(), 104)",
             "DECLARE @DWH_IdUserCreated AS INTEGER = SUSER_ID(CURRENT_USER())",
@@ -2014,3 +2026,39 @@ FROM OPENJSON(@json) WITH (
         self.validate_identity(
             "GRANT EXECUTE ON TestProc TO User2 AS TesterRole", check_command_warning=True
         )
+
+    def test_parsename(self):
+        for i in range(4):
+            with self.subTest("Testing PARSENAME <-> SPLIT_PART"):
+                self.validate_all(
+                    f"SELECT PARSENAME('1.2.3', {i})",
+                    read={
+                        "spark": f"SELECT SPLIT_PART('1.2.3', '.', {4 - i})",
+                        "databricks": f"SELECT SPLIT_PART('1.2.3', '.', {4 - i})",
+                    },
+                    write={
+                        "spark": f"SELECT SPLIT_PART('1.2.3', '.', {4 - i})",
+                        "databricks": f"SELECT SPLIT_PART('1.2.3', '.', {4 - i})",
+                        "tsql": f"SELECT PARSENAME('1.2.3', {i})",
+                    },
+                )
+
+        # Test non-dot delimiter
+        self.validate_all(
+            "SELECT SPLIT_PART('1,2,3', ',', 1)",
+            write={
+                "spark": "SELECT SPLIT_PART('1,2,3', ',', 1)",
+                "databricks": "SELECT SPLIT_PART('1,2,3', ',', 1)",
+                "tsql": UnsupportedError,
+            },
+        )
+
+        # Test column-type parameters
+        self.validate_all(
+            "WITH t AS (SELECT 'a.b.c' AS value, 1 AS idx) SELECT SPLIT_PART(value, '.', idx) FROM t",
+            write={
+                "spark": "WITH t AS (SELECT 'a.b.c' AS value, 1 AS idx) SELECT SPLIT_PART(value, '.', idx) FROM t",
+                "databricks": "WITH t AS (SELECT 'a.b.c' AS value, 1 AS idx) SELECT SPLIT_PART(value, '.', idx) FROM t",
+                "tsql": UnsupportedError,
+            },
+        )
diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql
index 8aa77d4..1f59a5a 100644
--- a/tests/fixtures/optimizer/annotate_functions.sql
+++ b/tests/fixtures/optimizer/annotate_functions.sql
@@ -186,4 +186,124 @@ DOUBLE;
 
 # dialect: bigquery
 EXP(tbl.bignum_col);
-BIGDECIMAL;
\ No newline at end of file
+BIGDECIMAL;
+
+# dialect: bigquery
+CONCAT(tbl.str_col, tbl.str_col);
+STRING;
+
+# dialect: bigquery
+CONCAT(tbl.bin_col, tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+LEFT(tbl.str_col, 1);
+STRING;
+
+# dialect: bigquery
+LEFT(tbl.bin_col, 1);
+BINARY;
+
+# dialect: bigquery
+RIGHT(tbl.str_col, 1);
+STRING;
+
+# dialect: bigquery
+RIGHT(tbl.bin_col, 1);
+BINARY;
+
+# dialect: bigquery
+LOWER(tbl.str_col);
+STRING;
+
+# dialect: bigquery
+LOWER(tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+UPPER(tbl.str_col);
+STRING;
+
+# dialect: bigquery
+UPPER(tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+LPAD(tbl.str_col, 1, tbl.str_col);
+STRING;
+
+# dialect: bigquery
+LPAD(tbl.bin_col, 1, tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+RPAD(tbl.str_col, 1, tbl.str_col);
+STRING;
+
+# dialect: bigquery
+RPAD(tbl.bin_col, 1, tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+LTRIM(tbl.str_col);
+STRING;
+
+# dialect: bigquery
+LTRIM(tbl.bin_col, tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+RTRIM(tbl.str_col);
+STRING;
+
+# dialect: bigquery
+RTRIM(tbl.bin_col, tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+TRIM(tbl.str_col);
+STRING;
+
+# dialect: bigquery
+TRIM(tbl.bin_col, tbl.bin_col);
+BINARY;
+
+# dialect: bigquery
+REGEXP_EXTRACT(tbl.str_col, pattern);
+STRING;
+
+# dialect: bigquery
+REGEXP_EXTRACT(tbl.bin_col, pattern);
+BINARY;
+
+# dialect: bigquery
+REGEXP_REPLACE(tbl.str_col, pattern, replacement);
+STRING;
+
+# dialect: bigquery
+REGEXP_REPLACE(tbl.bin_col, pattern, replacement);
+BINARY;
+
+# dialect: bigquery
+REPEAT(tbl.str_col, 1);
+STRING;
+
+# dialect: bigquery
+REPEAT(tbl.bin_col, 1);
+BINARY;
+
+# dialect: bigquery
+SUBSTRING(tbl.str_col, 1);
+STRING;
+
+# dialect: bigquery
+SUBSTRING(tbl.bin_col, 1);
+BINARY;
+
+# dialect: bigquery
+SPLIT(tbl.str_col, delim);
+ARRAY<STRING>;
+
+# dialect: bigquery
+SPLIT(tbl.bin_col, delim);
+ARRAY<BINARY>;
diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql
index ce5a435..e39e7d1 100644
--- a/tests/fixtures/optimizer/merge_subqueries.sql
+++ b/tests/fixtures/optimizer/merge_subqueries.sql
@@ -446,3 +446,21 @@ SELECT
   1 AS a;
 WITH q AS (SELECT x.a AS a FROM x AS x ORDER BY x.a) SELECT q.a AS a FROM q AS q UNION ALL SELECT 1 AS a;
 
+# title: Consecutive inner - outer conflicting names
+WITH tbl AS (select 1 as id)
+SELECT
+  id
+FROM (
+  SELECT OTBL.id
+  FROM (
+    SELECT OTBL.id
+    FROM (
+      SELECT OTBL.id
+      FROM tbl AS OTBL
+      LEFT OUTER JOIN tbl AS ITBL ON OTBL.id = ITBL.id
+    ) AS OTBL
+    LEFT OUTER JOIN tbl AS ITBL ON OTBL.id = ITBL.id
+  ) AS OTBL
+  LEFT OUTER JOIN tbl AS ITBL ON OTBL.id = ITBL.id
+) AS ITBL;
+WITH tbl AS (SELECT 1 AS id) SELECT OTBL.id AS id FROM tbl AS OTBL LEFT OUTER JOIN tbl AS ITBL_2 ON OTBL.id = ITBL_2.id LEFT OUTER JOIN tbl AS ITBL_3 ON OTBL.id = ITBL_3.id LEFT OUTER JOIN tbl AS ITBL ON OTBL.id = ITBL.id;
diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql
index 68c0caa..7c901ce 100644
--- a/tests/fixtures/optimizer/qualify_columns.sql
+++ b/tests/fixtures/optimizer/qualify_columns.sql
@@ -190,6 +190,17 @@ SELECT x._col_0 AS _col_0, x._col_1 AS _col_1 FROM (VALUES (1, 2)) AS x(_col_0,
 SELECT SOME_UDF(data).* FROM t;
 SELECT SOME_UDF(t.data).* FROM t AS t;
 
+# execute: false
+# allow_partial_qualification: true
+# validate_qualify_columns: false
+SELECT a + 1 AS i, missing_column FROM x;
+SELECT x.a + 1 AS i, missing_column AS missing_column FROM x AS x;
+
+# execute: false
+# dialect: clickhouse
+SELECT s, arr1, arr2 FROM arrays_test LEFT ARRAY JOIN arr1, arrays_test.arr2;
+SELECT arrays_test.s AS s, arrays_test.arr1 AS arr1, arrays_test.arr2 AS arr2 FROM arrays_test AS arrays_test LEFT ARRAY JOIN arrays_test.arr1, arrays_test.arr2;
+
 --------------------------------------
 -- Derived tables
 --------------------------------------
diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql
index 61d0b96..49e07fa 100644
--- a/tests/fixtures/optimizer/qualify_tables.sql
+++ b/tests/fixtures/optimizer/qualify_tables.sql
@@ -184,3 +184,6 @@ COPY INTO (SELECT * FROM c.db.x AS x) TO 'data' WITH (FORMAT 'CSV');
 # title: tablesample
 SELECT 1 FROM x TABLESAMPLE SYSTEM (10 PERCENT) CROSS JOIN y TABLESAMPLE SYSTEM (10 PERCENT);
 SELECT 1 FROM c.db.x AS x TABLESAMPLE SYSTEM (10 PERCENT) CROSS JOIN c.db.y AS y TABLESAMPLE SYSTEM (10 PERCENT);
+
+WITH cte_tbl AS (SELECT 1 AS col2) UPDATE y SET col1 = (SELECT * FROM x) WHERE EXISTS(SELECT 1 FROM cte_tbl);
+WITH cte_tbl AS (SELECT 1 AS col2) UPDATE c.db.y SET col1 = (SELECT * FROM c.db.x AS x) WHERE EXISTS(SELECT 1 FROM cte_tbl AS cte_tbl);
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index fa2dc79..1842e55 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -140,6 +140,9 @@ TRUE;
 COALESCE(x, y) <> ALL (SELECT z FROM w);
 COALESCE(x, y) <> ALL (SELECT z FROM w);
 
+SELECT NOT (2 <> ALL (SELECT 2 UNION ALL SELECT 3));
+SELECT 2 = ANY(SELECT 2 UNION ALL SELECT 3);
+
 --------------------------------------
 -- Absorption
 --------------------------------------
diff --git a/tests/test_build.py b/tests/test_build.py
index 5d383ad..92a7b2e 100644
--- a/tests/test_build.py
+++ b/tests/test_build.py
@@ -801,6 +801,10 @@ class TestBuild(unittest.TestCase):
             ),
             "MERGE INTO target_table AS target USING source_table AS source ON target.id = source.id WHEN MATCHED THEN UPDATE SET target.name = source.name RETURNING target.*",
         ),
+        (
+            lambda: exp.union("SELECT 1", "SELECT 2", "SELECT 3", "SELECT 4"),
+            "SELECT 1 UNION SELECT 2 UNION SELECT 3 UNION SELECT 4",
+        ),
     ]:
         with self.subTest(sql):
             self.assertEqual(expression().sql(dialect[0] if dialect else None), sql)
diff --git a/tests/test_diff.py b/tests/test_diff.py
index edd3b26..f0e0747 100644
--- a/tests/test_diff.py
+++ b/tests/test_diff.py
@@ -240,5 +240,43 @@ class TestDiff(unittest.TestCase):
 
         self.assertEqual(["WARNING:sqlglot:Dummy warning"], cm.output)
 
+    def test_non_expression_leaf_delta(self):
+        expr_src = parse_one("SELECT a UNION SELECT b")
+        expr_tgt = parse_one("SELECT a UNION ALL SELECT b")
+
+        self._validate_delta_only(
+            diff_delta_only(expr_src, expr_tgt),
+            [
+                Update(source=expr_src, target=expr_tgt),
+            ],
+        )
+
+        expr_src = parse_one("SELECT a FROM t ORDER BY b ASC")
+        expr_tgt = parse_one("SELECT a FROM t ORDER BY b DESC")
+
+        self._validate_delta_only(
+            diff_delta_only(expr_src, expr_tgt),
+            [
+                Update(
+                    source=expr_src.find(exp.Order).expressions[0],
+                    target=expr_tgt.find(exp.Order).expressions[0],
+                ),
+            ],
+        )
+
+        expr_src = parse_one("SELECT a, b FROM t ORDER BY c ASC")
+        expr_tgt = parse_one("SELECT b, a FROM t ORDER BY c DESC")
+
+        self._validate_delta_only(
+            diff_delta_only(expr_src, expr_tgt),
+            [
+                Update(
+                    source=expr_src.find(exp.Order).expressions[0],
+                    target=expr_tgt.find(exp.Order).expressions[0],
+                ),
+                Move(parse_one("a")),
+            ],
+        )
+
     def _validate_delta_only(self, actual_delta, expected_delta):
         self.assertEqual(set(actual_delta), set(expected_delta))
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 1c88952..8ff117a 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -935,15 +935,13 @@ FROM foo""",
     def test_to_interval(self):
         self.assertEqual(exp.to_interval("1day").sql(), "INTERVAL '1' DAY")
         self.assertEqual(exp.to_interval(" 5 months").sql(), "INTERVAL '5' MONTHS")
-        with self.assertRaises(ValueError):
-            exp.to_interval("bla")
+        self.assertEqual(exp.to_interval("-2 day").sql(), "INTERVAL '-2' DAY")
 
         self.assertEqual(exp.to_interval(exp.Literal.string("1day")).sql(), "INTERVAL '1' DAY")
+        self.assertEqual(exp.to_interval(exp.Literal.string("-2 day")).sql(), "INTERVAL '-2' DAY")
         self.assertEqual(
             exp.to_interval(exp.Literal.string(" 5 months")).sql(), "INTERVAL '5' MONTHS"
         )
-        with self.assertRaises(ValueError):
-            exp.to_interval(exp.Literal.string("bla"))
 
     def test_to_table(self):
         table_only = exp.to_table("table_name")
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 2c2015b..9313285 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -602,6 +602,16 @@ SELECT :with,WITH :expressions,CTE :this,UNION :this,SELECT :expressions,1,:expr
             "WITH data AS (SELECT 1 AS id, 2 AS my_id, 'a' AS name, 'b' AS full_name) SELECT data.id AS my_id, CONCAT(data.id, data.name) AS full_name FROM data WHERE data.id = 1 GROUP BY data.id, CONCAT(data.id, data.name) HAVING data.id = 1",
         )
 
+        # Edge case: BigQuery shouldn't expand aliases in complex expressions
+        sql = "WITH data AS (SELECT 1 AS id) SELECT FUNC(id) AS id FROM data GROUP BY FUNC(id)"
+        self.assertEqual(
+            optimizer.qualify_columns.qualify_columns(
+                parse_one(sql, dialect="bigquery"),
+                schema=MappingSchema(schema=unused_schema, dialect="bigquery"),
+            ).sql(),
+            "WITH data AS (SELECT 1 AS id) SELECT FUNC(data.id) AS id FROM data GROUP BY FUNC(data.id)",
+        )
+
     def test_optimize_joins(self):
         self.check_file(
             "optimize_joins",
@@ -779,6 +789,18 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
         self.assertEqual(scopes[2].expression.sql(), f"SELECT a FROM foo CROSS JOIN {udtf}")
         self.assertEqual(set(scopes[2].sources), {"", "foo"})
 
+        # Check DML statement scopes
+        sql = (
+            "UPDATE customers SET total_spent = (SELECT 1 FROM t1) WHERE EXISTS (SELECT 1 FROM t2)"
+        )
+        self.assertEqual(len(traverse_scope(parse_one(sql))), 3)
+
+        sql = "UPDATE tbl1 SET col = 1 WHERE EXISTS (SELECT 1 FROM tbl2 WHERE tbl1.id = tbl2.id)"
+        self.assertEqual(len(traverse_scope(parse_one(sql))), 1)
+
+        sql = "UPDATE tbl1 SET col = 0"
+        self.assertEqual(len(traverse_scope(parse_one(sql))), 0)
+
     @patch("sqlglot.optimizer.scope.logger")
     def test_scope_warning(self, logger):
         self.assertEqual(len(traverse_scope(parse_one("WITH q AS (@y) SELECT * FROM q"))), 1)
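
The behaviors exercised by these new tests are all reachable through sqlglot's public API. A minimal smoke-test sketch, assuming a sqlglot build that includes this commit; the SQL strings and expected outputs below are copied from the tests in this diff:

    import sqlglot
    from sqlglot import parse_one

    # ClickHouse ARRAY JOIN now round-trips and is modeled as a join of kind "ARRAY".
    expr = parse_one("SELECT s, arr FROM arrays_test ARRAY JOIN arr", read="clickhouse")
    assert expr.args["joins"][0].kind == "ARRAY"
    assert expr.sql(dialect="clickhouse") == "SELECT s, arr FROM arrays_test ARRAY JOIN arr"

    # Single-argument Snowflake TRY_TO_TIMESTAMP becomes TRY_CAST, here transpiled to DuckDB.
    assert (
        sqlglot.transpile("SELECT TRY_TO_TIMESTAMP('invalid')", read="snowflake", write="duckdb")[0]
        == "SELECT TRY_CAST('invalid' AS TIMESTAMP)"
    )

    # LEVENSHTEIN now maps to BigQuery's EDIT_DISTANCE.
    assert (
        sqlglot.transpile("LEVENSHTEIN(col1, col2)", read="duckdb", write="bigquery")[0]
        == "EDIT_DISTANCE(col1, col2)"
    )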