diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 78 | ||||
-rw-r--r-- | tests/dialects/test_drill.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 5 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 41 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 9 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 35 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 3 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 27 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 6 | ||||
-rw-r--r-- | tests/fixtures/optimizer/canonicalize.sql | 16 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 31 | ||||
-rw-r--r-- | tests/fixtures/optimizer/pushdown_projections.sql | 3 | ||||
-rw-r--r-- | tests/fixtures/pretty.sql | 3 | ||||
-rw-r--r-- | tests/test_build.py | 21 | ||||
-rw-r--r-- | tests/test_expressions.py | 29 | ||||
-rw-r--r-- | tests/test_generator.py | 10 | ||||
-rw-r--r-- | tests/test_parser.py | 40 | ||||
-rw-r--r-- | tests/test_tokens.py | 15 | ||||
-rw-r--r-- | tests/test_transforms.py | 35 |
21 files changed, 399 insertions, 23 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 22387da..e731b50 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -8,6 +8,9 @@ class TestBigQuery(Validator): def test_bigquery(self): self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") + self.validate_identity( + "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" + ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all( @@ -280,7 +283,7 @@ class TestBigQuery(Validator): "duckdb": "CURRENT_DATE + INTERVAL 1 DAY", "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", "postgres": "CURRENT_DATE + INTERVAL '1' DAY", - "presto": "DATE_ADD(DAY, 1, CURRENT_DATE)", + "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", "hive": "DATE_ADD(CURRENT_DATE, 1)", "spark": "DATE_ADD(CURRENT_DATE, 1)", }, diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 905e1f4..d206bb1 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -16,6 +16,10 @@ class TestClickhouse(Validator): self.validate_identity("SELECT * FROM foo ANY JOIN bla") self.validate_identity("SELECT quantile(0.5)(a)") self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t") + self.validate_identity("SELECT quantiles(0.1, 0.2, 0.3)(a)") + self.validate_identity("SELECT histogram(5)(a)") + self.validate_identity("SELECT groupUniqArray(2)(a)") + self.validate_identity("SELECT exponentialTimeDecayedAvg(60)(a, b)") self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)") self.validate_identity("position(haystack, needle)") self.validate_identity("position(haystack, needle, position)") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 21f6be6..6214c43 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -519,7 +519,7 @@ class TestDialect(Validator): "duckdb": "x + INTERVAL 1 day", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "postgres": "x + INTERVAL '1' 'day'", + "postgres": "x + INTERVAL '1' day", "presto": "DATE_ADD('day', 1, x)", "snowflake": "DATEADD(day, 1, x)", "spark": "DATE_ADD(x, 1)", @@ -543,12 +543,49 @@ class TestDialect(Validator): ) self.validate_all( "DATE_TRUNC('day', x)", + read={ + "bigquery": "DATE_TRUNC(x, day)", + "duckdb": "DATE_TRUNC('day', x)", + "spark": "TRUNC(x, 'day')", + }, write={ + "bigquery": "DATE_TRUNC(x, day)", + "duckdb": "DATE_TRUNC('day', x)", "mysql": "DATE(x)", + "presto": "DATE_TRUNC('day', x)", + "postgres": "DATE_TRUNC('day', x)", "snowflake": "DATE_TRUNC('day', x)", + "starrocks": "DATE_TRUNC('day', x)", + "spark": "TRUNC(x, 'day')", + }, + ) + self.validate_all( + "TIMESTAMP_TRUNC(x, day)", + read={ + "bigquery": "TIMESTAMP_TRUNC(x, day)", + "presto": "DATE_TRUNC('day', x)", + "postgres": "DATE_TRUNC('day', x)", + "snowflake": "DATE_TRUNC('day', x)", + "starrocks": "DATE_TRUNC('day', x)", + "spark": "DATE_TRUNC('day', x)", + }, + ) + self.validate_all( + "DATE_TRUNC('day', CAST(x AS DATE))", + read={ + "presto": "DATE_TRUNC('day', x::DATE)", + "snowflake": "DATE_TRUNC('day', x::DATE)", }, ) self.validate_all( + "TIMESTAMP_TRUNC(CAST(x AS DATE), day)", + read={ + "postgres": "DATE_TRUNC('day', x::DATE)", + "starrocks": "DATE_TRUNC('day', x::DATE)", + }, + ) + + self.validate_all( "DATE_TRUNC('week', x)", write={ "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", @@ -582,8 +619,6 @@ class TestDialect(Validator): "DATE_TRUNC('year', x)", read={ "bigquery": "DATE_TRUNC(x, year)", - "snowflake": "DATE_TRUNC(year, x)", - "starrocks": "DATE_TRUNC('year', x)", "spark": "TRUNC(x, 'year')", }, write={ @@ -599,7 +634,10 @@ class TestDialect(Validator): "TIMESTAMP_TRUNC(x, year)", read={ "bigquery": "TIMESTAMP_TRUNC(x, year)", + "postgres": "DATE_TRUNC(year, x)", "spark": "DATE_TRUNC('year', x)", + "snowflake": "DATE_TRUNC(year, x)", + "starrocks": "DATE_TRUNC('year', x)", }, write={ "bigquery": "TIMESTAMP_TRUNC(x, year)", @@ -752,7 +790,6 @@ class TestDialect(Validator): "trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", - "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)", "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", }, @@ -1455,3 +1492,36 @@ SELECT "postgres": "SUBSTRING('123456' FROM 2 FOR 3)", }, ) + + def test_count_if(self): + self.validate_identity("COUNT_IF(DISTINCT cond)") + + self.validate_all( + "SELECT COUNT_IF(cond) FILTER", write={"": "SELECT COUNT_IF(cond) AS FILTER"} + ) + self.validate_all( + "SELECT COUNT_IF(col % 2 = 0) FROM foo", + write={ + "": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "databricks": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "presto": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "snowflake": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FROM foo", + "tsql": "SELECT COUNT_IF(col % 2 = 0) FROM foo", + }, + ) + self.validate_all( + "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + read={ + "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + }, + write={ + "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "presto": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FILTER(WHERE col < 1000) FROM foo", + "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", + }, + ) diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py index e41bd34..f035176 100644 --- a/tests/dialects/test_drill.py +++ b/tests/dialects/test_drill.py @@ -4,6 +4,12 @@ from tests.dialects.test_dialect import Validator class TestDrill(Validator): dialect = "drill" + def test_drill(self): + self.validate_all( + "DATE_FORMAT(a, 'yyyy')", + write={"drill": "TO_CHAR(a, 'yyyy')"}, + ) + def test_string_literals(self): self.validate_all( "SELECT '2021-01-01' + INTERVAL 1 MONTH", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index c314163..1cabade 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -125,6 +125,11 @@ class TestDuckDB(Validator): "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)" ) + self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) + self.validate_all( + "WITH 'x' AS (SELECT 1) SELECT * FROM x", + write={"duckdb": 'WITH "x" AS (SELECT 1) SELECT * FROM x'}, + ) self.validate_all( "CREATE TABLE IF NOT EXISTS table (cola INT, colb STRING) USING ICEBERG PARTITIONED BY (colb)", write={ diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index bf22652..0a9111c 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -63,8 +63,8 @@ class TestPresto(Validator): "bigquery": "CAST(x AS TIMESTAMPTZ)", "duckdb": "CAST(x AS TIMESTAMPTZ(9))", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", - "hive": "CAST(x AS TIMESTAMPTZ)", - "spark": "CAST(x AS TIMESTAMPTZ)", + "hive": "CAST(x AS TIMESTAMP)", + "spark": "CAST(x AS TIMESTAMP)", }, ) @@ -189,34 +189,38 @@ class TestPresto(Validator): ) self.validate_all( - "DAY_OF_WEEK(timestamp '2012-08-08 01:00')", + "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')", write={ - "spark": "DAYOFWEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))", - "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "spark": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "presto": "DAY_OF_WEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "duckdb": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", }, ) self.validate_all( - "DAY_OF_MONTH(timestamp '2012-08-08 01:00')", + "DAY_OF_MONTH(timestamp '2012-08-08 01:00:00')", write={ - "spark": "DAYOFMONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))", - "presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "spark": "DAYOFMONTH(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "presto": "DAY_OF_MONTH(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "duckdb": "DAYOFMONTH(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", }, ) self.validate_all( - "DAY_OF_YEAR(timestamp '2012-08-08 01:00')", + "DAY_OF_YEAR(timestamp '2012-08-08 01:00:00')", write={ - "spark": "DAYOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", - "presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "spark": "DAYOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "presto": "DAY_OF_YEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "duckdb": "DAYOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", }, ) self.validate_all( - "WEEK_OF_YEAR(timestamp '2012-08-08 01:00')", + "WEEK_OF_YEAR(timestamp '2012-08-08 01:00:00')", write={ - "spark": "WEEKOFYEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", - "presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00' AS TIMESTAMP))", + "spark": "WEEKOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "presto": "WEEK_OF_YEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", + "duckdb": "WEEKOFYEAR(CAST('2012-08-08 01:00:00' AS TIMESTAMP))", }, ) @@ -366,6 +370,15 @@ class TestPresto(Validator): self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") self.validate_all( + "ARRAY_AGG(x ORDER BY y DESC)", + write={ + "hive": "COLLECT_LIST(x)", + "presto": "ARRAY_AGG(x ORDER BY y DESC)", + "spark": "COLLECT_LIST(x)", + "trino": "ARRAY_AGG(x ORDER BY y DESC)", + }, + ) + self.validate_all( "SELECT a FROM t GROUP BY a, ROLLUP(b), ROLLUP(c), ROLLUP(d)", write={ "presto": "SELECT a FROM t GROUP BY a, ROLLUP (b, c, d)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index f5b8a43..ff730f8 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -7,6 +7,9 @@ class TestRedshift(Validator): def test_redshift(self): self.validate_all("CONVERT(INTEGER, x)", write={"redshift": "CAST(x AS INTEGER)"}) self.validate_all( + "DATEADD('day', ndays, caldate)", write={"redshift": "DATEADD(day, ndays, caldate)"} + ) + self.validate_all( 'create table "group" ("col" char(10))', write={ "redshift": 'CREATE TABLE "group" ("col" CHAR(10))', @@ -80,10 +83,10 @@ class TestRedshift(Validator): }, ) self.validate_all( - "DATEDIFF(d, a, b)", + "DATEDIFF('day', a, b)", write={ - "redshift": "DATEDIFF(d, a, b)", - "presto": "DATE_DIFF(d, a, b)", + "redshift": "DATEDIFF(day, a, b)", + "presto": "DATE_DIFF('day', a, b)", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 1ac910c..5f6efce 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -18,6 +18,41 @@ class TestSnowflake(Validator): self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'") self.validate_all( + "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", + write={ + "": "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) = 1", + "databricks": "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) = 1", + "hive": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) AS _w FROM qt) AS _t WHERE _w = 1", + "presto": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1", + "snowflake": "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", + "spark": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) AS _w FROM qt) AS _t WHERE _w = 1", + "sqlite": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o NULLS LAST) AS _w FROM qt) AS _t WHERE _w = 1", + "trino": "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1", + }, + ) + self.validate_all( + "SELECT BOOLOR_AGG(c1), BOOLOR_AGG(c2) FROM test", + write={ + "": "SELECT LOGICAL_OR(c1), LOGICAL_OR(c2) FROM test", + "duckdb": "SELECT BOOL_OR(c1), BOOL_OR(c2) FROM test", + "postgres": "SELECT BOOL_OR(c1), BOOL_OR(c2) FROM test", + "snowflake": "SELECT BOOLOR_AGG(c1), BOOLOR_AGG(c2) FROM test", + "spark": "SELECT BOOL_OR(c1), BOOL_OR(c2) FROM test", + "sqlite": "SELECT MAX(c1), MAX(c2) FROM test", + }, + ) + self.validate_all( + "SELECT BOOLAND_AGG(c1), BOOLAND_AGG(c2) FROM test", + write={ + "": "SELECT LOGICAL_AND(c1), LOGICAL_AND(c2) FROM test", + "duckdb": "SELECT BOOL_AND(c1), BOOL_AND(c2) FROM test", + "postgres": "SELECT BOOL_AND(c1), BOOL_AND(c2) FROM test", + "snowflake": "SELECT BOOLAND_AGG(c1), BOOLAND_AGG(c2) FROM test", + "spark": "SELECT BOOL_AND(c1), BOOL_AND(c2) FROM test", + "sqlite": "SELECT MIN(c1), MIN(c2) FROM test", + }, + ) + self.validate_all( "TO_CHAR(x, y)", read={ "": "TO_CHAR(x, y)", diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 9328eaa..5b21349 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -214,6 +214,9 @@ TBLPROPERTIES ( self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_all( + "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"} + ) + self.validate_all( "SELECT DATE_ADD(my_date_column, 1)", write={ "spark": "SELECT DATE_ADD(my_date_column, 1)", diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index f889445..98c4a79 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -57,6 +57,33 @@ class TestSQLite(Validator): def test_sqlite(self): self.validate_all( + "CURRENT_DATE", + read={ + "": "CURRENT_DATE", + "snowflake": "CURRENT_DATE()", + }, + ) + self.validate_all( + "CURRENT_TIME", + read={ + "": "CURRENT_TIME", + "snowflake": "CURRENT_TIME()", + }, + ) + self.validate_all( + "CURRENT_TIMESTAMP", + read={ + "": "CURRENT_TIMESTAMP", + "snowflake": "CURRENT_TIMESTAMP()", + }, + ) + self.validate_all( + "SELECT DATE('2020-01-01 16:03:05')", + read={ + "snowflake": "SELECT CAST('2020-01-01 16:03:05' AS DATE)", + }, + ) + self.validate_all( "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "sqlite": 'SELECT CAST("a"."b" AS INTEGER) FROM foo', diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 380d945..3551423 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -132,6 +132,8 @@ INTERVAL '-31' CAST(GETDATE() AS DATE) INTERVAL 2 months INTERVAL (1 + 3) DAYS CAST('45' AS INTERVAL DAYS) +FILTER(a, x -> x.a.b.c.d.e.f.g) +FILTER(a, x -> FOO(x.a.b.c.d.e.f.g) + x.a.b.c.d.e.f.g) TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) DATETIME_DIFF(CURRENT_DATE, 1, DAY) QUANTILE(x, 0.5) @@ -161,6 +163,10 @@ CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' SET x = 1 SET -v SET x = ';' +SET variable = value +SET GLOBAL variable = value +SET LOCAL variable = value +SET @user OFF COMMIT USE db USE role x diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 8c7cd45..50fee7f 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -9,3 +9,19 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w"; SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; SELECT 1 + 3.2 AS "a" FROM "w" AS "w"; + +-------------------------------------- +-- Ensure boolean predicates +-------------------------------------- + +SELECT a FROM x WHERE b; +SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE "x"."b" <> 0; + +SELECT a FROM x GROUP BY a HAVING SUM(b); +SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0; + +SELECT a FROM x GROUP BY a HAVING SUM(b) AND TRUE; +SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0 AND TRUE; + +SELECT a FROM x WHERE 1; +SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE 1 <> 0; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a14e325..0b5504d 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -386,6 +386,29 @@ SELECT "x"."b" + 1 AS "c" FROM "x" AS "x"; +# title: unqualified struct element is selected in the outer query +# execute: false +WITH "cte" AS ( + SELECT + FROM_JSON("value", 'STRUCT<f1: STRUCT<f2: STRUCT<f3: STRUCT<f4: STRING>>>>') AS "struct" + FROM "tbl" +) SELECT "struct"."f1"."f2"."f3"."f4" AS "f4" FROM "cte"; +SELECT + FROM_JSON("tbl"."value", 'STRUCT<f1: STRUCT<f2: STRUCT<f3: STRUCT<f4: STRING>>>>')."f1"."f2"."f3"."f4" AS "f4" +FROM "tbl" AS "tbl"; + +# title: qualified struct element is selected in the outer query +# execute: false +WITH "cte" AS ( + SELECT + FROM_JSON("value", 'STRUCT<f1: STRUCT<f2: INTEGER>, STRUCT<f3: STRING>>') AS "struct" + FROM "tbl" +) SELECT "cte"."struct"."f1"."f2" AS "f2", "cte"."struct"."f1"."f3" AS "f3" FROM "cte"; +SELECT + FROM_JSON("tbl"."value", 'STRUCT<f1: STRUCT<f2: INTEGER>, STRUCT<f3: STRING>>')."f1"."f2" AS "f2", + FROM_JSON("tbl"."value", 'STRUCT<f1: STRUCT<f2: INTEGER>, STRUCT<f3: STRING>>')."f1"."f3" AS "f3" +FROM "tbl" AS "tbl"; + # title: left join doesnt push down predicate to join in merge subqueries # execute: false SELECT @@ -430,3 +453,11 @@ LEFT JOIN "unlocked" AS "unlocked" WHERE CASE WHEN "unlocked"."company_id" IS NULL THEN 0 ELSE 1 END = FALSE AND NOT "company_table_2"."id" IS NULL; + +# title: db.table alias clash +# execute: false +select * from db1.tbl, db2.tbl; +SELECT + * +FROM "db1"."tbl" AS "tbl" +CROSS JOIN "db2"."tbl" AS "tbl_2"; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index f3b1a69..6ff9383 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -4,6 +4,9 @@ SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; SELECT 1 AS "1" FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2; +SELECT a, b, a from x; +SELECT x.a AS a, x.b AS b, x.a AS a FROM x AS x; + SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q; diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index a06af88..8de9c85 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -1,3 +1,6 @@ +SET x TO 1; +SET x = 1; + SELECT * FROM test; SELECT * diff --git a/tests/test_build.py b/tests/test_build.py index 718e471..43707b0 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -510,6 +510,27 @@ class TestBuild(unittest.TestCase): .qualify("row_number() OVER (PARTITION BY a ORDER BY b) = 1"), "SELECT * FROM table QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1", ), + (lambda: exp.delete("tbl1", "x = 1").delete("tbl2"), "DELETE FROM tbl2 WHERE x = 1"), + (lambda: exp.delete("tbl").where("x = 1"), "DELETE FROM tbl WHERE x = 1"), + (lambda: exp.delete(exp.table_("tbl")), "DELETE FROM tbl"), + ( + lambda: exp.delete("tbl", "x = 1").where("y = 2"), + "DELETE FROM tbl WHERE x = 1 AND y = 2", + ), + ( + lambda: exp.delete("tbl", "x = 1").where(exp.condition("y = 2").or_("z = 3")), + "DELETE FROM tbl WHERE x = 1 AND (y = 2 OR z = 3)", + ), + ( + lambda: exp.delete("tbl").where("x = 1").returning("*", dialect="postgres"), + "DELETE FROM tbl WHERE x = 1 RETURNING *", + "postgres", + ), + ( + lambda: exp.delete("tbl", where="x = 1", returning="*", dialect="postgres"), + "DELETE FROM tbl WHERE x = 1 RETURNING *", + "postgres", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index ecbdc24..69e1d14 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -6,6 +6,8 @@ from sqlglot import alias, exp, parse_one class TestExpressions(unittest.TestCase): + maxDiff = None + def test_arg_key(self): self.assertEqual(parse_one("sum(1)").find(exp.Literal).arg_key, "this") @@ -91,6 +93,32 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(column.parent_select, exp.Select) self.assertIsNone(column.find_ancestor(exp.Join)) + def test_to_dot(self): + column = parse_one('a.b.c."d".e.f').find(exp.Column) + dot = column.to_dot() + + self.assertEqual(dot.sql(), 'a.b.c."d".e.f') + + self.assertEqual( + dot, + exp.Dot( + this=exp.Dot( + this=exp.Dot( + this=exp.Dot( + this=exp.Dot( + this=exp.to_identifier("a"), + expression=exp.to_identifier("b"), + ), + expression=exp.to_identifier("c"), + ), + expression=exp.to_identifier("d", quoted=True), + ), + expression=exp.to_identifier("e"), + ), + expression=exp.to_identifier("f"), + ), + ) + def test_root(self): ast = parse_one("select * from (select a from x)") self.assertIs(ast, ast.root()) @@ -480,6 +508,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("COMMIT"), exp.Commit) self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback) self.assertIsInstance(parse_one("GENERATE_SERIES(a, b, c)"), exp.GenerateSeries) + self.assertIsInstance(parse_one("COUNT_IF(a > 0)"), exp.CountIf) def test_column(self): column = parse_one("a.b.c.d") diff --git a/tests/test_generator.py b/tests/test_generator.py index d64a818..fce5c81 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,5 +1,6 @@ import unittest +from sqlglot import parse_one from sqlglot.expressions import Func from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer @@ -28,3 +29,12 @@ class TestGenerator(unittest.TestCase): tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") expression = NewParser().parse(tokens)[0] self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") + + def test_identify(self): + assert parse_one("x").sql(identify=True) == '"x"' + assert parse_one("x").sql(identify="always") == '"x"' + assert parse_one("X").sql(identify="always") == '"X"' + assert parse_one("x").sql(identify="safe") == '"x"' + assert parse_one("X").sql(identify="safe") == "X" + assert parse_one("x as 1").sql(identify="safe") == '"x" AS "1"' + assert parse_one("X as 1").sql(identify="safe") == 'X AS "1"' diff --git a/tests/test_parser.py b/tests/test_parser.py index dbde437..861e47f 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -102,6 +102,13 @@ class TestParser(unittest.TestCase): self.assertEqual(expressions[1].sql(), "ADD JAR s3://a") self.assertEqual(expressions[2].sql(), "SELECT 1") + def test_lambda_struct(self): + expression = parse_one("FILTER(a.b, x -> x.id = id)") + lambda_expr = expression.expression + + self.assertIsInstance(lambda_expr.this.this, exp.Dot) + self.assertEqual(lambda_expr.sql(), "x -> x.id = id") + def test_transactions(self): expression = parse_one("BEGIN TRANSACTION") self.assertIsNone(expression.this) @@ -280,6 +287,39 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func) self.assertIsInstance(parse_one("map.x"), exp.Column) + def test_set_expression(self): + set_ = parse_one("SET") + + self.assertEqual(set_.sql(), "SET") + self.assertIsInstance(set_, exp.Set) + + set_session = parse_one("SET SESSION x = 1") + + self.assertEqual(set_session.sql(), "SET SESSION x = 1") + self.assertIsInstance(set_session, exp.Set) + + set_item = set_session.expressions[0] + + self.assertIsInstance(set_item, exp.SetItem) + self.assertIsInstance(set_item.this, exp.EQ) + self.assertIsInstance(set_item.this.this, exp.Identifier) + self.assertIsInstance(set_item.this.expression, exp.Literal) + + self.assertEqual(set_item.args.get("kind"), "SESSION") + + set_to = parse_one("SET x TO 1") + + self.assertEqual(set_to.sql(), "SET x = 1") + self.assertIsInstance(set_to, exp.Set) + + set_as_command = parse_one("SET DEFAULT ROLE ALL TO USER") + + self.assertEqual(set_as_command.sql(), "SET DEFAULT ROLE ALL TO USER") + + self.assertIsInstance(set_as_command, exp.Command) + self.assertEqual(set_as_command.this, "SET") + self.assertEqual(set_as_command.expression, " DEFAULT ROLE ALL TO USER") + def test_pretty_config_override(self): self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") with patch("sqlglot.pretty", True): diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 0888555..909eb18 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -30,6 +30,21 @@ class TestTokens(unittest.TestCase): self.assertEqual(tokens[-1].line, 6) + def test_command(self): + tokens = Tokenizer().tokenize("SHOW;") + self.assertEqual(tokens[0].token_type, TokenType.SHOW) + self.assertEqual(tokens[1].token_type, TokenType.SEMICOLON) + + tokens = Tokenizer().tokenize("EXECUTE") + self.assertEqual(tokens[0].token_type, TokenType.EXECUTE) + self.assertEqual(len(tokens), 1) + + tokens = Tokenizer().tokenize("FETCH;SHOW;") + self.assertEqual(tokens[0].token_type, TokenType.FETCH) + self.assertEqual(tokens[1].token_type, TokenType.SEMICOLON) + self.assertEqual(tokens[2].token_type, TokenType.SHOW) + self.assertEqual(tokens[3].token_type, TokenType.SEMICOLON) + def test_jinja(self): tokenizer = Tokenizer() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 76d63b6..1e85b80 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -3,12 +3,13 @@ import unittest from sqlglot import parse_one from sqlglot.transforms import ( eliminate_distinct_on, + eliminate_qualify, remove_precision_parameterized_types, unalias_group, ) -class TestTime(unittest.TestCase): +class TestTransforms(unittest.TestCase): maxDiff = None def validate(self, transform, sql, target): @@ -74,6 +75,38 @@ class TestTime(unittest.TestCase): 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1', ) + def test_eliminate_qualify(self): + self.validate( + eliminate_qualify, + "SELECT i, a + 1 FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p) = 1", + "SELECT i, _c FROM (SELECT i, a + 1 AS _c, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM qt) AS _t WHERE _w = 1", + ) + self.validate( + eliminate_qualify, + "SELECT i FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1 AND p = 0", + "SELECT i FROM (SELECT i, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1 AND p = 0", + ) + self.validate( + eliminate_qualify, + "SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", + "SELECT i, p, o FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w FROM qt) AS _t WHERE _w = 1", + ) + self.validate( + eliminate_qualify, + "SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt QUALIFY row_num = 1", + "SELECT i, p, o, row_num FROM (SELECT i, p, o, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS row_num FROM qt) AS _t WHERE row_num = 1", + ) + self.validate( + eliminate_qualify, + "SELECT * FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1", + "SELECT * FROM (SELECT *, ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) AS _w, p, o FROM qt) AS _t WHERE _w = 1", + ) + self.validate( + eliminate_qualify, + "SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3 QUALIFY r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)", + "SELECT c2, r FROM (SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r, c1 FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3) AS _t WHERE r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)", + ) + def test_remove_precision_parameterized_types(self): self.validate( remove_precision_parameterized_types, |