diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-03-08 07:22:15 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-03-08 07:22:15 +0000 |
commit | 5b1ac5070c43c40a2b5bbc991198b0dddf45dc75 (patch) | |
tree | ed329138d5e8e5c9d5164b5c853d6f40a116f4d6 /tests | |
parent | Releasing debian version 11.3.0-1. (diff) | |
download | sqlglot-5b1ac5070c43c40a2b5bbc991198b0dddf45dc75.tar.xz sqlglot-5b1ac5070c43c40a2b5bbc991198b0dddf45dc75.zip |
Merging upstream version 11.3.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r-- | tests/dialects/test_bigquery.py | 17 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 34 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 26 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 6 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 88 | ||||
-rw-r--r-- | tests/dialects/test_sqlite.py | 7 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 21 | ||||
-rw-r--r-- | tests/test_build.py | 6 | ||||
-rw-r--r-- | tests/test_optimizer.py | 5 | ||||
-rw-r--r-- | tests/test_transpile.py | 8 |
13 files changed, 197 insertions, 31 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 7b18a6a..22387da 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -8,6 +8,8 @@ class TestBigQuery(Validator): def test_bigquery(self): self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") + + self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all( "REGEXP_CONTAINS('foo', '.*')", read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, @@ -390,3 +392,18 @@ class TestBigQuery(Validator): "bigquery": "INSERT INTO test (cola, colb) VALUES (CAST(7 AS STRING), CAST(14 AS STRING))", }, ) + + def test_merge(self): + self.validate_all( + """ + MERGE dataset.Inventory T + USING dataset.NewArrivals S ON FALSE + WHEN NOT MATCHED BY TARGET AND product LIKE '%a%' + THEN DELETE + WHEN NOT MATCHED BY SOURCE AND product LIKE '%b%' + THEN DELETE""", + write={ + "bigquery": "MERGE INTO dataset.Inventory AS T USING dataset.NewArrivals AS S ON FALSE WHEN NOT MATCHED AND product LIKE '%a%' THEN DELETE WHEN NOT MATCHED BY SOURCE AND product LIKE '%b%' THEN DELETE", + "snowflake": "MERGE INTO dataset.Inventory AS T USING dataset.NewArrivals AS S ON FALSE WHEN NOT MATCHED AND product LIKE '%a%' THEN DELETE WHEN NOT MATCHED AND product LIKE '%b%' THEN DELETE", + }, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 5054d94..69563cb 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -510,6 +510,7 @@ class TestDialect(Validator): "DATE_ADD(x, 1, 'day')", read={ "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", + "snowflake": "DATEADD('day', 1, x)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", }, write={ diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 0efb7e7..a1a0090 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -75,6 +75,40 @@ class TestDuckDB(Validator): }, ) + def test_sample(self): + self.validate_all( + "SELECT * FROM tbl USING SAMPLE 5", + write={"duckdb": "SELECT * FROM tbl USING SAMPLE (5)"}, + ) + self.validate_all( + "SELECT * FROM tbl USING SAMPLE 10%", + write={"duckdb": "SELECT * FROM tbl USING SAMPLE (10 PERCENT)"}, + ) + self.validate_all( + "SELECT * FROM tbl USING SAMPLE 10 PERCENT (bernoulli)", + write={"duckdb": "SELECT * FROM tbl USING SAMPLE BERNOULLI (10 PERCENT)"}, + ) + self.validate_all( + "SELECT * FROM tbl USING SAMPLE reservoir(50 ROWS) REPEATABLE (100)", + write={"duckdb": "SELECT * FROM tbl USING SAMPLE RESERVOIR (50 ROWS) REPEATABLE (100)"}, + ) + self.validate_all( + "SELECT * FROM tbl USING SAMPLE 10% (system, 377)", + write={"duckdb": "SELECT * FROM tbl USING SAMPLE SYSTEM (10 PERCENT) REPEATABLE (377)"}, + ) + self.validate_all( + "SELECT * FROM tbl TABLESAMPLE RESERVOIR(20%), tbl2 WHERE tbl.i=tbl2.i", + write={ + "duckdb": "SELECT * FROM tbl TABLESAMPLE RESERVOIR (20 PERCENT), tbl2 WHERE tbl.i = tbl2.i" + }, + ) + self.validate_all( + "SELECT * FROM tbl, tbl2 WHERE tbl.i=tbl2.i USING SAMPLE RESERVOIR(20%)", + write={ + "duckdb": "SELECT * FROM tbl, tbl2 WHERE tbl.i = tbl2.i USING SAMPLE RESERVOIR (20 PERCENT)" + }, + ) + def test_duckdb(self): self.validate_identity("SELECT {'a': 1} AS x") self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x") diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index a067764..8484805 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -482,9 +482,9 @@ class TestHive(Validator): self.validate_all( "SELECT * FROM x TABLESAMPLE(10) y", write={ - "presto": "SELECT * FROM x AS y TABLESAMPLE(10)", - "hive": "SELECT * FROM x TABLESAMPLE(10) AS y", - "spark": "SELECT * FROM x TABLESAMPLE(10) AS y", + "presto": "SELECT * FROM x AS y TABLESAMPLE (10)", + "hive": "SELECT * FROM x TABLESAMPLE (10) AS y", + "spark": "SELECT * FROM x TABLESAMPLE (10) AS y", }, ) self.validate_all( @@ -626,25 +626,25 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE(1) AS foo", + "SELECT * FROM x TABLESAMPLE (1) AS foo", read={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", + "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", }, write={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", - "hive": "SELECT * FROM x TABLESAMPLE(1) AS foo", - "spark": "SELECT * FROM x TABLESAMPLE(1) AS foo", + "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", + "hive": "SELECT * FROM x TABLESAMPLE (1) AS foo", + "spark": "SELECT * FROM x TABLESAMPLE (1) AS foo", }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE(1) AS foo", + "SELECT * FROM x TABLESAMPLE (1) AS foo", read={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", + "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", }, write={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", - "hive": "SELECT * FROM x TABLESAMPLE(1) AS foo", - "spark": "SELECT * FROM x TABLESAMPLE(1) AS foo", + "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", + "hive": "SELECT * FROM x TABLESAMPLE (1) AS foo", + "spark": "SELECT * FROM x TABLESAMPLE (1) AS foo", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 192f9fc..5f8560a 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -64,6 +64,7 @@ class TestMySQL(Validator): self.validate_identity("SET TRANSACTION READ ONLY") self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") self.validate_identity("SELECT SCHEMA()") + self.validate_identity("SELECT DATABASE()") def test_types(self): self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 5c4a23e..0881a89 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -11,6 +11,14 @@ class TestPostgres(Validator): self.validate_identity("CREATE TABLE test (foo JSONB)") self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") + self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a") + self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a, b") + self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING *") + self.validate_identity( + "DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid RETURNING a" + ) + self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a") + self.validate_all( "CREATE OR REPLACE FUNCTION function_name (input_a character varying DEFAULT NULL::character varying)", write={ diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 640706a..ad83b99 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -85,6 +85,12 @@ class TestRedshift(Validator): "presto": "DATE_DIFF(d, a, b)", }, ) + self.validate_all( + "SELECT TOP 1 x FROM y", + write={ + "redshift": "SELECT x FROM y LIMIT 1", + }, + ) def test_identity(self): self.validate_identity( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 3358227..c28c58d 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -63,6 +63,28 @@ class TestSnowflake(Validator): }, ) self.validate_all( + "ZEROIFNULL(foo)", + write={ + "snowflake": "IFF(foo IS NULL, 0, foo)", + "sqlite": "CASE WHEN foo IS NULL THEN 0 ELSE foo END", + "presto": "IF(foo IS NULL, 0, foo)", + "spark": "IF(foo IS NULL, 0, foo)", + "hive": "IF(foo IS NULL, 0, foo)", + "duckdb": "CASE WHEN foo IS NULL THEN 0 ELSE foo END", + }, + ) + self.validate_all( + "NULLIFZERO(foo)", + write={ + "snowflake": "IFF(foo = 0, NULL, foo)", + "sqlite": "CASE WHEN foo = 0 THEN NULL ELSE foo END", + "presto": "IF(foo = 0, NULL, foo)", + "spark": "IF(foo = 0, NULL, foo)", + "hive": "IF(foo = 0, NULL, foo)", + "duckdb": "CASE WHEN foo = 0 THEN NULL ELSE foo END", + }, + ) + self.validate_all( "CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))", write={ "snowflake": "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL AUTOINCREMENT START 0 INCREMENT 1)", @@ -280,12 +302,6 @@ class TestSnowflake(Validator): }, ) self.validate_all( - "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", - write={ - "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", - }, - ) - self.validate_all( "SELECT a FROM test pivot", write={ "snowflake": "SELECT a FROM test AS pivot", @@ -356,6 +372,51 @@ class TestSnowflake(Validator): }, ) + def test_sample(self): + self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)") + self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)") + self.validate_identity( + "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i" + ) + self.validate_identity( + "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)" + ) + self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)") + self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)") + + self.validate_all( + "SELECT * FROM testtable SAMPLE (10)", + write={"snowflake": "SELECT * FROM testtable TABLESAMPLE (10)"}, + ) + self.validate_all( + "SELECT * FROM testtable SAMPLE ROW (0)", + write={"snowflake": "SELECT * FROM testtable TABLESAMPLE ROW (0)"}, + ) + self.validate_all( + "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", + write={ + "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", + }, + ) + self.validate_all( + """ + SELECT i, j + FROM + table1 AS t1 SAMPLE (25) -- 25% of rows in table1 + INNER JOIN + table2 AS t2 SAMPLE (50) -- 50% of rows in table2 + WHERE t2.j = t1.i""", + write={ + "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", + }, + ) + self.validate_all( + "SELECT * FROM testtable SAMPLE BLOCK (0.012) REPEATABLE (99992)", + write={ + "snowflake": "SELECT * FROM testtable TABLESAMPLE BLOCK (0.012) SEED (99992)", + }, + ) + def test_timestamps(self): self.validate_identity("SELECT EXTRACT(month FROM a)") @@ -415,6 +476,13 @@ class TestSnowflake(Validator): "presto": "SELECT TO_UNIXTIME(CAST(foo AS TIMESTAMP)) * 1000 AS ddate FROM table_name", }, ) + self.validate_all( + "DATEADD(DAY, 5, CAST('2008-12-25' AS DATE))", + write={ + "bigquery": "DATE_ADD(CAST('2008-12-25' AS DATE), INTERVAL 5 DAY)", + "snowflake": "DATEADD(DAY, 5, CAST('2008-12-25' AS DATE))", + }, + ) def test_semi_structured_types(self): self.validate_identity("SELECT CAST(a AS VARIANT)") @@ -655,6 +723,14 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA }, ) + self.validate_all( + """SELECT $1 AS "_1" FROM VALUES ('a'), ('b')""", + write={ + "snowflake": """SELECT $1 AS "_1" FROM (VALUES ('a'), ('b'))""", + "spark": """SELECT @1 AS `_1` FROM VALUES ('a'), ('b')""", + }, + ) + def test_describe_table(self): self.validate_all( "DESCRIBE TABLE db.table", diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py index 19a88f3..a3e4cc9 100644 --- a/tests/dialects/test_sqlite.py +++ b/tests/dialects/test_sqlite.py @@ -81,6 +81,13 @@ class TestSQLite(Validator): "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", }, ) + self.validate_all("x", read={"snowflake": "LEAST(x)"}) + self.validate_all("MIN(x)", read={"snowflake": "MIN(x)"}, write={"snowflake": "MIN(x)"}) + self.validate_all( + "MIN(x, y, z)", + read={"snowflake": "LEAST(x, y, z)"}, + write={"snowflake": "LEAST(x, y, z)"}, + ) def test_hexadecimal_literal(self): self.validate_all( diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 0677a05..085880c 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -153,6 +153,7 @@ SUM(ROW() OVER (PARTITION BY x AND y)) CASE WHEN (x > 1) THEN 1 ELSE 0 END CASE (1) WHEN 1 THEN 1 ELSE 0 END CASE 1 WHEN 1 THEN 1 ELSE 0 END +CASE 1 WHEN 1 THEN timestamp ELSE date END x AT TIME ZONE 'UTC' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' SET x = 1 @@ -191,6 +192,7 @@ SELECT DISTINCT TIMESTAMP_TRUNC(time_field, MONTH) AS time_value FROM "table" SELECT DISTINCT ON (x) x, y FROM z SELECT DISTINCT ON (x, y + 1) * FROM z SELECT DISTINCT ON (x.y) * FROM z +SELECT DISTINCT FROM_SOMETHING SELECT top.x SELECT TIMESTAMP(DATE_TRUNC(DATE(time_field), MONTH)) AS time_value FROM "table" SELECT GREATEST((3 + 1), LEAST(3, 4)) @@ -295,13 +297,13 @@ SELECT CASE CASE x > 1 WHEN TRUE THEN 1 END WHEN 1 THEN 1 ELSE 2 END SELECT a FROM (SELECT a FROM test) AS x SELECT a FROM (SELECT a FROM (SELECT a FROM test) AS y) AS x SELECT a FROM test WHERE a IN (1, 2, 3) OR b BETWEEN 1 AND 4 -SELECT a FROM test AS x TABLESAMPLE(BUCKET 1 OUT OF 5) -SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5) -SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON x) -SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON RAND()) -SELECT a FROM test TABLESAMPLE(0.1 PERCENT) -SELECT a FROM test TABLESAMPLE(100) -SELECT a FROM test TABLESAMPLE(100 ROWS) +SELECT a FROM test AS x TABLESAMPLE (BUCKET 1 OUT OF 5) +SELECT a FROM test TABLESAMPLE (BUCKET 1 OUT OF 5) +SELECT a FROM test TABLESAMPLE (BUCKET 1 OUT OF 5 ON x) +SELECT a FROM test TABLESAMPLE (BUCKET 1 OUT OF 5 ON RAND()) +SELECT a FROM test TABLESAMPLE (0.1 PERCENT) +SELECT a FROM test TABLESAMPLE (100) +SELECT a FROM test TABLESAMPLE (100 ROWS) SELECT a FROM test TABLESAMPLE BERNOULLI (50) SELECT a FROM test TABLESAMPLE SYSTEM (75) SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) @@ -310,7 +312,7 @@ SELECT a FROM test PIVOT(SOMEAGG(x, y, z) FOR q IN (1)) SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) PIVOT(MAX(b) FOR c IN ('d')) SELECT a FROM (SELECT a, b FROM test) PIVOT(SUM(x) FOR y IN ('z', 'q')) SELECT a FROM test UNPIVOT(x FOR y IN (z, q)) AS x -SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE(0.1) +SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) AS x TABLESAMPLE (0.1) SELECT a FROM test PIVOT(SUM(x) FOR y IN ('z', 'q')) UNPIVOT(x FOR y IN (z, q)) AS x SELECT ABS(a) FROM test SELECT AVG(a) FROM test @@ -466,6 +468,7 @@ SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3) SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) SELECT AVG(x) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t +SELECT SUM(x) OVER (PARTITION BY a ORDER BY date ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) @@ -708,6 +711,7 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */ +INSERT INTO foo SELECT * FROM bar /* comment */ SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' SELECT x AS INTO FROM bla SELECT * INTO newevent FROM event @@ -731,6 +735,7 @@ ALTER TABLE orders DROP PARTITION(dt = '2014-05-14', country = 'IN') ALTER TABLE orders DROP IF EXISTS PARTITION(dt = '2014-05-14', country = 'IN') ALTER TABLE orders DROP PARTITION(dt = '2014-05-14', country = 'IN'), PARTITION(dt = '2014-05-15', country = 'IN') ALTER TABLE mydataset.mytable DELETE WHERE x = 1 +ALTER TABLE table1 MODIFY COLUMN name1 SET TAG foo='bar' SELECT div.a FROM test_table AS div WITH view AS (SELECT 1 AS x) SELECT * FROM view ARRAY<STRUCT<INT, DOUBLE, ARRAY<INT>>> diff --git a/tests/test_build.py b/tests/test_build.py index fbfbb62..718e471 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -504,6 +504,12 @@ class TestBuild(unittest.TestCase): .window("d AS (PARTITION BY g ORDER BY h)"), "SELECT AVG(a) OVER b, MIN(c) OVER d FROM table WINDOW b AS (PARTITION BY e ORDER BY f), d AS (PARTITION BY g ORDER BY h)", ), + ( + lambda: select("*") + .from_("table") + .qualify("row_number() OVER (PARTITION BY a ORDER BY b) = 1"), + "SELECT * FROM table QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) = 1", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 8ddd95f..e10d05e 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -364,6 +364,11 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>")) self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType)) + expression = annotate_types(parse_one("CAST(x AS INTERVAL)")) + self.assertEqual(expression.type.this, exp.DataType.Type.INTERVAL) + self.assertEqual(expression.this.type.this, exp.DataType.Type.UNKNOWN) + self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.INTERVAL) + def test_cache_annotation(self): expression = annotate_types( parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 0463aed..6355400 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -202,9 +202,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", ) def test_types(self): - self.validate("INT x", "CAST(x AS INT)") - self.validate("VARCHAR x y", "CAST(x AS VARCHAR) AS y") - self.validate("STRING x y", "CAST(x AS TEXT) AS y") + self.validate("INT 1", "CAST(1 AS INT)") + self.validate("VARCHAR 'x' y", "CAST('x' AS VARCHAR) AS y") + self.validate("STRING 'x' y", "CAST('x' AS TEXT) AS y") self.validate("x::INT", "CAST(x AS INT)") self.validate("x::INTEGER", "CAST(x AS INT)") self.validate("x::INT y", "CAST(x AS INT) AS y") @@ -221,7 +221,7 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", self.validate("a NOT BETWEEN b AND c", "NOT a BETWEEN b AND c") self.validate("a NOT IN (1, 2)", "NOT a IN (1, 2)") self.validate("a IS NOT NULL", "NOT a IS NULL") - self.validate("a LIKE TEXT y", "a LIKE CAST(y AS TEXT)") + self.validate("a LIKE TEXT 'y'", "a LIKE CAST('y' AS TEXT)") def test_extract(self): self.validate( |