diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_bigquery.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 7 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 46 | ||||
-rw-r--r-- | tests/dialects/test_doris.py | 20 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 24 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 35 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 20 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 1 | ||||
-rw-r--r-- | tests/fixtures/optimizer/optimizer.sql | 6 | ||||
-rw-r--r-- | tests/fixtures/optimizer/qualify_columns.sql | 8 | ||||
-rw-r--r-- | tests/fixtures/optimizer/simplify.sql | 65 | ||||
-rw-r--r-- | tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 26 | ||||
-rw-r--r-- | tests/test_executor.py | 21 | ||||
-rw-r--r-- | tests/test_optimizer.py | 14 | ||||
-rw-r--r-- | tests/test_tokens.py | 1 | ||||
-rw-r--r-- | tests/test_transpile.py | 56 |
19 files changed, 343 insertions, 26 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index b5f91cf..52f86bd 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -112,6 +112,14 @@ class TestBigQuery(Validator): self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) self.validate_all("CAST(x AS RECORD)", write={"bigquery": "CAST(x AS STRUCT)"}) self.validate_all( + 'SELECT TIMESTAMP_ADD(TIMESTAMP "2008-12-25 15:30:00+00", INTERVAL 10 MINUTE)', + write={ + "bigquery": "SELECT TIMESTAMP_ADD(CAST('2008-12-25 15:30:00+00' AS TIMESTAMP), INTERVAL 10 MINUTE)", + "databricks": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + "spark": "SELECT DATEADD(MINUTE, 10, CAST('2008-12-25 15:30:00+00' AS TIMESTAMP))", + }, + ) + self.validate_all( "MD5(x)", write={ "": "MD5_DIGEST(x)", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 16c10fe..583be3e 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -10,6 +10,13 @@ class TestClickhouse(Validator): self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) + self.validate_identity("CAST(x AS Nested(ID UInt32, Serial UInt32, EventTime DATETIME))") + self.validate_identity("CAST(x AS Enum('hello' = 1, 'world' = 2))") + self.validate_identity("CAST(x AS Enum('hello', 'world'))") + self.validate_identity("CAST(x AS Enum('hello' = 1, 'world'))") + self.validate_identity("CAST(x AS Enum8('hello' = -123, 'world'))") + self.validate_identity("CAST(x AS FixedString(1))") + self.validate_identity("CAST(x AS LowCardinality(FixedString))") self.validate_identity("SELECT isNaN(1.0)") self.validate_identity("SELECT startsWith('Spider-Man', 'Spi')") self.validate_identity("SELECT xor(TRUE, FALSE)") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 14f7cd0..38a7952 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -11,6 +11,12 @@ class TestDatabricks(Validator): self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))") + self.validate_identity( + "SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))" + ) + self.validate_identity( + "SELECT * FROM sales UNPIVOT EXCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))" + ) self.validate_all( "CREATE TABLE foo (x INT GENERATED ALWAYS AS (YEAR(y)))", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index aaaffab..63f789f 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -90,6 +90,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", + "doris": "CAST(a AS STRING)", }, ) self.validate_all( @@ -169,6 +170,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS TEXT)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS STRING)", + "doris": "CAST(a AS STRING)", }, ) self.validate_all( @@ -186,6 +188,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS VARCHAR)", "spark": "CAST(a AS STRING)", "starrocks": "CAST(a AS VARCHAR)", + "doris": "CAST(a AS VARCHAR)", }, ) self.validate_all( @@ -203,6 +206,7 @@ class TestDialect(Validator): "snowflake": "CAST(a AS VARCHAR(3))", "spark": "CAST(a AS VARCHAR(3))", "starrocks": "CAST(a AS VARCHAR(3))", + "doris": "CAST(a AS VARCHAR(3))", }, ) self.validate_all( @@ -221,6 +225,7 @@ class TestDialect(Validator): "spark": "CAST(a AS SMALLINT)", "sqlite": "CAST(a AS INTEGER)", "starrocks": "CAST(a AS SMALLINT)", + "doris": "CAST(a AS SMALLINT)", }, ) self.validate_all( @@ -234,6 +239,7 @@ class TestDialect(Validator): "drill": "CAST(a AS DOUBLE)", "postgres": "CAST(a AS DOUBLE PRECISION)", "redshift": "CAST(a AS DOUBLE PRECISION)", + "doris": "CAST(a AS DOUBLE)", }, ) @@ -267,13 +273,15 @@ class TestDialect(Validator): write={ "starrocks": "CAST(a AS DATETIME)", "redshift": "CAST(a AS TIMESTAMP)", + "doris": "CAST(a AS DATETIME)", }, ) self.validate_all( "CAST(a AS TIMESTAMPTZ)", write={ "starrocks": "CAST(a AS DATETIME)", - "redshift": "CAST(a AS TIMESTAMPTZ)", + "redshift": "CAST(a AS TIMESTAMP WITH TIME ZONE)", + "doris": "CAST(a AS DATETIME)", }, ) self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"}) @@ -402,12 +410,13 @@ class TestDialect(Validator): }, ) self.validate_all( - "STR_TO_UNIX('2020-01-01', '%Y-%M-%d')", + "STR_TO_UNIX('2020-01-01', '%Y-%m-%d')", write={ - "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))", - "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", - "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))", - "starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%i-%d')", + "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%m-%d'))", + "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", + "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d'))", + "starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')", + "doris": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')", }, ) self.validate_all( @@ -418,6 +427,7 @@ class TestDialect(Validator): "hive": "TO_DATE('2020-01-01')", "presto": "CAST('2020-01-01' AS TIMESTAMP)", "starrocks": "TO_DATE('2020-01-01')", + "doris": "TO_DATE('2020-01-01')", }, ) self.validate_all( @@ -428,6 +438,7 @@ class TestDialect(Validator): "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "CAST('2020-01-01' AS TIMESTAMP)", "sqlite": "'2020-01-01'", + "doris": "CAST('2020-01-01' AS DATETIME)", }, ) self.validate_all( @@ -437,6 +448,7 @@ class TestDialect(Validator): "hive": "UNIX_TIMESTAMP('2020-01-01')", "mysql": "UNIX_TIMESTAMP('2020-01-01')", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))", + "doris": "UNIX_TIMESTAMP('2020-01-01')", }, ) self.validate_all( @@ -449,6 +461,7 @@ class TestDialect(Validator): "postgres": "TO_CHAR(x, 'YYYY-MM-DD')", "presto": "DATE_FORMAT(x, '%Y-%m-%d')", "redshift": "TO_CHAR(x, 'YYYY-MM-DD')", + "doris": "DATE_FORMAT(x, '%Y-%m-%d')", }, ) self.validate_all( @@ -459,6 +472,7 @@ class TestDialect(Validator): "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", "redshift": "CAST(x AS VARCHAR(MAX))", + "doris": "CAST(x AS STRING)", }, ) self.validate_all( @@ -468,6 +482,7 @@ class TestDialect(Validator): "duckdb": "EPOCH(x)", "hive": "UNIX_TIMESTAMP(x)", "presto": "TO_UNIXTIME(x)", + "doris": "UNIX_TIMESTAMP(x)", }, ) self.validate_all( @@ -476,6 +491,7 @@ class TestDialect(Validator): "duckdb": "SUBSTRING(CAST(x AS TEXT), 1, 10)", "hive": "SUBSTRING(CAST(x AS STRING), 1, 10)", "presto": "SUBSTRING(CAST(x AS VARCHAR), 1, 10)", + "doris": "SUBSTRING(CAST(x AS STRING), 1, 10)", }, ) self.validate_all( @@ -487,6 +503,7 @@ class TestDialect(Validator): "postgres": "CAST(x AS DATE)", "presto": "CAST(CAST(x AS TIMESTAMP) AS DATE)", "snowflake": "CAST(x AS DATE)", + "doris": "TO_DATE(x)", }, ) self.validate_all( @@ -505,6 +522,7 @@ class TestDialect(Validator): "hive": "FROM_UNIXTIME(x, y)", "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", "starrocks": "FROM_UNIXTIME(x, y)", + "doris": "FROM_UNIXTIME(x, y)", }, ) self.validate_all( @@ -516,6 +534,7 @@ class TestDialect(Validator): "postgres": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "starrocks": "FROM_UNIXTIME(x)", + "doris": "FROM_UNIXTIME(x)", }, ) self.validate_all( @@ -582,6 +601,7 @@ class TestDialect(Validator): "sqlite": "DATE(x, '1 DAY')", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", "tsql": "DATEADD(DAY, 1, x)", + "doris": "DATE_ADD(x, INTERVAL 1 DAY)", }, ) self.validate_all( @@ -595,6 +615,7 @@ class TestDialect(Validator): "presto": "DATE_ADD('day', 1, x)", "spark": "DATE_ADD(x, 1)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + "doris": "DATE_ADD(x, INTERVAL 1 DAY)", }, ) self.validate_all( @@ -612,6 +633,7 @@ class TestDialect(Validator): "snowflake": "DATE_TRUNC('day', x)", "starrocks": "DATE_TRUNC('day', x)", "spark": "TRUNC(x, 'day')", + "doris": "DATE_TRUNC(x, 'day')", }, ) self.validate_all( @@ -624,6 +646,7 @@ class TestDialect(Validator): "snowflake": "DATE_TRUNC('day', x)", "starrocks": "DATE_TRUNC('day', x)", "spark": "DATE_TRUNC('day', x)", + "doris": "DATE_TRUNC('day', x)", }, ) self.validate_all( @@ -684,6 +707,7 @@ class TestDialect(Validator): "snowflake": "DATE_TRUNC('year', x)", "starrocks": "DATE_TRUNC('year', x)", "spark": "TRUNC(x, 'year')", + "doris": "DATE_TRUNC(x, 'year')", }, ) self.validate_all( @@ -698,6 +722,7 @@ class TestDialect(Validator): write={ "bigquery": "TIMESTAMP_TRUNC(x, year)", "spark": "DATE_TRUNC('year', x)", + "doris": "DATE_TRUNC(x, 'year')", }, ) self.validate_all( @@ -719,6 +744,7 @@ class TestDialect(Validator): "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%T') AS DATE)", "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')", + "doris": "STR_TO_DATE(x, '%Y-%m-%dT%T')", }, ) self.validate_all( @@ -730,6 +756,7 @@ class TestDialect(Validator): "hive": "CAST(x AS DATE)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", "spark": "TO_DATE(x)", + "doris": "STR_TO_DATE(x, '%Y-%m-%d')", }, ) self.validate_all( @@ -784,6 +811,7 @@ class TestDialect(Validator): "mysql": "CAST('2022-01-01' AS TIMESTAMP)", "starrocks": "CAST('2022-01-01' AS DATETIME)", "hive": "CAST('2022-01-01' AS TIMESTAMP)", + "doris": "CAST('2022-01-01' AS DATETIME)", }, ) self.validate_all( @@ -792,6 +820,7 @@ class TestDialect(Validator): "mysql": "TIMESTAMP('2022-01-01')", "starrocks": "TIMESTAMP('2022-01-01')", "hive": "TIMESTAMP('2022-01-01')", + "doris": "TIMESTAMP('2022-01-01')", }, ) @@ -807,6 +836,7 @@ class TestDialect(Validator): "mysql", "presto", "starrocks", + "doris", ) }, write={ @@ -820,6 +850,7 @@ class TestDialect(Validator): "hive", "spark", "starrocks", + "doris", ) }, ) @@ -886,6 +917,7 @@ class TestDialect(Validator): "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", "starrocks": "x -> 'y'", + "doris": "x -> 'y'", }, write={ "mysql": "JSON_EXTRACT(x, 'y')", @@ -893,6 +925,7 @@ class TestDialect(Validator): "postgres": "x -> 'y'", "presto": "JSON_EXTRACT(x, 'y')", "starrocks": "x -> 'y'", + "doris": "x -> 'y'", }, ) self.validate_all( @@ -1115,6 +1148,7 @@ class TestDialect(Validator): "sqlite": "LOWER(x) LIKE '%y'", "starrocks": "LOWER(x) LIKE '%y'", "trino": "LOWER(x) LIKE '%y'", + "doris": "LOWER(x) LIKE '%y'", }, ) self.validate_all( diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py new file mode 100644 index 0000000..63325a6 --- /dev/null +++ b/tests/dialects/test_doris.py @@ -0,0 +1,20 @@ +from tests.dialects.test_dialect import Validator + + +class TestDoris(Validator): + dialect = "doris" + + def test_identity(self): + self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x") + + def test_time(self): + self.validate_identity("TIMESTAMP('2022-01-01')") + + def test_regex(self): + self.validate_all( + "SELECT REGEXP_LIKE(abc, '%foo%')", + write={ + "doris": "SELECT REGEXP(abc, '%foo%')", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 5c35d8f..c33c899 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -337,6 +337,8 @@ class TestDuckDB(Validator): unsupported_level=ErrorLevel.IMMEDIATE, ) + self.validate_identity("SELECT ISNAN(x)") + def test_time(self): self.validate_identity("SELECT CURRENT_DATE") self.validate_identity("SELECT CURRENT_TIMESTAMP") @@ -399,7 +401,7 @@ class TestDuckDB(Validator): "bigquery": "TIME_TO_STR(x, '%y-%-m-%S')", "duckdb": "STRFTIME(x, '%y-%-m-%S')", "postgres": "TO_CHAR(x, 'YY-FMMM-SS')", - "presto": "DATE_FORMAT(x, '%y-%c-%S')", + "presto": "DATE_FORMAT(x, '%y-%c-%s')", "spark": "DATE_FORMAT(x, 'yy-M-ss')", }, ) @@ -497,8 +499,12 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS USMALLINT)") self.validate_identity("CAST(x AS UTINYINT)") self.validate_identity("CAST(x AS TEXT)") + self.validate_identity("CAST(x AS INT128)") + self.validate_identity("CAST(x AS DOUBLE)") + self.validate_identity("CAST(x AS DECIMAL(15, 4))") - self.validate_all("CAST(x AS NUMERIC)", write={"duckdb": "CAST(x AS DOUBLE)"}) + self.validate_all("CAST(x AS NUMERIC(1, 2))", write={"duckdb": "CAST(x AS DECIMAL(1, 2))"}) + self.validate_all("CAST(x AS HUGEINT)", write={"duckdb": "CAST(x AS INT128)"}) self.validate_all("CAST(x AS CHAR)", write={"duckdb": "CAST(x AS TEXT)"}) self.validate_all("CAST(x AS BPCHAR)", write={"duckdb": "CAST(x AS TEXT)"}) self.validate_all("CAST(x AS STRING)", write={"duckdb": "CAST(x AS TEXT)"}) @@ -514,6 +520,20 @@ class TestDuckDB(Validator): self.validate_all("CAST(x AS VARBINARY)", write={"duckdb": "CAST(x AS BLOB)"}) self.validate_all("CAST(x AS LOGICAL)", write={"duckdb": "CAST(x AS BOOLEAN)"}) self.validate_all( + "CAST(x AS NUMERIC)", + write={ + "duckdb": "CAST(x AS DECIMAL(18, 3))", + "postgres": "CAST(x AS DECIMAL(18, 3))", + }, + ) + self.validate_all( + "CAST(x AS DECIMAL)", + write={ + "duckdb": "CAST(x AS DECIMAL(18, 3))", + "postgres": "CAST(x AS DECIMAL(18, 3))", + }, + ) + self.validate_all( "CAST(x AS BIT)", read={ "duckdb": "CAST(x AS BITSTRING)", diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index d021d62..d60f09d 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -103,6 +103,7 @@ class TestMySQL(Validator): self.validate_identity("@@GLOBAL.max_connections") self.validate_identity("CREATE TABLE A LIKE B") self.validate_identity("SELECT * FROM t1, t2 FOR SHARE OF t1, t2 SKIP LOCKED") + self.validate_identity("SELECT a || b", "SELECT a OR b") self.validate_identity( """SELECT * FROM foo WHERE 3 MEMBER OF(JSON_EXTRACT(info, '$.value'))""" ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index be34d8c..a7719a9 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -9,6 +9,10 @@ class TestPostgres(Validator): dialect = "postgres" def test_ddl(self): + self.validate_identity( + "CREATE TABLE test (x TIMESTAMP WITHOUT TIME ZONE[][])", + "CREATE TABLE test (x TIMESTAMP[][])", + ) self.validate_identity("CREATE TABLE test (elems JSONB[])") self.validate_identity("CREATE TABLE public.y (x TSTZRANGE NOT NULL)") self.validate_identity("CREATE TABLE test (foo HSTORE)") diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ec1ad30..5091540 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -1,6 +1,6 @@ from unittest import mock -from sqlglot import UnsupportedError +from sqlglot import UnsupportedError, exp, parse_one from tests.dialects.test_dialect import Validator @@ -8,6 +8,23 @@ class TestPresto(Validator): dialect = "presto" def test_cast(self): + self.validate_identity("CAST(x AS IPADDRESS)") + self.validate_identity("CAST(x AS IPPREFIX)") + + self.validate_all( + "CAST(x AS INTERVAL YEAR TO MONTH)", + write={ + "oracle": "CAST(x AS INTERVAL YEAR TO MONTH)", + "presto": "CAST(x AS INTERVAL YEAR TO MONTH)", + }, + ) + self.validate_all( + "CAST(x AS INTERVAL DAY TO SECOND)", + write={ + "oracle": "CAST(x AS INTERVAL DAY TO SECOND)", + "presto": "CAST(x AS INTERVAL DAY TO SECOND)", + }, + ) self.validate_all( "SELECT CAST('10C' AS INTEGER)", read={ @@ -100,17 +117,24 @@ class TestPresto(Validator): }, ) self.validate_all( + "CAST(x AS TIME(5) WITH TIME ZONE)", + write={ + "duckdb": "CAST(x AS TIMETZ)", + "postgres": "CAST(x AS TIMETZ(5))", + "presto": "CAST(x AS TIME(5) WITH TIME ZONE)", + "redshift": "CAST(x AS TIME(5) WITH TIME ZONE)", + }, + ) + self.validate_all( "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", write={ "bigquery": "CAST(x AS TIMESTAMP)", - "duckdb": "CAST(x AS TIMESTAMPTZ(9))", + "duckdb": "CAST(x AS TIMESTAMPTZ)", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", "hive": "CAST(x AS TIMESTAMP)", "spark": "CAST(x AS TIMESTAMP)", }, ) - self.validate_identity("CAST(x AS IPADDRESS)") - self.validate_identity("CAST(x AS IPPREFIX)") def test_regex(self): self.validate_all( @@ -179,6 +203,9 @@ class TestPresto(Validator): ) def test_time(self): + expr = parse_one("TIME(7) WITH TIME ZONE", into=exp.DataType, read="presto") + self.assertEqual(expr.this, exp.DataType.Type.TIMETZ) + self.validate_identity("FROM_UNIXTIME(a, b)") self.validate_identity("FROM_UNIXTIME(a, b, c)") self.validate_identity("TRIM(a, b)") diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 96e9e20..3af27d4 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,6 +6,26 @@ class TestRedshift(Validator): def test_redshift(self): self.validate_all( + "SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)", + read={ + "postgres": "SELECT CAST('01:03:05.124' AS TIMETZ(2))", + }, + write={ + "postgres": "SELECT CAST('01:03:05.124' AS TIMETZ(2))", + "redshift": "SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)", + }, + ) + self.validate_all( + "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMP(2) WITH TIME ZONE)", + read={ + "postgres": "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMPTZ(2))", + }, + write={ + "postgres": "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMPTZ(2))", + "redshift": "SELECT CAST('2020-02-02 01:03:05.124' AS TIMESTAMP(2) WITH TIME ZONE)", + }, + ) + self.validate_all( "SELECT INTERVAL '5 days'", read={ "": "SELECT INTERVAL '5' days", diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 0690421..b21d65d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -853,4 +853,5 @@ SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) /* comment1 */ INSERT INTO x /* comment2 */ VALUES (1, 2, 3) /* comment1 */ UPDATE tbl /* comment2 */ SET x = 2 WHERE x < 2 /* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1 +/* comment */ CREATE TABLE foo AS SELECT 1 SELECT next, transform, if diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index 74572d2..b318a92 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -944,3 +944,9 @@ SELECT FROM "m" JOIN "n" AS "foo"("a") ON "m"."a" = "foo"."a"; + +# title: reduction of string concatenation that uses CONCAT(..), || and + +# execute: false +SELECT CONCAT('a', 'b') || CONCAT(CONCAT('c', 'd'), CONCAT('e', 'f')) + ('g' || 'h' || 'i'); +SELECT + 'abcdefghi' AS "_col_0"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 7ba8e54..3224a83 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -431,6 +431,14 @@ SELECT x.a AS a, i.b AS b FROM x AS x CROSS JOIN UNNEST(SPLIT(x.b, ',')) AS i(b) SELECT c FROM (SELECT 1 a) AS x LATERAL VIEW EXPLODE(a) AS c; SELECT _q_0.c AS c FROM (SELECT 1 AS a) AS x LATERAL VIEW EXPLODE(x.a) _q_0 AS c; +# execute: false +SELECT * FROM foo(bar) AS t(c1, c2, c3); +SELECT t.c1 AS c1, t.c2 AS c2, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3); + +# execute: false +SELECT c1, c3 FROM foo(bar) AS t(c1, c2, c3); +SELECT t.c1 AS c1, t.c3 AS c3 FROM FOO(bar) AS t(c1, c2, c3); + -------------------------------------- -- Window functions -------------------------------------- diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index f821575..3ed02cd 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -240,9 +240,18 @@ A AND B AND C; SELECT x WHERE TRUE; SELECT x; -SELECT x FROM y LEFT JOIN z ON TRUE; +SELECT x FROM y JOIN z ON TRUE; +SELECT x FROM y CROSS JOIN z; + +SELECT x FROM y RIGHT JOIN z ON TRUE; SELECT x FROM y CROSS JOIN z; +SELECT x FROM y LEFT JOIN z ON TRUE; +SELECT x FROM y LEFT JOIN z ON TRUE; + +SELECT x FROM y FULL OUTER JOIN z ON TRUE; +SELECT x FROM y FULL OUTER JOIN z ON TRUE; + SELECT x FROM y JOIN z USING (x); SELECT x FROM y JOIN z USING (x); @@ -602,3 +611,57 @@ TRUE; x = 2018 OR x <> 2018; x <> 2018 OR x = 2018; + +-------------------------------------- +-- Coalesce +-------------------------------------- +COALESCE(x); +x; + +COALESCE(x, 1) = 2; +x = 2 AND NOT x IS NULL; + +2 = COALESCE(x, 1); +2 = x AND NOT x IS NULL; + +COALESCE(x, 1, 1) = 1 + 1; +x = 2 AND NOT x IS NULL; + +COALESCE(x, 1, 2) = 2; +x = 2 AND NOT x IS NULL; + +COALESCE(x, 3) <= 2; +x <= 2 AND NOT x IS NULL; + +COALESCE(x, 1) <> 2; +x <> 2 OR x IS NULL; + +COALESCE(x, 1) <= 2; +x <= 2 OR x IS NULL; + +COALESCE(x, 1) = 1; +x = 1 OR x IS NULL; + +COALESCE(x, 1) IS NULL; +FALSE; + +-------------------------------------- +-- CONCAT +-------------------------------------- +CONCAT(x, y); +CONCAT(x, y); + +CONCAT(x); +x; + +CONCAT('a', 'b', 'c'); +'abc'; + +CONCAT('a', x, y, 'b', 'c'); +CONCAT('a', x, y, 'bc'); + +'a' || 'b'; +'ab'; + +'a' || 'b' || x; +CONCAT('ab', x); diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 1205c33..f50cf0b 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -857,7 +857,7 @@ WITH "salesreturns" AS ( ), "cte_10" AS ( SELECT 'catalog channel' AS "channel", - 'catalog_page' || "csr"."cp_catalog_page_id" AS "id", + CONCAT('catalog_page', "csr"."cp_catalog_page_id") AS "id", "csr"."sales" AS "sales", "csr"."returns1" AS "returns1", "csr"."profit" - "csr"."profit_loss" AS "profit" @@ -865,7 +865,7 @@ WITH "salesreturns" AS ( UNION ALL SELECT 'web channel' AS "channel", - 'web_site' || "wsr"."web_site_id" AS "id", + CONCAT('web_site', "wsr"."web_site_id") AS "id", "wsr"."sales" AS "sales", "wsr"."returns1" AS "returns1", "wsr"."profit" - "wsr"."profit_loss" AS "profit" @@ -873,7 +873,7 @@ WITH "salesreturns" AS ( ), "x" AS ( SELECT 'store channel' AS "channel", - 'store' || "ssr"."s_store_id" AS "id", + CONCAT('store', "ssr"."s_store_id") AS "id", "ssr"."sales" AS "sales", "ssr"."returns1" AS "returns1", "ssr"."profit" - "ssr"."profit_loss" AS "profit" @@ -8611,7 +8611,7 @@ WITH "date_dim_2" AS ( "warehouse"."w_county" AS "w_county", "warehouse"."w_state" AS "w_state", "warehouse"."w_country" AS "w_country", - 'ZOUROS' || ',' || 'ZHOU' AS "ship_carriers", + 'ZOUROS,ZHOU' AS "ship_carriers", "date_dim"."d_year" AS "year1", SUM( CASE @@ -8806,7 +8806,7 @@ WITH "date_dim_2" AS ( "warehouse"."w_county" AS "w_county", "warehouse"."w_state" AS "w_state", "warehouse"."w_country" AS "w_country", - 'ZOUROS' || ',' || 'ZHOU' AS "ship_carriers", + 'ZOUROS,ZHOU' AS "ship_carriers", "date_dim"."d_year" AS "year1", SUM( CASE @@ -10833,9 +10833,11 @@ LEFT JOIN "ws" AND "ws"."ws_item_sk" = "ss"."ss_item_sk" AND "ws"."ws_sold_year" = "ss"."ss_sold_year" WHERE - "ss"."ss_sold_year" = 1999 - AND COALESCE("cs"."cs_qty", 0) > 0 - AND COALESCE("ws"."ws_qty", 0) > 0 + "cs"."cs_qty" > 0 + AND "ss"."ss_sold_year" = 1999 + AND "ws"."ws_qty" > 0 + AND NOT "cs"."cs_qty" IS NULL + AND NOT "ws"."ws_qty" IS NULL ORDER BY "ss_item_sk", "ss"."ss_qty" DESC, @@ -11121,7 +11123,7 @@ WITH "date_dim_2" AS ( ), "cte_4" AS ( SELECT 'catalog channel' AS "channel", - 'catalog_page' || "csr"."catalog_page_id" AS "id", + CONCAT('catalog_page', "csr"."catalog_page_id") AS "id", "csr"."sales" AS "sales", "csr"."returns1" AS "returns1", "csr"."profit" AS "profit" @@ -11129,7 +11131,7 @@ WITH "date_dim_2" AS ( UNION ALL SELECT 'web channel' AS "channel", - 'web_site' || "wsr"."web_site_id" AS "id", + CONCAT('web_site', "wsr"."web_site_id") AS "id", "wsr"."sales" AS "sales", "wsr"."returns1" AS "returns1", "wsr"."profit" AS "profit" @@ -11137,7 +11139,7 @@ WITH "date_dim_2" AS ( ), "x" AS ( SELECT 'store channel' AS "channel", - 'store' || "ssr"."store_id" AS "id", + CONCAT('store', "ssr"."store_id") AS "id", "ssr"."sales" AS "sales", "ssr"."returns1" AS "returns1", "ssr"."profit" AS "profit" @@ -11569,7 +11571,7 @@ ORDER BY c_customer_id LIMIT 100; SELECT "customer"."c_customer_id" AS "customer_id", - "customer"."c_last_name" || ', ' || "customer"."c_first_name" AS "customername" + CONCAT("customer"."c_last_name", ', ', "customer"."c_first_name") AS "customername" FROM "customer" AS "customer" JOIN "customer_address" AS "customer_address" ON "customer"."c_current_addr_sk" = "customer_address"."ca_address_sk" diff --git a/tests/test_executor.py b/tests/test_executor.py index 9dacbbf..ffe0229 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -723,3 +723,24 @@ class TestExecutor(unittest.TestCase): result = execute(sql, tables=tables) self.assertEqual(result.columns, columns) self.assertEqual(result.rows, expected) + + def test_dict_values(self): + tables = { + "foo": [{"raw": {"name": "Hello, World"}}], + } + result = execute("SELECT raw:name AS name FROM foo", read="snowflake", tables=tables) + + self.assertEqual(result.columns, ("NAME",)) + self.assertEqual(result.rows, [("Hello, World",)]) + + tables = { + '"ITEM"': [ + {"id": 1, "attributes": {"flavor": "cherry", "taste": "sweet"}}, + {"id": 2, "attributes": {"flavor": "lime", "taste": "sour"}}, + {"id": 3, "attributes": {"flavor": "apple", "taste": None}}, + ] + } + result = execute("SELECT i.attributes.flavor FROM `ITEM` i", read="bigquery", tables=tables) + + self.assertEqual(result.columns, ("flavor",)) + self.assertEqual(result.rows, [("cherry",), ("lime",), ("apple",)]) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 3fe53e4..a1bd309 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -277,6 +277,20 @@ class TestOptimizer(unittest.TestCase): self.assertEqual(exp.true(), optimizer.simplify.simplify(expression)) self.assertEqual(exp.true(), optimizer.simplify.simplify(expression.this)) + # CONCAT in (e.g.) Presto is parsed as Concat instead of SafeConcat which is the default type + # This test checks that simplify_concat preserves the corresponding expression types. + concat = parse_one("CONCAT('a', x, 'b', 'c')", read="presto") + simplified_concat = optimizer.simplify.simplify(concat) + + safe_concat = parse_one("CONCAT('a', x, 'b', 'c')") + simplified_safe_concat = optimizer.simplify.simplify(safe_concat) + + self.assertIs(type(simplified_concat), exp.Concat) + self.assertIs(type(simplified_safe_concat), exp.SafeConcat) + + self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto")) + self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql()) + def test_unnest_subqueries(self): self.check_file( "unnest_subqueries", diff --git a/tests/test_tokens.py b/tests/test_tokens.py index e6e984d..f3343e7 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -17,6 +17,7 @@ class TestTokens(unittest.TestCase): ("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]), ("foo\n-- comment", [" comment"]), ("1 /*/2 */", ["/2 "]), + ("1\n/*comment*/;", ["comment"]), ] for sql, comment in sql_comment: diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 3f284c9..e58ed86 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -90,6 +90,19 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): + self.validate("SELECT\n foo\n/* comments */\n;", "SELECT foo /* comments */") + self.validate( + "SELECT * FROM a INNER /* comments */ JOIN b", + "SELECT * FROM a /* comments */ INNER JOIN b", + ) + self.validate( + "SELECT * FROM a LEFT /* comment 1 */ OUTER /* comment 2 */ JOIN b", + "SELECT * FROM a /* comment 1 */ /* comment 2 */ LEFT OUTER JOIN b", + ) + self.validate( + "SELECT CASE /* test */ WHEN a THEN b ELSE c END", + "SELECT CASE WHEN a THEN b ELSE c END /* test */", + ) self.validate("SELECT 1 /*/2 */", "SELECT 1 /* /2 */") self.validate("SELECT */*comment*/", "SELECT * /* comment */") self.validate( @@ -308,6 +321,7 @@ DROP TABLE IF EXISTS db.tba""", ) self.validate( """ + -- comment4 CREATE TABLE db.tba AS SELECT a, b, c FROM tb_01 @@ -316,8 +330,10 @@ DROP TABLE IF EXISTS db.tba""", a = 1 AND b = 2 --comment6 -- and c = 1 -- comment7 + ; """, - """CREATE TABLE db.tba AS + """/* comment4 */ +CREATE TABLE db.tba AS SELECT a, b, @@ -329,6 +345,44 @@ WHERE /* comment7 */""", pretty=True, ) + self.validate( + """ + SELECT + -- This is testing comments + col, + -- 2nd testing comments + CASE WHEN a THEN b ELSE c END as d + FROM t + """, + """SELECT + col, /* This is testing comments */ + CASE WHEN a THEN b ELSE c END /* 2nd testing comments */ AS d +FROM t""", + pretty=True, + ) + self.validate( + """ + SELECT * FROM a + -- comments + INNER JOIN b + """, + """SELECT + * +FROM a +/* comments */ +INNER JOIN b""", + pretty=True, + ) + self.validate( + "SELECT * FROM a LEFT /* comment 1 */ OUTER /* comment 2 */ JOIN b", + """SELECT + * +FROM a +/* comment 1 */ +/* comment 2 */ +LEFT OUTER JOIN b""", + pretty=True, + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") |