diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dialects/test_bigquery.py | 38 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 27 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 126 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 76 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 45 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 15 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 31 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 47 | ||||
-rw-r--r-- | tests/dialects/test_starrocks.py | 8 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 12 |
13 files changed, 355 insertions, 113 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 703b7dc..87bba6f 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,10 +6,19 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): + self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))") + self.validate_identity("SELECT b'abc'") + self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""") self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") + self.validate_identity("SELECT DISTINCT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") + self.validate_identity("""CREATE TABLE x (a STRUCT<values ARRAY<INT64>>)""") + self.validate_identity("""CREATE TABLE x (a STRUCT<b STRING OPTIONS (description='b')>)""") + self.validate_identity( + """CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)""" + ) self.validate_identity( "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" ) @@ -98,6 +107,16 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "CAST(a AS BYTES)", + write={ + "bigquery": "CAST(a AS BYTES)", + "duckdb": "CAST(a AS BLOB)", + "presto": "CAST(a AS VARBINARY)", + "hive": "CAST(a AS BINARY)", + "spark": "CAST(a AS BINARY)", + }, + ) + self.validate_all( "CAST(a AS NUMERIC)", write={ "bigquery": "CAST(a AS NUMERIC)", @@ -173,7 +192,6 @@ class TestBigQuery(Validator): "current_datetime", write={ "bigquery": "CURRENT_DATETIME()", - "duckdb": "CURRENT_DATETIME()", "presto": "CURRENT_DATETIME()", "hive": "CURRENT_DATETIME()", "spark": "CURRENT_DATETIME()", @@ -183,7 +201,7 @@ class TestBigQuery(Validator): "current_time", write={ "bigquery": "CURRENT_TIME()", - "duckdb": "CURRENT_TIME()", + "duckdb": "CURRENT_TIME", "presto": "CURRENT_TIME()", "hive": "CURRENT_TIME()", "spark": "CURRENT_TIME()", @@ -193,7 +211,7 @@ class TestBigQuery(Validator): "current_timestamp", write={ "bigquery": "CURRENT_TIMESTAMP()", - "duckdb": "CURRENT_TIMESTAMP()", + "duckdb": "CURRENT_TIMESTAMP", "postgres": "CURRENT_TIMESTAMP", "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", @@ -204,7 +222,7 @@ class TestBigQuery(Validator): "current_timestamp()", write={ "bigquery": "CURRENT_TIMESTAMP()", - "duckdb": "CURRENT_TIMESTAMP()", + "duckdb": "CURRENT_TIMESTAMP", "postgres": "CURRENT_TIMESTAMP", "presto": "CURRENT_TIMESTAMP", "hive": "CURRENT_TIMESTAMP()", @@ -343,6 +361,18 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "SELECT cola, colb FROM (VALUES (1, 'test')) AS tab", + write={ + "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])", + }, + ) + self.validate_all( + "SELECT cola, colb FROM (VALUES (1, 'test'))", + write={ + "bigquery": "SELECT cola, colb FROM UNNEST([STRUCT(1 AS _c0, 'test' AS _c1)])", + }, + ) + self.validate_all( "SELECT cola, colb, colc FROM (VALUES (1, 'test', NULL)) AS tab(cola, colb, colc)", write={ "spark": "SELECT cola, colb, colc FROM VALUES (1, 'test', NULL) AS tab(cola, colb, colc)", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 9fd2b45..1060881 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -65,3 +65,24 @@ class TestClickhouse(Validator): self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts") self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") + + def test_signed_and_unsigned_types(self): + data_types = [ + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "UInt128", + "UInt256", + "Int8", + "Int16", + "Int32", + "Int64", + "Int128", + "Int256", + ] + for data_type in data_types: + self.validate_all( + f"pow(2, 32)::{data_type}", + write={"clickhouse": f"CAST(POWER(2, 32) AS {data_type})"}, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index bcbbfd6..f12273b 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -95,7 +95,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS BINARY(4))", write={ - "bigquery": "CAST(a AS BINARY)", + "bigquery": "CAST(a AS BYTES)", "clickhouse": "CAST(a AS BINARY(4))", "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BLOB(4))", @@ -114,7 +114,7 @@ class TestDialect(Validator): self.validate_all( "CAST(a AS VARBINARY(4))", write={ - "bigquery": "CAST(a AS VARBINARY)", + "bigquery": "CAST(a AS BYTES)", "clickhouse": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BLOB(4))", "mysql": "CAST(a AS VARBINARY(4))", diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 9e0040c..8c1b748 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -6,6 +6,9 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_time(self): + self.validate_identity("SELECT CURRENT_DATE") + self.validate_identity("SELECT CURRENT_TIMESTAMP") + self.validate_all( "EPOCH(x)", read={ @@ -24,7 +27,7 @@ class TestDuckDB(Validator): "bigquery": "UNIX_TO_TIME(x / 1000)", "duckdb": "TO_TIMESTAMP(x / 1000)", "presto": "FROM_UNIXTIME(x / 1000)", - "spark": "FROM_UNIXTIME(x / 1000)", + "spark": "CAST(FROM_UNIXTIME(x / 1000) AS TIMESTAMP)", }, ) self.validate_all( @@ -124,18 +127,34 @@ class TestDuckDB(Validator): self.validate_identity("SELECT {'a': 1} AS x") self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x") self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}") - self.validate_identity( - "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" - ) self.validate_identity("SELECT {'key1': 'string', 'key2': 1, 'key3': 12.345}") self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)") self.validate_identity("ATTACH DATABASE ':memory:' AS new_database") self.validate_identity( + "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" + ) + self.validate_identity( "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)" ) + self.validate_all("0b1010", write={"": "0 AS b1010"}) + self.validate_all("0x1010", write={"": "0 AS x1010"}) + self.validate_all( + """SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""", + write={ + "duckdb": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""", + "trino": """SELECT DATE_DIFF('day', t1."A", t1."B") FROM "table" AS t1""", + }, + ) + self.validate_all( + "SELECT DATE_DIFF('day', DATE '2020-01-01', DATE '2020-01-05')", + write={ + "duckdb": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))", + "trino": "SELECT DATE_DIFF('day', CAST('2020-01-01' AS DATE), CAST('2020-01-05' AS DATE))", + }, + ) self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 524d95e..f31b1b9 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -12,6 +12,7 @@ class TestMySQL(Validator): "duckdb": "CREATE TABLE z (a INT)", "mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", "spark": "CREATE TABLE z (a INT) COMMENT 'x'", + "sqlite": "CREATE TABLE z (a INTEGER)", }, ) self.validate_all( @@ -24,6 +25,19 @@ class TestMySQL(Validator): "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" ) + self.validate_all( + "CREATE TABLE x (id int not null auto_increment, primary key (id))", + write={ + "sqlite": "CREATE TABLE x (id INTEGER NOT NULL AUTOINCREMENT PRIMARY KEY)", + }, + ) + self.validate_all( + "CREATE TABLE x (id int not null auto_increment)", + write={ + "sqlite": "CREATE TABLE x (id INTEGER NOT NULL)", + }, + ) + def test_identity(self): self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") self.validate_identity("x ->> '$.name'") @@ -150,47 +164,81 @@ class TestMySQL(Validator): ) def test_hexadecimal_literal(self): - self.validate_all( - "SELECT 0xCC", - write={ - "mysql": "SELECT x'CC'", - "sqlite": "SELECT x'CC'", - "spark": "SELECT X'CC'", - "trino": "SELECT X'CC'", - "bigquery": "SELECT 0xCC", - "oracle": "SELECT 204", - }, - ) - self.validate_all( - "SELECT X'1A'", - write={ - "mysql": "SELECT x'1A'", - }, - ) - self.validate_all( - "SELECT 0xz", - write={ - "mysql": "SELECT `0xz`", - }, - ) + write_CC = { + "bigquery": "SELECT 0xCC", + "clickhouse": "SELECT 0xCC", + "databricks": "SELECT 204", + "drill": "SELECT 204", + "duckdb": "SELECT 204", + "hive": "SELECT 204", + "mysql": "SELECT x'CC'", + "oracle": "SELECT 204", + "postgres": "SELECT x'CC'", + "presto": "SELECT 204", + "redshift": "SELECT 204", + "snowflake": "SELECT x'CC'", + "spark": "SELECT X'CC'", + "sqlite": "SELECT x'CC'", + "starrocks": "SELECT x'CC'", + "tableau": "SELECT 204", + "teradata": "SELECT 204", + "trino": "SELECT X'CC'", + "tsql": "SELECT 0xCC", + } + write_CC_with_leading_zeros = { + "bigquery": "SELECT 0x0000CC", + "clickhouse": "SELECT 0x0000CC", + "databricks": "SELECT 204", + "drill": "SELECT 204", + "duckdb": "SELECT 204", + "hive": "SELECT 204", + "mysql": "SELECT x'0000CC'", + "oracle": "SELECT 204", + "postgres": "SELECT x'0000CC'", + "presto": "SELECT 204", + "redshift": "SELECT 204", + "snowflake": "SELECT x'0000CC'", + "spark": "SELECT X'0000CC'", + "sqlite": "SELECT x'0000CC'", + "starrocks": "SELECT x'0000CC'", + "tableau": "SELECT 204", + "teradata": "SELECT 204", + "trino": "SELECT X'0000CC'", + "tsql": "SELECT 0x0000CC", + } + + self.validate_all("SELECT X'1A'", write={"mysql": "SELECT x'1A'"}) + self.validate_all("SELECT 0xz", write={"mysql": "SELECT `0xz`"}) + self.validate_all("SELECT 0xCC", write=write_CC) + self.validate_all("SELECT 0xCC ", write=write_CC) + self.validate_all("SELECT x'CC'", write=write_CC) + self.validate_all("SELECT 0x0000CC", write=write_CC_with_leading_zeros) + self.validate_all("SELECT x'0000CC'", write=write_CC_with_leading_zeros) def test_bits_literal(self): - self.validate_all( - "SELECT 0b1011", - write={ - "mysql": "SELECT b'1011'", - "postgres": "SELECT b'1011'", - "oracle": "SELECT 11", - }, - ) - self.validate_all( - "SELECT B'1011'", - write={ - "mysql": "SELECT b'1011'", - "postgres": "SELECT b'1011'", - "oracle": "SELECT 11", - }, - ) + write_1011 = { + "bigquery": "SELECT 11", + "clickhouse": "SELECT 0b1011", + "databricks": "SELECT 11", + "drill": "SELECT 11", + "hive": "SELECT 11", + "mysql": "SELECT b'1011'", + "oracle": "SELECT 11", + "postgres": "SELECT b'1011'", + "presto": "SELECT 11", + "redshift": "SELECT 11", + "snowflake": "SELECT 11", + "spark": "SELECT 11", + "sqlite": "SELECT 11", + "mysql": "SELECT b'1011'", + "tableau": "SELECT 11", + "teradata": "SELECT 11", + "trino": "SELECT 11", + "tsql": "SELECT 11", + } + + self.validate_all("SELECT 0b1011", write=write_1011) + self.validate_all("SELECT b'1011'", write=write_1011) def test_string_literals(self): self.validate_all( diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index dd297d6..88c79fd 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -5,6 +5,8 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("SELECT * FROM table_name@dblink_name.database_link_domain") + self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT * FROM V$SESSION") self.validate_identity( "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" @@ -17,7 +19,6 @@ class TestOracle(Validator): "": "IFNULL(NULL, 1)", }, ) - self.validate_all( "DATE '2022-01-01'", write={ @@ -28,6 +29,21 @@ class TestOracle(Validator): }, ) + self.validate_all( + "x::binary_double", + write={ + "oracle": "CAST(x AS DOUBLE PRECISION)", + "": "CAST(x AS DOUBLE)", + }, + ) + self.validate_all( + "x::binary_float", + write={ + "oracle": "CAST(x AS FLOAT)", + "": "CAST(x AS FLOAT)", + }, + ) + def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y = e2.y (+)") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index e2f9c41..b535a84 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -98,6 +98,21 @@ class TestPostgres(Validator): self.validate_identity("STRING_AGG(x, ',' ORDER BY y DESC)") self.validate_identity("STRING_AGG(DISTINCT x, ',' ORDER BY y DESC)") self.validate_identity("SELECT CASE WHEN SUBSTRING('abcdefg') IN ('ab') THEN 1 ELSE 0 END") + self.validate_identity("COMMENT ON TABLE mytable IS 'this'") + self.validate_identity("SELECT e'\\xDEADBEEF'") + self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") + self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") + self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") + self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") + self.validate_identity("x ~ 'y'") + self.validate_identity("x ~* 'y'") + self.validate_identity( + "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" + ) + self.validate_identity( + "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" + ) self.validate_identity( "SELECT CASE WHEN SUBSTRING('abcdefg' FROM 1) IN ('ab') THEN 1 ELSE 0 END" ) @@ -107,37 +122,31 @@ class TestPostgres(Validator): self.validate_identity( 'SELECT * FROM "x" WHERE SUBSTRING("x"."foo" FROM 1 FOR 2) IN (\'mas\')' ) - self.validate_identity("SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '...$') IN ('mas')") self.validate_identity( "SELECT * FROM x WHERE SUBSTRING('Thomas' FROM '%#\"o_a#\"_' FOR '#') IN ('mas')" ) self.validate_identity( "SELECT SUBSTRING('bla' + 'foo' || 'bar' FROM 3 - 1 + 5 FOR 4 + SOME_FUNC(arg1, arg2))" ) - self.validate_identity("SELECT TRIM(' X' FROM ' XXX ')") - self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ' COLLATE utf8_bin)") self.validate_identity( "SELECT TO_TIMESTAMP(1284352323.5), TO_TIMESTAMP('05 Dec 2000', 'DD Mon YYYY')" ) - self.validate_identity("COMMENT ON TABLE mytable IS 'this'") - self.validate_identity("SELECT e'\\xDEADBEEF'") - self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + + self.validate_all( + "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount)", + write={ + "databricks": "SELECT PERCENTILE_APPROX(amount, 0.5)", + "presto": "SELECT APPROX_PERCENTILE(amount, 0.5)", + "spark": "SELECT PERCENTILE_APPROX(amount, 0.5)", + "trino": "SELECT APPROX_PERCENTILE(amount, 0.5)", + }, + ) self.validate_all( "e'x'", write={ "mysql": "x", }, ) - self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") - self.validate_identity( - "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" - ) - self.validate_identity( - "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" - ) - self.validate_identity("x ~ 'y'") - self.validate_identity("x ~* 'y'") - self.validate_all( "SELECT DATE_PART('isodow'::varchar(6), current_date)", write={ @@ -198,6 +207,33 @@ class TestPostgres(Validator): }, ) self.validate_all( + "GENERATE_SERIES(a, b)", + write={ + "postgres": "GENERATE_SERIES(a, b)", + "presto": "SEQUENCE(a, b)", + "trino": "SEQUENCE(a, b)", + "tsql": "GENERATE_SERIES(a, b)", + }, + ) + self.validate_all( + "GENERATE_SERIES(a, b)", + read={ + "postgres": "GENERATE_SERIES(a, b)", + "presto": "SEQUENCE(a, b)", + "trino": "SEQUENCE(a, b)", + "tsql": "GENERATE_SERIES(a, b)", + }, + ) + self.validate_all( + "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)", + write={ + "postgres": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)", + "presto": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))", + "trino": "SELECT * FROM t CROSS JOIN UNNEST(SEQUENCE(2, 4))", + "tsql": "SELECT * FROM t CROSS JOIN GENERATE_SERIES(2, 4)", + }, + ) + self.validate_all( "END WORK AND NO CHAIN", write={"postgres": "COMMIT AND NO CHAIN"}, ) @@ -464,6 +500,14 @@ class TestPostgres(Validator): }, ) + self.validate_all( + "x / y ^ z", + write={ + "": "x / POWER(y, z)", + "postgres": "x / y ^ z", + }, + ) + self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.TryCast) def test_bool_or(self): diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 3080476..15962cc 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -7,6 +7,26 @@ class TestPresto(Validator): def test_cast(self): self.validate_all( + "FROM_BASE64(x)", + read={ + "hive": "UNBASE64(x)", + }, + write={ + "hive": "UNBASE64(x)", + "presto": "FROM_BASE64(x)", + }, + ) + self.validate_all( + "TO_BASE64(x)", + read={ + "hive": "BASE64(x)", + }, + write={ + "hive": "BASE64(x)", + "presto": "TO_BASE64(x)", + }, + ) + self.validate_all( "CAST(a AS ARRAY(INT))", write={ "bigquery": "CAST(a AS ARRAY<INT64>)", @@ -105,6 +125,13 @@ class TestPresto(Validator): "spark": "SIZE(x)", }, ) + self.validate_all( + "ARRAY_JOIN(x, '-', 'a')", + write={ + "hive": "CONCAT_WS('-', x)", + "spark": "ARRAY_JOIN(x, '-', 'a')", + }, + ) def test_interval_plural_to_singular(self): # Microseconds, weeks and quarters are not supported in Presto/Trino INTERVAL literals @@ -134,6 +161,14 @@ class TestPresto(Validator): self.validate_identity("VAR_POP(a)") self.validate_all( + "SELECT FROM_UNIXTIME(col) FROM tbl", + write={ + "presto": "SELECT FROM_UNIXTIME(col) FROM tbl", + "spark": "SELECT CAST(FROM_UNIXTIME(col) AS TIMESTAMP) FROM tbl", + "trino": "SELECT FROM_UNIXTIME(col) FROM tbl", + }, + ) + self.validate_all( "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", @@ -181,7 +216,7 @@ class TestPresto(Validator): "duckdb": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)", - "spark": "FROM_UNIXTIME(x)", + "spark": "CAST(FROM_UNIXTIME(x) AS TIMESTAMP)", }, ) self.validate_all( @@ -583,6 +618,14 @@ class TestPresto(Validator): }, ) + self.validate_all( + "JSON_FORMAT(JSON 'x')", + write={ + "presto": "JSON_FORMAT(CAST('x' AS JSON))", + "spark": "TO_JSON('x')", + }, + ) + def test_encode_decode(self): self.validate_all( "TO_UTF8(x)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index e5bd0e5..f75480e 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -101,7 +101,22 @@ class TestRedshift(Validator): self.validate_all( "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", write={ + "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1", + "oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', + "snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', + "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", + "sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1", + "tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 5c8b096..57ee235 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -227,7 +227,7 @@ class TestSnowflake(Validator): write={ "bigquery": "SELECT UNIX_TO_TIME(1659981729)", "snowflake": "SELECT TO_TIMESTAMP(1659981729)", - "spark": "SELECT FROM_UNIXTIME(1659981729)", + "spark": "SELECT CAST(FROM_UNIXTIME(1659981729) AS TIMESTAMP)", }, ) self.validate_all( @@ -243,7 +243,7 @@ class TestSnowflake(Validator): write={ "bigquery": "SELECT UNIX_TO_TIME('1659981729')", "snowflake": "SELECT TO_TIMESTAMP('1659981729')", - "spark": "SELECT FROM_UNIXTIME('1659981729')", + "spark": "SELECT CAST(FROM_UNIXTIME('1659981729') AS TIMESTAMP)", }, ) self.validate_all( @@ -401,7 +401,7 @@ class TestSnowflake(Validator): self.validate_all( r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1", write={ - "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" + "snowflake": r"SELECT FIRST_VALUE(TABLE1.COLUMN1 RESPECT NULLS) OVER (PARTITION BY RANDOM_COLUMN1, RANDOM_COLUMN2 ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS MY_ALIAS FROM TABLE1" }, ) self.validate_all( @@ -426,30 +426,19 @@ class TestSnowflake(Validator): def test_sample(self): self.validate_identity("SELECT * FROM testtable TABLESAMPLE BERNOULLI (20.3)") self.validate_identity("SELECT * FROM testtable TABLESAMPLE (100)") + self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)") + self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)") + self.validate_identity("SELECT * FROM testtable SAMPLE (10)") + self.validate_identity("SELECT * FROM testtable SAMPLE ROW (0)") + self.validate_identity("SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)") self.validate_identity( "SELECT i, j FROM table1 AS t1 INNER JOIN table2 AS t2 TABLESAMPLE (50) WHERE t2.j = t1.i" ) self.validate_identity( "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) TABLESAMPLE (1)" ) - self.validate_identity("SELECT * FROM testtable TABLESAMPLE SYSTEM (3) SEED (82)") - self.validate_identity("SELECT * FROM testtable TABLESAMPLE (10 ROWS)") self.validate_all( - "SELECT * FROM testtable SAMPLE (10)", - write={"snowflake": "SELECT * FROM testtable TABLESAMPLE (10)"}, - ) - self.validate_all( - "SELECT * FROM testtable SAMPLE ROW (0)", - write={"snowflake": "SELECT * FROM testtable TABLESAMPLE ROW (0)"}, - ) - self.validate_all( - "SELECT a FROM test SAMPLE BLOCK (0.5) SEED (42)", - write={ - "snowflake": "SELECT a FROM test TABLESAMPLE BLOCK (0.5) SEED (42)", - }, - ) - self.validate_all( """ SELECT i, j FROM @@ -458,13 +447,13 @@ class TestSnowflake(Validator): table2 AS t2 SAMPLE (50) -- 50% of rows in table2 WHERE t2.j = t1.i""", write={ - "snowflake": "SELECT i, j FROM table1 AS t1 TABLESAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 TABLESAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", + "snowflake": "SELECT i, j FROM table1 AS t1 SAMPLE (25) /* 25% of rows in table1 */ INNER JOIN table2 AS t2 SAMPLE (50) /* 50% of rows in table2 */ WHERE t2.j = t1.i", }, ) self.validate_all( "SELECT * FROM testtable SAMPLE BLOCK (0.012) REPEATABLE (99992)", write={ - "snowflake": "SELECT * FROM testtable TABLESAMPLE BLOCK (0.012) SEED (99992)", + "snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)", }, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index bfaed53..be03b4e 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -215,40 +215,45 @@ TBLPROPERTIES ( self.validate_identity("SPLIT(str, pattern, lim)") self.validate_all( - "BOOLEAN(x)", - write={ - "": "CAST(x AS BOOLEAN)", - "spark": "CAST(x AS BOOLEAN)", + "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR quarter IN ('Q1', 'Q2'))", + read={ + "snowflake": "SELECT * FROM produce PIVOT (SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))", }, ) self.validate_all( - "INT(x)", - write={ - "": "CAST(x AS INT)", - "spark": "CAST(x AS INT)", - }, - ) - self.validate_all( - "STRING(x)", - write={ - "": "CAST(x AS TEXT)", - "spark": "CAST(x AS STRING)", + "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR quarter IN ('Q1' AS Q1, 'Q2' AS Q1))", + read={ + "bigquery": "SELECT * FROM produce AS p PIVOT(SUM(p.sales) AS sales FOR p.quarter IN ('Q1' AS Q1, 'Q2' AS Q1))", }, ) self.validate_all( - "DATE(x)", + "SELECT DATEDIFF(MONTH, '2020-01-01', '2020-03-05')", write={ - "": "CAST(x AS DATE)", - "spark": "CAST(x AS DATE)", + "databricks": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", + "hive": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", + "presto": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))", + "spark": "SELECT DATEDIFF(MONTH, TO_DATE('2020-01-01'), TO_DATE('2020-03-05'))", + "spark2": "SELECT MONTHS_BETWEEN(TO_DATE('2020-03-05'), TO_DATE('2020-01-01'))", + "trino": "SELECT DATE_DIFF('MONTH', CAST(SUBSTR(CAST('2020-01-01' AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST('2020-03-05' AS VARCHAR), 1, 10) AS DATE))", }, ) + + for data_type in ("BOOLEAN", "DATE", "DOUBLE", "FLOAT", "INT", "TIMESTAMP"): + self.validate_all( + f"{data_type}(x)", + write={ + "": f"CAST(x AS {data_type})", + "spark": f"CAST(x AS {data_type})", + }, + ) self.validate_all( - "TIMESTAMP(x)", + "STRING(x)", write={ - "": "CAST(x AS TIMESTAMP)", - "spark": "CAST(x AS TIMESTAMP)", + "": "CAST(x AS TEXT)", + "spark": "CAST(x AS STRING)", }, ) + self.validate_all( "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"} ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index b33231c..96e20da 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -10,3 +10,11 @@ class TestMySQL(Validator): def test_time(self): self.validate_identity("TIMESTAMP('2022-01-01')") + + def test_regex(self): + self.validate_all( + "SELECT REGEXP_LIKE(abc, '%foo%')", + write={ + "starrocks": "SELECT REGEXP(abc, '%foo%')", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index b6e893c..3a3ac73 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -485,26 +485,30 @@ WHERE def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") + self.validate_all( "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", write={ "tsql": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", - "spark": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", + "spark": "SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')", + "spark2": "SELECT MONTHS_BETWEEN('2021/01/01', '2020/01/01') / 12", }, ) self.validate_all( "SELECT DATEDIFF(mm, 'start','end')", write={ - "spark": "SELECT MONTHS_BETWEEN('end', 'start')", - "tsql": "SELECT DATEDIFF(month, 'start', 'end')", "databricks": "SELECT DATEDIFF(month, 'start', 'end')", + "spark2": "SELECT MONTHS_BETWEEN('end', 'start')", + "tsql": "SELECT DATEDIFF(month, 'start', 'end')", }, ) self.validate_all( "SELECT DATEDIFF(quarter, 'start', 'end')", write={ - "spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3", "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", + "spark": "SELECT DATEDIFF(quarter, 'start', 'end')", + "spark2": "SELECT MONTHS_BETWEEN('end', 'start') / 3", + "tsql": "SELECT DATEDIFF(quarter, 'start', 'end')", }, ) |