From c03ba18c491e52cc85d8aae1825dd9e0b4f75e32 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 26 Oct 2023 19:21:54 +0200 Subject: Merging upstream version 18.17.0. Signed-off-by: Daniel Baumann --- tests/dialects/test_bigquery.py | 23 +++++++-- tests/dialects/test_clickhouse.py | 49 +++++++++++++++++++ tests/dialects/test_databricks.py | 4 +- tests/dialects/test_dialect.py | 4 ++ tests/dialects/test_duckdb.py | 1 + tests/dialects/test_mysql.py | 2 + tests/dialects/test_postgres.py | 5 ++ tests/dialects/test_presto.py | 62 ++++++++++++++++++++++++ tests/dialects/test_redshift.py | 26 ++++++---- tests/dialects/test_snowflake.py | 42 +++++++++++++++- tests/dialects/test_spark.py | 7 +-- tests/dialects/test_teradata.py | 1 + tests/dialects/test_tsql.py | 14 +++--- tests/fixtures/identity.sql | 1 + tests/fixtures/optimizer/canonicalize.sql | 1 - tests/fixtures/optimizer/qualify_tables.sql | 8 ++++ tests/fixtures/optimizer/simplify.sql | 67 +++++++++++++++++++++++++- tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 4 +- tests/test_expressions.py | 9 ++++ tests/test_lineage.py | 74 +++++++++++++++++++++++++++++ tests/test_optimizer.py | 41 ++++++++++++++++ tests/test_parser.py | 3 ++ tests/test_transpile.py | 57 +++++++++++++++++++++- 23 files changed, 475 insertions(+), 30 deletions(-) (limited to 'tests') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 3cf95a7..3601e47 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -9,6 +9,10 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'") + self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'") + self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')") + with self.assertRaises(TokenError): transpile("'\\'", read="bigquery") @@ -138,6 +142,20 @@ class TestBigQuery(Validator): self.validate_all('x <> """"""', write={"bigquery": "x <> ''"}) self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) + self.validate_all( + "SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)", + write={ + "bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)", + "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)", + }, + ) + self.validate_all( + "NULL", + read={ + "duckdb": "NULL = a", + "postgres": "a = NULL", + }, + ) self.validate_all( "SELECT '\\n'", read={ @@ -465,9 +483,8 @@ class TestBigQuery(Validator): }, write={ "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x", - "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS (x)", - "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", - "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS _t(x)", + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS _t(x)", }, ) self.validate_all( diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 948c00e..93d1ced 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -6,6 +6,22 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_identity("x <> y") + + self.validate_all( + "has([1], x)", + read={ + "postgres": "x = any(array[1])", + }, + ) + self.validate_all( + "NOT has([1], x)", + read={ + "postgres": "any(array[1]) <> x", + }, + ) + self.validate_identity("x = y") + string_types = [ "BLOB", "LONGBLOB", @@ -85,6 +101,39 @@ class TestClickhouse(Validator): "CREATE MATERIALIZED VIEW test_view (id UInt8) TO db.table1 AS SELECT * FROM test_data" ) + self.validate_all( + "SELECT CAST('2020-01-01' AS TIMESTAMP) + INTERVAL '500' microsecond", + read={ + "duckdb": "SELECT TIMESTAMP '2020-01-01' + INTERVAL '500 us'", + "postgres": "SELECT TIMESTAMP '2020-01-01' + INTERVAL '500 us'", + }, + ) + self.validate_all( + "SELECT CURRENT_DATE()", + read={ + "clickhouse": "SELECT CURRENT_DATE()", + "postgres": "SELECT CURRENT_DATE", + }, + ) + self.validate_all( + "SELECT CURRENT_TIMESTAMP()", + read={ + "clickhouse": "SELECT CURRENT_TIMESTAMP()", + "postgres": "SELECT CURRENT_TIMESTAMP", + }, + ) + self.validate_all( + "SELECT match('ThOmAs', CONCAT('(?i)', 'thomas'))", + read={ + "postgres": "SELECT 'ThOmAs' ~* 'thomas'", + }, + ) + self.validate_all( + "SELECT match('ThOmAs', CONCAT('(?i)', x)) FROM t", + read={ + "postgres": "SELECT 'ThOmAs' ~* x FROM t", + }, + ) self.validate_all( "SELECT '\\0'", read={ diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 7c03c83..8bb88b3 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -6,8 +6,8 @@ class TestDatabricks(Validator): def test_databricks(self): self.validate_identity("CREATE TABLE t (c STRUCT)") - self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES (a.b=15)") - self.validate_identity("CREATE TABLE my_table () TBLPROPERTIES ('a.b'=15)") + self.validate_identity("CREATE TABLE my_table TBLPROPERTIES (a.b=15)") + self.validate_identity("CREATE TABLE my_table TBLPROPERTIES ('a.b'=15)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO HOUR)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO MINUTE)") self.validate_identity("SELECT CAST('11 23:4:0' AS INTERVAL DAY TO SECOND)") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 91eba17..0d43b2a 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -99,6 +99,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", + "tsql": "CAST(a AS VARCHAR(MAX))", "doris": "CAST(a AS STRING)", }, ) @@ -179,6 +180,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", + "tsql": "CAST(a AS VARCHAR(MAX))", "doris": "CAST(a AS STRING)", }, ) @@ -197,6 +199,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS VARCHAR)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS VARCHAR)", + "tsql": "CAST(a AS VARCHAR)", "doris": "CAST(a AS VARCHAR)", }, ) @@ -215,6 +218,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS VARCHAR(3))", "spark": "CAST(a AS VARCHAR(3))", "starrocks": "CAST(a AS VARCHAR(3))", + "tsql": "CAST(a AS VARCHAR(3))", "doris": "CAST(a AS VARCHAR(3))", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 54553b3..f9de953 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -249,6 +249,7 @@ class TestDuckDB(Validator): "SELECT ARRAY_LENGTH([0], 1) AS x", write={"duckdb": "SELECT ARRAY_LENGTH([0], 1) AS x"}, ) + self.validate_identity("REGEXP_REPLACE(this, pattern, replacement, modifiers)") self.validate_all( "REGEXP_MATCHES(x, y)", write={ diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index b9d1d26..dce2b9d 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -586,6 +586,8 @@ class TestMySQL(Validator): write={ "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", "postgres": "SELECT * FROM test LIMIT 0 + 1 OFFSET 0 + 1", + "presto": "SELECT * FROM test OFFSET 1 LIMIT 1", + "trino": "SELECT * FROM test OFFSET 1 LIMIT 1", }, ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 22bede4..3121cb0 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -732,3 +732,8 @@ class TestPostgres(Validator): self.validate_all( "VAR_POP(x)", read={"": "VARIANCE_POP(x)"}, write={"postgres": "VAR_POP(x)"} ) + + def test_regexp_binary(self): + """See https://github.com/tobymao/sqlglot/pull/2404 for details.""" + self.assertIsInstance(parse_one("'thomas' ~ '.*thomas.*'", read="postgres"), exp.Binary) + self.assertIsInstance(parse_one("'thomas' ~* '.*thomas.*'", read="postgres"), exp.Binary) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index fd297d7..ed734b6 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -367,6 +367,21 @@ class TestPresto(Validator): "CAST(x AS TIMESTAMP)", read={"mysql": "TIMESTAMP(x)"}, ) + self.validate_all( + "TIMESTAMP(x, 'America/Los_Angeles')", + write={ + "duckdb": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'", + "presto": "CAST(x AS TIMESTAMP) AT TIME ZONE 'America/Los_Angeles'", + }, + ) + # this case isn't really correct, but it's a fall back for mysql's version + self.validate_all( + "TIMESTAMP(x, '12:00:00')", + write={ + "duckdb": "TIMESTAMP(x, '12:00:00')", + "presto": "TIMESTAMP(x, '12:00:00')", + }, + ) def test_ddl(self): self.validate_all( @@ -441,6 +456,22 @@ class TestPresto(Validator): }, ) + self.validate_all( + "CREATE OR REPLACE VIEW x (cola) SELECT 1 as cola", + write={ + "spark": "CREATE OR REPLACE VIEW x (cola) AS SELECT 1 AS cola", + "presto": "CREATE OR REPLACE VIEW x AS SELECT 1 AS cola", + }, + ) + + self.validate_all( + 'CREATE TABLE IF NOT EXISTS x ("cola" INTEGER, "ds" TEXT) WITH (PARTITIONED BY=("ds"))', + write={ + "spark": "CREATE TABLE IF NOT EXISTS x (`cola` INT, `ds` STRING) PARTITIONED BY (`ds`)", + "presto": """CREATE TABLE IF NOT EXISTS x ("cola" INTEGER, "ds" VARCHAR) WITH (PARTITIONED_BY=ARRAY['ds'])""", + }, + ) + def test_quotes(self): self.validate_all( "''''", @@ -527,6 +558,37 @@ class TestPresto(Validator): "SELECT SPLIT_TO_MAP('a:1;b:2;a:3', ';', ':', (k, v1, v2) -> CONCAT(v1, v2))" ) + self.validate_all( + "SELECT MAX_BY(a.id, a.timestamp) FROM a", + read={ + "bigquery": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "clickhouse": "SELECT argMax(a.id, a.timestamp) FROM a", + "duckdb": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "snowflake": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "spark": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "teradata": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + }, + write={ + "bigquery": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "clickhouse": "SELECT argMax(a.id, a.timestamp) FROM a", + "duckdb": "SELECT ARG_MAX(a.id, a.timestamp) FROM a", + "presto": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "snowflake": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "spark": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + "teradata": "SELECT MAX_BY(a.id, a.timestamp) FROM a", + }, + ) + self.validate_all( + "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a", + write={ + "clickhouse": "SELECT argMin(a.id, a.timestamp) FROM a", + "duckdb": "SELECT ARG_MIN(a.id, a.timestamp) FROM a", + "presto": "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a", + "snowflake": "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a", + "spark": "SELECT MIN_BY(a.id, a.timestamp) FROM a", + "teradata": "SELECT MIN_BY(a.id, a.timestamp, 3) FROM a", + }, + ) self.validate_all( """JSON '"foo"'""", write={ diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index f182feb..c848010 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,6 +6,10 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): + self.validate_identity( + "SELECT * FROM x WHERE y = DATEADD('month', -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", + "SELECT * FROM x WHERE y = DATEADD(month, -1, CAST(DATE_TRUNC('month', (SELECT y FROM #temp_table)) AS DATE))", + ) self.validate_all( "SELECT APPROXIMATE COUNT(DISTINCT y)", read={ @@ -16,13 +20,6 @@ class TestRedshift(Validator): "spark": "SELECT APPROX_COUNT_DISTINCT(y)", }, ) - self.validate_identity("SELECT APPROXIMATE AS y") - - self.validate_identity( - "SELECT 'a''b'", - "SELECT 'a\\'b'", - ) - self.validate_all( "x ~* 'pat'", write={ @@ -30,7 +27,6 @@ class TestRedshift(Validator): "snowflake": "REGEXP_LIKE(x, 'pat', 'i')", }, ) - self.validate_all( "SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)", read={ @@ -248,6 +244,19 @@ class TestRedshift(Validator): self.validate_identity("CAST('foo' AS HLLSKETCH)") self.validate_identity("'abc' SIMILAR TO '(b|c)%'") self.validate_identity("CREATE TABLE datetable (start_date DATE, end_date DATE)") + self.validate_identity("SELECT APPROXIMATE AS y") + self.validate_identity("CREATE TABLE t (c BIGINT IDENTITY(0, 1))") + self.validate_identity( + "SELECT 'a''b'", + "SELECT 'a\\'b'", + ) + self.validate_identity( + "CREATE TABLE t (c BIGINT GENERATED BY DEFAULT AS IDENTITY (0, 1))", + "CREATE TABLE t (c BIGINT IDENTITY(0, 1))", + ) + self.validate_identity( + "CREATE OR REPLACE VIEW v1 AS SELECT id, AVG(average_metric1) AS m1, AVG(average_metric2) AS m2 FROM t GROUP BY id WITH NO SCHEMA BINDING" + ) self.validate_identity( "SELECT caldate + INTERVAL '1 second' AS dateplus FROM date WHERE caldate = '12-31-2008'" ) @@ -301,6 +310,7 @@ ORDER BY self.validate_identity( "SELECT attr AS attr, JSON_TYPEOF(val) AS value_type FROM customer_orders_lineitem AS c, UNPIVOT c.c_orders AS val AT attr WHERE c_custkey = 9451" ) + self.validate_identity("SELECT JSON_PARSE('[]')") def test_values(self): # Test crazy-sized VALUES clause to UNION ALL conversion to ensure we don't get RecursionError diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 7c36bea..65b77ea 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -9,6 +9,12 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t") + expr.selects[0].assert_is(exp.AggFunc) + self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t") + + self.validate_identity("SELECT DAYOFMONTH(CURRENT_TIMESTAMP())") + self.validate_identity("SELECT DAYOFYEAR(CURRENT_TIMESTAMP())") self.validate_identity("LISTAGG(data['some_field'], ',')") self.validate_identity("WEEKOFYEAR(tstamp)") self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL") @@ -36,6 +42,7 @@ class TestSnowflake(Validator): self.validate_identity("COMMENT IF EXISTS ON TABLE foo IS 'bar'") self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)") self.validate_identity("REGEXP_REPLACE('target', 'pattern', '\n')") + self.validate_identity("ALTER TABLE a SWAP WITH b") self.validate_identity( 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' ) @@ -58,6 +65,18 @@ class TestSnowflake(Validator): "SELECT {'test': 'best'}::VARIANT", "SELECT CAST(OBJECT_CONSTRUCT('test', 'best') AS VARIANT)", ) + self.validate_identity( + "SELECT {fn DAYNAME('2022-5-13')}", + "SELECT DAYNAME('2022-5-13')", + ) + self.validate_identity( + "SELECT {fn LOG(5)}", + "SELECT LN(5)", + ) + self.validate_identity( + "SELECT {fn CEILING(5.3)}", + "SELECT CEIL(5.3)", + ) self.validate_all("CAST(x AS BYTEINT)", write={"snowflake": "CAST(x AS INT)"}) self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) @@ -911,7 +930,23 @@ FROM cs.telescope.dag_report, TABLE(FLATTEN(input => SPLIT(operators, ','))) AS f.value AS "Contact", f1.value['type'] AS "Type", f1.value['content'] AS "Details" -FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERAL FLATTEN(input => f.value['business']) AS f1""", +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f(SEQ, KEY, PATH, INDEX, VALUE, THIS), LATERAL FLATTEN(input => f.value['business']) AS f1(SEQ, KEY, PATH, INDEX, VALUE, THIS)""", + }, + pretty=True, + ) + + self.validate_all( + """ + SELECT id as "ID", + value AS "Contact" + FROM persons p, + lateral flatten(input => p.c, path => 'contact') + """, + write={ + "snowflake": """SELECT + id AS "ID", + value AS "Contact" +FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS _flattened(SEQ, KEY, PATH, INDEX, VALUE, THIS)""", }, pretty=True, ) @@ -1134,3 +1169,8 @@ MATCH_RECOGNIZE ( self.assertIsNotNone(table) self.assertEqual(table.sql(dialect="snowflake"), '"TEST"."PUBLIC"."customers"') + + def test_swap(self): + ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake") + assert isinstance(ast, exp.AlterTable) + assert isinstance(ast.args["actions"][0], exp.SwapTable) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 9bb9d79..e08915b 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -230,6 +230,7 @@ TBLPROPERTIES ( self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.validate_identity("SELECT CASE WHEN a = NULL THEN 1 ELSE 2 END") self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") @@ -295,7 +296,7 @@ TBLPROPERTIES ( }, write={ "spark": "SELECT DATEDIFF(month, TO_DATE(CAST('1996-10-30' AS TIMESTAMP)), TO_DATE(CAST('1997-02-28 10:30:00' AS TIMESTAMP)))", - "spark2": "SELECT MONTHS_BETWEEN(TO_DATE(CAST('1997-02-28 10:30:00' AS TIMESTAMP)), TO_DATE(CAST('1996-10-30' AS TIMESTAMP)))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(TO_DATE(CAST('1997-02-28 10:30:00' AS TIMESTAMP)), TO_DATE(CAST('1996-10-30' AS TIMESTAMP))) AS INT)", }, ) self.validate_all( @@ -403,10 +404,10 @@ TBLPROPERTIES ( "SELECT DATEDIFF(MONTH, '2020-01-01', '2020-03-05')", write={ "databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", - "hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", + "hive": "SELECT CAST(MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01')) AS INT)", "presto": "SELECT DATE_DIFF('MONTH', CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-05' AS TIMESTAMP) AS DATE))", "spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", - "spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01')) AS INT)", "trino": "SELECT DATE_DIFF('MONTH', CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-05' AS TIMESTAMP) AS DATE))", }, ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 9dbac8c..b5c0fe8 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -5,6 +5,7 @@ class TestTeradata(Validator): dialect = "teradata" def test_teradata(self): + self.validate_identity("SELECT TOP 10 * FROM tbl") self.validate_identity("SELECT * FROM tbl SAMPLE 5") self.validate_identity( "SELECT * FROM tbl SAMPLE 0.33, .25, .1", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index f9a720a..4775020 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1058,18 +1058,18 @@ WHERE }, ) self.validate_all( - "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')", write={ - "tsql": "SELECT DATEDIFF(year, CAST('2020/01/01' AS DATETIME2), CAST('2021/01/01' AS DATETIME2))", - "spark": "SELECT DATEDIFF(year, CAST('2020/01/01' AS TIMESTAMP), CAST('2021/01/01' AS TIMESTAMP))", - "spark2": "SELECT MONTHS_BETWEEN(CAST('2021/01/01' AS TIMESTAMP), CAST('2020/01/01' AS TIMESTAMP)) / 12", + "tsql": "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(year, CAST('2020-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) AS INT) / 12", }, ) self.validate_all( "SELECT DATEDIFF(mm, 'start', 'end')", write={ "databricks": "SELECT DATEDIFF(month, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT)", "tsql": "SELECT DATEDIFF(month, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) @@ -1078,7 +1078,7 @@ WHERE write={ "databricks": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", "spark": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark2": "SELECT MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT) / 3", "tsql": "SELECT DATEDIFF(quarter, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) @@ -1374,7 +1374,7 @@ FROM OPENJSON(@json) WITH ( Date DATETIME2 '$.Order.Date', Customer VARCHAR(200) '$.AccountNumber', Quantity INTEGER '$.Item.Quantity', - "Order" TEXT AS JSON + "Order" VARCHAR(MAX) AS JSON )""" }, pretty=True, diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 2738707..6e0a3e5 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -866,3 +866,4 @@ KILL CONNECTION 123 KILL QUERY '123' CHR(97) SELECT * FROM UNNEST(x) WITH ORDINALITY UNION ALL SELECT * FROM UNNEST(y) WITH ORDINALITY +WITH use(use) AS (SELECT 1) SELECT use FROM use diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 2ba762d..954b1c1 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -16,7 +16,6 @@ SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0"; -------------------------------------- -- Ensure boolean predicates -------------------------------------- - SELECT a FROM x WHERE b; SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE "x"."b" <> 0; diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index f43ac01..3717cd4 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -109,3 +109,11 @@ SELECT * FROM ((SELECT * FROM c.db.t AS t) AS _q_0); # title: wrapped subquery without alias joined with a table SELECT * FROM ((SELECT * FROM t1) INNER JOIN t2 ON a = b); SELECT * FROM ((SELECT * FROM c.db.t1 AS t1) AS _q_0 INNER JOIN c.db.t2 AS t2 ON a = b); + +# title: lateral unnest with alias +SELECT x FROM t, LATERAL UNNEST(t.xs) AS x; +SELECT x FROM c.db.t AS t, LATERAL UNNEST(t.xs) AS x; + +# title: lateral unnest without alias +SELECT x FROM t, LATERAL UNNEST(t.xs); +SELECT x FROM c.db.t AS t, LATERAL UNNEST(t.xs) AS _q_0; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index e54170c..c53a972 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -911,13 +911,76 @@ t1.a = 39 AND t2.b = t1.a AND t3.c = t2.b; t1.a = 39 AND t2.b = 39 AND t3.c = 39; x = 1 AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END; -x = 1 AND CASE WHEN FALSE THEN FALSE ELSE TRUE END; +x = 1; x = 1 AND IF(x = 5, FALSE, TRUE); -x = 1 AND CASE WHEN FALSE THEN FALSE ELSE TRUE END; +x = 1; + +x = 1 AND CASE x WHEN 5 THEN FALSE ELSE TRUE END; +x = 1; x = y AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END; x = y AND CASE WHEN x = 5 THEN FALSE ELSE TRUE END; x = 1 AND CASE WHEN y = 5 THEN x = z END; x = 1 AND CASE WHEN y = 5 THEN 1 = z END; + +-------------------------------------- +-- Simplify Conditionals +-------------------------------------- +IF(TRUE, x, y); +x; + +IF(FALSE, x, y); +y; + +IF(FALSE, x); +NULL; + +IF(NULL, x, y); +y; + +IF(cond, x, y); +CASE WHEN cond THEN x ELSE y END; + +CASE WHEN TRUE THEN x ELSE y END; +x; + +CASE WHEN FALSE THEN x ELSE y END; +y; + +CASE WHEN FALSE THEN x WHEN FALSE THEN y WHEN TRUE THEN z END; +z; + +CASE NULL WHEN NULL THEN x ELSE y END; +y; + +CASE 4 WHEN 1 THEN x WHEN 2 THEN y WHEN 3 THEN z ELSE w END; +w; + +CASE 4 WHEN 1 THEN x WHEN 2 THEN y WHEN 3 THEN z WHEN 4 THEN w END; +w; + +CASE WHEN value = 1 THEN x ELSE y END; +CASE WHEN value = 1 THEN x ELSE y END; + +CASE WHEN FALSE THEN x END; +NULL; + +CASE 1 WHEN 1 + 1 THEN x END; +NULL; + +CASE WHEN cond THEN x ELSE y END; +CASE WHEN cond THEN x ELSE y END; + +CASE WHEN cond THEN x END; +CASE WHEN cond THEN x END; + +CASE x WHEN y THEN z ELSE w END; +CASE WHEN x = y THEN z ELSE w END; + +CASE x WHEN y THEN z END; +CASE WHEN x = y THEN z END; + +CASE x1 + x2 WHEN x3 THEN x4 WHEN x5 + x6 THEN x7 ELSE x8 END; +CASE WHEN (x1 + x2) = x3 THEN x4 WHEN (x1 + x2) = (x5 + x6) THEN x7 ELSE x8 END; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 91b553e..52ee12c 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -4808,10 +4808,10 @@ WITH "foo" AS ( "foo"."i_item_sk" AS "i_item_sk", "foo"."d_moy" AS "d_moy", "foo"."mean" AS "mean", - CASE "foo"."mean" WHEN FALSE THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov" + CASE WHEN "foo"."mean" = 0 THEN NULL ELSE "foo"."stdev" / "foo"."mean" END AS "cov" FROM "foo" AS "foo" WHERE - CASE "foo"."mean" WHEN FALSE THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1 + CASE WHEN "foo"."mean" = 0 THEN 0 ELSE "foo"."stdev" / "foo"."mean" END > 1 ) SELECT "inv1"."w_warehouse_sk" AS "w_warehouse_sk", diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f8c8bcc..6c48943 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -632,6 +632,11 @@ class TestExpressions(unittest.TestCase): week = unit.find(exp.Week) self.assertEqual(week.this, exp.var("thursday")) + for abbreviated_unit, unnabreviated_unit in exp.TimeUnit.UNABBREVIATED_UNIT_NAME.items(): + interval = parse_one(f"interval '500 {abbreviated_unit}'") + self.assertIsInstance(interval.unit, exp.Var) + self.assertEqual(interval.unit.name, unnabreviated_unit) + def test_identifier(self): self.assertTrue(exp.to_identifier('"x"').quoted) self.assertFalse(exp.to_identifier("x").quoted) @@ -861,6 +866,10 @@ FROM foo""", self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY") self.assertEqual(exp.DataType.build("ARRAY").sql(), "ARRAY") + self.assertEqual(exp.DataType.build("varchar(100) collate 'en-ci'").sql(), "VARCHAR(100)") + + with self.assertRaises(ParseError): + exp.DataType.build("varchar(") def test_rename_table(self): self.assertEqual( diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 0fd9da8..25329e2 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -199,3 +199,77 @@ class TestLineage(unittest.TestCase): "SELECT x FROM (SELECT ax AS x FROM a UNION SELECT bx FROM b UNION SELECT cx FROM c)", ) assert len(node.downstream) == 3 + + def test_lineage_lateral_flatten(self) -> None: + node = lineage( + "VALUE", + "SELECT FLATTENED.VALUE FROM TEST_TABLE, LATERAL FLATTEN(INPUT => RESULT, OUTER => TRUE) FLATTENED", + dialect="snowflake", + ) + self.assertEqual(node.name, "VALUE") + + downstream = node.downstream[0] + self.assertEqual(downstream.name, "FLATTENED.VALUE") + self.assertEqual( + downstream.source.sql(dialect="snowflake"), + "LATERAL FLATTEN(INPUT => TEST_TABLE.RESULT, OUTER => TRUE) AS FLATTENED(SEQ, KEY, PATH, INDEX, VALUE, THIS)", + ) + self.assertEqual( + downstream.expression.sql(dialect="snowflake"), + "VALUE", + ) + self.assertEqual(len(downstream.downstream), 1) + + downstream = downstream.downstream[0] + self.assertEqual(downstream.name, "TEST_TABLE.RESULT") + self.assertEqual(downstream.source.sql(dialect="snowflake"), "TEST_TABLE AS TEST_TABLE") + + def test_subquery(self) -> None: + node = lineage( + "output", + "SELECT (SELECT max(t3.my_column) my_column FROM foo t3) AS output FROM table3", + ) + self.assertEqual(node.name, "SUBQUERY") + node = node.downstream[0] + self.assertEqual(node.name, "my_column") + node = node.downstream[0] + self.assertEqual(node.name, "t3.my_column") + self.assertEqual(node.source.sql(), "foo AS t3") + + def test_lineage_cte_union(self) -> None: + query = """ + WITH dataset AS ( + SELECT * + FROM catalog.db.table_a + + UNION + + SELECT * + FROM catalog.db.table_b + ) + + SELECT x, created_at FROM dataset; + """ + node = lineage("x", query) + + self.assertEqual(node.name, "x") + + downstream_a = node.downstream[0] + self.assertEqual(downstream_a.name, "0") + self.assertEqual(downstream_a.source.sql(), "SELECT * FROM catalog.db.table_a AS table_a") + downstream_b = node.downstream[1] + self.assertEqual(downstream_b.name, "0") + self.assertEqual(downstream_b.source.sql(), "SELECT * FROM catalog.db.table_b AS table_b") + + def test_select_star(self) -> None: + node = lineage("x", "SELECT x from (SELECT * from table_a)") + + self.assertEqual(node.name, "x") + + downstream = node.downstream[0] + self.assertEqual(downstream.name, "_q_0.x") + self.assertEqual(downstream.source.sql(), "SELECT * FROM table_a AS table_a") + + downstream = downstream.downstream[0] + self.assertEqual(downstream.name, "*") + self.assertEqual(downstream.source.sql(), "table_a AS table_a") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index c43a84e..8f5dd08 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -550,6 +550,47 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) + def test_bracket_annotation(self): + expression = annotate_types(parse_one("SELECT A[:]")).expressions[0] + + self.assertEqual(expression.type.this, exp.DataType.Type.UNKNOWN) + self.assertEqual(expression.expressions[0].type.this, exp.DataType.Type.UNKNOWN) + + expression = annotate_types(parse_one("SELECT ARRAY[1, 2, 3][1]")).expressions[0] + self.assertEqual(expression.this.type.sql(), "ARRAY") + self.assertEqual(expression.type.this, exp.DataType.Type.INT) + + expression = annotate_types(parse_one("SELECT ARRAY[1, 2, 3][1 : 2]")).expressions[0] + self.assertEqual(expression.this.type.sql(), "ARRAY") + self.assertEqual(expression.type.sql(), "ARRAY") + + expression = annotate_types( + parse_one("SELECT ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]][1][2]") + ).expressions[0] + self.assertEqual(expression.this.this.type.sql(), "ARRAY>") + self.assertEqual(expression.this.type.sql(), "ARRAY") + self.assertEqual(expression.type.this, exp.DataType.Type.INT) + + expression = annotate_types( + parse_one("SELECT ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]][1:2]") + ).expressions[0] + self.assertEqual(expression.type.sql(), "ARRAY>") + + expression = annotate_types(parse_one("MAP(1.0, 2, '2', 3.0)['2']", read="spark")) + self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE) + + expression = annotate_types(parse_one("MAP(1.0, 2, x, 3.0)[2]", read="spark")) + self.assertEqual(expression.type.this, exp.DataType.Type.UNKNOWN) + + expression = annotate_types(parse_one("MAP(ARRAY(1.0, x), ARRAY(2, 3.0))[x]")) + self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE) + + expression = annotate_types( + parse_one("SELECT MAP(1.0, 2, 2, t.y)[2] FROM t", read="spark"), + schema={"t": {"y": "int"}}, + ).expressions[0] + self.assertEqual(expression.type.this, exp.DataType.Type.INT) + def test_interval_math_annotation(self): schema = { "x": { diff --git a/tests/test_parser.py b/tests/test_parser.py index 53e1a85..f3e663e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -234,6 +234,9 @@ class TestParser(unittest.TestCase): "CREATE TABLE t (i UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()", ) + with self.assertRaises(ParseError): + parse_one("SELECT A[:") + def test_space(self): self.assertEqual( parse_one("SELECT ROW() OVER(PARTITION BY x) FROM x GROUP BY y").sql(), diff --git a/tests/test_transpile.py b/tests/test_transpile.py index d588f07..c16b1f6 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -19,6 +19,9 @@ class TestTranspile(unittest.TestCase): def validate(self, sql, target, **kwargs): self.assertEqual(transpile(sql, **kwargs)[0], target) + def test_weird_chars(self): + self.assertEqual(transpile("0Êß")[0], "0 AS Êß") + def test_alias(self): self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP") self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite") @@ -87,7 +90,18 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): - self.validate("SELECT\n foo\n/* comments */\n;", "SELECT foo /* comments */") + self.validate( + "SELECT * FROM t1\n/*x*/\nUNION ALL SELECT * FROM t2", + "SELECT * FROM t1 /* x */ UNION ALL SELECT * FROM t2", + ) + self.validate( + "SELECT * FROM t1\n/*x*/\nINTERSECT ALL SELECT * FROM t2", + "SELECT * FROM t1 /* x */ INTERSECT ALL SELECT * FROM t2", + ) + self.validate( + "SELECT\n foo\n/* comments */\n;", + "SELECT foo /* comments */", + ) self.validate( "SELECT * FROM a INNER /* comments */ JOIN b", "SELECT * FROM a /* comments */ INNER JOIN b", @@ -379,6 +393,47 @@ LEFT OUTER JOIN b""", FROM tbl""", pretty=True, ) + self.validate( + """ +SELECT + 'hotel1' AS hotel, + * +FROM dw_1_dw_1_1.exactonline_1.transactionlines +/* + UNION ALL + SELECT + 'Thon Partner Hotel Jølster' AS hotel, + name, + date, + CAST(identifier AS VARCHAR) AS identifier, + value + FROM d2o_889_oupjr_1348.public.accountvalues_forecast +*/ +UNION ALL +SELECT + 'hotel2' AS hotel, + * +FROM dw_1_dw_1_1.exactonline_2.transactionlines""", + """SELECT + 'hotel1' AS hotel, + * +FROM dw_1_dw_1_1.exactonline_1.transactionlines /* + UNION ALL + SELECT + 'Thon Partner Hotel Jølster' AS hotel, + name, + date, + CAST(identifier AS VARCHAR) AS identifier, + value + FROM d2o_889_oupjr_1348.public.accountvalues_forecast +*/ +UNION ALL +SELECT + 'hotel2' AS hotel, + * +FROM dw_1_dw_1_1.exactonline_2.transactionlines""", + pretty=True, + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") -- cgit v1.2.3