From beba715b97dd2349e01dde9b077d2535680ebdca Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 10 May 2023 08:44:58 +0200 Subject: Merging upstream version 12.2.0. Signed-off-by: Daniel Baumann --- tests/dataframe/integration/dataframe_validator.py | 6 +- tests/dialects/test_bigquery.py | 38 ++++++- tests/dialects/test_clickhouse.py | 21 ++++ tests/dialects/test_dialect.py | 4 +- tests/dialects/test_duckdb.py | 27 ++++- tests/dialects/test_mysql.py | 126 ++++++++++++++------- tests/dialects/test_oracle.py | 18 ++- tests/dialects/test_postgres.py | 76 ++++++++++--- tests/dialects/test_presto.py | 45 +++++++- tests/dialects/test_redshift.py | 15 +++ tests/dialects/test_snowflake.py | 31 ++--- tests/dialects/test_spark.py | 47 ++++---- tests/dialects/test_starrocks.py | 8 ++ tests/dialects/test_tsql.py | 12 +- tests/fixtures/identity.sql | 6 + tests/fixtures/optimizer/qualify_columns.sql | 7 ++ tests/fixtures/optimizer/qualify_tables.sql | 23 ++++ tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 16 +++ tests/test_build.py | 19 +++- tests/test_expressions.py | 5 +- tests/test_optimizer.py | 13 +++ tests/test_tokens.py | 11 ++ tests/test_transpile.py | 17 ++- 23 files changed, 466 insertions(+), 125 deletions(-) diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 16f8922..c84a342 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -3,17 +3,13 @@ import unittest import warnings import sqlglot -from sqlglot.helper import PYTHON_VERSION from tests.helpers import SKIP_INTEGRATION if t.TYPE_CHECKING: from pyspark.sql import DataFrame as SparkDataFrame -@unittest.skipIf( - SKIP_INTEGRATION or PYTHON_VERSION > (3, 10), - "Skipping Integration Tests since `SKIP_INTEGRATION` is set", -) +@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") class DataFrameValidator(unittest.TestCase): spark = None sqlglot = None diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 703b7dc..87bba6f 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,10 +6,19 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))") + self.validate_identity("SELECT b'abc'") + self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<INT64>>[1, 2])""") self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") + self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") + self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""") + self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""") + self.validate_identity( + """CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)""" + ) self.validate_identity( "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" ) @@ -97,6 +106,16 @@ class TestBigQuery(Validator): "spark": "CAST(a AS LONG)", }, ) + self.validate_all( + "CAST(a AS BYTES)", + write={ + "bigquery": "CAST(a AS BYTES)", + "duckdb": "CAST(a AS BLOB)", + "presto": "CAST(a AS VARBINARY)", + "hive": "CAST(a AS BINARY)", + "spark": "CAST(a AS BINARY)", + }, + ) self.validate_all( "CAST(a AS NUMERIC)", write={ @@ -173,7 +192,6 @@ class 
TestBigQuery(Validator): "current_datetime", write={ "bigquery": "CURRENT_DATETIME()", - "duckdb": "CURRENT_DATETIME()", "presto": "CURRENT_DATETIME()", "hive": "CURRENT_DATETIME()", "spark": "CURRENT_DATETIME()", @@ -183,7 +201,7 @@ class TestBigQuery(Validator): "current_time", write={ "bigquery": "CURRENT_TIME()", - "duckdb": "CURRENT_TIME()", + "duckdb": "CURRENT_TIME", "presto": "CURRENT_TIME()", "hive": "CURRENT_TIME()", "spark": "CURRENT_TIME()", @@ -193,7 +211,7 @@ class TestBigQuery(Validator): "current_timestamp", write={ "bigquery": "CURRENT_TIMESTAMP()", - "duckdb": "CURRENT_TIMESTAMP()", + "duckdb": "CURRENT_TIMESTAMP", "postgres": "CURRENT_TIMESTAMP", "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", @@ -204,7 +222,7 @@ class TestBigQuery(Validator): "current_timestamp()", write={ "bigquery": "CURRENT_TIMESTAMP()", - "duckdb": "CURRENT_TIMESTAMP()", + "duckdb": "CURRENT_TIMESTAMP", "postgres": "CURRENT_TIMESTAMP", "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", @@ -342,6 +360,18 @@ class TestBigQuery(Validator): "snowflake": "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab(cola, colb)", }, ) + self.validate_all( + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab", + write={ + "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])", + }, + ) + self.validate_all( + "SELECT cola, colb FROM (VALUES (1, 'test'))", + write={ + "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])", + }, + ) self.validate_all( "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", write={ diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 9fd2b45..1060881 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -65,3 +65,24 @@ class TestClickhouse(Validator): self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts") self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") + + def test_signed_and_unsigned_types(self): + data_types = [ + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "UInt128", + "UInt256", + "Int8", + "Int16", + "Int32", + "Int64", + "Int128", + "Int256", + ] + for data_type in data_types: + self.validate_all( + f"pow(2, 32)::{data_type}", + write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"}, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index bcbbfd6..f12273b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -95,7 +95,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS BINARY(4))", write={ - "bigquery": "CAST(a AS BINARY)", + "bigquery": "CAST(a AS BYTES)", "clickhouse": "CAST(a AS BINARY(4))", "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BLOB(4))", @@ -114,7 +114,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS VARBINARY(4))", write={ - "bigquery": "CAST(a AS VARBINARY)", + "bigquery": "CAST(a AS BYTES)", "clickhouse": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BLOB(4))", "mysql": "CAST(a AS VARBINARY(4))", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 9e0040c..8c1b748 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -6,6 +6,9 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_time(self): + self.validate_identity("SELECT CURRENT_DATE") + self.validate_identity("SELECT CURRENT_TIMESTAMP") + 
self.validate_all( "EPOCH(x)", read={ @@ -24,7 +27,7 @@ class TestDuckDB(Validator): "bigquery": "UNIX_TO_TIME(x / 1000)", "duckdb": "TO_TIMESTAMP(x / 1000)", "presto": "FROM_UNIXTIME(x / 1000)", - "spark": "FROM_UNIXTIME(x / 1000)", + "spark": "CAST(FROM_UNIXTIME(x / 1000) AS TIMESTAMP)", }, ) self.validate_all( @@ -124,18 +127,34 @@ class TestDuckDB(Validator): self.validate_identity("SELECT {'a': 1} AS x") self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x") self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}") - self.validate_identity( - "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" - ) self.validate_identity("SELECT {'key1': 'string', 'key2': 1, 'key3': 12.345}") self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)") self.validate_identity("ATTACH DATABASE ':memory:' AS new_database") + self.validate_identity( + "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" + ) self.validate_identity( "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)" ) + self.validate_all("0b1010", write={"": "0 AS b1010"}) + self.validate_all("0x1010", write={"": "0 AS x1010"}) + self.validate_all( + """SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""", + write={ + "duckdb": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""", + "trino": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""", + }, + ) + self.validate_all( + "SELECT DATE_DIFF('day', DATE '2020-01-01', DATE '2020-01-05')", + write={ + "duckdb": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))", + "trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))", + }, + ) self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 524d95e..f31b1b9 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -12,6 +12,7 @@ class TestMySQL(Validator): "duckdb": "CREATE TABLE z (a INT)", "mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", "spark": "CREATE TABLE z (a INT) COMMENT 'x'", + "sqlite": "CREATE TABLE z (a INTEGER)", }, ) self.validate_all( @@ -24,6 +25,19 @@ class TestMySQL(Validator): "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" ) + self.validate_all( + "CREATE TABLE x (id int not null auto_increment, primary key (id))", + write={ + "sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)", + }, + ) + self.validate_all( + "CREATE TABLE x (id int not null auto_increment)", + write={ + "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)", + }, + ) + def test_identity(self): self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") self.validate_identity("x ->> '$.name'") @@ -150,47 +164,81 @@ class TestMySQL(Validator): ) def test_hexadecimal_literal(self): - self.validate_all( - "SELECT 0xCC", - write={ - "mysql": "SELECT x'CC'", - "sqlite": "SELECT x'CC'", - "spark": "SELECT X'CC'", - "trino": "SELECT X'CC'", - "bigquery": "SELECT 0xCC", - "oracle": "SELECT 204", - }, - ) - self.validate_all( - "SELECT X'1A'", - write={ - "mysql": "SELECT x'1A'", - }, - ) - 
self.validate_all( - "SELECT 0xz", - write={ - "mysql": "SELECT `0xz`", - }, - ) + write_CC = { + "bigquery": "SELECT 0xCC", + "clickhouse": "SELECT 0xCC", + "databricks": "SELECT 204", + "drill": "SELECT 204", + "duckdb": "SELECT 204", + "hive": "SELECT 204", + "mysql": "SELECT x'CC'", + "oracle": "SELECT 204", + "postgres": "SELECT x'CC'", + "presto": "SELECT 204", + "redshift": "SELECT 204", + "snowflake": "SELECT x'CC'", + "spark": "SELECT X'CC'", + "sqlite": "SELECT x'CC'", + "starrocks": "SELECT x'CC'", + "tableau": "SELECT 204", + "teradata": "SELECT 204", + "trino": "SELECT X'CC'", + "tsql": "SELECT 0xCC", + } + write_CC_with_leading_zeros = { + "bigquery": "SELECT 0x0000CC", + "clickhouse": "SELECT 0x0000CC", + "databricks": "SELECT 204", + "drill": "SELECT 204", + "duckdb": "SELECT 204", + "hive": "SELECT 204", + "mysql": "SELECT x'0000CC'", + "oracle": "SELECT 204", + "postgres": "SELECT x'0000CC'", + "presto": "SELECT 204", + "redshift": "SELECT 204", + "snowflake": "SELECT x'0000CC'", + "spark": "SELECT X'0000CC'", + "sqlite": "SELECT x'0000CC'", + "starrocks": "SELECT x'0000CC'", + "tableau": "SELECT 204", + "teradata": "SELECT 204", + "trino": "SELECT X'0000CC'", + "tsql": "SELECT 0x0000CC", + } + + self.validate_all("SELECT X'1A'", write={"mysql": "SELECT x'1A'"}) + self.validate_all("SELECT 0xz", write={"mysql": "SELECT `0xz`"}) + self.validate_all("SELECT 0xCC", write=write_CC) + self.validate_all("SELECT 0xCC ", write=write_CC) + self.validate_all("SELECT x'CC'", write=write_CC) + self.validate_all("SELECT 0x0000CC", write=write_CC_with_leading_zeros) + self.validate_all("SELECT x'0000CC'", write=write_CC_with_leading_zeros) def test_bits_literal(self): - self.validate_all( - "SELECT 0b1011", - write={ - "mysql": "SELECT b'1011'", - "postgres": "SELECT b'1011'", - "oracle": "SELECT 11", - }, - ) - self.validate_all( - "SELECT B'1011'", - write={ - "mysql": "SELECT b'1011'", - "postgres": "SELECT b'1011'", - "oracle": "SELECT 11", - }, - ) + write_1011 = { + "bigquery": "SELECT 11", + "clickhouse": "SELECT 0b1011", + "databricks": "SELECT 11", + "drill": "SELECT 11", + "hive": "SELECT 11", + "mysql": "SELECT b'1011'", + "oracle": "SELECT 11", + "postgres": "SELECT b'1011'", + "presto": "SELECT 11", + "redshift": "SELECT 11", + "snowflake": "SELECT 11", + "spark": "SELECT 11", + "sqlite": "SELECT 11", + "starrocks": "SELECT b'1011'", + "tableau": "SELECT 11", + "teradata": "SELECT 11", + "trino": "SELECT 11", + "tsql": "SELECT 11", + } + + self.validate_all("SELECT 0b1011", write=write_1011) + self.validate_all("SELECT b'1011'", write=write_1011) def test_string_literals(self): self.validate_all( diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index dd297d6..88c79fd 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -5,6 +5,8 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain") + self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT * FROM V$SESSION") self.validate_identity( "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" ) @@ -17,7 +19,6 @@ "": "IFNULL(NULL, 1)", }, ) - self.validate_all( "DATE '2022-01-01'", write={ @@ -28,6 +29,21 @@ }, ) + self.validate_all( + "x::binary_double", + write={ + "oracle": "CAST(x AS DOUBLE PRECISION)", + "": "CAST(x AS DOUBLE)", + }, + ) + 
self.validate_all( + "x::binary_float", + write={ + "oracle": "CAST(x AS FLOAT)", + "": "CAST(x AS FLOAT)", + }, + ) + def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index e2f9c41..b535a84 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -98,6 +98,21 @@ class TestPostgres(Validator): self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)") self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)") self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity("COMMENT ON TABLE mytable IS 'this'") + self.validate_identity("SELECT e'\\xDEADBEEF'") + self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") + self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") + self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") + self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") + self.validate_identity("x ~ 'y'") + self.validate_identity("x ~* 'y'") + self.validate_identity( + "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" + ) + self.validate_identity( + "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" + ) self.validate_identity( "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" ) @@ -107,37 +122,31 @@ class TestPostgres(Validator): self.validate_identity( 'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')' ) - self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") self.validate_identity( "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')" ) self.validate_identity( "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))" ) - self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") - self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") self.validate_identity( "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')" ) - self.validate_identity("COMMENT ON TABLE mytable IS 'this'") - self.validate_identity("SELECT e'\\xDEADBEEF'") - self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + + self.validate_all( + "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)", + write={ + "databricks": "SELECT PERCENTILE_APPROX(amount, 0.5)", + "presto": "SELECT APPROX_PERCENTILE(amount, 0.5)", + "spark": "SELECT PERCENTILE_APPROX(amount, 0.5)", + "trino": "SELECT APPROX_PERCENTILE(amount, 0.5)", + }, + ) self.validate_all( "e'x'", write={ "mysql": "x", }, ) - self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") - self.validate_identity( - "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" - ) - self.validate_identity( - "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" - ) - self.validate_identity("x ~ 'y'") - self.validate_identity("x ~* 'y'") - self.validate_all( "SELECT DATE_PART('isodow'::varchar(6), current_date)", write={ @@ -197,6 +206,33 @@ class TestPostgres(Validator): "trino": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), 
CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", }, ) + self.validate_all( + "GENERATE_SERIES(a, b)", + write={ + "postgres": "GENERATE_SERIES(a, b)", + "presto": "SEQUENCE(a, b)", + "trino": "SEQUENCE(a, b)", + "tsql": "GENERATE_SERIES(a, b)", + }, + ) + self.validate_all( + "GENERATE_SERIES(a, b)", + read={ + "postgres": "GENERATE_SERIES(a, b)", + "presto": "SEQUENCE(a, b)", + "trino": "SEQUENCE(a, b)", + "tsql": "GENERATE_SERIES(a, b)", + }, + ) + self.validate_all( + "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)", + write={ + "postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)", + "presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))", + "trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))", + "tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)", + }, + ) self.validate_all( "END WORK AND NO CHAIN", write={"postgres": "COMMIT AND NO CHAIN"}, @@ -464,6 +500,14 @@ class TestPostgres(Validator): }, ) + self.validate_all( + "x / y ^ z", + write={ + "": "x / POWER(y, z)", + "postgres": "x / y ^ z", + }, + ) + self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.TryCast) def test_bool_or(self): diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 3080476..15962cc 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -6,6 +6,26 @@ class TestPresto(Validator): dialect = "presto" def test_cast(self): + self.validate_all( + "FROM_BASE64(x)", + read={ + "hive": "UNBASE64(x)", + }, + write={ + "hive": "UNBASE64(x)", + "presto": "FROM_BASE64(x)", + }, + ) + self.validate_all( + "TO_BASE64(x)", + read={ + "hive": "BASE64(x)", + }, + write={ + "hive": "BASE64(x)", + "presto": "TO_BASE64(x)", + }, + ) self.validate_all( "CAST(a AS ARRAY(INT))", write={ @@ -105,6 +125,13 @@ class TestPresto(Validator): "spark": "SIZE(x)", }, ) + self.validate_all( + "ARRAY_JOIN(x, '-', 'a')", + write={ + "hive": "CONCAT_WS('-', x)", + "spark": "ARRAY_JOIN(x, '-', 'a')", + }, + ) def test_interval_plural_to_singular(self): # Microseconds, weeks and quarters are not supported in Presto/Trino INTERVAL literals @@ -133,6 +160,14 @@ class TestPresto(Validator): self.validate_identity("TRIM(a, b)") self.validate_identity("VAR_POP(a)") + self.validate_all( + "SELECT FROM_UNIXTIME(col) FROM tbl", + write={ + "presto": "SELECT FROM_UNIXTIME(col) FROM tbl", + "spark": "SELECT CAST(FROM_UNIXTIME(col) AS TIMESTAMP) FROM tbl", + "trino": "SELECT FROM_UNIXTIME(col) FROM tbl", + }, + ) self.validate_all( "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ @@ -181,7 +216,7 @@ class TestPresto(Validator): "duckdb": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)", - "spark": "FROM_UNIXTIME(x)", + "spark": "CAST(FROM_UNIXTIME(x) AS TIMESTAMP)", }, ) self.validate_all( @@ -583,6 +618,14 @@ class TestPresto(Validator): }, ) + self.validate_all( + "JSON_FORMAT(JSON 'x')", + write={ + "presto": "JSON_FORMAT(CAST('x' AS JSON))", + "spark": "TO_JSON('x')", + }, + ) + def test_encode_decode(self): self.validate_all( "TO_UTF8(x)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index e5bd0e5..f75480e 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -101,7 +101,22 @@ class TestRedshift(Validator): self.validate_all( "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", write={ + "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "databricks": 
"SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1", + "oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', + "snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', + "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1", + "tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 5c8b096..57ee235 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -227,7 +227,7 @@ class TestSnowflake(Validator): write={ "bigquery": "SELECT UNIX_TO_TIME(1659981729)", "snowflake": "SELECT TO_TIMESTAMP(1659981729)", - "spark": "SELECT FROM_UNIXTIME(1659981729)", + "spark": "SELECT CAST(FROM_UNIXTIME(1659981729) AS TIMESTAMP)", }, ) self.validate_all( @@ -243,7 +243,7 @@ class TestSnowflake(Validator): write={ "bigquery": "SELECT UNIX_TO_TIME('1659981729')", "snowflake": "SELECT TO_TIMESTAMP('1659981729')", - "spark": "SELECT FROM_UNIXTIME('1659981729')", + "spark": "SELECT CAST(FROM_UNIXTIME('1659981729') AS TIMESTAMP)", }, ) self.validate_all( @@ -401,7 +401,7 @@ class TestSnowflake(Validator): self.validate_all( r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", write={ - "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + "snowflake": 
r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" }, ) self.validate_all( @@ -426,29 +426,18 @@ class TestSnowflake(Validator): def test_sample(self): self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)") self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)") + self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)") + self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)") + self.validate_identity("SELECT * FROM testtable SAMPLE (10)") + self.validate_identity("SELECT * FROM testtable SAMPLE ROW (0)") + self.validate_identity("SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)") self.validate_identity( "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i" ) self.validate_identity( "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)" ) - self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)") - self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)") - self.validate_all( - "SELECT * FROM testtable SAMPLE (10)", - write={"snowflake": "SELECT * FROM testtable TABLESAMPLE (10)"}, - ) - self.validate_all( - "SELECT * FROM testtable SAMPLE ROW (0)", - write={"snowflake": "SELECT * FROM testtable TABLESAMPLE ROW (0)"}, - ) - self.validate_all( - "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", - write={ - "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", - }, - ) self.validate_all( """ SELECT i, j @@ -458,13 +447,13 @@ class TestSnowflake(Validator): table2 AS t2 SAMPLE (50) -- 50% of rows in table2 WHERE t2.j = t1.i""", write={ - "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", + "snowflake": "SELECT i, j FROM table1 AS t1 SAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 SAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", }, ) self.validate_all( "SELECT * FROM testtable SAMPLE BLOCK (0.012) REPEATABLE (99992)", write={ - "snowflake": "SELECT * FROM testtable TABLESAMPLE BLOCK (0.012) SEED (99992)", + "snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index bfaed53..be03b4e 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -215,40 +215,45 @@ TBLPROPERTIES ( self.validate_identity("SPLIT(str, pattern, lim)") self.validate_all( - "BOOLEAN(x)", - write={ - "": "CAST(x AS BOOLEAN)", - "spark": "CAST(x AS BOOLEAN)", + "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1', 'Q2'))", + read={ + "snowflake": "SELECT * FROM produce PIVOT (SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))", }, ) self.validate_all( - "INT(x)", - write={ - "": "CAST(x AS INT)", - "spark": "CAST(x AS INT)", - }, - ) - self.validate_all( - "STRING(x)", - write={ - "": "CAST(x AS TEXT)", - "spark": "CAST(x AS STRING)", + "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR quarter IN ('Q1' AS Q1, 'Q2' AS Q1))", + read={ + "bigquery": "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR p.quarter IN ('Q1' AS Q1, 'Q2' AS Q1))", }, ) self.validate_all( - "DATE(x)", + "SELECT DATEDIFF(MONTH, '2020-01-01', '2020-03-05')", write={ - "": "CAST(x AS DATE)", - "spark": "CAST(x AS DATE)", + 
"databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", + "hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", + "presto": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))", + "spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", + "spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", + "trino": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))", }, ) + + for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"): + self.validate_all( + f"{data_type}(x)", + write={ + "": f"CAST(x AS {data_type})", + "spark": f"CAST(x AS {data_type})", + }, + ) self.validate_all( - "TIMESTAMP(x)", + "STRING(x)", write={ - "": "CAST(x AS TIMESTAMP)", - "spark": "CAST(x AS TIMESTAMP)", + "": "CAST(x AS TEXT)", + "spark": "CAST(x AS STRING)", }, ) + self.validate_all( "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"} ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index b33231c..96e20da 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -10,3 +10,11 @@ class TestMySQL(Validator): def test_time(self): self.validate_identity("TIMESTAMP('2022-01-01')") + + def test_regex(self): + self.validate_all( + "SELECT REGEXP_LIKE(abc, '%foo%')", + write={ + "starrocks": "SELECT REGEXP(abc, '%foo%')", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b6e893c..3a3ac73 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -485,26 +485,30 @@ WHERE def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") + self.validate_all( "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", write={ "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", - "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", + "spark": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + "spark2": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", }, ) self.validate_all( "SELECT DATEDIFF(mm, 'start','end')", write={ - "spark": "SELECT MONTHS_BETWEEN('end', 'start')", - "tsql": "SELECT DATEDIFF(month, 'start', 'end')", "databricks": "SELECT DATEDIFF(month, 'start', 'end')", + "spark2": "SELECT MONTHS_BETWEEN('end', 'start')", + "tsql": "SELECT DATEDIFF(month, 'start', 'end')", }, ) self.validate_all( "SELECT DATEDIFF(quarter, 'start', 'end')", write={ - "spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3", "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", + "spark": "SELECT DATEDIFF(quarter, 'start', 'end')", + "spark2": "SELECT MONTHS_BETWEEN('end', 'start') / 3", + "tsql": "SELECT DATEDIFF(quarter, 'start', 'end')", }, ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index a08a7a8..ea695c9 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -85,6 +85,7 @@ x IS TRUE x IS FALSE x IS TRUE IS TRUE x LIKE y IS TRUE +TRIM('a' || 'b') MAP() GREATEST(x) LEAST(y) @@ -104,6 +105,7 @@ ARRAY(time, foo) ARRAY(foo, time) ARRAY(LENGTH(waiter_name) > 0) ARRAY_CONTAINS(x, 1) +x.EXTRACT(1) EXTRACT(x FROM y) EXTRACT(DATE FROM y) EXTRACT(WEEK(monday) FROM created_at) @@ -215,6 +217,7 @@ SELECT COUNT(DISTINCT a, b) SELECT COUNT(DISTINCT a, b + 1) SELECT SUM(DISTINCT x) 
SELECT SUM(x IGNORE NULLS) AS x +SELECT COUNT(x RESPECT NULLS) SELECT TRUNCATE(a, b) SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x @@ -820,3 +823,6 @@ JSON_OBJECT('x': 1 RETURNING VARBINARY FORMAT JSON ENCODING UTF8) SELECT if.x SELECT NEXT VALUE FOR db.schema.sequence_name SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col +SELECT PERCENTILE_CONT(x, 0.5) OVER () +SELECT PERCENTILE_CONT(x, 0.5 RESPECT NULLS) OVER () +SELECT PERCENTILE_CONT(x, 0.5 IGNORE NULLS) OVER () diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 3013bba..f077647 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -4,6 +4,9 @@ SELECT a FROM x; SELECT x.a AS a FROM x AS x; +SELECT "a" FROM x; +SELECT x."a" AS "a" FROM x AS x; + # execute: false SELECT a FROM zz GROUP BY a ORDER BY a; SELECT zz.a AS a FROM zz AS zz GROUP BY zz.a ORDER BY a; @@ -212,6 +215,10 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x); SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b)); SELECT i.a AS a FROM x AS i WHERE i.b IN (SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b)); +# execute: false +SELECT (SELECT n.a FROM n WHERE n.id = m.id) FROM m AS m; +SELECT (SELECT n.a AS a FROM n AS n WHERE n.id = m.id) AS _col_0 FROM m AS m; + -------------------------------------- -- Expand * -------------------------------------- diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql index 2cea85d..0ad155a 100644 --- a/tests/fixtures/optimizer/qualify_tables.sql +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -15,3 +15,26 @@ WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a; SELECT (SELECT y.c FROM y AS y) FROM x; SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x; + +------------------------- +-- Expand join constructs +------------------------- + +-- This is valid in Trino, so we treat the (tbl AS tbl) as a "join construct" per postgres' terminology. 
+SELECT * FROM (tbl AS tbl) AS _q_0; +SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0; + +SELECT * FROM ((tbl AS tbl)) AS _q_0; +SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0; + +SELECT * FROM (((tbl AS tbl))) AS _q_0; +SELECT * FROM (SELECT * FROM c.db.tbl AS tbl) AS _q_0; + +SELECT * FROM (tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3) AS _q_0; +SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0; + +SELECT * FROM ((tbl1 AS tbl1 JOIN tbl2 AS tbl2 ON id1 = id2 JOIN tbl3 AS tbl3 ON id1 = id3)) AS _q_0; +SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN c.db.tbl2 AS tbl2 ON id1 = id2 JOIN c.db.tbl3 AS tbl3 ON id1 = id3) AS _q_0; + +SELECT * FROM (tbl1 AS tbl1 JOIN (tbl2 AS tbl2 JOIN tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1; +SELECT * FROM (SELECT * FROM c.db.tbl1 AS tbl1 JOIN (SELECT * FROM c.db.tbl2 AS tbl2 JOIN c.db.tbl3 AS tbl3 ON id2 = id3) AS _q_0 ON id1 = id3) AS _q_1; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 9168508..9908756 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -6385,6 +6385,14 @@ WITH "tmp1" AS ( "item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1') OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help') ) + AND ( + "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9') + OR "item"."i_category" IN ('Women', 'Music', 'Men') + ) + AND ( + "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9') + OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants') + ) AND ( "item"."i_category" IN ('Books', 'Children', 'Electronics') OR "item"."i_category" IN ('Women', 'Music', 'Men') @@ -7589,6 +7597,14 @@ WITH "tmp1" AS ( "item"."i_brand" IN ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1') OR "item"."i_class" IN ('personal', 'portable', 'reference', 'self-help') ) + AND ( + "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9') + OR "item"."i_category" IN ('Women', 'Music', 'Men') + ) + AND ( + "item"."i_brand" IN ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9') + OR "item"."i_class" IN ('accessories', 'classical', 'fragrances', 'pants') + ) AND ( "item"."i_category" IN ('Books', 'Children', 'Electronics') OR "item"."i_category" IN ('Women', 'Music', 'Men') diff --git a/tests/test_build.py b/tests/test_build.py index c4b97ce..509b857 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -19,6 +19,11 @@ from sqlglot import ( class TestBuild(unittest.TestCase): def test_build(self): x = condition("x") + x_plus_one = x + 1 + + # Make sure we're not mutating x by changing its parent to be x_plus_one + self.assertIsNone(x.parent) + self.assertNotEqual(id(x_plus_one.this), id(x)) for expression, sql, *dialect in [ (lambda: x + 1, "x + 1"), @@ -51,6 +56,7 @@ class TestBuild(unittest.TestCase): (lambda: x.neq(1), "x <> 1"), (lambda: x.isin(1, "2"), "x IN (1, '2')"), (lambda: x.isin(query="select 1"), "x IN (SELECT 1)"), + (lambda: x.between(1, 2), "x BETWEEN 1 AND 2"), (lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"), (lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"), (lambda: x * 1 * 2 + 3, "(x * 1 
* 2) + 3"), @@ -137,10 +143,14 @@ "SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a", ), ( - lambda: select("x").distinct(True).from_("tbl"), + lambda: select("x").distinct("a", "b").from_("tbl"), + "SELECT DISTINCT ON (a, b) x FROM tbl", + ), + ( + lambda: select("x").distinct(distinct=True).from_("tbl"), "SELECT DISTINCT x FROM tbl", ), - (lambda: select("x").distinct(False).from_("tbl"), "SELECT x FROM tbl"), + (lambda: select("x").distinct(distinct=False).from_("tbl"), "SELECT x FROM tbl"), ( lambda: select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl"), "SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z", ), @@ -583,6 +593,11 @@ "DELETE FROM tbl WHERE x = 1 RETURNING *", "postgres", ), + ( + lambda: exp.convert((exp.column("x"), exp.column("y"))).isin((1, 2), (3, 4)), + "(x, y) IN ((1, 2), (3, 4))", + "postgres", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index eb0cf56..e7588b5 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -297,6 +297,9 @@ class TestExpressions(unittest.TestCase): expression = parse_one("SELECT a, b FROM x") self.assertEqual([s.sql() for s in expression.selects], ["a", "b"]) + expression = parse_one("(SELECT a, b FROM x)") + self.assertEqual([s.sql() for s in expression.selects], ["a", "b"]) + def test_alias_column_names(self): expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y") subquery = expression.find(exp.Subquery) @@ -761,7 +764,7 @@ FROM foo""", "t", {"a": exp.DataType.build("TEXT"), "b": exp.DataType.build("TEXT")}, ).sql(), - "(VALUES (CAST(1 AS TEXT), CAST(2 AS TEXT)), (3, 4)) AS t(a, b)", + "(VALUES (1, 2), (3, 4)) AS t(a, b)", ) with self.assertRaises(ValueError): exp.values([(1, 2), (3, 4)], columns=["a"]) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index d077570..423cb84 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -47,6 +47,7 @@ class TestOptimizer(unittest.TestCase): @classmethod def setUpClass(cls): + sqlglot.schema = MappingSchema() cls.conn = duckdb.connect() cls.conn.execute( """ @@ -221,6 +222,12 @@ class TestOptimizer(unittest.TestCase): self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates) def test_expand_laterals(self): + # check order of lateral expansion with no schema + self.assertEqual( + optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x").sql(), + 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x"', + ) + self.check_file( "expand_laterals", optimizer.expand_laterals.expand_laterals, ) @@ -612,6 +619,12 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') expression = annotate_types(parse_one("CONCAT('A', 'B')")) self.assertEqual(expression.type.this, exp.DataType.Type.VARCHAR) + def test_root_subquery_annotation(self): + expression = annotate_types(parse_one("(SELECT 1, 2 FROM x) LIMIT 0")) + self.assertIsInstance(expression, exp.Subquery) + self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this) + self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this) + def test_recursive_cte(self): query = parse_one( """ diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 987c60b..f70d70e 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -102,3 +102,14 @@ x""" (TokenType.SEMICOLON, ";"), ], ) + + tokens = 
tokenizer.tokenize("""'{{ var('x') }}'""") + tokens = [(token.token_type, token.text) for token in tokens] + self.assertEqual( + tokens, + [ + (TokenType.STRING, "{{ var("), + (TokenType.VAR, "x"), + (TokenType.STRING, ") }}"), + ], + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index d68f6f8..ad8ec72 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -263,6 +263,18 @@ FROM v""", "(/* 1 */ 1 ) /* 2 */", "(1) /* 1 */ /* 2 */", ) + self.validate( + "select * from t where not a in (23) /*test*/ and b in (14)", + "SELECT * FROM t WHERE NOT a IN (23) /* test */ AND b IN (14)", + ) + self.validate( + "select * from t where a in (23) /*test*/ and b in (14)", + "SELECT * FROM t WHERE a IN (23) /* test */ AND b IN (14)", + ) + self.validate( + "select * from t where ((condition = 1)/*test*/)", + "SELECT * FROM t WHERE ((condition = 1) /* test */)", + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") @@ -324,9 +336,6 @@ FROM v""", ) self.validate("SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo") - def test_ignore_nulls(self): - self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)") - def test_with(self): self.validate( "WITH a AS (SELECT 1) WITH b AS (SELECT 2) SELECT *", @@ -482,7 +491,7 @@ FROM v""", self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="spark") self.validate( "UNIX_TO_TIME(123)", - "FROM_UNIXTIME(123)", + "CAST(FROM_UNIXTIME(123) AS TIMESTAMP)", write="spark", ) self.validate( -- cgit v1.2.3