From 3d48060515ba25b4c49d975a520ee0682327d1b7 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 16 Feb 2024 06:45:52 +0100 Subject: Merging upstream version 21.1.1. Signed-off-by: Daniel Baumann --- tests/dialects/test_bigquery.py | 266 ++++++++++++-------------- tests/dialects/test_clickhouse.py | 1 + tests/dialects/test_dialect.py | 4 +- tests/dialects/test_hive.py | 2 +- tests/dialects/test_postgres.py | 2 + tests/dialects/test_presto.py | 2 +- tests/dialects/test_redshift.py | 96 +++++----- tests/dialects/test_snowflake.py | 18 +- tests/dialects/test_spark.py | 26 ++- tests/dialects/test_tableau.py | 15 ++ tests/dialects/test_tsql.py | 22 +-- tests/fixtures/identity.sql | 3 + tests/fixtures/optimizer/canonicalize.sql | 12 ++ tests/fixtures/optimizer/merge_subqueries.sql | 6 +- tests/fixtures/optimizer/qualify_columns.sql | 32 ++-- tests/fixtures/optimizer/tpc-h/tpc-h.sql | 38 ++-- tests/test_expressions.py | 9 + tests/test_lineage.py | 33 +++- tests/test_optimizer.py | 6 + tests/test_parser.py | 13 ++ tests/test_transpile.py | 3 +- 21 files changed, 344 insertions(+), 265 deletions(-) (limited to 'tests') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 5cc5480..f231179 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -18,78 +18,6 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): - self.validate_identity("ARRAY_AGG(x IGNORE NULLS LIMIT 1)") - self.validate_identity("ARRAY_AGG(x IGNORE NULLS ORDER BY x LIMIT 1)") - self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY x LIMIT 1)") - self.validate_identity("ARRAY_AGG(x IGNORE NULLS)") - self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS HAVING MAX x ORDER BY x LIMIT 1)") - - self.validate_all( - "SELECT SUM(x IGNORE NULLS) AS x", - read={ - "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", - "duckdb": "SELECT SUM(x IGNORE NULLS) AS x", - "postgres": "SELECT SUM(x) IGNORE NULLS AS x", - "spark": "SELECT SUM(x) IGNORE NULLS AS x", - "snowflake": "SELECT SUM(x) IGNORE NULLS AS x", - }, - write={ - "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", - "duckdb": "SELECT SUM(x IGNORE NULLS) AS x", - "postgres": "SELECT SUM(x) IGNORE NULLS AS x", - "spark": "SELECT SUM(x) IGNORE NULLS AS x", - "snowflake": "SELECT SUM(x) IGNORE NULLS AS x", - }, - ) - self.validate_all( - "SELECT SUM(x RESPECT NULLS) AS x", - read={ - "bigquery": "SELECT SUM(x RESPECT NULLS) AS x", - "duckdb": "SELECT SUM(x RESPECT NULLS) AS x", - "postgres": "SELECT SUM(x) RESPECT NULLS AS x", - "spark": "SELECT SUM(x) RESPECT NULLS AS x", - "snowflake": "SELECT SUM(x) RESPECT NULLS AS x", - }, - write={ - "bigquery": "SELECT SUM(x RESPECT NULLS) AS x", - "duckdb": "SELECT SUM(x RESPECT NULLS) AS x", - "postgres": "SELECT SUM(x) RESPECT NULLS AS x", - "spark": "SELECT SUM(x) RESPECT NULLS AS x", - "snowflake": "SELECT SUM(x) RESPECT NULLS AS x", - }, - ) - self.validate_all( - "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", - write={ - "bigquery": "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", - "duckdb": "SELECT QUANTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", - "spark": "SELECT PERCENTILE_CONT(x, 0.5) RESPECT NULLS OVER ()", - }, - ) - self.validate_all( - "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", - write={ - "bigquery": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", - "duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 10) AS x", - "spark": 
"SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 10) IGNORE NULLS AS x", - }, - ) - self.validate_all( - "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", - write={ - "bigquery": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", - "duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10) AS x", - "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 1, 10) IGNORE NULLS AS x", - }, - ) - self.validate_identity("SELECT COUNT(x RESPECT NULLS)") - self.validate_identity("SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x") - - self.validate_identity( - "create or replace view test (tenant_id OPTIONS(description='Test description on table creation')) select 1 as tenant_id, 1 as customer_id;", - "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id", - ) - with self.assertLogs(helper_logger) as cm: statements = parse( """ @@ -131,19 +59,12 @@ class TestBigQuery(Validator): self.validate_all( "a[0]", read={ + "bigquery": "a[0]", "duckdb": "a[1]", "presto": "a[1]", }, ) - self.validate_identity( - "select array_contains([1, 2, 3], 1)", - "SELECT EXISTS(SELECT 1 FROM UNNEST([1, 2, 3]) AS _col WHERE _col = 1)", - ) - self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'") - self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'") - self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')") - with self.assertRaises(TokenError): transpile("'\\'", read="bigquery") @@ -179,6 +100,16 @@ class TestBigQuery(Validator): ) assert "'END FOR'" in cm.output[0] + self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'") + self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'") + self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')") + self.validate_identity("FOO(values)") + self.validate_identity("STRUCT(values AS value)") + self.validate_identity("ARRAY_AGG(x IGNORE NULLS LIMIT 1)") + self.validate_identity("ARRAY_AGG(x IGNORE NULLS ORDER BY x LIMIT 1)") + self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY x LIMIT 1)") + self.validate_identity("ARRAY_AGG(x IGNORE NULLS)") + self.validate_identity("ARRAY_AGG(DISTINCT x IGNORE NULLS HAVING MAX x ORDER BY x LIMIT 1)") self.validate_identity("SELECT * FROM dataset.my_table TABLESAMPLE SYSTEM (10 PERCENT)") self.validate_identity("TIME('2008-12-25 15:30:00+08')") self.validate_identity("TIME('2008-12-25 15:30:00+08', 'America/Los_Angeles')") @@ -237,6 +168,13 @@ class TestBigQuery(Validator): self.validate_identity("SELECT TIMESTAMP_SECONDS(2) AS t") self.validate_identity("SELECT TIMESTAMP_MILLIS(2) AS t") self.validate_identity("""SELECT JSON_EXTRACT_SCALAR('{"a": 5}', '$.a')""") + self.validate_identity("UPDATE x SET y = NULL") + self.validate_identity("LOG(n, b)") + self.validate_identity("SELECT COUNT(x RESPECT NULLS)") + self.validate_identity("SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x") + self.validate_identity( + "SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)" + ) self.validate_identity( "FOR record IN (SELECT word, word_count FROM bigquery-public-data.samples.shakespeare LIMIT 5) DO SELECT record.word, record.word_count" ) @@ -264,6 +202,14 @@ class TestBigQuery(Validator): self.validate_identity( """SELECT JSON_EXTRACT_SCALAR('5')""", """SELECT JSON_EXTRACT_SCALAR('5', '$')""" ) + self.validate_identity( + "select array_contains([1, 
2, 3], 1)", + "SELECT EXISTS(SELECT 1 FROM UNNEST([1, 2, 3]) AS _col WHERE _col = 1)", + ) + self.validate_identity( + "create or replace view test (tenant_id OPTIONS(description='Test description on table creation')) select 1 as tenant_id, 1 as customer_id;", + "CREATE OR REPLACE VIEW test (tenant_id OPTIONS (description='Test description on table creation')) AS SELECT 1 AS tenant_id, 1 AS customer_id", + ) self.validate_identity( "SELECT SPLIT(foo)", "SELECT SPLIT(foo, ',')", @@ -312,7 +258,81 @@ class TestBigQuery(Validator): "SELECT * FROM UNNEST(x) WITH OFFSET EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET", "SELECT * FROM UNNEST(x) WITH OFFSET AS offset EXCEPT DISTINCT SELECT * FROM UNNEST(y) WITH OFFSET AS offset", ) + self.validate_identity( + "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", + "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", + ) + self.validate_identity( + r"REGEXP_EXTRACT(svc_plugin_output, r'\\\((.*)')", + r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", + ) + self.validate_all( + "TIMESTAMP(x)", + write={ + "bigquery": "TIMESTAMP(x)", + "duckdb": "CAST(x AS TIMESTAMPTZ)", + "presto": "CAST(x AS TIMESTAMP WITH TIME ZONE)", + }, + ) + self.validate_all( + "SELECT SUM(x IGNORE NULLS) AS x", + read={ + "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", + "duckdb": "SELECT SUM(x IGNORE NULLS) AS x", + "postgres": "SELECT SUM(x) IGNORE NULLS AS x", + "spark": "SELECT SUM(x) IGNORE NULLS AS x", + "snowflake": "SELECT SUM(x) IGNORE NULLS AS x", + }, + write={ + "bigquery": "SELECT SUM(x IGNORE NULLS) AS x", + "duckdb": "SELECT SUM(x IGNORE NULLS) AS x", + "postgres": "SELECT SUM(x) IGNORE NULLS AS x", + "spark": "SELECT SUM(x) IGNORE NULLS AS x", + "snowflake": "SELECT SUM(x) IGNORE NULLS AS x", + }, + ) + self.validate_all( + "SELECT SUM(x RESPECT NULLS) AS x", + read={ + "bigquery": "SELECT SUM(x RESPECT NULLS) AS x", + "duckdb": "SELECT SUM(x RESPECT NULLS) AS x", + "postgres": "SELECT SUM(x) RESPECT NULLS AS x", + "spark": "SELECT SUM(x) RESPECT NULLS AS x", + "snowflake": "SELECT SUM(x) RESPECT NULLS AS x", + }, + write={ + "bigquery": "SELECT SUM(x RESPECT NULLS) AS x", + "duckdb": "SELECT SUM(x RESPECT NULLS) AS x", + "postgres": "SELECT SUM(x) RESPECT NULLS AS x", + "spark": "SELECT SUM(x) RESPECT NULLS AS x", + "snowflake": "SELECT SUM(x) RESPECT NULLS AS x", + }, + ) + self.validate_all( + "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", + write={ + "bigquery": "SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", + "duckdb": "SELECT QUANTILE_CONT(x, 0.5 RESPECT NULLS) OVER ()", + "spark": "SELECT PERCENTILE_CONT(x, 0.5) RESPECT NULLS OVER ()", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", + write={ + "bigquery": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x", + "duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 10) AS x", + "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 10) IGNORE NULLS AS x", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", + write={ + "bigquery": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 1, 10) AS x", + "duckdb": "SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a NULLS FIRST, b DESC LIMIT 1, 10) AS x", + "spark": "SELECT COLLECT_LIST(DISTINCT x ORDER BY a, b DESC LIMIT 1, 10) IGNORE NULLS AS x", + 
}, + ) self.validate_all( "SELECT * FROM Produce UNPIVOT((first_half_sales, second_half_sales) FOR semesters IN ((Q1, Q2) AS 'semester_1', (Q3, Q4) AS 'semester_2'))", read={ @@ -464,7 +484,6 @@ class TestBigQuery(Validator): "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)", }, ) - self.validate_identity("UPDATE x SET y = NULL") self.validate_all( "NULL", read={ @@ -620,6 +639,14 @@ class TestBigQuery(Validator): "spark": "WITH cte AS (SELECT ARRAY(1, 2, 3) AS arr) SELECT EXPLODE(arr) FROM cte" }, ) + self.validate_all( + "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], []))) - 1)) AS pos CROSS JOIN UNNEST(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) - 1) AND pos_2 = (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) - 1))", + read={"spark": "select explode_outer([])"}, + ) + self.validate_all( + "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], []))) - 1)) AS pos CROSS JOIN UNNEST(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) - 1) AND pos_2 = (ARRAY_LENGTH(IF(ARRAY_LENGTH(COALESCE([], [])) = 0, [[][SAFE_ORDINAL(0)]], [])) - 1))", + read={"spark": "select posexplode_outer([])"}, + ) self.validate_all( "SELECT AS STRUCT ARRAY(SELECT AS STRUCT b FROM x) AS y FROM z", write={ @@ -660,10 +687,6 @@ class TestBigQuery(Validator): "bigquery": "SELECT ARRAY(SELECT AS STRUCT 1 AS a, 2 AS b)", }, ) - self.validate_identity( - r"REGEXP_EXTRACT(svc_plugin_output, r'\\\((.*)')", - r"REGEXP_EXTRACT(svc_plugin_output, '\\\\\\((.*)')", - ) self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={ @@ -986,9 +1009,6 @@ class TestBigQuery(Validator): "postgres": "CURRENT_DATE AT TIME ZONE 'UTC'", }, ) - self.validate_identity( - "SELECT * FROM test QUALIFY a IS DISTINCT FROM b WINDOW c AS (PARTITION BY d)" - ) self.validate_all( "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", write={ @@ -997,45 +1017,20 @@ class TestBigQuery(Validator): }, ) self.validate_all( - "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", - write={ - "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)", + "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])", + read={ "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb)])", "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", - }, - ) - self.validate_all( - "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab", - write={ - "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])", - }, - ) - self.validate_all( - "SELECT cola, colb FROM (VALUES (1, 'test'))", - write={ - "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])", + "spark": "SELECT cola, colb FROM VALUES (1, 'test') AS tab(cola, colb)", }, ) self.validate_all( "SELECT * FROM UNNEST([STRUCT(1 AS id)]) CROSS JOIN UNNEST([STRUCT(1 AS id)])", read={ + "bigquery": "SELECT * FROM UNNEST([STRUCT(1 AS id)]) CROSS JOIN 
UNNEST([STRUCT(1 AS id)])", "postgres": "SELECT * FROM (VALUES (1)) AS t1(id) CROSS JOIN (VALUES (1)) AS t2(id)", }, ) - self.validate_all( - "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", - write={ - "spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)", - "bigquery": "SELECT cola, colb, colc FROM UNNEST([STRUCT(1 AS cola, 'test' AS colb, NULL AS colc)])", - "snowflake": "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", - }, - ) - self.validate_all( - "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) d, COUNT(*) e FOR c IN ('x', 'y'))", - write={ - "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", - }, - ) self.validate_all( "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", write={ @@ -1091,8 +1086,6 @@ WHERE pretty=True, ) - self.validate_identity("LOG(n, b)") - def test_user_defined_functions(self): self.validate_identity( "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" @@ -1114,35 +1107,22 @@ WHERE ) def test_remove_precision_parameterized_types(self): - self.validate_all( - "SELECT CAST(1 AS NUMERIC(10, 2))", - write={ - "bigquery": "SELECT CAST(1 AS NUMERIC)", - }, + self.validate_identity("CREATE TABLE test (a NUMERIC(10, 2))") + self.validate_identity( + "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))", + "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))", ) - self.validate_all( - "CREATE TABLE test (a NUMERIC(10, 2))", - write={ - "bigquery": "CREATE TABLE test (a NUMERIC(10, 2))", - }, + self.validate_identity( + "SELECT CAST(1 AS NUMERIC(10, 2))", + "SELECT CAST(1 AS NUMERIC)", ) - self.validate_all( + self.validate_identity( "SELECT CAST('1' AS STRING(10)) UNION ALL SELECT CAST('2' AS STRING(10))", - write={ - "bigquery": "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)", - }, + "SELECT CAST('1' AS STRING) UNION ALL SELECT CAST('2' AS STRING)", ) - self.validate_all( + self.validate_identity( "SELECT cola FROM (SELECT CAST('1' AS STRING(10)) AS cola UNION ALL SELECT CAST('2' AS STRING(10)) AS cola)", - write={ - "bigquery": "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)", - }, - ) - self.validate_all( - "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING(10)), CAST(14 AS STRING(10)))", - write={ - "bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))", - }, + "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)", ) def test_models(self): diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 7351f6a..0148812 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -369,6 +369,7 @@ class TestClickhouse(Validator): "SELECT STARTS_WITH('a', 'b'), STARTSWITH('a', 'b')", write={"clickhouse": "SELECT startsWith('a', 'b'), startsWith('a', 'b')"}, ) + self.validate_identity("SYSTEM STOP MERGES foo.bar", check_command_warning=True) def test_cte(self): self.validate_identity("WITH 'x' AS foo SELECT foo") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index fd9dbdb..4b1e2a7 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -2101,7 +2101,7 @@ SELECT "databricks": "SELECT COUNT_IF(col % 2 = 0) FROM 
foo", "presto": "SELECT COUNT_IF(col % 2 = 0) FROM foo", "snowflake": "SELECT COUNT_IF(col % 2 = 0) FROM foo", - "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FROM foo", + "sqlite": "SELECT SUM(IIF(col % 2 = 0, 1, 0)) FROM foo", "tsql": "SELECT COUNT_IF(col % 2 = 0) FROM foo", }, ) @@ -2116,7 +2116,7 @@ SELECT "": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", "databricks": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", "presto": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", - "sqlite": "SELECT SUM(CASE WHEN col % 2 = 0 THEN 1 ELSE 0 END) FILTER(WHERE col < 1000) FROM foo", + "sqlite": "SELECT SUM(IIF(col % 2 = 0, 1, 0)) FILTER(WHERE col < 1000) FROM foo", "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", }, ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index d1b7589..ea28f29 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -152,7 +152,7 @@ class TestHive(Validator): "duckdb": "CREATE TABLE x (w TEXT)", # Partition columns should exist in table "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", - "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + "spark": "CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)", }, ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 61421e5..e77fa8a 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,6 +8,8 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_identity("|/ x", "SQRT(x)") + self.validate_identity("||/ x", "CBRT(x)") expr = parse_one( "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", read="postgres" ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 36006d2..d3d1a76 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -424,7 +424,7 @@ class TestPresto(Validator): "duckdb": "CREATE TABLE x (w TEXT, y INT, z INT)", "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY=ARRAY['y', 'z'])", "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", - "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + "spark": "CREATE TABLE x (w STRING, y INT, z INT) PARTITIONED BY (y, z)", }, ) self.validate_all( diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index b6b6ccc..33cfa0c 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -1,4 +1,4 @@ -from sqlglot import transpile +from sqlglot import exp, parse_one, transpile from tests.dialects.test_dialect import Validator @@ -381,8 +381,6 @@ class TestRedshift(Validator): "SELECT DATEADD(DAY, 1, DATE('2023-01-01'))", ) - self.validate_identity("SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l") - self.validate_identity( """SELECT c_name, @@ -408,8 +406,9 @@ ORDER BY union_query = f"SELECT * FROM ({' UNION ALL '.join('SELECT ' + v for v in values)})" self.assertEqual(transpile(values_query, write="redshift")[0], union_query) - self.validate_identity( - "SELECT * FROM (VALUES (1), (2))", + values_sql = transpile("SELECT * FROM (VALUES (1), (2))", write="redshift", pretty=True)[0] + self.assertEqual( + values_sql, """SELECT * FROM ( @@ -419,69 +418,51 @@ FROM ( SELECT 2 )""", - 
pretty=True, ) + self.validate_identity("INSERT INTO t (a) VALUES (1), (2), (3)") + self.validate_identity("INSERT INTO t (a, b) VALUES (1, 2), (3, 4)") + self.validate_all( - "SELECT * FROM (VALUES (1, 2)) AS t", - write={ - "redshift": "SELECT * FROM (SELECT 1, 2) AS t", - "mysql": "SELECT * FROM (SELECT 1, 2) AS t", - "presto": "SELECT * FROM (VALUES (1, 2)) AS t", - }, - ) - self.validate_all( - "SELECT * FROM (VALUES (1)) AS t1(id) CROSS JOIN (VALUES (1)) AS t2(id)", - write={ - "redshift": "SELECT * FROM (SELECT 1 AS id) AS t1 CROSS JOIN (SELECT 1 AS id) AS t2", + "SELECT * FROM (SELECT 1, 2) AS t", + read={ + "": "SELECT * FROM (VALUES (1, 2)) AS t", }, - ) - self.validate_all( - "SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)", write={ - "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t", + "mysql": "SELECT * FROM (SELECT 1, 2) AS t", + "presto": "SELECT * FROM (SELECT 1, 2) AS t", }, ) self.validate_all( - 'SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS "t" (a, b)', - write={ - "redshift": 'SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS "t"', + "SELECT * FROM (SELECT 1 AS id) AS t1 CROSS JOIN (SELECT 1 AS id) AS t2", + read={ + "": "SELECT * FROM (VALUES (1)) AS t1(id) CROSS JOIN (VALUES (1)) AS t2(id)", }, ) self.validate_all( - "SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)", - write={ - "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t", + "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t", + read={ + "": "SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)", }, ) self.validate_all( - "INSERT INTO t(a) VALUES (1), (2), (3)", - write={ - "redshift": "INSERT INTO t (a) VALUES (1), (2), (3)", + 'SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS "t"', + read={ + "": 'SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS "t" (a, b)', }, ) self.validate_all( - "INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", - write={ - "redshift": "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t", + read={ + "": "SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)", }, ) self.validate_all( - "INSERT INTO t(a, b) VALUES (1, 2), (3, 4)", - write={ - "redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)", + "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t", + read={ + "": "INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)", }, ) - self.validate_identity( - 'SELECT * FROM (VALUES (1)) AS "t"(a)', - '''SELECT - * -FROM ( - SELECT - 1 AS a -) AS "t"''', - pretty=True, - ) def test_create_table_like(self): self.validate_identity( @@ -532,3 +513,26 @@ FROM ( "redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", }, ) + + def test_column_unnesting(self): + ast = parse_one("SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3", read="redshift") + ast.args["from"].this.assert_is(exp.Table) + ast.args["joins"][0].this.assert_is(exp.Table) + self.assertEqual(ast.sql("redshift"), "SELECT * FROM t.t JOIN t.c1 ON c1.c2 = t.c3") + + ast = parse_one("SELECT * FROM t AS t CROSS JOIN t.c1", read="redshift") + ast.args["from"].this.assert_is(exp.Table) + ast.args["joins"][0].this.assert_is(exp.Column) + self.assertEqual(ast.sql("redshift"), "SELECT * FROM t AS t CROSS JOIN t.c1") + + ast = parse_one( + "SELECT * 
FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l", read="redshift" + ) + joins = ast.args["joins"] + ast.args["from"].this.assert_is(exp.Table) + joins[0].this.this.assert_is(exp.Column) + joins[1].this.this.assert_is(exp.Column) + joins[2].this.this.assert_is(exp.Dot) + self.assertEqual( + ast.sql("redshift"), "SELECT * FROM x AS a, a.b AS c, c.d.e AS f, f.g.h.i.j.k AS l" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 7a821f6..321dd73 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -84,6 +84,10 @@ WHERE self.validate_identity( "SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1)" ) + self.validate_identity( + "value:values::string", + "CAST(GET_PATH(value, 'values') AS TEXT)", + ) self.validate_identity( """SELECT GET_PATH(PARSE_JSON('{"y": [{"z": 1}]}'), 'y[0]:z')""", """SELECT GET_PATH(PARSE_JSON('{"y": [{"z": 1}]}'), 'y[0].z')""", @@ -462,7 +466,7 @@ WHERE "DIV0(foo, bar)", write={ "snowflake": "IFF(bar = 0, 0, foo / bar)", - "sqlite": "CASE WHEN bar = 0 THEN 0 ELSE CAST(foo AS REAL) / bar END", + "sqlite": "IIF(bar = 0, 0, CAST(foo AS REAL) / bar)", "presto": "IF(bar = 0, 0, CAST(foo AS DOUBLE) / bar)", "spark": "IF(bar = 0, 0, foo / bar)", "hive": "IF(bar = 0, 0, foo / bar)", @@ -473,7 +477,7 @@ WHERE "ZEROIFNULL(foo)", write={ "snowflake": "IFF(foo IS NULL, 0, foo)", - "sqlite": "CASE WHEN foo IS NULL THEN 0 ELSE foo END", + "sqlite": "IIF(foo IS NULL, 0, foo)", "presto": "IF(foo IS NULL, 0, foo)", "spark": "IF(foo IS NULL, 0, foo)", "hive": "IF(foo IS NULL, 0, foo)", @@ -484,7 +488,7 @@ WHERE "NULLIFZERO(foo)", write={ "snowflake": "IFF(foo = 0, NULL, foo)", - "sqlite": "CASE WHEN foo = 0 THEN NULL ELSE foo END", + "sqlite": "IIF(foo = 0, NULL, foo)", "presto": "IF(foo = 0, NULL, foo)", "spark": "IF(foo = 0, NULL, foo)", "hive": "IF(foo = 0, NULL, foo)", @@ -1513,6 +1517,10 @@ MATCH_RECOGNIZE ( self.validate_identity("SHOW COLUMNS IN VIEW") self.validate_identity("SHOW COLUMNS LIKE '_foo%' IN VIEW dt_test") + self.validate_identity("SHOW USERS") + self.validate_identity("SHOW TERSE USERS") + self.validate_identity("SHOW USERS LIKE '_foo%' STARTS WITH 'bar' LIMIT 5 FROM 'baz'") + ast = parse_one("SHOW COLUMNS LIKE '_testing%' IN dt_test", read="snowflake") table = ast.find(exp.Table) literal = ast.find(exp.Literal) @@ -1536,6 +1544,10 @@ MATCH_RECOGNIZE ( table = ast.find(exp.Table) self.assertEqual(table.sql(dialect="snowflake"), "db1.schema1") + users_exp = self.validate_identity("SHOW USERS") + self.assertTrue(isinstance(users_exp, exp.Show)) + self.assertEqual(users_exp.this, "USERS") + def test_swap(self): ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake") assert isinstance(ast, exp.AlterTable) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 75bb91a..196735b 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -93,11 +93,12 @@ TBLPROPERTIES ( 'x'='1' )""", "spark": """CREATE TABLE blah ( - col_a INT + col_a INT, + date STRING ) COMMENT 'Test comment: blah' PARTITIONED BY ( - date STRING + date ) USING ICEBERG TBLPROPERTIES ( @@ -125,13 +126,6 @@ TBLPROPERTIES ( "spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)", }, ) - self.validate_all( - "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", - identify=True, - write={ - "spark": "CREATE TABLE `x` USING ICEBERG PARTITIONED BY (MONTHS(`y`)) LOCATION 's3://z'", - }, - ) def test_to_date(self): self.validate_all( @@ 
-256,6 +250,14 @@ TBLPROPERTIES ( self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("SPLIT(str, pattern, lim)") + self.validate_identity( + "SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL 23 HOUR + 59 MINUTE + 59 SECONDS", + "SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL '23' HOUR + INTERVAL '59' MINUTE + INTERVAL '59' SECONDS", + ) + self.validate_identity( + "SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL '23' HOUR + '59' MINUTE + '59' SECONDS", + "SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL '23' HOUR + INTERVAL '59' MINUTE + INTERVAL '59' SECONDS", + ) self.validate_identity( "SELECT INTERVAL '5' HOURS '30' MINUTES '5' SECONDS '6' MILLISECONDS '7' MICROSECONDS", "SELECT INTERVAL '5' HOURS + INTERVAL '30' MINUTES + INTERVAL '5' SECONDS + INTERVAL '6' MILLISECONDS + INTERVAL '7' MICROSECONDS", @@ -616,12 +618,6 @@ TBLPROPERTIES ( }, ) - def test_iif(self): - self.validate_all( - "SELECT IIF(cond, 'True', 'False')", - write={"spark": "SELECT IF(cond, 'True', 'False')"}, - ) - def test_bool_or(self): self.validate_all( "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", diff --git a/tests/dialects/test_tableau.py b/tests/dialects/test_tableau.py index 0f612dd..fe605b1 100644 --- a/tests/dialects/test_tableau.py +++ b/tests/dialects/test_tableau.py @@ -5,6 +5,21 @@ class TestTableau(Validator): dialect = "tableau" def test_tableau(self): + self.validate_all( + "[x]", + write={ + "hive": "`x`", + "tableau": "[x]", + }, + ) + self.validate_all( + '"x"', + write={ + "hive": "'x'", + "tableau": "'x'", + }, + ) + self.validate_all( "IF x = 'a' THEN y ELSE NULL END", read={ diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index c8c0d82..e2ec15b 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -28,6 +28,14 @@ class TestTSQL(Validator): self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)") self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") + self.validate_all( + "SELECT IIF(cond <> 0, 'True', 'False')", + read={ + "spark": "SELECT IF(cond, 'True', 'False')", + "sqlite": "SELECT IIF(cond, 'True', 'False')", + "tsql": "SELECT IIF(cond <> 0, 'True', 'False')", + }, + ) self.validate_all( "SELECT TRIM(BOTH 'a' FROM a)", read={ @@ -1302,20 +1310,6 @@ WHERE }, ) - def test_iif(self): - self.validate_identity( - "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond <> 0, 'True', 'False')" - ) - self.validate_identity( - "SELECT IIF(cond, 'True', 'False')", "SELECT IIF(cond <> 0, 'True', 'False')" - ) - self.validate_all( - "SELECT IIF(cond, 'True', 'False');", - write={ - "spark": "SELECT IF(cond, 'True', 'False')", - }, - ) - def test_lateral_subquery(self): self.validate_all( "SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)", diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 366b79e..d9efc57 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -847,3 +847,6 @@ WITH use(use) AS (SELECT 1) SELECT use FROM use SELECT recursive FROM t SELECT (ROW_NUMBER() OVER (PARTITION BY user ORDER BY date ASC) - ROW_NUMBER() OVER (PARTITION BY user, segment ORDER BY date ASC)) AS group_id FROM example_table CAST(foo AS BPCHAR) +values +SELECT values +SELECT values AS values FROM t WHERE values + 1 > 3 diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 4db3764..98b2f07 100644 --- 
a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -96,3 +96,15 @@ DATE_TRUNC('DAY', CAST('2023-01-01' AS DATE)); DATEDIFF('2023-01-01', '2023-01-02', DAY); DATEDIFF(CAST('2023-01-01' AS DATETIME), CAST('2023-01-02' AS DATETIME), DAY); + +-------------------------------------- +-- Remove redundant casts +-------------------------------------- +CAST(CAST('2023-01-01' AS DATE) AS DATE); +CAST('2023-01-01' AS DATE); + +CAST(DATE_TRUNC('YEAR', CAST('2023-01-01' AS DATE)) AS DATE); +DATE_TRUNC('YEAR', CAST('2023-01-01' AS DATE)); + +DATE(DATE_TRUNC('YEAR', CAST("x" AS DATE))); +DATE_TRUNC('YEAR', CAST("x" AS DATE)); diff --git a/tests/fixtures/optimizer/merge_subqueries.sql b/tests/fixtures/optimizer/merge_subqueries.sql index 7bc45a7..0f22925 100644 --- a/tests/fixtures/optimizer/merge_subqueries.sql +++ b/tests/fixtures/optimizer/merge_subqueries.sql @@ -218,6 +218,7 @@ with t1 as ( ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x + ORDER BY x.a, x.b, row_num ) SELECT t1.a, @@ -226,7 +227,7 @@ FROM t1 WHERE row_num = 1; -WITH t1 AS (SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x) SELECT t1.a AS a, t1.b AS b FROM t1 AS t1 WHERE t1.row_num = 1; +WITH t1 AS (SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x ORDER BY x.a, x.b, row_num) SELECT t1.a AS a, t1.b AS b FROM t1 AS t1 WHERE t1.row_num = 1; # title: Test preventing merge of window expressions join clause with t1 as ( @@ -301,6 +302,7 @@ with t1 as ( ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) as row_num FROM x + ORDER BY x.a, x.b, row_num ) SELECT t1.a, @@ -308,7 +310,7 @@ SELECT t1.row_num FROM t1; -SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x; +SELECT x.a AS a, x.b AS b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) AS row_num FROM x AS x ORDER BY x.a, x.b, row_num; # title: Don't merge window functions, inner table is aliased in outer query with t1 as ( diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index ad197db..4fdf33b 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -208,14 +208,14 @@ SELECT x.a AS a, y.c AS c FROM x AS x, y AS y; -------------------------------------- -- Unions -------------------------------------- -SELECT a FROM x UNION SELECT a FROM x; -SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; +SELECT a FROM x UNION SELECT a FROM x ORDER BY a; +SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x ORDER BY a; -SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x; -SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; +SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x ORDER BY a; +SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x ORDER BY a; -SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x); -SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS _q_0; +SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x) ORDER BY a; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS _q_0 ORDER BY a; -------------------------------------- -- Subqueries @@ -318,8 +318,8 @@ WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z AS z; WITH z AS (SELECT a FROM x), q AS (SELECT * 
FROM z) SELECT * FROM q; WITH z AS (SELECT x.a AS a FROM x AS x), q AS (SELECT z.a AS a FROM z AS z) SELECT q.a AS a FROM q AS q; -WITH z AS (SELECT * FROM x) SELECT * FROM z UNION SELECT * FROM z; -WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT z.a AS a, z.b AS b FROM z AS z UNION SELECT z.a AS a, z.b AS b FROM z AS z; +WITH z AS (SELECT * FROM x) SELECT * FROM z UNION SELECT * FROM z ORDER BY a, b; +WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT z.a AS a, z.b AS b FROM z AS z UNION SELECT z.a AS a, z.b AS b FROM z AS z ORDER BY a, b; WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q; WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z AS z) SELECT q.b AS b FROM q AS q; @@ -359,8 +359,8 @@ SELECT x.b AS b FROM x AS x; SELECT * EXCEPT (a, b) FROM x; SELECT * EXCEPT (x.a, x.b) FROM x AS x; -SELECT COALESCE(t1.a, '') AS a, t2.* EXCEPT (a) FROM x AS t1, x AS t2; -SELECT COALESCE(t1.a, '') AS a, t2.b AS b FROM x AS t1, x AS t2; +SELECT COALESCE(CAST(t1.a AS VARCHAR), '') AS a, t2.* EXCEPT (a) FROM x AS t1, x AS t2; +SELECT COALESCE(CAST(t1.a AS VARCHAR), '') AS a, t2.b AS b FROM x AS t1, x AS t2; -------------------------------------- -- Using @@ -468,8 +468,8 @@ select * from unnest ([1, 2]) as x with offset as y; SELECT x AS x, y AS y FROM UNNEST([1, 2]) AS x WITH OFFSET AS y; # dialect: presto -SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(b, ',')) AS i(b); -SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(x.b, ',')) AS i(b); +SELECT x.a, i.b FROM x CROSS JOIN UNNEST(SPLIT(CAST(b AS VARCHAR), ',')) AS i(b); +SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(CAST(x.b AS VARCHAR), ',')) AS i(b); # execute: false SELECT c FROM (SELECT 1 a) AS x LATERAL VIEW EXPLODE(a) AS c; @@ -487,16 +487,16 @@ SELECT t.c1 AS c1, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3); -- Window functions -------------------------------------- # title: ORDER BY in window function -SELECT a + 1 AS a, ROW_NUMBER() OVER (PARTITION BY b ORDER BY a) AS row_num FROM x; -SELECT x.a + 1 AS a, ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a) AS row_num FROM x AS x; +SELECT a + 1 AS a, ROW_NUMBER() OVER (PARTITION BY b ORDER BY a) AS row_num FROM x ORDER BY a, row_num; +SELECT x.a + 1 AS a, ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a) AS row_num FROM x AS x ORDER BY a, row_num; # dialect: bigquery SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1; SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) = 1; # dialect: bigquery -SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; -SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; +SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1 ORDER BY x.b, x.a; +SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1 ORDER BY x.b, x.a; SELECT * FROM x QUALIFY COUNT(a) OVER (PARTITION BY b) > 1; SELECT x.a AS a, x.b AS b FROM x AS x QUALIFY COUNT(x.a) OVER (PARTITION BY x.b) > 1; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index bf624da..a99abcd 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ 
b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -15,7 +15,7 @@ select from lineitem where - l_shipdate <= date '1998-12-01' - interval '90' day + CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day group by l_returnflag, l_linestatus @@ -218,8 +218,8 @@ select from orders where - o_orderdate >= date '1993-07-01' - and o_orderdate < date '1993-07-01' + interval '3' month + CAST(o_orderdate AS DATE) >= date '1993-07-01' + and CAST(o_orderdate AS DATE) < date '1993-07-01' + interval '3' month and exists ( select * @@ -278,8 +278,8 @@ where and s_nationkey = n_nationkey and n_regionkey = r_regionkey and r_name = 'ASIA' - and o_orderdate >= date '1994-01-01' - and o_orderdate < date '1994-01-01' + interval '1' year + and CAST(o_orderdate AS DATE) >= date '1994-01-01' + and CAST(o_orderdate AS DATE) < date '1994-01-01' + interval '1' year group by n_name order by @@ -316,8 +316,8 @@ select from lineitem where - l_shipdate >= date '1994-01-01' - and l_shipdate < date '1994-01-01' + interval '1' year + CAST(l_shipdate AS DATE) >= date '1994-01-01' + and CAST(l_shipdate AS DATE) < date '1994-01-01' + interval '1' year and l_discount between 0.06 - 0.01 and 0.06 + 0.01 and l_quantity < 24; SELECT @@ -362,7 +362,7 @@ from (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') ) - and l_shipdate between date '1995-01-01' and date '1996-12-31' + and CAST(l_shipdate AS DATE) between date '1995-01-01' and date '1996-12-31' ) as shipping group by supp_nation, @@ -446,7 +446,7 @@ from and n1.n_regionkey = r_regionkey and r_name = 'AMERICA' and s_nationkey = n2.n_nationkey - and o_orderdate between date '1995-01-01' and date '1996-12-31' + and CAST(o_orderdate AS DATE) between date '1995-01-01' and date '1996-12-31' and p_type = 'ECONOMY ANODIZED STEEL' ) as all_nations group by @@ -574,8 +574,8 @@ from where c_custkey = o_custkey and l_orderkey = o_orderkey - and o_orderdate >= date '1993-10-01' - and o_orderdate < date '1993-10-01' + interval '3' month + and CAST(o_orderdate AS DATE) >= date '1993-10-01' + and CAST(o_orderdate AS DATE) < date '1993-10-01' + interval '3' month and l_returnflag = 'R' and c_nationkey = n_nationkey group by @@ -714,8 +714,8 @@ where and l_shipmode in ('MAIL', 'SHIP') and l_commitdate < l_receiptdate and l_shipdate < l_commitdate - and l_receiptdate >= date '1994-01-01' - and l_receiptdate < date '1994-01-01' + interval '1' year + and CAST(l_receiptdate AS DATE) >= date '1994-01-01' + and CAST(l_receiptdate AS DATE) < date '1994-01-01' + interval '1' year group by l_shipmode order by @@ -813,8 +813,8 @@ from part where l_partkey = p_partkey - and l_shipdate >= date '1995-09-01' - and l_shipdate < date '1995-09-01' + interval '1' month; + and CAST(l_shipdate AS DATE) >= date '1995-09-01' + and CAST(l_shipdate AS DATE) < date '1995-09-01' + interval '1' month; SELECT 100.00 * SUM( CASE @@ -844,8 +844,8 @@ with revenue (supplier_no, total_revenue) as ( from lineitem where - l_shipdate >= date '1996-01-01' - and l_shipdate < date '1996-01-01' + interval '3' month + CAST(l_shipdate AS DATE) >= date '1996-01-01' + and CAST(l_shipdate AS DATE) < date '1996-01-01' + interval '3' month group by l_suppkey) select @@ -1223,8 +1223,8 @@ where where l_partkey = ps_partkey and l_suppkey = ps_suppkey - and l_shipdate >= date '1994-01-01' - and l_shipdate < date '1994-01-01' + interval '1' year + and CAST(l_shipdate AS DATE) >= date '1994-01-01' + and CAST(l_shipdate AS DATE) < date '1994-01-01' + interval '1' year ) ) and s_nationkey 
= n_nationkey diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f415ff6..d42eeca 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -243,6 +243,15 @@ class TestExpressions(unittest.TestCase): 'SELECT * FROM a1 AS a /* a */, b.a /* b */, c.a2 /* c.a */, d2 /* d.a */ CROSS JOIN e.a CROSS JOIN "F" /* f-F.A */ CROSS JOIN g1.a /* g */', ) + self.assertEqual( + exp.replace_tables( + parse_one("select * from example.table", dialect="bigquery"), + {"example.table": "`my-project.example.table`"}, + dialect="bigquery", + ).sql(), + 'SELECT * FROM "my-project".example.table /* example.table */', + ) + def test_expand(self): self.assertEqual( exp.expand( diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 2f3456d..922edcb 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -95,7 +95,7 @@ class TestLineage(unittest.TestCase): downstream.source.sql(), "SELECT x.a AS a FROM x AS x", ) - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.alias, "z") def test_lineage_source_with_star(self) -> None: node = lineage( @@ -153,7 +153,7 @@ class TestLineage(unittest.TestCase): downstream = downstream.downstream[0] self.assertEqual(downstream.source.sql(), "(VALUES (1), (2)) AS t(a)") self.assertEqual(downstream.expression.sql(), "a") - self.assertEqual(downstream.alias, "") + self.assertEqual(downstream.alias, "y") def test_lineage_cte_name_appears_in_schema(self) -> None: schema = {"a": {"b": {"t1": {"c1": "int"}, "t2": {"c2": "int"}}}} @@ -284,6 +284,35 @@ class TestLineage(unittest.TestCase): self.assertEqual(downstream_b.name, "0") self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b") + def test_lineage_source_union(self) -> None: + query = "SELECT x, created_at FROM dataset;" + node = lineage( + "x", + query, + sources={ + "dataset": """ + SELECT * + FROM catalog.db.table_a + + UNION + + SELECT * + FROM catalog.db.table_b + """ + }, + ) + + self.assertEqual(node.name, "x") + + downstream_a = node.downstream[0] + self.assertEqual(downstream_a.name, "0") + self.assertEqual(downstream_a.alias, "dataset") + self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS table_a") + downstream_b = node.downstream[1] + self.assertEqual(downstream_b.name, "0") + self.assertEqual(downstream_b.alias, "dataset") + self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b") + def test_select_star(self) -> None: node = lineage("x", "SELECT x from (SELECT * from table_a)") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 0e8a803..d4f2edb 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -874,6 +874,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR) self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR) + timestamp = annotate_types(parse_one("TIMESTAMP(x)")) + self.assertEqual(timestamp.type.this, exp.DataType.Type.TIMESTAMP) + + timestamptz = annotate_types(parse_one("TIMESTAMP(x)", read="bigquery")) + self.assertEqual(timestamptz.type.this, exp.DataType.Type.TIMESTAMPTZ) + def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola + SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" diff --git a/tests/test_parser.py b/tests/test_parser.py index c7e1dbe..035b5de 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -839,3 +839,16 @@ class 
TestParser(unittest.TestCase):
         ),
     )
     self.assertEqual(ast.sql(dialect=dialect), "CREATE SCHEMA catalog.schema")
+
+    def test_values_as_identifier(self):
+        sql = "SELECT values FROM t WHERE values + 1 > x"
+        for dialect in (
+            "bigquery",
+            "clickhouse",
+            "duckdb",
+            "postgres",
+            "redshift",
+            "snowflake",
+        ):
+            with self.subTest(dialect):
+                self.assertEqual(parse_one(sql, dialect=dialect).sql(dialect=dialect), sql)
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index fdbf2e0..99b3fac 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -554,11 +554,12 @@ FROM base""",
         self.validate(
             "WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
             "WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
+            read="presto",
         )
         self.validate(
             "SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)",
             "SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)",
-            write="presto",
+            read="presto",
         )
 
     def test_alter(self):