diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_clickhouse.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 54 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 23 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 35 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 11 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 10 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 14 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 3 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 1 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_columns.sql | 22 | ||||
-rw-r--r-- | tests/fixtures/optimizer/simplify.sql | 16 | ||||
-rw-r--r-- | tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 16 | ||||
-rw-r--r-- | tests/test_executor.py | 23 | ||||
-rw-r--r-- | tests/test_expressions.py | 40 | ||||
-rw-r--r-- | tests/test_optimizer.py | 9 | ||||
-rw-r--r-- | tests/test_tokens.py | 1 |
20 files changed, 255 insertions, 59 deletions
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index d206bb1..40a3a04 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -5,6 +5,7 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_identity("SELECT match('abc', '([a-z]+)')") self.validate_identity("dictGet(x, 'y')") self.validate_identity("SELECT * FROM x FINAL") self.validate_identity("SELECT * FROM x AS y FINAL") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 0805e9c..3558d62 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -290,6 +290,60 @@ class TestDialect(Validator): read={"postgres": "INET '127.0.0.1/32'"}, ) + def test_decode(self): + self.validate_identity("DECODE(bin, charset)") + + self.validate_all( + "SELECT DECODE(a, 1, 'one')", + write={ + "": "SELECT CASE WHEN a = 1 THEN 'one' END", + "oracle": "SELECT CASE WHEN a = 1 THEN 'one' END", + "redshift": "SELECT CASE WHEN a = 1 THEN 'one' END", + "snowflake": "SELECT CASE WHEN a = 1 THEN 'one' END", + "spark": "SELECT CASE WHEN a = 1 THEN 'one' END", + }, + ) + self.validate_all( + "SELECT DECODE(a, 1, 'one', 'default')", + write={ + "": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "oracle": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "redshift": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "snowflake": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + "spark": "SELECT CASE WHEN a = 1 THEN 'one' ELSE 'default' END", + }, + ) + self.validate_all( + "SELECT DECODE(a, NULL, 'null')", + write={ + "": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "oracle": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "redshift": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "snowflake": "SELECT CASE WHEN a IS NULL THEN 'null' END", + "spark": "SELECT CASE WHEN a IS NULL THEN 'null' END", + }, + ) + self.validate_all( + "SELECT DECODE(a, b, c)", + write={ + "": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "oracle": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "redshift": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "snowflake": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + "spark": "SELECT CASE WHEN a = b OR (a IS NULL AND b IS NULL) THEN c END", + }, + ) + self.validate_all( + "SELECT DECODE(tbl.col, 'some_string', 'foo')", + write={ + "": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "oracle": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "redshift": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "snowflake": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + "spark": "SELECT CASE WHEN tbl.col = 'some_string' THEN 'foo' END", + }, + ) + def test_if_null(self): self.validate_all( "SELECT IFNULL(1, NULL) FROM foo", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index a15e6b4..245d82a 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -74,6 +74,17 @@ class TestDuckDB(Validator): "spark": "TO_TIMESTAMP(x, 'M/d/yy h:mm a')", }, ) + self.validate_all( + "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'", + read={ + "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", + }, + write={ + "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMPTZ), 'America/New_York'))", + "duckdb": "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'", + "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", + }, + ) def test_sample(self): self.validate_all( @@ -421,6 +432,18 @@ class TestDuckDB(Validator): "snowflake": "CAST(COL AS ARRAY)", }, ) + self.validate_all( + "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", + write={ + "duckdb": "CAST(LIST_VALUE({'a': 1}) AS STRUCT(a BIGINT)[])", + }, + ) + self.validate_all( + "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", + write={ + "duckdb": "CAST(LIST_VALUE(LIST_VALUE({'a': 1})) AS STRUCT(a BIGINT)[][])", + }, + ) def test_bool_or(self): self.validate_all( diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 0161f1e..1a83575 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -689,3 +689,11 @@ class TestHive(Validator): self.validate_identity("'\\z'") self.validate_identity("'\\\z'") self.validate_identity("'\\\\z'") + + def test_data_type(self): + self.validate_all( + "CAST(a AS BIT)", + write={ + "hive": "CAST(a AS BOOLEAN)", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 5059d05..f618728 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -204,6 +204,41 @@ class TestMySQL(Validator): }, ) + def test_match_against(self): + self.validate_all( + "MATCH(col1, col2, col3) AGAINST('abc')", + read={ + "": "MATCH(col1, col2, col3) AGAINST('abc')", + "mysql": "MATCH(col1, col2, col3) AGAINST('abc')", + }, + write={ + "": "MATCH(col1, col2, col3) AGAINST('abc')", + "mysql": "MATCH(col1, col2, col3) AGAINST('abc')", + }, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE)", + write={"mysql": "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE)"}, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)", + write={ + "mysql": "MATCH(col1, col2) AGAINST('abc' IN NATURAL LANGUAGE MODE WITH QUERY EXPANSION)" + }, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' IN BOOLEAN MODE)", + write={"mysql": "MATCH(col1, col2) AGAINST('abc' IN BOOLEAN MODE)"}, + ) + self.validate_all( + "MATCH(col1, col2) AGAINST('abc' WITH QUERY EXPANSION)", + write={"mysql": "MATCH(col1, col2) AGAINST('abc' WITH QUERY EXPANSION)"}, + ) + self.validate_all( + "MATCH(a.b) AGAINST('abc')", + write={"mysql": "MATCH(a.b) AGAINST('abc')"}, + ) + def test_mysql(self): self.validate_all( "SELECT a FROM tbl FOR UPDATE", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 1762e7a..1007899 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -539,6 +539,16 @@ class TestPresto(Validator): "hive": "SELECT a, b, c, d, SUM(y) FROM z GROUP BY d, GROUPING SETS ((b, c)), CUBE (a), ROLLUP (a)", }, ) + self.validate_all( + "JSON_FORMAT(x)", + read={ + "spark": "TO_JSON(x)", + }, + write={ + "presto": "JSON_FORMAT(x)", + "spark": "TO_JSON(x)", + }, + ) def test_encode_decode(self): self.validate_all( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index ff730f8..0933051 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -69,9 +69,11 @@ class TestRedshift(Validator): self.validate_all( "DECODE(x, a, b, c, d)", write={ - "": "MATCHES(x, a, b, c, d)", - "oracle": "DECODE(x, a, b, c, d)", - "snowflake": "DECODE(x, a, b, c, d)", + "": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "oracle": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "redshift": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "snowflake": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", + "spark": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d END", }, ) self.validate_all( @@ -97,9 +99,6 @@ class TestRedshift(Validator): ) def test_identity(self): - self.validate_identity( - "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')" - ) self.validate_identity("CAST('bla' AS SUPER)") self.validate_identity("CREATE TABLE real1 (realcol REAL)") self.validate_identity("CAST('foo' AS HLLSKETCH)") diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 940fa50..eb423a5 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,6 +6,11 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("SELECT HLL(*)") + self.validate_identity("SELECT HLL(a)") + self.validate_identity("SELECT HLL(DISTINCT t.a)") + self.validate_identity("SELECT HLL(a, b, c)") + self.validate_identity("SELECT HLL(DISTINCT a, b, c)") self.validate_identity("$x") self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") self.validate_identity("PUT file:///dir/tmp.csv @%table") @@ -376,14 +381,10 @@ class TestSnowflake(Validator): }, ) self.validate_all( - "DECODE(x, a, b, c, d)", - read={ - "": "MATCHES(x, a, b, c, d)", - }, + "DECODE(x, a, b, c, d, e)", write={ - "": "MATCHES(x, a, b, c, d)", - "oracle": "DECODE(x, a, b, c, d)", - "snowflake": "DECODE(x, a, b, c, d)", + "": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d ELSE e END", + "snowflake": "CASE WHEN x = a OR (x IS NULL AND a IS NULL) THEN b WHEN x = c OR (x IS NULL AND c IS NULL) THEN d ELSE e END", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index b12f272..0da2931 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -345,3 +345,13 @@ TBLPROPERTIES ( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", write={"spark": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, ) + + def test_current_user(self): + self.validate_all( + "CURRENT_USER", + write={"spark": "CURRENT_USER()"}, + ) + self.validate_all( + "CURRENT_USER()", + write={"spark": "CURRENT_USER()"}, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index fd9e52b..583d5be 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -12,7 +12,7 @@ class TestSQLite(Validator): self.validate_identity("INSERT OR ROLLBACK INTO foo (x, y) VALUES (1, 2)") self.validate_all( - "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", + "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)", write={"sqlite": "CREATE TABLE foo (id INTEGER PRIMARY KEY ASC)"}, ) self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 60867be..d9ee4ae 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -547,11 +547,11 @@ WHERE def test_string(self): self.validate_all( "SELECT N'test'", - write={"spark": "SELECT N'test'"}, + write={"spark": "SELECT 'test'"}, ) self.validate_all( "SELECT n'test'", - write={"spark": "SELECT N'test'"}, + write={"spark": "SELECT 'test'"}, ) self.validate_all( "SELECT '''test'''", @@ -621,3 +621,13 @@ WHERE "tsql": """SELECT "x" FROM "a"."b" FOR SYSTEM_TIME ALL AS alias""", }, ) + + def test_current_user(self): + self.validate_all( + "SUSER_NAME()", + write={"spark": "CURRENT_USER()"}, + ) + self.validate_all( + "SUSER_SNAME()", + write={"spark": "CURRENT_USER()"}, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 5fab65b..54e5583 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -666,6 +666,7 @@ INSERT INTO x VALUES (1, 'a', 2.0) INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x) INSERT INTO y (a, b, c) SELECT a, b, c FROM x INSERT INTO y (SELECT 1) UNION (SELECT 2) +INSERT INTO result_table (WITH test AS (SELECT * FROM source_table) SELECT * FROM test) INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y INSERT OVERWRITE DIRECTORY 'x' SELECT 1 @@ -728,6 +729,8 @@ SELECT * INTO newevent FROM event SELECT * INTO TEMPORARY newevent FROM event SELECT * INTO UNLOGGED newevent FROM event ALTER TABLE integers ADD COLUMN k INT +ALTER TABLE integers ADD COLUMN k INT FIRST +ALTER TABLE integers ADD COLUMN k INT AFTER m ALTER TABLE integers ADD COLUMN IF NOT EXISTS k INT ALTER TABLE IF EXISTS integers ADD COLUMN k INT ALTER TABLE integers ADD COLUMN l INT DEFAULT 10 diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index c5112b2..9e7880c 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -462,6 +462,7 @@ SELECT FROM "db1"."tbl" AS "tbl" CROSS JOIN "db2"."tbl" AS "tbl_2"; +# execute: false SELECT *, IFF( diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index df65e65..74e2d0a 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -35,6 +35,9 @@ SELECT 1 AS "1", 2 + 3 AS _col_1 FROM x AS x; SELECT a + b FROM x; SELECT x.a + x.b AS _col_0 FROM x AS x; +SELECT l.a FROM x l WHERE a IN (select a FROM x ORDER by a); +SELECT l.a AS a FROM x AS l WHERE l.a IN (SELECT x.a AS a FROM x AS x ORDER BY a); + # execute: false SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; SELECT x.a AS a, SUM(x.b) AS _col_1 FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; @@ -46,15 +49,14 @@ SELECT SUM(a) AS a FROM x HAVING SUM(a) > 3; SELECT SUM(x.a) AS a FROM x AS x HAVING SUM(x.a) > 3; SELECT SUM(a) AS c FROM x HAVING c > 3; -SELECT SUM(x.a) AS c FROM x AS x HAVING c > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3; # execute: false SELECT SUM(a) AS a FROM x HAVING a > 3; -SELECT SUM(x.a) AS a FROM x AS x HAVING a > 3; +SELECT SUM(x.a) AS a FROM x AS x HAVING SUM(x.a) > 3; -# execute: false -SELECT SUM(a) AS c FROM x HAVING SUM(c) > 3; -SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(c) > 3; +SELECT SUM(a) AS c FROM x HAVING SUM(b) > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.b) > 3; SELECT a AS j, b FROM x ORDER BY j; SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; @@ -95,6 +97,7 @@ SELECT COALESCE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY COALESCE SELECT a + 1 AS d FROM x WHERE d > 1; SELECT x.a + 1 AS d FROM x AS x WHERE x.a + 1 > 1; +# execute: false SELECT a + 1 AS d, d + 2 FROM x; SELECT x.a + 1 AS d, x.a + 1 + 2 AS _col_1 FROM x AS x; @@ -124,6 +127,10 @@ SELECT DATE_TRUNC('week', x.a) AS a FROM x AS x; SELECT DATE_TRUNC(a, MONTH) AS a FROM x; SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x; +# execute: false +SELECT x FROM READ_PARQUET('path.parquet', hive_partition=1); +SELECT _q_0.x AS x FROM READ_PARQUET('path.parquet', hive_partition = 1) AS _q_0; + -------------------------------------- -- Derived tables -------------------------------------- @@ -262,11 +269,9 @@ SELECT x.a AS d, x.b AS b FROM x AS x; SELECT * EXCEPT(b) REPLACE(a AS d) FROM x; SELECT x.a AS d FROM x AS x; -# execute: false SELECT x.* EXCEPT(a), y.* FROM x, y; SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y; -# execute: false SELECT * EXCEPT(a) FROM x; SELECT x.b AS b FROM x AS x; @@ -338,12 +343,11 @@ SELECT t.c AS c FROM x AS x LATERAL VIEW EXPLODE(x.a) t AS c; SELECT aa FROM x, UNNEST(a) AS t(aa); SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa); -# execute: false # dialect: bigquery +# execute: false SELECT aa FROM x, UNNEST(a) AS aa; SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa; -# execute: false # dialect: presto SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(b, ',')) AS i(b); SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(x.b, ',')) AS i(b); diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 08e8700..54ec64b 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -282,6 +282,9 @@ x * (1 - y); -1 + 3; 2; +1 - 2 - 4; +-5; + -(-1); 1; @@ -307,19 +310,22 @@ x * (1 - y); 0.0219; 1 / 3; -0; +1 / 3; + +1 / 3.0; +0.3333333333333333333333333333; 20.0 / 6; 3.333333333333333333333333333; 10 / 5; -2; +10 / 5; (1.0 * 3) * 4 - 2 * (5 / 2); -8.0; +12.0 - 2 * (5 / 2); 6 - 2 + 4 * 2 + a; -a + 12; +12 + a; a + 1 + 1 + 2; a + 4; @@ -376,7 +382,7 @@ interval '1' year + date '1998-01-01'; CAST('1999-01-01' AS DATE); interval '1' year + date '1998-01-01' + 3 * 7 * 4; -84 + CAST('1999-01-01' AS DATE); +CAST('1999-01-01' AS DATE) + 84; date '1998-12-01' - interval '90' foo; CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index b92ad37..d9a06cc 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -6145,7 +6145,7 @@ WITH web_v1 AS ( SELECT ws_item_sk item_sk, d_date, - sum(Sum(ws_sales_price)) OVER (partition BY ws_item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) cume_sales + sum(Sum(ws_sales_price)) OVER (partition BY ws_item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) cume_sales FROM web_sales , date_dim WHERE ws_sold_date_sk=d_date_sk @@ -6156,7 +6156,7 @@ WITH web_v1 AS ( SELECT ss_item_sk item_sk, d_date, - sum(sum(ss_sales_price)) OVER (partition BY ss_item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) cume_sales + sum(sum(ss_sales_price)) OVER (partition BY ss_item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) cume_sales FROM store_sales , date_dim WHERE ss_sold_date_sk=d_date_sk @@ -6171,8 +6171,8 @@ FROM ( d_date , web_sales , store_sales , - max(web_sales) OVER (partition BY item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) web_cumulative , - max(store_sales) OVER (partition BY item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) store_cumulative + max(web_sales) OVER (partition BY item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) web_cumulative , + max(store_sales) OVER (partition BY item_sk ORDER BY d_date rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) store_cumulative FROM ( SELECT CASE @@ -6206,7 +6206,7 @@ WITH "date_dim_2" AS ( SELECT "web_sales"."ws_item_sk" AS "item_sk", "date_dim"."d_date" AS "d_date", - SUM(SUM("web_sales"."ws_sales_price")) OVER (PARTITION BY "web_sales"."ws_item_sk" ORDER BY "date_dim"."d_date" rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) AS "cume_sales" + SUM(SUM("web_sales"."ws_sales_price")) OVER (PARTITION BY "web_sales"."ws_item_sk" ORDER BY "date_dim"."d_date" rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS "cume_sales" FROM "web_sales" AS "web_sales" JOIN "date_dim_2" AS "date_dim" ON "web_sales"."ws_sold_date_sk" = "date_dim"."d_date_sk" @@ -6219,7 +6219,7 @@ WITH "date_dim_2" AS ( SELECT "store_sales"."ss_item_sk" AS "item_sk", "date_dim"."d_date" AS "d_date", - SUM(SUM("store_sales"."ss_sales_price")) OVER (PARTITION BY "store_sales"."ss_item_sk" ORDER BY "date_dim"."d_date" rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) AS "cume_sales" + SUM(SUM("store_sales"."ss_sales_price")) OVER (PARTITION BY "store_sales"."ss_item_sk" ORDER BY "date_dim"."d_date" rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS "cume_sales" FROM "store_sales" AS "store_sales" JOIN "date_dim_2" AS "date_dim" ON "store_sales"."ss_sold_date_sk" = "date_dim"."d_date_sk" @@ -6242,12 +6242,12 @@ WITH "date_dim_2" AS ( WHEN NOT "web"."item_sk" IS NULL THEN "web"."item_sk" ELSE "store"."item_sk" - END ORDER BY CASE WHEN NOT "web"."d_date" IS NULL THEN "web"."d_date" ELSE "store"."d_date" END rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) AS "web_cumulative", + END ORDER BY CASE WHEN NOT "web"."d_date" IS NULL THEN "web"."d_date" ELSE "store"."d_date" END rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS "web_cumulative", MAX("store"."cume_sales") OVER (PARTITION BY CASE WHEN NOT "web"."item_sk" IS NULL THEN "web"."item_sk" ELSE "store"."item_sk" - END ORDER BY CASE WHEN NOT "web"."d_date" IS NULL THEN "web"."d_date" ELSE "store"."d_date" END rows BETWEEN UNBOUNDED PRECEDING AND CURRENT row) AS "store_cumulative" + END ORDER BY CASE WHEN NOT "web"."d_date" IS NULL THEN "web"."d_date" ELSE "store"."d_date" END rows BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS "store_cumulative" FROM "web_v1" AS "web" FULL JOIN "store_v1" AS "store" ON "web"."d_date" = "store"."d_date" AND "web"."item_sk" = "store"."item_sk" diff --git a/tests/test_executor.py b/tests/test_executor.py index c455f3a..6bf7d6a 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,3 +1,4 @@ +import datetime import unittest from datetime import date @@ -513,6 +514,8 @@ class TestExecutor(unittest.TestCase): self.assertEqual(result.rows, [(3,)]) def test_scalar_functions(self): + now = datetime.datetime.now() + for sql, expected in [ ("CONCAT('a', 'b')", "ab"), ("CONCAT('a', NULL)", None), @@ -569,6 +572,26 @@ class TestExecutor(unittest.TestCase): ("NULL IS NOT NULL", False), ("NULL = NULL", None), ("NULL <> NULL", None), + ("YEAR(CURRENT_TIMESTAMP)", now.year), + ("MONTH(CURRENT_TIME)", now.month), + ("DAY(CURRENT_DATETIME())", now.day), + ("YEAR(CURRENT_DATE())", now.year), + ("MONTH(CURRENT_DATE())", now.month), + ("DAY(CURRENT_DATE())", now.day), + ("YEAR(CURRENT_TIMESTAMP) + 1", now.year + 1), + ( + "YEAR(CURRENT_TIMESTAMP) IN (YEAR(CURRENT_TIMESTAMP) + 1, YEAR(CURRENT_TIMESTAMP) * 10)", + False, + ), + ("YEAR(CURRENT_TIMESTAMP) = (YEAR(CURRENT_TIMESTAMP))", True), + ("YEAR(CURRENT_TIMESTAMP) <> (YEAR(CURRENT_TIMESTAMP))", False), + ("YEAR(CURRENT_DATE()) + 1", now.year + 1), + ( + "YEAR(CURRENT_DATE()) IN (YEAR(CURRENT_DATE()) + 1, YEAR(CURRENT_DATE()) * 10)", + False, + ), + ("YEAR(CURRENT_DATE()) = (YEAR(CURRENT_DATE()))", True), + ("YEAR(CURRENT_DATE()) <> (YEAR(CURRENT_DATE()))", False), ]: with self.subTest(sql): result = execute(f"SELECT {sql}") diff --git a/tests/test_expressions.py b/tests/test_expressions.py index c22f13e..b09b2ab 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -194,30 +194,31 @@ class TestExpressions(unittest.TestCase): def test_replace_placeholders(self): self.assertEqual( exp.replace_placeholders( - parse_one("select * from :tbl1 JOIN :tbl2 ON :col1 = :col2 WHERE :col3 > 100"), - tbl1="foo", - tbl2="bar", - col1="a", - col2="b", - col3="c", + parse_one("select * from :tbl1 JOIN :tbl2 ON :col1 = :str1 WHERE :col2 > :int1"), + tbl1=exp.to_identifier("foo"), + tbl2=exp.to_identifier("bar"), + col1=exp.to_identifier("a"), + col2=exp.to_identifier("c"), + str1="b", + int1=100, ).sql(), - "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100", + "SELECT * FROM foo JOIN bar ON a = 'b' WHERE c > 100", ) self.assertEqual( exp.replace_placeholders( - parse_one("select * from ? JOIN ? ON ? = ? WHERE ? > 100"), - "foo", - "bar", - "a", + parse_one("select * from ? JOIN ? ON ? = ? WHERE ? = 'bla'"), + exp.to_identifier("foo"), + exp.to_identifier("bar"), + exp.to_identifier("a"), "b", - "c", + "bla", ).sql(), - "SELECT * FROM foo JOIN bar ON a = b WHERE c > 100", + "SELECT * FROM foo JOIN bar ON a = 'b' WHERE 'bla' = 'bla'", ) self.assertEqual( exp.replace_placeholders( parse_one("select * from ? WHERE ? > 100"), - "foo", + exp.to_identifier("foo"), ).sql(), "SELECT * FROM foo WHERE ? > 100", ) @@ -229,12 +230,12 @@ class TestExpressions(unittest.TestCase): ) self.assertEqual( exp.replace_placeholders( - parse_one("select * from (SELECT :col1 FROM ?) WHERE :col2 > 100"), - "tbl1", - "tbl2", + parse_one("select * from (SELECT :col1 FROM ?) WHERE :col2 > ?"), + exp.to_identifier("tbl1"), + 100, "tbl3", - col1="a", - col2="b", + col1=exp.to_identifier("a"), + col2=exp.to_identifier("b"), col3="c", ).sql(), "SELECT * FROM (SELECT a FROM tbl1) WHERE b > 100", @@ -526,6 +527,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) + self.assertIsInstance(parse_one("HLL(a)"), exp.Hll) def test_column(self): column = parse_one("a.b.c.d") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 597fa6f..d077570 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -106,7 +106,6 @@ class TestOptimizer(unittest.TestCase): ): title = meta.get("title") or f"{i}, {sql}" dialect = meta.get("dialect") - execute = execute if meta.get("execute") is None else False leave_tables_isolated = meta.get("leave_tables_isolated") func_kwargs = {**kwargs} @@ -114,7 +113,13 @@ class TestOptimizer(unittest.TestCase): func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) future = pool.submit(parse_and_optimize, func, sql, dialect, **func_kwargs) - results[future] = (sql, title, expected, dialect, execute) + results[future] = ( + sql, + title, + expected, + dialect, + execute if meta.get("execute") is None else False, + ) for future in as_completed(results): optimized = future.result() diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 909eb18..8481f4d 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -13,6 +13,7 @@ class TestTokens(unittest.TestCase): ("foo --comment", ["comment"]), ("foo", []), ("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]), + ("foo\n-- comment", [" comment"]), ] for sql, comment in sql_comment: |