diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 192 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 4 | ||||
-rw-r--r-- | tests/fixtures/optimizer/merge_subqueries.sql | 2 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 38 | ||||
-rw-r--r-- | tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 34 | ||||
-rw-r--r-- | tests/test_executor.py | 58 | ||||
-rw-r--r-- | tests/test_optimizer.py | 32 | ||||
-rw-r--r-- | tests/test_schema.py | 4 |
13 files changed, 306 insertions, 103 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e05fca0..e95ff3e 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -1,3 +1,5 @@ +from unittest import mock + from sqlglot import ErrorLevel, ParseError, UnsupportedError, transpile from tests.dialects.test_dialect import Validator @@ -6,6 +8,35 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + with self.assertRaises(ValueError): + transpile("'\\'", read="bigquery") + + # Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators + with self.assertRaises(UnsupportedError): + transpile( + "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + write="bigquery", + unsupported_level=ErrorLevel.RAISE, + ) + + with self.assertRaises(UnsupportedError): + transpile( + "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + write="bigquery", + unsupported_level=ErrorLevel.RAISE, + ) + + with self.assertRaises(ParseError): + transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery") + + self.validate_identity("SELECT `project-id`.udfs.func(call.dir)") + self.validate_identity("SELECT CAST(CURRENT_DATE AS STRING FORMAT 'DAY') AS current_day") + self.validate_identity("SAFE_CAST(encrypted_value AS STRING FORMAT 'BASE64')") + self.validate_identity("CAST(encrypted_value AS STRING FORMAT 'BASE64')") + self.validate_identity("STRING_AGG(a)") + self.validate_identity("STRING_AGG(a, ' & ')") + self.validate_identity("STRING_AGG(DISTINCT a, ' & ')") + self.validate_identity("STRING_AGG(a, ' & ' ORDER BY LENGTH(a))") self.validate_identity("DATE(2016, 12, 25)") self.validate_identity("DATE(CAST('2016-12-25 23:59:59' AS DATETIME))") self.validate_identity("SELECT foo IN UNNEST(bar) AS bla") @@ -21,16 +52,8 @@ class TestBigQuery(Validator): self.validate_identity("x <> ''") self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))") self.validate_identity("SELECT b'abc'") - self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""") + self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[])""") self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") - self.validate_all( - "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", - write={ - "": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", - "bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", - "duckdb": "SELECT {'y': ARRAY(SELECT {'b': b} FROM x)} FROM z", - }, - ) self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") @@ -38,6 +61,13 @@ class TestBigQuery(Validator): self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""") self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""") + self.validate_identity("CAST(x AS TIMESTAMP)") + self.validate_identity("REGEXP_EXTRACT(`foo`, 'bar: (.+?)', 1, 1)") + self.validate_identity("BEGIN A B C D E F") + self.validate_identity("BEGIN TRANSACTION") + self.validate_identity("COMMIT TRANSACTION") + self.validate_identity("ROLLBACK TRANSACTION") + self.validate_identity("CAST(x AS BIGNUMERIC)") self.validate_identity( "DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')" ) @@ -50,8 +80,55 @@ class TestBigQuery(Validator): self.validate_identity( "CREATE TABLE IF NOT EXISTS foo AS SELECT * FROM bla EXCEPT DISTINCT (SELECT * FROM bar) LIMIT 0" ) + self.validate_identity( + "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" + ) + self.validate_identity( + "SELECT item, purchases, LAST_VALUE(item) OVER (item_window ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce WINDOW item_window AS (ORDER BY purchases)" + ) + self.validate_identity( + "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", + ) self.validate_all("SELECT SPLIT(foo)", write={"bigquery": "SELECT SPLIT(foo, ',')"}) + self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"}) + self.validate_all("SELECT 1 AS at", write={"bigquery": "SELECT 1 AS `at`"}) + self.validate_all('x <> ""', write={"bigquery": "x <> ''"}) + self.validate_all('x <> """"""', write={"bigquery": "x <> ''"}) + self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) + self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) + self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) + self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) + self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) + self.validate_all( + "SELECT CAST(TIMESTAMP '2008-12-25 00:00:00+00:00' AS STRING FORMAT 'YYYY-MM-DD HH24:MI:SS TZH:TZM') AS date_time_to_string", + write={ + "bigquery": "SELECT CAST(CAST('2008-12-25 00:00:00+00:00' AS TIMESTAMP) AS STRING FORMAT 'YYYY-MM-DD HH24:MI:SS TZH:TZM') AS date_time_to_string", + }, + ) + self.validate_all( + "SELECT CAST(TIMESTAMP '2008-12-25 00:00:00+00:00' AS STRING FORMAT 'YYYY-MM-DD HH24:MI:SS TZH:TZM' AT TIME ZONE 'Asia/Kolkata') AS date_time_to_string", + write={ + "bigquery": "SELECT CAST(CAST('2008-12-25 00:00:00+00:00' AS TIMESTAMP) AS STRING FORMAT 'YYYY-MM-DD HH24:MI:SS TZH:TZM' AT TIME ZONE 'Asia/Kolkata') AS date_time_to_string", + }, + ) + self.validate_all( + "WITH cte AS (SELECT [1, 2, 3] AS arr) SELECT col FROM cte CROSS JOIN UNNEST(arr) AS col", + read={ + "spark": "WITH cte AS (SELECT ARRAY(1, 2, 3) AS arr) SELECT EXPLODE(arr) FROM cte" + }, + ) + self.validate_all( + "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + write={ + "": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + "bigquery": "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", + "duckdb": "SELECT {'y': ARRAY(SELECT {'b': b} FROM x)} FROM z", + }, + ) self.validate_all( "cast(x as date format 'MM/DD/YYYY')", write={ @@ -64,10 +141,6 @@ class TestBigQuery(Validator): "bigquery": "PARSE_TIMESTAMP('%Y.%m.%d %I:%M:%S%z', x)", }, ) - self.validate_all("SELECT 1 AS hash", write={"bigquery": "SELECT 1 AS `hash`"}) - self.validate_all('x <> ""', write={"bigquery": "x <> ''"}) - self.validate_all('x <> """"""', write={"bigquery": "x <> ''"}) - self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all( "CREATE TEMP TABLE foo AS SELECT 1", write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"}, @@ -82,14 +155,6 @@ class TestBigQuery(Validator): "SELECT * FROM `my-project.my-dataset.my-table`", write={"bigquery": "SELECT * FROM `my-project`.`my-dataset`.`my-table`"}, ) - self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) - self.validate_identity("CAST(x AS TIMESTAMP)") - self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) - self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) - self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) - self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) - self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) - self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) self.validate_all( "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", write={ @@ -121,9 +186,6 @@ class TestBigQuery(Validator): "spark": "'x\\''", }, ) - with self.assertRaises(ValueError): - transpile("'\\'", read="bigquery") - self.validate_all( "r'x\\''", write={ @@ -301,7 +363,6 @@ class TestBigQuery(Validator): "spark": "CURRENT_TIMESTAMP()", }, ) - self.validate_all( "DIV(x, y)", write={ @@ -309,19 +370,6 @@ class TestBigQuery(Validator): "duckdb": "x // y", }, ) - - self.validate_identity( - "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" - ) - - self.validate_identity( - "SELECT item, purchases, LAST_VALUE(item) OVER (item_window ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) AS most_popular FROM Produce WINDOW item_window AS (ORDER BY purchases)" - ) - - self.validate_identity( - "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", - ) - self.validate_all( "CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)", write={ @@ -358,25 +406,6 @@ class TestBigQuery(Validator): "spark": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))", }, ) - - # Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators - with self.assertRaises(UnsupportedError): - transpile( - "SELECT * FROM a INTERSECT ALL SELECT * FROM b", - write="bigquery", - unsupported_level=ErrorLevel.RAISE, - ) - - with self.assertRaises(UnsupportedError): - transpile( - "SELECT * FROM a EXCEPT ALL SELECT * FROM b", - write="bigquery", - unsupported_level=ErrorLevel.RAISE, - ) - - with self.assertRaises(ParseError): - transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery") - self.validate_all( "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)", write={ @@ -465,14 +494,6 @@ class TestBigQuery(Validator): "duckdb": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM table", }, ) - self.validate_identity("REGEXP_EXTRACT(`foo`, 'bar: (.+?)', 1, 1)") - self.validate_identity("BEGIN A B C D E F") - self.validate_identity("BEGIN TRANSACTION") - self.validate_identity("COMMIT TRANSACTION") - self.validate_identity("ROLLBACK TRANSACTION") - self.validate_identity("CAST(x AS BIGNUMERIC)") - - self.validate_identity("SELECT * FROM UNNEST([1]) WITH ORDINALITY") self.validate_all( "SELECT * FROM UNNEST([1]) WITH OFFSET", write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS offset"}, @@ -497,6 +518,16 @@ class TestBigQuery(Validator): }, ) + self.validate_identity( + "SELECT y + 1 z FROM x GROUP BY y + 1 ORDER BY z", + "SELECT y + 1 AS z FROM x GROUP BY z ORDER BY z", + ) + self.validate_identity( + "SELECT y + 1 z FROM x GROUP BY y + 1", + "SELECT y + 1 AS z FROM x GROUP BY y + 1", + ) + self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") + def test_user_defined_functions(self): self.validate_identity( "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" @@ -568,3 +599,34 @@ class TestBigQuery(Validator): "bigquery": "ALTER TABLE db.t1 RENAME TO t2", }, ) + + @mock.patch("sqlglot.dialects.bigquery.logger") + def test_pushdown_cte_column_names(self, mock_logger): + with self.assertRaises(UnsupportedError): + transpile( + "WITH cte(foo) AS (SELECT * FROM tbl) SELECT foo FROM cte", + read="spark", + write="bigquery", + unsupported_level=ErrorLevel.RAISE, + ) + + self.validate_all( + "WITH cte AS (SELECT 1 AS foo) SELECT foo FROM cte", + read={"spark": "WITH cte(foo) AS (SELECT 1) SELECT foo FROM cte"}, + ) + self.validate_all( + "WITH cte AS (SELECT 1 AS foo) SELECT foo FROM cte", + read={"spark": "WITH cte(foo) AS (SELECT 1 AS bar) SELECT foo FROM cte"}, + ) + self.validate_all( + "WITH cte AS (SELECT 1 AS bar) SELECT bar FROM cte", + read={"spark": "WITH cte AS (SELECT 1 AS bar) SELECT bar FROM cte"}, + ) + self.validate_all( + "WITH cte AS (SELECT 1 AS foo, 2) SELECT foo FROM cte", + read={"postgres": "WITH cte(foo) AS (SELECT 1, 2) SELECT foo FROM cte"}, + ) + self.validate_all( + "WITH cte AS (SELECT 1 AS foo UNION ALL SELECT 2) SELECT foo FROM cte", + read={"postgres": "WITH cte(foo) AS (SELECT 1 UNION ALL SELECT 2) SELECT foo FROM cte"}, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3ac05cf..78f87ff 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1413,7 +1413,8 @@ class TestDialect(Validator): "presto": "SELECT a AS b FROM x GROUP BY 1", "hive": "SELECT a AS b FROM x GROUP BY 1", "oracle": "SELECT a AS b FROM x GROUP BY 1", - "spark": "SELECT a AS b FROM x GROUP BY 1", + "spark": "SELECT a AS b FROM x GROUP BY b", + "spark2": "SELECT a AS b FROM x GROUP BY 1", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index b8f7af0..ca2f921 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -6,6 +6,7 @@ class TestMySQL(Validator): dialect = "mysql" def test_ddl(self): + self.validate_identity("CREATE TABLE foo (id BIGINT)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") self.validate_identity( @@ -397,6 +398,16 @@ class TestMySQL(Validator): self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)") def test_mysql(self): + self.validate_all("CAST(x AS SIGNED)", write={"mysql": "CAST(x AS SIGNED)"}) + self.validate_all("CAST(x AS SIGNED INTEGER)", write={"mysql": "CAST(x AS SIGNED)"}) + self.validate_all("CAST(x AS UNSIGNED)", write={"mysql": "CAST(x AS UNSIGNED)"}) + self.validate_all("CAST(x AS UNSIGNED INTEGER)", write={"mysql": "CAST(x AS UNSIGNED)"}) + self.validate_all( + "SELECT DATE_ADD('2023-06-23 12:00:00', INTERVAL 2 * 2 MONTH) FROM foo", + write={ + "mysql": "SELECT DATE_ADD('2023-06-23 12:00:00', INTERVAL (2 * 2) MONTH) FROM foo", + }, + ) self.validate_all( "SELECT * FROM t LOCK IN SHARE MODE", write={"mysql": "SELECT * FROM t FOR SHARE"} ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 852b494..49139f9 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -1,3 +1,5 @@ +from unittest import mock + from sqlglot import UnsupportedError from tests.dialects.test_dialect import Validator @@ -439,7 +441,8 @@ class TestPresto(Validator): }, ) - def test_presto(self): + @mock.patch("sqlglot.helper.logger") + def test_presto(self, mock_logger): self.validate_identity("SELECT * FROM x OFFSET 1 LIMIT 1") self.validate_identity("SELECT * FROM x OFFSET 1 FETCH FIRST 1 ROWS ONLY") self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") @@ -453,6 +456,21 @@ class TestPresto(Validator): self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"}) self.validate_all( + "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", + write={ + "postgres": UnsupportedError, + "presto": "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", + }, + ) + self.validate_all( + "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", + write={ + "": "SELECT ARRAY(1, 2, 3)[3]", + "postgres": "SELECT (ARRAY[1, 2, 3])[4]", + "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", + }, + ) + self.validate_all( "SELECT SUBSTRING(a, 1, 3), SUBSTRING(a, LENGTH(a) - (3 - 1))", read={ "redshift": "SELECT LEFT(a, 3), RIGHT(a, 3)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 88168ab..620aae2 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -82,7 +82,7 @@ class TestRedshift(Validator): }, ) self.validate_all("SELECT INTERVAL '5 days'", read={"": "SELECT INTERVAL '5' days"}) - self.validate_all("CONVERT(INTEGER, x)", write={"redshift": "CAST(x AS INTEGER)"}) + self.validate_all("CONVERT(INT, x)", write={"redshift": "CAST(x AS INTEGER)"}) self.validate_all( "DATEADD('day', ndays, caldate)", write={"redshift": "DATEADD(day, ndays, caldate)"} ) @@ -104,7 +104,7 @@ class TestRedshift(Validator): "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)", write={ "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", - "bigquery": "SELECT ST_ASEWKT(SAFE_CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", + "bigquery": "SELECT ST_AsEWKT(SAFE_CAST(ST_GeomFromEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 48bb2f7..4d2c392 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,6 +6,8 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')") + self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT IDENTIFIER('foo') FROM x") self.validate_identity("INITCAP('iqamqinterestedqinqthisqtopic', 'q')") self.validate_identity("CAST(x AS GEOMETRY)") self.validate_identity("OBJECT_CONSTRUCT(*)") @@ -23,6 +25,9 @@ class TestSnowflake(Validator): self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") self.validate_identity("ALTER TABLE IF EXISTS foo SET TAG a = 'a', b = 'b', c = 'c'") self.validate_identity("ALTER TABLE foo UNSET TAG a, b, c") + self.validate_identity("ALTER TABLE foo SET COMMENT = 'bar'") + self.validate_identity("ALTER TABLE foo SET CHANGE_TRACKING = FALSE") + self.validate_identity("ALTER TABLE foo UNSET DATA_RETENTION_TIME_IN_DAYS, CHANGE_TRACKING") self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'") self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)") self.validate_identity( @@ -582,6 +587,8 @@ class TestSnowflake(Validator): self.validate_identity("CREATE DATABASE mytestdb_clone CLONE mytestdb") self.validate_identity("CREATE SCHEMA mytestschema_clone CLONE testschema") self.validate_identity("CREATE TABLE orders_clone CLONE orders") + self.validate_identity("CREATE TABLE IDENTIFIER('foo') (COLUMN1 VARCHAR, COLUMN2 VARCHAR)") + self.validate_identity("CREATE TABLE IDENTIFIER($foo) (col1 VARCHAR, col2 VARCHAR)") self.validate_identity( "CREATE TABLE orders_clone_restore CLONE orders AT (TIMESTAMP => TO_TIMESTAMP_TZ('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss'))" ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 54c39e7..8acc48e 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -7,6 +7,10 @@ class TestSpark(Validator): def test_ddl(self): self.validate_identity("CREATE TABLE foo (col VARCHAR(50))") self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)") + self.validate_identity("CREATE TABLE foo (col STRING) CLUSTERED BY (col) INTO 10 BUCKETS") + self.validate_identity( + "CREATE TABLE foo (col STRING) CLUSTERED BY (col) SORTED BY (col) INTO 10 BUCKETS" + ) self.validate_all( "CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)", diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index 1124a79..bd56e07 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -252,7 +252,7 @@ FROM t1 GROUP BY t1.row_num ORDER BY t1.row_num; -WITH t1 AS (SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x) SELECT t1.row_num AS row_num, SUM(t1.a) AS total FROM t1 GROUP BY t1.row_num ORDER BY t1.row_num; +WITH t1 AS (SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x) SELECT t1.row_num AS row_num, SUM(t1.a) AS total FROM t1 GROUP BY t1.row_num ORDER BY row_num; # title: Test prevent merging of window if in order by func with t1 as ( diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 214535a..f71ddde 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -577,10 +577,10 @@ FROM `u_cte` AS `u_cte` PIVOT(SUM(`u_cte`.`f`) AS `sum` FOR `u_cte`.`h` IN ('x', # dialect: snowflake SELECT * FROM u PIVOT (SUM(f) FOR h IN ('x', 'y')); SELECT - "_q_0"."G" AS "G", - "_q_0"."'x'" AS "'x'", - "_q_0"."'y'" AS "'y'" -FROM "U" AS "U" PIVOT(SUM("U"."F") FOR "U"."H" IN ('x', 'y')) AS "_q_0" + "_Q_0"."G" AS "G", + "_Q_0"."'x'" AS "'x'", + "_Q_0"."'y'" AS "'y'" +FROM "U" AS "U" PIVOT(SUM("U"."F") FOR "U"."H" IN ('x', 'y')) AS "_Q_0" ; # title: selecting all columns from a pivoted source and generating spark @@ -668,16 +668,28 @@ WHERE GROUP BY `dAy`, `top_term`, rank ORDER BY `DaY` DESC; SELECT - `TOp_TeRmS`.`refresh_date` AS `day`, - `TOp_TeRmS`.`term` AS `top_term`, - `TOp_TeRmS`.`rank` AS `rank` -FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `TOp_TeRmS` + `top_terms`.`refresh_date` AS `day`, + `top_terms`.`term` AS `top_term`, + `top_terms`.`rank` AS `rank` +FROM `bigquery-public-data`.`GooGle_tReNDs`.`TOp_TeRmS` AS `top_terms` WHERE - `TOp_TeRmS`.`rank` = 1 - AND CAST(`TOp_TeRmS`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) + `top_terms`.`rank` = 1 + AND CAST(`top_terms`.`refresh_date` AS DATE) >= DATE_SUB(CURRENT_DATE, INTERVAL 2 WEEK) GROUP BY - `TOp_TeRmS`.`refresh_date`, - `TOp_TeRmS`.`term`, - `TOp_TeRmS`.`rank` + `day`, + `top_term`, + `rank` ORDER BY `day` DESC; + + +# title: group by keys cannot be simplified +SELECT a + 1 + 1 + 1 + 1 AS b, 2 + 1 AS c FROM x GROUP BY a + 1 + 1 HAVING a + 1 + 1 + 1 + 1 > 1; +SELECT + "x"."a" + 1 + 1 + 1 + 1 AS "b", + 3 AS "c" +FROM "x" AS "x" +GROUP BY + "x"."a" + 1 + 1 +HAVING + "x"."a" + 1 + 1 + 1 + 1 > 1; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 7ef7a6d..bbfd47f 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -254,7 +254,7 @@ GROUP BY "item"."i_brand", "item"."i_brand_id" ORDER BY - "dt"."d_year", + "d_year", "sum_agg" DESC, "brand_id" LIMIT 100; @@ -2767,8 +2767,8 @@ GROUP BY "item"."i_manufact" ORDER BY "ext_price" DESC, - "item"."i_brand", - "item"."i_brand_id", + "brand", + "brand_id", "i_manufact_id", "i_manufact" LIMIT 100; @@ -5112,10 +5112,10 @@ GROUP BY "item"."i_category_id", "item"."i_category" ORDER BY - SUM("store_sales"."ss_ext_sales_price") DESC, - "dt"."d_year", - "item"."i_category_id", - "item"."i_category" + "_col_3" DESC, + "d_year", + "i_category_id", + "i_category" LIMIT 100; -------------------------------------- @@ -6353,7 +6353,7 @@ GROUP BY "item"."i_brand", "item"."i_brand_id" ORDER BY - "dt"."d_year", + "d_year", "ext_price" DESC, "brand_id" LIMIT 100; @@ -6648,7 +6648,7 @@ GROUP BY "item"."i_brand_id" ORDER BY "ext_price" DESC, - "item"."i_brand_id" + "brand_id" LIMIT 100; -------------------------------------- @@ -7770,7 +7770,7 @@ GROUP BY "ship_mode"."sm_type", "web_site"."web_name" ORDER BY - SUBSTR("warehouse"."w_warehouse_name", 1, 20), + "_col_0", "sm_type", "web_name" LIMIT 100; @@ -9668,7 +9668,7 @@ GROUP BY "time_dim"."t_minute" ORDER BY "ext_price" DESC, - "item"."i_brand_id"; + "brand_id"; -------------------------------------- -- TPC-DS 72 @@ -11692,10 +11692,10 @@ JOIN "customer_demographics" AS "cd1" GROUP BY "reason"."r_reason_desc" ORDER BY - SUBSTR("reason"."r_reason_desc", 1, 20), - AVG("web_sales"."ws_quantity"), - AVG("web_returns"."wr_refunded_cash"), - AVG("web_returns"."wr_fee") + "_col_0", + "_col_1", + "_col_2", + "_col_3" LIMIT 100; -------------------------------------- @@ -12364,7 +12364,7 @@ GROUP BY "customer_demographics"."cd_marital_status", "customer_demographics"."cd_education_status" ORDER BY - SUM("catalog_returns"."cr_net_loss") DESC; + "returns_loss" DESC; -------------------------------------- -- TPC-DS 92 @@ -12940,7 +12940,7 @@ GROUP BY "ship_mode"."sm_type", "call_center"."cc_name" ORDER BY - SUBSTR("warehouse"."w_warehouse_name", 1, 20), + "_col_0", "sm_type", "cc_name" LIMIT 100; diff --git a/tests/test_executor.py b/tests/test_executor.py index 3a37cd4..6dd530f 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -254,6 +254,11 @@ class TestExecutor(unittest.TestCase): [("a",)], ), ( + "(SELECT a FROM x) EXCEPT (SELECT a FROM y)", + ["a"], + [("a",)], + ), + ( "SELECT a FROM x INTERSECT SELECT a FROM y", ["a"], [("b",), ("c",)], @@ -646,3 +651,56 @@ class TestExecutor(unittest.TestCase): self.assertEqual(result.columns, ("id", "price")) self.assertEqual(result.rows, [(1, 1.0), (2, 2.0), (3, 3.0)]) + + def test_group_by(self): + tables = { + "x": [ + {"a": 1, "b": 10}, + {"a": 2, "b": 20}, + {"a": 3, "b": 28}, + {"a": 2, "b": 25}, + {"a": 1, "b": 40}, + ], + } + + for sql, expected, columns in ( + ( + "SELECT a, AVG(b) FROM x GROUP BY a ORDER BY AVG(b)", + [(2, 22.5), (1, 25.0), (3, 28.0)], + ("a", "_col_1"), + ), + ( + "SELECT a, AVG(b) FROM x GROUP BY a having avg(b) > 23", + [(1, 25.0), (3, 28.0)], + ("a", "_col_1"), + ), + ( + "SELECT a, AVG(b) FROM x GROUP BY a having avg(b + 1) > 23", + [(1, 25.0), (2, 22.5), (3, 28.0)], + ("a", "_col_1"), + ), + ( + "SELECT a, AVG(b) FROM x GROUP BY a having sum(b) + 5 > 50", + [(1, 25.0)], + ("a", "_col_1"), + ), + ( + "SELECT a + 1 AS a, AVG(b + 1) FROM x GROUP BY a + 1 having AVG(b + 1) > 26", + [(4, 29.0)], + ("a", "_col_1"), + ), + ( + "SELECT a, avg(b) FROM x GROUP BY a HAVING a = 1", + [(1, 25.0)], + ("a", "_col_1"), + ), + ( + "SELECT a + 1, avg(b) FROM x GROUP BY a + 1 HAVING a + 1 = 2", + [(2, 25.0)], + ("_col_0", "_col_1"), + ), + ): + with self.subTest(sql): + result = execute(sql, tables=tables) + self.assertEqual(result.columns, columns) + self.assertEqual(result.rows, expected) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 94bd0ba..b7425af 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -216,6 +216,17 @@ class TestOptimizer(unittest.TestCase): "SELECT y AS y FROM x", ) + self.assertEqual( + optimizer.qualify.qualify( + parse_one( + "WITH X AS (SELECT Y.A FROM DB.Y CROSS JOIN a.b.INFORMATION_SCHEMA.COLUMNS) SELECT `A` FROM X", + read="bigquery", + ), + dialect="bigquery", + ).sql(), + 'WITH "x" AS (SELECT "y"."a" AS "a" FROM "DB"."Y" AS "y" CROSS JOIN "a"."b"."INFORMATION_SCHEMA"."COLUMNS" AS "columns") SELECT "x"."a" AS "a" FROM "x"', + ) + self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) def test_qualify_columns__with_invisible(self): @@ -262,7 +273,7 @@ class TestOptimizer(unittest.TestCase): # check order of lateral expansion with no schema self.assertEqual( optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x WHERE e > 1 GROUP BY e").sql(), - 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x" WHERE "x"."a" + 2 > 1 GROUP BY "x"."a" + 2', + 'SELECT "x"."a" + 1 AS "d", "x"."a" + 1 + 1 AS "e" FROM "x" AS "x" WHERE "x"."a" + 2 > 1 GROUP BY "x"."a" + 1 + 1', ) self.assertEqual( @@ -724,6 +735,23 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ).sql(pretty=True, dialect="snowflake") for func in (optimizer.qualify.qualify, optimizer.optimize): - source_query = parse_one('SELECT * FROM example."source"', read="snowflake") + source_query = parse_one('SELECT * FROM example."source" AS "source"', read="snowflake") transformed = func(source_query, dialect="snowflake", schema=schema) self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected) + + def test_no_pseudocolumn_expansion(self): + schema = { + "a": { + "a": "text", + "b": "text", + "_PARTITIONDATE": "date", + "_PARTITIONTIME": "timestamp", + } + } + + self.assertEqual( + optimizer.optimize( + parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery") + ), + parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'), + ) diff --git a/tests/test_schema.py b/tests/test_schema.py index 23690b9..b89754f 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -227,12 +227,14 @@ class TestSchema(unittest.TestCase): self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"]) # Check that the correct dialect is used when calling schema methods + # Note: T-SQL is case-insensitive by default, so `fo` in clickhouse will match the normalized table name schema = MappingSchema(schema={"[Fo]": {"x": "int"}}, dialect="tsql") self.assertEqual( - schema.column_names("[Fo]"), schema.column_names("`Fo`", dialect="clickhouse") + schema.column_names("[Fo]"), schema.column_names("`fo`", dialect="clickhouse") ) # Check that all column identifiers are normalized to lowercase for BigQuery, even quoted # ones. Also, ensure that tables aren't normalized, since they're case-sensitive by default. schema = MappingSchema(schema={"Foo": {"`BaR`": "int"}}, dialect="bigquery") self.assertEqual(schema.column_names("Foo"), ["bar"]) + self.assertEqual(schema.column_names("foo"), []) |