From 278f416d08028bd175e1d6433739461f2168f4e2 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 11 Jun 2024 18:34:56 +0200 Subject: Adding upstream version 25.0.3. Signed-off-by: Daniel Baumann --- tests/dialects/test_bigquery.py | 22 +++++++- tests/dialects/test_dialect.py | 38 ++++++++++---- tests/dialects/test_doris.py | 8 +++ tests/dialects/test_duckdb.py | 32 +++++++----- tests/dialects/test_hive.py | 10 ++++ tests/dialects/test_materialize.py | 77 ++++++++++++++++++++++++++++ tests/dialects/test_mysql.py | 3 ++ tests/dialects/test_postgres.py | 3 ++ tests/dialects/test_redshift.py | 1 + tests/dialects/test_risingwave.py | 14 +++++ tests/dialects/test_snowflake.py | 3 ++ tests/fixtures/optimizer/qualify_columns.sql | 4 ++ tests/fixtures/optimizer/simplify.sql | 18 +++++++ tests/test_build.py | 4 ++ tests/test_optimizer.py | 9 ++++ 15 files changed, 223 insertions(+), 23 deletions(-) create mode 100644 tests/dialects/test_materialize.py create mode 100644 tests/dialects/test_risingwave.py (limited to 'tests') diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 728785c..bfaf009 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -1269,7 +1269,7 @@ LANGUAGE js AS "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", write={ "bigquery": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)') FROM table", - "duckdb": "SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM table", + "duckdb": '''SELECT REGEXP_EXTRACT(abc, 'pattern(group)', 1) FROM "table"''', }, ) self.validate_all( @@ -1524,6 +1524,26 @@ WHERE "SELECT cola FROM (SELECT CAST('1' AS STRING) AS cola UNION ALL SELECT CAST('2' AS STRING) AS cola)", ) + def test_gap_fill(self): + self.validate_identity( + "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'locf')]) ORDER BY time" + ) + self.validate_identity( + "SELECT a, b, c, d, e FROM GAP_FILL(TABLE foo, ts_column => 'b', partitioning_columns => ['a'], value_columns => [('c', 'bar'), ('d', 'baz'), ('e', 'bla')], bucket_width => INTERVAL '1' DAY)" + ) + self.validate_identity( + "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'linear')], ignore_null_values => FALSE) ORDER BY time" + ) + self.validate_identity( + "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE) ORDER BY time" + ) + self.validate_identity( + "SELECT * FROM GAP_FILL(TABLE device_data, ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'null')], origin => CAST('2023-11-01 09:30:01' AS DATETIME)) ORDER BY time" + ) + self.validate_identity( + "SELECT * FROM GAP_FILL(TABLE (SELECT * FROM UNNEST(ARRAY>[STRUCT(1, CAST('2023-11-01 09:34:01' AS DATETIME), 74, 'INACTIVE'), STRUCT(2, CAST('2023-11-01 09:36:00' AS DATETIME), 77, 'ACTIVE'), STRUCT(3, CAST('2023-11-01 09:37:00' AS DATETIME), 78, 'ACTIVE'), STRUCT(4, CAST('2023-11-01 09:38:01' AS DATETIME), 80, 'ACTIVE')])), ts_column => 'time', bucket_width => INTERVAL '1' MINUTE, value_columns => [('signal', 'linear')]) ORDER BY time" + ) + def test_models(self): self.validate_identity( "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT label, column1, column2 FROM mydataset.mytable))" diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 9888a5d..aaeb7b0 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -155,6 +155,7 @@ class TestDialect(Validator): "clickhouse": "CAST(a AS String)", "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", + "materialize": "CAST(a AS TEXT)", "mysql": "CAST(a AS CHAR)", "hive": "CAST(a AS STRING)", "oracle": "CAST(a AS CLOB)", @@ -175,6 +176,7 @@ class TestDialect(Validator): "clickhouse": "CAST(a AS BINARY(4))", "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BLOB(4))", + "materialize": "CAST(a AS BYTEA(4))", "mysql": "CAST(a AS BINARY(4))", "hive": "CAST(a AS BINARY(4))", "oracle": "CAST(a AS BLOB(4))", @@ -193,6 +195,7 @@ class TestDialect(Validator): "bigquery": "CAST(a AS BYTES)", "clickhouse": "CAST(a AS String)", "duckdb": "CAST(a AS BLOB(4))", + "materialize": "CAST(a AS BYTEA(4))", "mysql": "CAST(a AS VARBINARY(4))", "hive": "CAST(a AS BINARY(4))", "oracle": "CAST(a AS BLOB(4))", @@ -236,6 +239,7 @@ class TestDialect(Validator): "bigquery": "CAST(a AS STRING)", "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", + "materialize": "CAST(a AS TEXT)", "mysql": "CAST(a AS CHAR)", "hive": "CAST(a AS STRING)", "oracle": "CAST(a AS CLOB)", @@ -255,6 +259,7 @@ class TestDialect(Validator): "bigquery": "CAST(a AS STRING)", "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", + "materialize": "CAST(a AS VARCHAR)", "mysql": "CAST(a AS CHAR)", "hive": "CAST(a AS STRING)", "oracle": "CAST(a AS VARCHAR2)", @@ -274,6 +279,7 @@ class TestDialect(Validator): "bigquery": "CAST(a AS STRING)", "drill": "CAST(a AS VARCHAR(3))", "duckdb": "CAST(a AS TEXT(3))", + "materialize": "CAST(a AS VARCHAR(3))", "mysql": "CAST(a AS CHAR(3))", "hive": "CAST(a AS VARCHAR(3))", "oracle": "CAST(a AS VARCHAR2(3))", @@ -293,6 +299,7 @@ class TestDialect(Validator): "bigquery": "CAST(a AS INT64)", "drill": "CAST(a AS INTEGER)", "duckdb": "CAST(a AS SMALLINT)", + "materialize": "CAST(a AS SMALLINT)", "mysql": "CAST(a AS SIGNED)", "hive": "CAST(a AS SMALLINT)", "oracle": "CAST(a AS NUMBER)", @@ -328,6 +335,7 @@ class TestDialect(Validator): "clickhouse": "CAST(a AS Float64)", "drill": "CAST(a AS DOUBLE)", "duckdb": "CAST(a AS DOUBLE)", + "materialize": "CAST(a AS DOUBLE PRECISION)", "mysql": "CAST(a AS DOUBLE)", "hive": "CAST(a AS DOUBLE)", "oracle": "CAST(a AS DOUBLE PRECISION)", @@ -599,6 +607,7 @@ class TestDialect(Validator): "drill": "TO_TIMESTAMP(x, 'yy')", "duckdb": "STRPTIME(x, '%y')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", + "materialize": "TO_TIMESTAMP(x, 'YY')", "presto": "DATE_PARSE(x, '%y')", "oracle": "TO_TIMESTAMP(x, 'YY')", "postgres": "TO_TIMESTAMP(x, 'YY')", @@ -655,6 +664,7 @@ class TestDialect(Validator): "drill": "TO_CHAR(x, 'yyyy-MM-dd')", "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", + "materialize": "TO_CHAR(x, 'YYYY-MM-DD')", "oracle": "TO_CHAR(x, 'YYYY-MM-DD')", "postgres": "TO_CHAR(x, 'YYYY-MM-DD')", "presto": "DATE_FORMAT(x, '%Y-%m-%d')", @@ -698,6 +708,7 @@ class TestDialect(Validator): "bigquery": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)", "hive": "TO_DATE(x)", + "materialize": "CAST(x AS DATE)", "postgres": "CAST(x AS DATE)", "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", "snowflake": "TO_DATE(x)", @@ -730,6 +741,7 @@ class TestDialect(Validator): "duckdb": "TO_TIMESTAMP(x)", "hive": "FROM_UNIXTIME(x)", "oracle": "TO_DATE('1970-01-01', 'YYYY-MM-DD') + (x / 86400)", + "materialize": "TO_TIMESTAMP(x)", "postgres": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "starrocks": "FROM_UNIXTIME(x)", @@ -790,6 +802,7 @@ class TestDialect(Validator): "drill": "DATE_ADD(x, INTERVAL 1 DAY)", "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", + "materialize": "x + INTERVAL '1 DAY'", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", "postgres": "x + INTERVAL '1 DAY'", "presto": "DATE_ADD('DAY', 1, x)", @@ -826,6 +839,7 @@ class TestDialect(Validator): "duckdb": "DATE_TRUNC('DAY', x)", "mysql": "DATE(x)", "presto": "DATE_TRUNC('DAY', x)", + "materialize": "DATE_TRUNC('DAY', x)", "postgres": "DATE_TRUNC('DAY', x)", "snowflake": "DATE_TRUNC('DAY', x)", "starrocks": "DATE_TRUNC('DAY', x)", @@ -838,6 +852,7 @@ class TestDialect(Validator): read={ "bigquery": "TIMESTAMP_TRUNC(x, day)", "duckdb": "DATE_TRUNC('day', x)", + "materialize": "DATE_TRUNC('day', x)", "presto": "DATE_TRUNC('day', x)", "postgres": "DATE_TRUNC('day', x)", "snowflake": "DATE_TRUNC('day', x)", @@ -899,6 +914,7 @@ class TestDialect(Validator): }, write={ "bigquery": "DATE_TRUNC(x, YEAR)", + "materialize": "DATE_TRUNC('YEAR', x)", "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", "postgres": "DATE_TRUNC('YEAR', x)", "snowflake": "DATE_TRUNC('YEAR', x)", @@ -911,6 +927,7 @@ class TestDialect(Validator): "TIMESTAMP_TRUNC(x, YEAR)", read={ "bigquery": "TIMESTAMP_TRUNC(x, year)", + "materialize": "DATE_TRUNC('YEAR', x)", "postgres": "DATE_TRUNC(year, x)", "spark": "DATE_TRUNC('year', x)", "snowflake": "DATE_TRUNC(year, x)", @@ -1024,6 +1041,7 @@ class TestDialect(Validator): write={ "": "TIMESTAMP_TRUNC(x, DAY, 'UTC')", "duckdb": "DATE_TRUNC('DAY', x)", + "materialize": "DATE_TRUNC('DAY', x, 'UTC')", "presto": "DATE_TRUNC('DAY', x)", "postgres": "DATE_TRUNC('DAY', x, 'UTC')", "snowflake": "DATE_TRUNC('DAY', x)", @@ -1485,21 +1503,21 @@ class TestDialect(Validator): "snowflake": "x ILIKE '%y'", }, write={ - "bigquery": "LOWER(x) LIKE '%y'", + "bigquery": "LOWER(x) LIKE LOWER('%y')", "clickhouse": "x ILIKE '%y'", "drill": "x `ILIKE` '%y'", "duckdb": "x ILIKE '%y'", - "hive": "LOWER(x) LIKE '%y'", - "mysql": "LOWER(x) LIKE '%y'", - "oracle": "LOWER(x) LIKE '%y'", + "hive": "LOWER(x) LIKE LOWER('%y')", + "mysql": "LOWER(x) LIKE LOWER('%y')", + "oracle": "LOWER(x) LIKE LOWER('%y')", "postgres": "x ILIKE '%y'", - "presto": "LOWER(x) LIKE '%y'", + "presto": "LOWER(x) LIKE LOWER('%y')", "snowflake": "x ILIKE '%y'", "spark": "x ILIKE '%y'", - "sqlite": "LOWER(x) LIKE '%y'", - "starrocks": "LOWER(x) LIKE '%y'", - "trino": "LOWER(x) LIKE '%y'", - "doris": "LOWER(x) LIKE '%y'", + "sqlite": "LOWER(x) LIKE LOWER('%y')", + "starrocks": "LOWER(x) LIKE LOWER('%y')", + "trino": "LOWER(x) LIKE LOWER('%y')", + "doris": "LOWER(x) LIKE LOWER('%y')", }, ) self.validate_all( @@ -2530,7 +2548,7 @@ FROM subquery2""", def test_reserved_keywords(self): order = exp.select("*").from_("order") - for dialect in ("presto", "redshift"): + for dialect in ("duckdb", "presto", "redshift"): dialect = Dialect.get_or_raise(dialect) self.assertEqual( order.sql(dialect=dialect), diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py index f7fce02..8180d05 100644 --- a/tests/dialects/test_doris.py +++ b/tests/dialects/test_doris.py @@ -48,6 +48,14 @@ class TestDoris(Validator): "postgres": """SELECT JSON_EXTRACT_PATH(CAST('{"key": 1}' AS JSONB), 'key')""", }, ) + self.validate_all( + "SELECT GROUP_CONCAT('aa', ',')", + read={ + "doris": "SELECT GROUP_CONCAT('aa', ',')", + "mysql": "SELECT GROUP_CONCAT('aa' SEPARATOR ',')", + "postgres": "SELECT STRING_AGG('aa', ',')", + }, + ) def test_identity(self): self.validate_identity("COALECSE(a, b, c, d)") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index bbf665d..cd68ff9 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1,4 +1,4 @@ -from sqlglot import ErrorLevel, UnsupportedError, exp, parse_one, transpile +from sqlglot import ErrorLevel, ParseError, UnsupportedError, exp, parse_one, transpile from sqlglot.helper import logger as helper_logger from sqlglot.optimizer.annotate_types import annotate_types from tests.dialects.test_dialect import Validator @@ -8,6 +8,9 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_duckdb(self): + with self.assertRaises(ParseError): + parse_one("1 //", read="duckdb") + query = "WITH _data AS (SELECT [{'a': 1, 'b': 2}, {'a': 2, 'b': 3}] AS col) SELECT t.col['b'] FROM _data, UNNEST(_data.col) AS t(col) WHERE t.col['a'] = 1" expr = annotate_types(self.validate_identity(query)) self.assertEqual( @@ -15,6 +18,13 @@ class TestDuckDB(Validator): "WITH _data AS (SELECT [STRUCT(1 AS a, 2 AS b), STRUCT(2 AS a, 3 AS b)] AS col) SELECT col.b FROM _data, UNNEST(_data.col) AS col WHERE col.a = 1", ) + self.validate_all( + "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)", + read={ + "duckdb": "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMP)", + "snowflake": "SELECT CAST('2020-01-01 12:05:01' AS TIMESTAMPNTZ)", + }, + ) self.validate_all( "SELECT CAST('2020-01-01' AS DATE) + INTERVAL (day_offset) DAY FROM t", read={ @@ -247,7 +257,7 @@ class TestDuckDB(Validator): self.validate_identity("SELECT EPOCH_MS(10) AS t") self.validate_identity("SELECT MAKE_TIMESTAMP(10) AS t") self.validate_identity("SELECT TO_TIMESTAMP(10) AS t") - self.validate_identity("SELECT UNNEST(column, recursive := TRUE) FROM table") + self.validate_identity("SELECT UNNEST(col, recursive := TRUE) FROM t") self.validate_identity("VAR_POP(a)") self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b") self.validate_identity("PIVOT Cities ON Year USING SUM(Population)") @@ -271,6 +281,10 @@ class TestDuckDB(Validator): self.validate_identity( "SELECT * FROM x LEFT JOIN UNNEST(y)", "SELECT * FROM x LEFT JOIN UNNEST(y) ON TRUE" ) + self.validate_identity( + "SELECT a, LOGICAL_OR(b) FROM foo GROUP BY a", + "SELECT a, BOOL_OR(b) FROM foo GROUP BY a", + ) self.validate_identity( "SELECT JSON_EXTRACT_STRING(c, '$.k1') = 'v1'", "SELECT (c ->> '$.k1') = 'v1'", @@ -424,15 +438,15 @@ class TestDuckDB(Validator): write={"duckdb": 'WITH "x" AS (SELECT 1) SELECT * FROM x'}, ) self.validate_all( - "CREATE TABLE IF NOT EXISTS table (cola INT, colb STRING) USING ICEBERG PARTITIONED BY (colb)", + "CREATE TABLE IF NOT EXISTS t (cola INT, colb STRING) USING ICEBERG PARTITIONED BY (colb)", write={ - "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)", + "duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)", }, ) self.validate_all( - "CREATE TABLE IF NOT EXISTS table (cola INT COMMENT 'cola', colb STRING) USING ICEBERG PARTITIONED BY (colb)", + "CREATE TABLE IF NOT EXISTS t (cola INT COMMENT 'cola', colb STRING) USING ICEBERG PARTITIONED BY (colb)", write={ - "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)", + "duckdb": "CREATE TABLE IF NOT EXISTS t (cola INT, colb TEXT)", }, ) self.validate_all( @@ -1086,12 +1100,6 @@ class TestDuckDB(Validator): }, ) - def test_bool_or(self): - self.validate_all( - "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", - write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, - ) - def test_encode_decode(self): self.validate_all( "ENCODE(x)", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 3ebaded..0311336 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -744,6 +744,16 @@ class TestHive(Validator): "hive": "SELECT a, SUM(c) FROM t GROUP BY a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy'), GROUPING SETS ((a, DATE_FORMAT(CAST(b AS TIMESTAMP), 'yyyy')), a)", }, ) + self.validate_all( + "SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH') AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'", + read={ + "hive": "SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH') AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'", + "presto": "SELECT DATE_TRUNC('MONTH', CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'", + }, + write={ + "presto": "SELECT DATE_TRUNC('MONTH', TRY_CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'", + }, + ) def test_escapes(self) -> None: self.validate_identity("'\n'", "'\\n'") diff --git a/tests/dialects/test_materialize.py b/tests/dialects/test_materialize.py new file mode 100644 index 0000000..617a9b5 --- /dev/null +++ b/tests/dialects/test_materialize.py @@ -0,0 +1,77 @@ +from tests.dialects.test_dialect import Validator + + +class TestMaterialize(Validator): + dialect = "materialize" + + def test_materialize(self): + self.validate_all( + "CREATE TABLE example (id INT PRIMARY KEY, name TEXT)", + write={ + "materialize": "CREATE TABLE example (id INT, name TEXT)", + "postgres": "CREATE TABLE example (id INT PRIMARY KEY, name TEXT)", + }, + ) + self.validate_all( + "INSERT INTO example (id, name) VALUES (1, 'Alice') ON CONFLICT(id) DO NOTHING", + write={ + "materialize": "INSERT INTO example (id, name) VALUES (1, 'Alice')", + "postgres": "INSERT INTO example (id, name) VALUES (1, 'Alice') ON CONFLICT(id) DO NOTHING", + }, + ) + self.validate_all( + "CREATE TABLE example (id SERIAL, name TEXT)", + write={ + "materialize": "CREATE TABLE example (id INT NOT NULL, name TEXT)", + "postgres": "CREATE TABLE example (id INT GENERATED BY DEFAULT AS IDENTITY NOT NULL, name TEXT)", + }, + ) + self.validate_all( + "CREATE TABLE example (id INT AUTO_INCREMENT, name TEXT)", + write={ + "materialize": "CREATE TABLE example (id INT NOT NULL, name TEXT)", + "postgres": "CREATE TABLE example (id INT GENERATED BY DEFAULT AS IDENTITY NOT NULL, name TEXT)", + }, + ) + self.validate_all( + 'SELECT JSON_EXTRACT_PATH_TEXT(\'{ "farm": {"barn": { "color": "red", "feed stocked": true }}}\', \'farm\', \'barn\', \'color\')', + write={ + "materialize": 'SELECT JSON_EXTRACT_PATH_TEXT(\'{ "farm": {"barn": { "color": "red", "feed stocked": true }}}\', \'farm\', \'barn\', \'color\')', + "postgres": 'SELECT JSON_EXTRACT_PATH_TEXT(\'{ "farm": {"barn": { "color": "red", "feed stocked": true }}}\', \'farm\', \'barn\', \'color\')', + }, + ) + self.validate_all( + "SELECT MAP['a' => 1]", + write={ + "duckdb": "SELECT MAP {'a': 1}", + "materialize": "SELECT MAP['a' => 1]", + }, + ) + + # Test now functions. + self.validate_identity("CURRENT_TIMESTAMP") + self.validate_identity("NOW()", write_sql="CURRENT_TIMESTAMP") + self.validate_identity("MZ_NOW()") + + # Test custom timestamp type. + self.validate_identity("SELECT CAST(1 AS mz_timestamp)") + + # Test DDL. + self.validate_identity("CREATE TABLE example (id INT, name LIST)") + + # Test list types. + self.validate_identity("SELECT LIST[]") + self.validate_identity("SELECT LIST[1, 2, 3]") + self.validate_identity("SELECT LIST[LIST[1], LIST[2], NULL]") + self.validate_identity("SELECT CAST(LIST[1, 2, 3] AS INT LIST)") + self.validate_identity("SELECT CAST(NULL AS INT LIST)") + self.validate_identity("SELECT CAST(NULL AS INT LIST LIST LIST)") + self.validate_identity("SELECT LIST(SELECT 1)") + + # Test map types. + self.validate_identity("SELECT MAP[]") + self.validate_identity("SELECT MAP['a' => MAP['b' => 'c']]") + self.validate_identity("SELECT CAST(MAP['a' => 1] AS MAP[TEXT => INT])") + self.validate_identity("SELECT CAST(NULL AS MAP[TEXT => INT])") + self.validate_identity("SELECT CAST(NULL AS MAP[TEXT => MAP[TEXT => INT]])") + self.validate_identity("SELECT MAP(SELECT 'a', 1)") diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 591b5dd..fdb7e91 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -222,6 +222,9 @@ class TestMySQL(Validator): self.validate_identity("CHAR(77, 121, 83, 81, '76')") self.validate_identity("CHAR(77, 77.3, '77.3' USING utf8mb4)") self.validate_identity("SELECT * FROM t1 PARTITION(p0)") + self.validate_identity("SELECT @var1 := 1, @var2") + self.validate_identity("SELECT @var1, @var2 := @var1") + self.validate_identity("SELECT @var1 := COUNT(*) FROM t1") def test_types(self): for char_type in MySQL.Generator.CHAR_CAST_MAPPING: diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 38c262f..74753be 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,6 +8,9 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_identity( + 'CREATE TABLE x (a TEXT COLLATE "de_DE")', "CREATE TABLE x (a TEXT COLLATE de_DE)" + ) self.validate_identity("1.x", "1. AS x") self.validate_identity("|/ x", "SQRT(x)") self.validate_identity("||/ x", "CBRT(x)") diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index ccabe2d..844fe46 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,6 +6,7 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): + self.validate_identity("1 div", "1 AS div") self.validate_all( "SELECT SPLIT_TO_ARRAY('12,345,6789')", write={ diff --git a/tests/dialects/test_risingwave.py b/tests/dialects/test_risingwave.py new file mode 100644 index 0000000..7d6d50c --- /dev/null +++ b/tests/dialects/test_risingwave.py @@ -0,0 +1,14 @@ +from tests.dialects.test_dialect import Validator + + +class TestRisingWave(Validator): + dialect = "risingwave" + maxDiff = None + + def test_risingwave(self): + self.validate_all( + "SELECT a FROM tbl", + read={ + "": "SELECT a FROM tbl FOR UPDATE", + }, + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index ba85719..9d9371d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -1934,6 +1934,9 @@ STORAGE_ALLOWED_LOCATIONS=('s3://mybucket1/path1/', 's3://mybucket2/path2/')""", self.validate_identity( """COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' STORAGE_INTEGRATION = "storage" ENCRYPTION = (TYPE='NONE' MASTER_KEY='key') FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME=my_csv_format NULL_IF=('')) PARSE_HEADER = TRUE""" ) + self.validate_identity( + """COPY INTO @my_stage/result/data FROM (SELECT * FROM orderstiny) FILE_FORMAT = (TYPE='csv')""" + ) self.validate_all( """COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 1092bc8..ea96fe5 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -131,6 +131,10 @@ SELECT DATE(x.a) AS _col_0, DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE( SELECT (SELECT MIN(a) FROM UNNEST([1, 2])) AS f FROM x GROUP BY 1; SELECT (SELECT MIN(_q_0.a) AS _col_0 FROM UNNEST(ARRAY(1, 2)) AS _q_0) AS f FROM x AS x GROUP BY 1; +# dialect: bigquery +WITH x AS (select 'a' as a, 1 as b) SELECT x.a AS c, y.a as d, SUM(x.b) AS y, FROM x join x as y on x.a = y.a group by 1, 2; +WITH x AS (SELECT 'a' AS a, 1 AS b) SELECT x.a AS c, y.a AS d, SUM(x.b) AS y FROM x AS x JOIN x AS y ON x.a = y.a GROUP BY x.a, 2; + SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c; SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 75abc38..87b42d1 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -113,6 +113,9 @@ a AND b; A XOR A; FALSE; +TRUE AND TRUE OR TRUE AND FALSE; +TRUE; + -------------------------------------- -- Absorption -------------------------------------- @@ -158,6 +161,12 @@ A OR C; A AND (B AND C) AND (D AND E); A AND B AND C AND D AND E; +A AND (A OR B) AND (A OR B OR C); +A; + +(A OR B) AND (A OR C) AND (A OR B OR C); +(A OR B) AND (A OR C); + -------------------------------------- -- Elimination -------------------------------------- @@ -194,6 +203,15 @@ NOT A; E OR (A AND B) OR C OR D OR (A AND NOT B); A OR C OR D OR E; +(A AND B) OR (A AND NOT B) OR (A AND NOT B); +A; + +(A AND B) OR (A AND B) OR (A AND NOT B); +A; + +(A AND B) OR (A AND NOT B) OR (A AND B) OR (A AND NOT B); +A; + -------------------------------------- -- Associativity -------------------------------------- diff --git a/tests/test_build.py b/tests/test_build.py index da1677f..150bb42 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -160,6 +160,10 @@ class TestBuild(unittest.TestCase): lambda: select("x", "y", "z", "a").from_("tbl").group_by("x, y", "z").group_by("a"), "SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a", ), + ( + lambda: select(1).from_("tbl").group_by("x with cube"), + "SELECT 1 FROM tbl GROUP BY x WITH CUBE", + ), ( lambda: select("x").distinct("a", "b").from_("tbl"), "SELECT DISTINCT ON (a, b) x FROM tbl", diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 7ec0872..41a5015 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -10,6 +10,7 @@ import sqlglot from sqlglot import exp, optimizer, parse_one from sqlglot.errors import OptimizeError, SchemaError from sqlglot.optimizer.annotate_types import annotate_types +from sqlglot.optimizer.normalize import normalization_distance from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from sqlglot.schema import MappingSchema from tests.helpers import ( @@ -1214,3 +1215,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') query = parse_one("select a.b:c from d", read="snowflake") qualified = optimizer.qualify.qualify(query) self.assertEqual(qualified.expressions[0].alias, "c") + + def test_normalization_distance(self): + def gen_expr(depth: int) -> exp.Expression: + return parse_one(" OR ".join("a AND b" for _ in range(depth))) + + self.assertEqual(4, normalization_distance(gen_expr(2), max_=100)) + self.assertEqual(18, normalization_distance(gen_expr(3), max_=100)) + self.assertEqual(110, normalization_distance(gen_expr(10), max_=100)) -- cgit v1.2.3