author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-10-09 06:28:48 +0000
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-10-09 06:28:48 +0000
commit | 33802ae744af096b1be30c5e2d02e03c8fce4c77 (patch)
tree | 13be65e148a9772441401d092259912c630a2adc /tests
parent | Adding upstream version 25.24.0. (diff)
download | sqlglot-33802ae744af096b1be30c5e2d02e03c8fce4c77.tar.xz, sqlglot-33802ae744af096b1be30c5e2d02e03c8fce4c77.zip
Adding upstream version 25.24.5. (refs: upstream/25.24.5, upstream)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_athena.py | 6
-rw-r--r-- | tests/dialects/test_bigquery.py | 14
-rw-r--r-- | tests/dialects/test_duckdb.py | 32
-rw-r--r-- | tests/dialects/test_mysql.py | 20
-rw-r--r-- | tests/dialects/test_oracle.py | 17
-rw-r--r-- | tests/dialects/test_postgres.py | 12
-rw-r--r-- | tests/dialects/test_redshift.py | 6
-rw-r--r-- | tests/dialects/test_spark.py | 27
-rw-r--r-- | tests/dialects/test_sqlite.py | 1
-rw-r--r-- | tests/dialects/test_trino.py | 6
-rw-r--r-- | tests/dialects/test_tsql.py | 9
-rw-r--r-- | tests/fixtures/optimizer/annotate_functions.sql | 189
-rw-r--r-- | tests/fixtures/pretty.sql | 31
-rw-r--r-- | tests/test_build.py | 30
-rw-r--r-- | tests/test_diff.py | 42
-rw-r--r-- | tests/test_expressions.py | 1
-rw-r--r-- | tests/test_optimizer.py | 57
-rw-r--r-- | tests/test_transpile.py | 31
18 files changed, 477 insertions, 54 deletions
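Most of this patch extends the per-dialect transpilation tests below. As a minimal sketch of what those tests exercise, the snippet below replays one of the new MySQL cases (emulating a Postgres FULL OUTER JOIN, which MySQL lacks) through sqlglot's public `transpile` API. The expected output string is copied from the updated `tests/dialects/test_mysql.py` in this diff and may differ in other sqlglot versions.

```python
# Sketch only: assumes a sqlglot 25.24.x install; the expected string is taken
# from the new test expectation in tests/dialects/test_mysql.py in this patch.
import sqlglot

sql = "SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x"

# MySQL has no FULL OUTER JOIN, so the generator rewrites it as
# LEFT JOIN ... UNION ALL ... RIGHT JOIN, guarded by NOT EXISTS so that rows
# matched by both sides are not duplicated.
print(sqlglot.transpile(sql, read="postgres", write="mysql")[0])
# SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x
#   UNION ALL
#   SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x
#   WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x)
# (emitted as a single line)
```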
diff --git a/tests/dialects/test_athena.py b/tests/dialects/test_athena.py index ca91d4a..ef96938 100644 --- a/tests/dialects/test_athena.py +++ b/tests/dialects/test_athena.py @@ -62,8 +62,12 @@ class TestAthena(Validator): # CTAS goes to the Trino engine, where the table properties cant be encased in single quotes like they can for Hive # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties + # They're also case sensitive and need to be lowercase, otherwise you get eg "Table properties [FORMAT] are not supported." self.validate_identity( - "CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a" + "CREATE TABLE foo WITH (table_type='ICEBERG', location='s3://foo/', format='orc', partitioning=ARRAY['bucket(id, 5)']) AS SELECT * FROM a" + ) + self.validate_identity( + "CREATE TABLE foo WITH (table_type='HIVE', external_location='s3://foo/', format='parquet', partitioned_by=ARRAY['ds']) AS SELECT * FROM a" ) self.validate_identity( "CREATE TABLE foo AS WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo" diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e2adfea..d854165 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -1985,3 +1985,17 @@ OPTIONS ( self.validate_identity( "SELECT RANGE(CAST('2022-10-01 14:53:27 America/Los_Angeles' AS TIMESTAMP), CAST('2022-10-01 16:00:00 America/Los_Angeles' AS TIMESTAMP))" ) + + def test_null_ordering(self): + # Aggregate functions allow "NULLS FIRST" only with ascending order and + # "NULLS LAST" only with descending + for sort_order, null_order in (("ASC", "NULLS LAST"), ("DESC", "NULLS FIRST")): + self.validate_all( + f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order}) AS ids FROM colors GROUP BY 1", + read={ + "": f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order} {null_order}) AS ids FROM colors GROUP BY 1" + }, + write={ + "bigquery": f"SELECT color, ARRAY_AGG(id ORDER BY id {sort_order}) AS ids FROM colors GROUP BY 1", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index e4788ec..6b58934 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -858,6 +858,28 @@ class TestDuckDB(Validator): self.validate_identity( "SELECT COALESCE(*COLUMNS(['a', 'b', 'c'])) AS result FROM (SELECT NULL AS a, 42 AS b, TRUE AS c)" ) + self.validate_all( + "SELECT UNNEST(foo) AS x", + write={ + "redshift": UnsupportedError, + }, + ) + self.validate_identity("a ^ b", "POWER(a, b)") + self.validate_identity("a ** b", "POWER(a, b)") + self.validate_identity("a ~~~ b", "a GLOB b") + self.validate_identity("a ~~ b", "a LIKE b") + self.validate_identity("a @> b") + self.validate_identity("a <@ b", "b @> a") + self.validate_identity("a && b").assert_is(exp.ArrayOverlaps) + self.validate_identity("a ^@ b", "STARTS_WITH(a, b)") + self.validate_identity( + "a !~~ b", + "NOT a LIKE b", + ) + self.validate_identity( + "a !~~* b", + "NOT a ILIKE b", + ) def test_array_index(self): with self.assertLogs(helper_logger) as cm: @@ -967,6 +989,15 @@ class TestDuckDB(Validator): "spark": "DATE_FORMAT(x, 'yy-M-ss')", }, ) + + self.validate_all( + "SHA1(x)", + write={ + "duckdb": "SHA1(x)", + "": "SHA(x)", + }, + ) + self.validate_all( "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", write={ @@ -1086,6 +1117,7 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)") self.validate_identity("CAST(x AS NUMERIC(1, 2))", "CAST(x AS DECIMAL(1, 2))") 
self.validate_identity("CAST(x AS HUGEINT)", "CAST(x AS INT128)") + self.validate_identity("CAST(x AS UHUGEINT)", "CAST(x AS UINT128)") self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)") self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)") diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 835ee7c..0e593ef 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -747,16 +747,28 @@ class TestMySQL(Validator): }, ) self.validate_all( - "SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION SELECT * FROM x RIGHT JOIN y ON x.id = y.id LIMIT 0", + "SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION ALL SELECT * FROM x RIGHT JOIN y ON x.id = y.id WHERE NOT EXISTS(SELECT 1 FROM x WHERE x.id = y.id) ORDER BY 1 LIMIT 0", read={ - "postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id LIMIT 0", + "postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id ORDER BY 1 LIMIT 0", }, ) self.validate_all( # MySQL doesn't support FULL OUTER joins - "WITH t1 AS (SELECT 1) SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x", + "SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION ALL SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x)", read={ - "postgres": "WITH t1 AS (SELECT 1) SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x", + "postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x", + }, + ) + self.validate_all( + "SELECT * FROM t1 LEFT OUTER JOIN t2 USING (x) UNION ALL SELECT * FROM t1 RIGHT OUTER JOIN t2 USING (x) WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x)", + read={ + "postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 USING (x) ", + }, + ) + self.validate_all( + "SELECT * FROM t1 LEFT OUTER JOIN t2 USING (x, y) UNION ALL SELECT * FROM t1 RIGHT OUTER JOIN t2 USING (x, y) WHERE NOT EXISTS(SELECT 1 FROM t1 WHERE t1.x = t2.x AND t1.y = t2.y)", + read={ + "postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 USING (x, y) ", }, ) self.validate_all( diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 8675086..d2bbedc 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -67,6 +67,15 @@ class TestOracle(Validator): "SELECT COUNT(1) INTO V_Temp FROM TABLE(CAST(somelist AS data_list)) WHERE col LIKE '%contact'" ) self.validate_identity( + "SELECT department_id INTO v_department_id FROM departments FETCH FIRST 1 ROWS ONLY" + ) + self.validate_identity( + "SELECT department_id BULK COLLECT INTO v_department_ids FROM departments" + ) + self.validate_identity( + "SELECT department_id, department_name BULK COLLECT INTO v_department_ids, v_department_names FROM departments" + ) + self.validate_identity( "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" ) self.validate_identity( @@ -103,6 +112,14 @@ class TestOracle(Validator): ) self.validate_all( + "SELECT department_id, department_name INTO v_department_id, v_department_name FROM departments FETCH FIRST 1 ROWS ONLY", + write={ + "oracle": "SELECT department_id, department_name INTO v_department_id, v_department_name FROM departments FETCH FIRST 1 ROWS ONLY", + "postgres": UnsupportedError, + "tsql": UnsupportedError, + }, + ) + self.validate_all( "TRUNC(SYSDATE, 'YEAR')", write={ "clickhouse": "DATE_TRUNC('YEAR', CURRENT_TIMESTAMP())", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 
63266a5..62ae247 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -354,10 +354,10 @@ class TestPostgres(Validator): self.validate_all( "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]", read={ - "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])", + "duckdb": "SELECT [1, 2, 3] @> [1, 2]", }, write={ - "duckdb": "SELECT ARRAY_HAS_ALL([1, 2, 3], [1, 2])", + "duckdb": "SELECT [1, 2, 3] @> [1, 2]", "mysql": UnsupportedError, "postgres": "SELECT ARRAY[1, 2, 3] @> ARRAY[1, 2]", }, @@ -399,13 +399,6 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]", - write={ - "": "SELECT ARRAY_OVERLAPS(ARRAY(1, 2, 3), ARRAY(1, 2))", - "postgres": "SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]", - }, - ) - self.validate_all( "SELECT JSON_EXTRACT_PATH_TEXT(x, k1, k2, k3) FROM t", read={ "clickhouse": "SELECT JSONExtractString(x, k1, k2, k3) FROM t", @@ -802,6 +795,7 @@ class TestPostgres(Validator): ) self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1)") self.validate_identity("SELECT OVERLAY(a PLACING b FROM 1 FOR 1)") + self.validate_identity("ARRAY[1, 2, 3] && ARRAY[1, 2]").assert_is(exp.ArrayOverlaps) def test_ddl(self): # Checks that user-defined types are parsed into DataType instead of Identifier diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 6f561da..01c7f78 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -214,6 +214,12 @@ class TestRedshift(Validator): }, ) self.validate_all( + "CREATE TABLE a (b BINARY VARYING(10))", + write={ + "redshift": "CREATE TABLE a (b VARBYTE(10))", + }, + ) + self.validate_all( "SELECT 'abc'::CHARACTER", write={ "redshift": "SELECT CAST('abc' AS CHAR)", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 4fed68c..01859c6 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -2,7 +2,6 @@ from unittest import mock from sqlglot import exp, parse_one from sqlglot.dialects.dialect import Dialects -from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -294,19 +293,19 @@ TBLPROPERTIES ( "SELECT STR_TO_MAP('a:1,b:2,c:3')", "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", ) - - with self.assertLogs(helper_logger): - self.validate_all( - "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", - read={ - "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", - }, - write={ - "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", - "duckdb": "SELECT ([1, 2, 3])[3]", - "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", - }, - ) + self.validate_all( + "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + read={ + "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)", + }, + write={ + "databricks": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + "spark": "SELECT TRY_ELEMENT_AT(ARRAY(1, 2, 3), 2)", + "duckdb": "SELECT ([1, 2, 3])[2]", + "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 2)", + }, + ) self.validate_all( "SELECT ARRAY_AGG(x) FILTER (WHERE x = 5) FROM (SELECT 1 UNION ALL SELECT NULL) AS t(x)", diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index f2c9802..230c0e8 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -26,6 +26,7 @@ class TestSQLite(Validator): """SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0""" ) 
self.validate_identity("SELECT * FROM GENERATE_SERIES(1, 5)") + self.validate_identity("SELECT INSTR(haystack, needle)") self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"}) self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"}) diff --git a/tests/dialects/test_trino.py b/tests/dialects/test_trino.py index 0ebe749..8c73ec1 100644 --- a/tests/dialects/test_trino.py +++ b/tests/dialects/test_trino.py @@ -4,6 +4,12 @@ from tests.dialects.test_dialect import Validator class TestTrino(Validator): dialect = "trino" + def test_trino(self): + self.validate_identity("JSON_EXTRACT(content, json_path)") + self.validate_identity("JSON_QUERY(content, 'lax $.HY.*')") + self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITH UNCONDITIONAL WRAPPER)") + self.validate_identity("JSON_QUERY(content, 'strict $.HY.*' WITHOUT CONDITIONAL WRAPPER)") + def test_trim(self): self.validate_identity("SELECT TRIM('!' FROM '!foo!')") self.validate_identity("SELECT TRIM(BOTH '$' FROM '$var$')") diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 453cd5a..9be6fcd 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -8,6 +8,11 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_identity( + "with x as (select 1) select * from x union select * from x order by 1 limit 0", + "WITH x AS (SELECT 1 AS [1]) SELECT TOP 0 * FROM (SELECT * FROM x UNION SELECT * FROM x) AS _l_0 ORDER BY 1", + ) + # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN # tsql allows .. which means use the default schema self.validate_identity("SELECT * FROM a..b") @@ -46,6 +51,10 @@ class TestTSQL(Validator): self.validate_identity( "COPY INTO test_1 FROM 'path' WITH (FORMAT_NAME = test, FILE_TYPE = 'CSV', CREDENTIAL = (IDENTITY='Shared Access Signature', SECRET='token'), FIELDTERMINATOR = ';', ROWTERMINATOR = '0X0A', ENCODING = 'UTF8', DATEFORMAT = 'ymd', MAXERRORS = 10, ERRORFILE = 'errorsfolder', IDENTITY_INSERT = 'ON')" ) + self.validate_identity( + 'SELECT 1 AS "[x]"', + "SELECT 1 AS [[x]]]", + ) self.assertEqual( annotate_types(self.validate_identity("SELECT 1 WHERE EXISTS(SELECT 1)")).sql("tsql"), "SELECT 1 WHERE EXISTS(SELECT 1)", diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql new file mode 100644 index 0000000..8aa77d4 --- /dev/null +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -0,0 +1,189 @@ +-------------------------------------- +-- Dialect +-------------------------------------- +ABS(1); +INT; + +ABS(1.5); +DOUBLE; + +GREATEST(1, 2, 3); +INT; + +GREATEST(1, 2.5, 3); +DOUBLE; + +LEAST(1, 2, 3); +INT; + +LEAST(1, 2.5, 3); +DOUBLE; + +-------------------------------------- +-- Spark2 / Spark3 / Databricks +-------------------------------------- + +# dialect: spark2, spark, databricks +SUBSTRING(tbl.str_col, 0, 0); +STRING; + +# dialect: spark2, spark, databricks +SUBSTRING(tbl.bin_col, 0, 0); +BINARY; + +# dialect: spark2, spark, databricks +CONCAT(tbl.bin_col, tbl.bin_col); +BINARY; + +# dialect: spark2, spark, databricks +CONCAT(tbl.bin_col, tbl.str_col); +STRING; + +# dialect: spark2, spark, databricks +CONCAT(tbl.str_col, tbl.bin_col); +STRING; + +# dialect: spark2, spark, databricks +CONCAT(tbl.str_col, tbl.str_col); +STRING; + +# dialect: spark2, spark, databricks +CONCAT(tbl.str_col, unknown); +STRING; + +# dialect: spark2, spark, databricks 
+CONCAT(tbl.bin_col, unknown); +UNKNOWN; + +# dialect: spark2, spark, databricks +CONCAT(unknown, unknown); +UNKNOWN; + +# dialect: spark2, spark, databricks +LPAD(tbl.bin_col, 1, tbl.bin_col); +BINARY; + +# dialect: spark2, spark, databricks +RPAD(tbl.bin_col, 1, tbl.bin_col); +BINARY; + +# dialect: spark2, spark, databricks +LPAD(tbl.bin_col, 1, tbl.str_col); +STRING; + +# dialect: spark2, spark, databricks +RPAD(tbl.bin_col, 1, tbl.str_col); +STRING; + +# dialect: spark2, spark, databricks +LPAD(tbl.str_col, 1, tbl.bin_col); +STRING; + +# dialect: spark2, spark, databricks +RPAD(tbl.str_col, 1, tbl.bin_col); +STRING; + +# dialect: spark2, spark, databricks +LPAD(tbl.str_col, 1, tbl.str_col); +STRING; + +# dialect: spark2, spark, databricks +RPAD(tbl.str_col, 1, tbl.str_col); +STRING; + + +-------------------------------------- +-- BigQuery +-------------------------------------- + +# dialect: bigquery +SIGN(1); +INT; + +# dialect: bigquery +SIGN(1.5); +DOUBLE; + +# dialect: bigquery +CEIL(1); +DOUBLE; + +# dialect: bigquery +CEIL(5.5); +DOUBLE; + +# dialect: bigquery +CEIL(tbl.bignum_col); +BIGDECIMAL; + +# dialect: bigquery +FLOOR(1); +DOUBLE; + +# dialect: bigquery +FLOOR(5.5); +DOUBLE; + +# dialect: bigquery +FLOOR(tbl.bignum_col); +BIGDECIMAL; + +# dialect: bigquery +SQRT(1); +DOUBLE; + +# dialect: bigquery +SQRT(5.5); +DOUBLE; + +# dialect: bigquery +SQRT(tbl.bignum_col); +BIGDECIMAL; + +# dialect: bigquery +LN(1); +DOUBLE; + +# dialect: bigquery +LN(5.5); +DOUBLE; + +# dialect: bigquery +LN(tbl.bignum_col); +BIGDECIMAL; + +# dialect: bigquery +LOG(1); +DOUBLE; + +# dialect: bigquery +LOG(5.5); +DOUBLE; + +# dialect: bigquery +LOG(tbl.bignum_col); +BIGDECIMAL; + +# dialect: bigquery +ROUND(1); +DOUBLE; + +# dialect: bigquery +ROUND(5.5); +DOUBLE; + +# dialect: bigquery +ROUND(tbl.bignum_col); +BIGDECIMAL; + +# dialect: bigquery +EXP(1); +DOUBLE; + +# dialect: bigquery +EXP(5.5); +DOUBLE; + +# dialect: bigquery +EXP(tbl.bignum_col); +BIGDECIMAL;
\ No newline at end of file diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 3e5619a..ca5e4a9 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -418,3 +418,34 @@ INSERT FIRST SELECT salary FROM employees; + +SELECT * +FROM foo +wHERE 1=1 + AND + -- my comment + EXISTS ( + SELECT 1 + FROM bar + ); +SELECT + * +FROM foo +WHERE + 1 = 1 AND EXISTS( + SELECT + 1 + FROM bar + ) /* my comment */; + +SELECT 1 +FROM foo +WHERE 1=1 +AND -- first comment + -- second comment + foo.a = 1; +SELECT + 1 +FROM foo +WHERE + 1 = 1 AND /* first comment */ foo.a /* second comment */ = 1; diff --git a/tests/test_build.py b/tests/test_build.py index 7518b72..5d383ad 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -578,6 +578,36 @@ class TestBuild(unittest.TestCase): "UPDATE tbl SET x = 1 FROM tbl2 CROSS JOIN tbl3", ), ( + lambda: exp.update( + "my_table", + {"x": 1}, + from_="baz", + where="my_table.id = baz.id", + with_={"baz": "SELECT id FROM foo UNION SELECT id FROM bar"}, + ), + "WITH baz AS (SELECT id FROM foo UNION SELECT id FROM bar) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id", + ), + ( + lambda: exp.update("my_table").set_("x = 1"), + "UPDATE my_table SET x = 1", + ), + ( + lambda: exp.update("my_table").set_("x = 1").where("y = 2"), + "UPDATE my_table SET x = 1 WHERE y = 2", + ), + ( + lambda: exp.update("my_table").set_("a = 1").set_("b = 2"), + "UPDATE my_table SET a = 1, b = 2", + ), + ( + lambda: exp.update("my_table") + .set_("x = 1") + .where("my_table.id = baz.id") + .from_("baz") + .with_("baz", "SELECT id FROM foo"), + "WITH baz AS (SELECT id FROM foo) UPDATE my_table SET x = 1 FROM baz WHERE my_table.id = baz.id", + ), + ( lambda: union("SELECT * FROM foo", "SELECT * FROM bla"), "SELECT * FROM foo UNION SELECT * FROM bla", ), diff --git a/tests/test_diff.py b/tests/test_diff.py index f83c805..edd3b26 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -157,11 +157,20 @@ class TestDiff(unittest.TestCase): self._validate_delta_only( diff_delta_only(expr_src, expr_tgt), [ - Remove(parse_one("ROW_NUMBER()")), # the Anonymous node - Insert(parse_one("RANK()")), # the Anonymous node + Remove(parse_one("ROW_NUMBER()")), + Insert(parse_one("RANK()")), + Update(source=expr_src.selects[0], target=expr_tgt.selects[0]), ], ) + expr_src = parse_one("SELECT MAX(x) OVER (ORDER BY y) FROM z", "oracle") + expr_tgt = parse_one("SELECT MAX(x) KEEP (DENSE_RANK LAST ORDER BY y) FROM z", "oracle") + + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), + [Update(source=expr_src.selects[0], target=expr_tgt.selects[0])], + ) + def test_pre_matchings(self): expr_src = parse_one("SELECT 1") expr_tgt = parse_one("SELECT 1, 2, 3, 4") @@ -202,5 +211,34 @@ class TestDiff(unittest.TestCase): ], ) + expr_src = parse_one("SELECT 1 AS c1, 2 AS c2") + expr_tgt = parse_one("SELECT 2 AS c1, 3 AS c2") + + self._validate_delta_only( + diff_delta_only(expr_src, expr_tgt), + [ + Remove(expression=exp.alias_(1, "c1")), + Remove(expression=exp.Literal.number(1)), + Insert(expression=exp.alias_(3, "c2")), + Insert(expression=exp.Literal.number(3)), + Update(source=exp.alias_(2, "c2"), target=exp.alias_(2, "c1")), + ], + ) + + def test_dialect_aware_diff(self): + from sqlglot.generator import logger + + with self.assertLogs(logger) as cm: + # We want to assert there are no warnings, but the 'assertLogs' method does not support that. + # Therefore, we are adding a dummy warning, and then we will assert it is the only warning. 
+ logger.warning("Dummy warning") + + expression = parse_one("SELECT foo FROM bar FOR UPDATE", dialect="oracle") + self._validate_delta_only( + diff_delta_only(expression, expression.copy(), dialect="oracle"), [] + ) + + self.assertEqual(["WARNING:sqlglot:Dummy warning"], cm.output) + def _validate_delta_only(self, actual_delta, expected_delta): self.assertEqual(set(actual_delta), set(expected_delta)) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index e88740b..1c88952 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -350,6 +350,7 @@ class TestExpressions(unittest.TestCase): ) self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition) + self.assertIsInstance(exp.func("instr", "x", "b", dialect="sqlite"), exp.StrPosition) self.assertIsInstance(exp.func("bla", 1, "foo"), exp.Anonymous) self.assertIsInstance( exp.func("cast", this=exp.Literal.number(5), to=exp.DataType.build("DOUBLE")), diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 857ba1a..2c2015b 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -54,6 +54,18 @@ def simplify(expression, **kwargs): return optimizer.simplify.simplify(expression, constant_propagation=True, **kwargs) +def annotate_functions(expression, **kwargs): + from sqlglot.dialects import Dialect + + dialect = kwargs.get("dialect") + schema = kwargs.get("schema") + + annotators = Dialect.get_or_raise(dialect).ANNOTATORS + annotated = annotate_types(expression, annotators=annotators, schema=schema) + + return annotated.expressions[0] + + class TestOptimizer(unittest.TestCase): maxDiff = None @@ -787,6 +799,28 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') with self.subTest(title): self.assertEqual(result.type.sql(), exp.DataType.build(expected).sql()) + def test_annotate_funcs(self): + test_schema = { + "tbl": {"bin_col": "BINARY", "str_col": "STRING", "bignum_col": "BIGNUMERIC"} + } + + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs("optimizer/annotate_functions.sql"), start=1 + ): + title = meta.get("title") or f"{i}, {sql}" + dialect = meta.get("dialect") or "" + sql = f"SELECT {sql} FROM tbl" + + for dialect in dialect.split(", "): + result = parse_and_optimize( + annotate_functions, sql, dialect, schema=test_schema, dialect=dialect + ) + + with self.subTest(title): + self.assertEqual( + result.type.sql(dialect), exp.DataType.build(expected).sql(dialect) + ) + def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ) @@ -1377,26 +1411,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(4, normalization_distance(gen_expr(2), max_=100)) self.assertEqual(18, normalization_distance(gen_expr(3), max_=100)) self.assertEqual(110, normalization_distance(gen_expr(10), max_=100)) - - def test_custom_annotators(self): - # In Spark hierarchy, SUBSTRING result type is dependent on input expr type - for dialect in ("spark2", "spark", "databricks"): - for expr_type_pair in ( - ("col", "STRING"), - ("col", "BINARY"), - ("'str_literal'", "STRING"), - ("CAST('str_literal' AS BINARY)", "BINARY"), - ): - with self.subTest( - f"Testing {dialect}'s SUBSTRING() result type for {expr_type_pair}" - ): - expr, type = expr_type_pair - ast = parse_one(f"SELECT substring({expr}, 2, 3) AS x FROM tbl", read=dialect) - - subst_type = ( - 
optimizer.optimize(ast, schema={"tbl": {"col": type}}, dialect=dialect) - .expressions[0] - .type - ) - - self.assertEqual(subst_type.sql(dialect), exp.DataType.build(type).sql(dialect)) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 07a915d..e7f1665 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -563,7 +563,36 @@ FROM x""", ) self.validate( """with a as /* comment */ ( select * from b) select * from a""", - """WITH a AS (SELECT * FROM b) /* comment */ SELECT * FROM a""", + """WITH a /* comment */ AS (SELECT * FROM b) SELECT * FROM a""", + ) + self.validate( + """ + -- comment at the top +WITH +-- comment for tbl1 +tbl1 AS (SELECT 1) +-- comment for tbl2 +, tbl2 AS (SELECT 2) +-- comment for tbl3 +, tbl3 AS (SELECT 3) +-- comment for final select +SELECT * FROM tbl1""", + """/* comment at the top */ +WITH tbl1 /* comment for tbl1 */ AS ( + SELECT + 1 +), tbl2 /* comment for tbl2 */ AS ( + SELECT + 2 +), tbl3 /* comment for tbl3 */ AS ( + SELECT + 3 +) +/* comment for final select */ +SELECT + * +FROM tbl1""", + pretty=True, ) def test_types(self): |