diff options
Diffstat (limited to 'tests/dialects')
-rw-r--r-- | tests/dialects/test_bigquery.py | 21 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_databricks.py | 36 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 41 | ||||
-rw-r--r-- | tests/dialects/test_drill.py | 2 | ||||
-rw-r--r-- | tests/dialects/test_duckdb.py | 18 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 55 | ||||
-rw-r--r-- | tests/dialects/test_mysql.py | 89 | ||||
-rw-r--r-- | tests/dialects/test_oracle.py | 48 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 94 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 83 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 40 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 19 | ||||
-rw-r--r-- | tests/dialects/test_spark.py | 35 | ||||
-rw-r--r-- | tests/dialects/test_starrocks.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_teradata.py | 37 | ||||
-rw-r--r-- | tests/dialects/test_tsql.py | 91 |
17 files changed, 637 insertions, 74 deletions
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e210292..703b7dc 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -14,10 +14,16 @@ class TestBigQuery(Validator): "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" ) + self.validate_all( + "CREATE TEMP TABLE foo AS SELECT 1", + write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"}, + ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS TIMESTAMP)", write={"bigquery": "CAST(x AS DATETIME)"}) + self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) self.validate_all( "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", write={ @@ -59,7 +65,7 @@ class TestBigQuery(Validator): "spark": r"'/\*.*\*/'", }, ) - with self.assertRaises(RuntimeError): + with self.assertRaises(ValueError): transpile("'\\'", read="bigquery") self.validate_all( @@ -285,6 +291,7 @@ class TestBigQuery(Validator): "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)", write={ "postgres": "CURRENT_DATE - INTERVAL '1' DAY", + "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL 1 DAY)", }, ) self.validate_all( @@ -359,11 +366,23 @@ class TestBigQuery(Validator): self.validate_identity("BEGIN TRANSACTION") self.validate_identity("COMMIT TRANSACTION") self.validate_identity("ROLLBACK TRANSACTION") + self.validate_identity("CAST(x AS BIGNUMERIC)") + + self.validate_identity("SELECT * FROM UNNEST([1]) WITH ORDINALITY") + self.validate_all( + "SELECT * FROM UNNEST([1]) WITH OFFSET", + write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS offset"}, + ) + self.validate_all( + "SELECT * FROM UNNEST([1]) WITH OFFSET y", + write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS y"}, + ) def test_user_defined_functions(self): self.validate_identity( "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) + self.validate_identity("CREATE TEMPORARY FUNCTION udf(x ANY TYPE) AS (x)") self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") self.validate_identity( "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 40a3a04..9fd2b45 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -5,6 +5,7 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_identity("SELECT INTERVAL t.days day") self.validate_identity("SELECT match('abc', '([a-z]+)')") self.validate_identity("dictGet(x, 'y')") self.validate_identity("SELECT * FROM x FINAL") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 48ea6d1..4619108 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -5,10 +5,46 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("SELECT c1 : price") self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1") self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") + # https://docs.databricks.com/sql/language-manual/functions/colonsign.html + def test_json(self): + self.validate_identity("""SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""") + + self.validate_all( + """SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""", + write={ + "databricks": """SELECT c1 : ARRAY('price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT c1 : item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT c1 : item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT from_json(c1:item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT FROM_JSON(c1 : item[*].price, 'ARRAY<DOUBLE>')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT inline(from_json(c1:item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT INLINE(FROM_JSON(c1 : item[*], 'ARRAY<STRUCT<model STRING, price DOUBLE>>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + def test_datediff(self): self.validate_all( "SELECT DATEDIFF(year, 'start', 'end')", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3558d62..bcbbfd6 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -435,6 +435,7 @@ class TestDialect(Validator): write={ "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))", "hive": "UNIX_TIMESTAMP('2020-01-01')", + "mysql": "UNIX_TIMESTAMP('2020-01-01')", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))", }, ) @@ -561,25 +562,25 @@ class TestDialect(Validator): }, ) self.validate_all( - "DATE_ADD(x, 1, 'day')", + "DATE_ADD(x, 1, 'DAY')", read={ "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "snowflake": "DATEADD('day', 1, x)", + "snowflake": "DATEADD('DAY', 1, x)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", }, write={ "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)", "drill": "DATE_ADD(x, INTERVAL 1 DAY)", - "duckdb": "x + INTERVAL 1 day", + "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "postgres": "x + INTERVAL '1' day", - "presto": "DATE_ADD('day', 1, x)", - "snowflake": "DATEADD(day, 1, x)", + "postgres": "x + INTERVAL '1' DAY", + "presto": "DATE_ADD('DAY', 1, x)", + "snowflake": "DATEADD(DAY, 1, x)", "spark": "DATE_ADD(x, 1)", - "sqlite": "DATE(x, '1 day')", + "sqlite": "DATE(x, '1 DAY')", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", - "tsql": "DATEADD(day, 1, x)", + "tsql": "DATEADD(DAY, 1, x)", }, ) self.validate_all( @@ -632,13 +633,13 @@ class TestDialect(Validator): }, ) self.validate_all( + "TIMESTAMP_TRUNC(TRY_CAST(x AS DATE), day)", + read={"postgres": "DATE_TRUNC('day', x::DATE)"}, + ) + self.validate_all( "TIMESTAMP_TRUNC(CAST(x AS DATE), day)", - read={ - "postgres": "DATE_TRUNC('day', x::DATE)", - "starrocks": "DATE_TRUNC('day', x::DATE)", - }, + read={"starrocks": "DATE_TRUNC('day', x::DATE)"}, ) - self.validate_all( "DATE_TRUNC('week', x)", write={ @@ -752,6 +753,20 @@ class TestDialect(Validator): }, ) self.validate_all( + "TS_OR_DS_ADD(x, 1, 'DAY')", + write={ + "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR(CAST(x AS VARCHAR), 1, 10), '%Y-%m-%d'))", + "hive": "DATE_ADD(x, 1)", + }, + ) + self.validate_all( + "TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')", + write={ + "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", + "hive": "DATE_ADD(CURRENT_DATE, 1)", + }, + ) + self.validate_all( "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", write={ "drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL 1 DAY)", diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py index f035176..a7f609a 100644 --- a/tests/dialects/test_drill.py +++ b/tests/dialects/test_drill.py @@ -14,7 +14,7 @@ class TestDrill(Validator): self.validate_all( "SELECT '2021-01-01' + INTERVAL 1 MONTH", write={ - "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + "mysql": "SELECT '2021-01-01' + INTERVAL '1' MONTH", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 245d82a..9e0040c 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -80,7 +80,7 @@ class TestDuckDB(Validator): "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", }, write={ - "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMPTZ), 'America/New_York'))", + "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMP), 'America/New_York'))", "duckdb": "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'", "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", }, @@ -149,6 +149,12 @@ class TestDuckDB(Validator): }, ) self.validate_all( + "CREATE TABLE IF NOT EXISTS table (cola INT COMMENT 'cola', colb STRING) USING ICEBERG PARTITIONED BY (colb)", + write={ + "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)", + }, + ) + self.validate_all( "LIST_VALUE(0, 1, 2)", read={ "spark": "ARRAY(0, 1, 2)", @@ -245,7 +251,7 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "POWER(CAST(2 AS SMALLINT), 3)", + "POWER(TRY_CAST(2 AS SMALLINT), 3)", read={ "hive": "POW(2S, 3)", "spark": "POW(2S, 3)", @@ -339,6 +345,12 @@ class TestDuckDB(Validator): "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", }, ) + self.validate_all( + "SELECT CAST(CAST(x AS DATE) AS DATE) + INTERVAL 1 DAY", + read={ + "hive": "SELECT DATE_ADD(TO_DATE(x), 1)", + }, + ) with self.assertRaises(UnsupportedError): transpile( @@ -408,7 +420,7 @@ class TestDuckDB(Validator): "CAST(x AS DATE) + INTERVAL (7 * -1) DAY", read={"spark": "DATE_SUB(x, 7)"} ) self.validate_all( - "CAST(1 AS DOUBLE)", + "TRY_CAST(1 AS DOUBLE)", read={ "hive": "1d", "spark": "1d", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 1a83575..c69368c 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -70,8 +70,8 @@ class TestHive(Validator): self.validate_all( "1s", write={ - "duckdb": "CAST(1 AS SMALLINT)", - "presto": "CAST(1 AS SMALLINT)", + "duckdb": "TRY_CAST(1 AS SMALLINT)", + "presto": "TRY_CAST(1 AS SMALLINT)", "hive": "CAST(1 AS SMALLINT)", "spark": "CAST(1 AS SHORT)", }, @@ -79,8 +79,8 @@ class TestHive(Validator): self.validate_all( "1S", write={ - "duckdb": "CAST(1 AS SMALLINT)", - "presto": "CAST(1 AS SMALLINT)", + "duckdb": "TRY_CAST(1 AS SMALLINT)", + "presto": "TRY_CAST(1 AS SMALLINT)", "hive": "CAST(1 AS SMALLINT)", "spark": "CAST(1 AS SHORT)", }, @@ -88,8 +88,8 @@ class TestHive(Validator): self.validate_all( "1Y", write={ - "duckdb": "CAST(1 AS TINYINT)", - "presto": "CAST(1 AS TINYINT)", + "duckdb": "TRY_CAST(1 AS TINYINT)", + "presto": "TRY_CAST(1 AS TINYINT)", "hive": "CAST(1 AS TINYINT)", "spark": "CAST(1 AS BYTE)", }, @@ -97,8 +97,8 @@ class TestHive(Validator): self.validate_all( "1L", write={ - "duckdb": "CAST(1 AS BIGINT)", - "presto": "CAST(1 AS BIGINT)", + "duckdb": "TRY_CAST(1 AS BIGINT)", + "presto": "TRY_CAST(1 AS BIGINT)", "hive": "CAST(1 AS BIGINT)", "spark": "CAST(1 AS LONG)", }, @@ -106,8 +106,8 @@ class TestHive(Validator): self.validate_all( "1.0bd", write={ - "duckdb": "CAST(1.0 AS DECIMAL)", - "presto": "CAST(1.0 AS DECIMAL)", + "duckdb": "TRY_CAST(1.0 AS DECIMAL)", + "presto": "TRY_CAST(1.0 AS DECIMAL)", "hive": "CAST(1.0 AS DECIMAL)", "spark": "CAST(1.0 AS DECIMAL)", }, @@ -148,6 +148,9 @@ class TestHive(Validator): self.validate_identity( """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""", ) + self.validate_identity( + """CREATE EXTERNAL TABLE `my_table` (`a7` ARRAY<DATE>) ROW FORMAT SERDE 'a' STORED AS INPUTFORMAT 'b' OUTPUTFORMAT 'c' LOCATION 'd' TBLPROPERTIES ('e'='f')""" + ) def test_lateral_view(self): self.validate_all( @@ -318,6 +321,11 @@ class TestHive(Validator): "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", }, ) + self.validate_all("DATE_ADD('2020-01-01', -1)", read={"": "DATE_SUB('2020-01-01', 1)"}) + self.validate_all("DATE_ADD(a, b * -1)", read={"": "DATE_SUB(a, b)"}) + self.validate_all( + "ADD_MONTHS('2020-01-01', -2)", read={"": "DATE_SUB('2020-01-01', 2, month)"} + ) self.validate_all( "DATEDIFF(TO_DATE(y), x)", write={ @@ -504,11 +512,10 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE(10) y", + "SELECT * FROM x TABLESAMPLE(10 PERCENT) y", write={ - "presto": "SELECT * FROM x AS y TABLESAMPLE (10)", - "hive": "SELECT * FROM x TABLESAMPLE (10) AS y", - "spark": "SELECT * FROM x TABLESAMPLE (10) AS y", + "hive": "SELECT * FROM x TABLESAMPLE (10 PERCENT) AS y", + "spark": "SELECT * FROM x TABLESAMPLE (10 PERCENT) AS y", }, ) self.validate_all( @@ -650,25 +657,13 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE (1) AS foo", - read={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", - }, - write={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", - "hive": "SELECT * FROM x TABLESAMPLE (1) AS foo", - "spark": "SELECT * FROM x TABLESAMPLE (1) AS foo", - }, - ) - self.validate_all( - "SELECT * FROM x TABLESAMPLE (1) AS foo", + "SELECT * FROM x TABLESAMPLE (1 PERCENT) AS foo", read={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", + "presto": "SELECT * FROM x AS foo TABLESAMPLE BERNOULLI (1)", }, write={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", - "hive": "SELECT * FROM x TABLESAMPLE (1) AS foo", - "spark": "SELECT * FROM x TABLESAMPLE (1) AS foo", + "hive": "SELECT * FROM x TABLESAMPLE (1 PERCENT) AS foo", + "spark": "SELECT * FROM x TABLESAMPLE (1 PERCENT) AS foo", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index f618728..524d95e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -14,8 +14,18 @@ class TestMySQL(Validator): "spark": "CREATE TABLE z (a INT) COMMENT 'x'", }, ) + self.validate_all( + "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC", + write={ + "mysql": "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", + }, + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" + ) def test_identity(self): + self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") self.validate_identity("x ->> '$.name'") self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')") @@ -186,7 +196,7 @@ class TestMySQL(Validator): self.validate_all( 'SELECT "2021-01-01" + INTERVAL 1 MONTH', write={ - "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + "mysql": "SELECT '2021-01-01' + INTERVAL '1' MONTH", }, ) @@ -239,14 +249,91 @@ class TestMySQL(Validator): write={"mysql": "MATCH(a.b) AGAINST('abc')"}, ) + def test_date_format(self): + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%Y')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%m')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%m')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'mm')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%d')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%d')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'DD')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy-mm-DD')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMPNTZ), 'hh24')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%w')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%w')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'dy')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", + write={ + "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", + "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMPNTZ), 'DY mmmm yyyy')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%H:%i:%s')", + write={ + "mysql": "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%T')", + "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMPNTZ), 'hh24:mi:ss')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')", + write={ + "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')", + "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMPNTZ), 'DD yy DY DD mm mon')", + }, + ) + + def test_mysql_time(self): + self.validate_identity("FROM_UNIXTIME(a, b)") + self.validate_identity("FROM_UNIXTIME(a, b, c)") + self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)") + def test_mysql(self): self.validate_all( + "SELECT DATE(DATE_SUB(`dt`, INTERVAL DAYOFMONTH(`dt`) - 1 DAY)) AS __timestamp FROM tableT", + write={ + "mysql": "SELECT DATE(DATE_SUB(`dt`, INTERVAL (DAYOFMONTH(`dt`) - 1) DAY)) AS __timestamp FROM tableT", + }, + ) + self.validate_all( "SELECT a FROM tbl FOR UPDATE", write={ "": "SELECT a FROM tbl", "mysql": "SELECT a FROM tbl FOR UPDATE", "oracle": "SELECT a FROM tbl FOR UPDATE", "postgres": "SELECT a FROM tbl FOR UPDATE", + "redshift": "SELECT a FROM tbl", "tsql": "SELECT a FROM tbl FOR UPDATE", }, ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 80fa0f1..dd297d6 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -6,6 +6,27 @@ class TestOracle(Validator): def test_oracle(self): self.validate_identity("SELECT * FROM V$SESSION") + self.validate_identity( + "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" + ) + + self.validate_all( + "NVL(NULL, 1)", + write={ + "oracle": "NVL(NULL, 1)", + "": "IFNULL(NULL, 1)", + }, + ) + + self.validate_all( + "DATE '2022-01-01'", + write={ + "": "DATE_STR_TO_DATE('2022-01-01')", + "mysql": "CAST('2022-01-01' AS DATE)", + "oracle": "TO_DATE('2022-01-01', 'YYYY-MM-DD')", + "postgres": "CAST('2022-01-01' AS DATE)", + }, + ) def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") @@ -81,7 +102,7 @@ FROM warehouses, XMLTABLE( FROM XMLTABLE( 'ROWSET/ROW' PASSING - dbms_xmlgen.getxmltype ("SELECT table_name, column_name, data_default FROM user_tab_columns") + dbms_xmlgen.GETXMLTYPE('SELECT table_name, column_name, data_default FROM user_tab_columns') COLUMNS table_name VARCHAR2(128) PATH '*[1]', column_name VARCHAR2(128) PATH '*[2]', @@ -90,3 +111,28 @@ FROM XMLTABLE( }, pretty=True, ) + + def test_match_recognize(self): + self.validate_identity( + """SELECT + * +FROM sales_history +MATCH_RECOGNIZE ( + PARTITION BY product + ORDER BY + tstamp + MEASURES + STRT.tstamp AS start_tstamp, + LAST(UP.tstamp) AS peak_tstamp, + LAST(DOWN.tstamp) AS end_tstamp, + MATCH_NUMBER() AS mno + ONE ROW PER MATCH + AFTER MATCH SKIP TO LAST DOWN + PATTERN (STRT UP+ FLAT* DOWN+) + DEFINE + UP AS UP.units_sold > PREV(UP.units_sold), + FLAT AS FLAT.units_sold = PREV(FLAT.units_sold), + DOWN AS DOWN.units_sold < PREV(DOWN.units_sold) +) MR""", + pretty=True, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index a89ae30..e2f9c41 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1,4 +1,4 @@ -from sqlglot import ParseError, transpile +from sqlglot import ParseError, exp, parse_one, transpile from tests.dialects.test_dialect import Validator @@ -10,11 +10,25 @@ class TestPostgres(Validator): self.validate_identity("CREATE TABLE test (foo HSTORE)") self.validate_identity("CREATE TABLE test (foo JSONB)") self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") - self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a") self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a, b") self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING *") self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = 1 RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = excluded.id RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO NOTHING RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO UPDATE SET x.id = 1 RETURNING *" + ) + self.validate_identity( "DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid RETURNING a" ) self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a") @@ -75,6 +89,7 @@ class TestPostgres(Validator): self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]") self.validate_identity("$x") + self.validate_identity("x$") self.validate_identity("SELECT ARRAY[1, 2, 3]") self.validate_identity("SELECT ARRAY(SELECT 1)") self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") @@ -107,6 +122,12 @@ class TestPostgres(Validator): self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + self.validate_all( + "e'x'", + write={ + "mysql": "x", + }, + ) self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") self.validate_identity( "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" @@ -118,6 +139,28 @@ class TestPostgres(Validator): self.validate_identity("x ~* 'y'") self.validate_all( + "SELECT DATE_PART('isodow'::varchar(6), current_date)", + write={ + "postgres": "SELECT EXTRACT(CAST('isodow' AS VARCHAR(6)) FROM CURRENT_DATE)", + }, + ) + self.validate_all( + "SELECT DATE_PART('minute', timestamp '2023-01-04 04:05:06.789')", + write={ + "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "snowflake": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + }, + ) + self.validate_all( + "SELECT DATE_PART('month', date '20220502')", + write={ + "postgres": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "redshift": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "snowflake": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + }, + ) + self.validate_all( "SELECT (DATE '2016-01-10', DATE '2016-02-01') OVERLAPS (DATE '2016-01-20', DATE '2016-02-10')", write={ "postgres": "SELECT (CAST('2016-01-10' AS DATE), CAST('2016-02-01' AS DATE)) OVERLAPS (CAST('2016-01-20' AS DATE), CAST('2016-02-10' AS DATE))", @@ -141,17 +184,17 @@ class TestPostgres(Validator): self.validate_all( "GENERATE_SERIES(a, b, ' 2 days ')", write={ - "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' days)", - "presto": "SEQUENCE(a, b, INTERVAL '2' days)", - "trino": "SEQUENCE(a, b, INTERVAL '2' days)", + "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' day)", + "presto": "SEQUENCE(a, b, INTERVAL '2' day)", + "trino": "SEQUENCE(a, b, INTERVAL '2' day)", }, ) self.validate_all( "GENERATE_SERIES('2019-01-01'::TIMESTAMP, NOW(), '1day')", write={ "postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1' day)", - "presto": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", - "trino": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", + "presto": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", + "trino": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", }, ) self.validate_all( @@ -296,7 +339,10 @@ class TestPostgres(Validator): ) self.validate_all( """'{"a":1,"b":2}'::json->'b'""", - write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'"""}, + write={ + "postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'""", + "redshift": """CAST('{"a":1,"b":2}' AS JSON)."b\"""", + }, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'->'y'""", @@ -326,7 +372,7 @@ class TestPostgres(Validator): """SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""", write={ "postgres": """SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""", - "presto": """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""", + "presto": """SELECT JSON_ARRAY_ELEMENTS(TRY_CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""", }, ) self.validate_all( @@ -389,6 +435,36 @@ class TestPostgres(Validator): "spark": "TRIM(BOTH 'as' FROM 'as string as')", }, ) + self.validate_all( + "merge into x as x using (select id) as y on a = b WHEN matched then update set X.a = y.b", + write={ + "postgres": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X.a = y.b", + }, + ) + self.validate_all( + "merge into x as z using (select id) as y on a = b WHEN matched then update set X.a = y.b", + write={ + "postgres": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X.a = y.b", + }, + ) + self.validate_all( + "merge into x as z using (select id) as y on a = b WHEN matched then update set Z.a = y.b", + write={ + "postgres": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET Z.a = y.b", + }, + ) + self.validate_all( + "merge into x using (select id) as y on a = b WHEN matched then update set x.a = y.b", + write={ + "postgres": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b", + }, + ) + + self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.TryCast) def test_bool_or(self): self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 1007899..3080476 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -60,7 +60,7 @@ class TestPresto(Validator): self.validate_all( "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", write={ - "bigquery": "CAST(x AS TIMESTAMPTZ)", + "bigquery": "CAST(x AS TIMESTAMP)", "duckdb": "CAST(x AS TIMESTAMPTZ(9))", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", "hive": "CAST(x AS TIMESTAMP)", @@ -106,7 +106,33 @@ class TestPresto(Validator): }, ) + def test_interval_plural_to_singular(self): + # Microseconds, weeks and quarters are not supported in Presto/Trino INTERVAL literals + unit_to_expected = { + "SeCoNds": "second", + "minutes": "minute", + "hours": "hour", + "days": "day", + "months": "month", + "years": "year", + } + + for unit, expected in unit_to_expected.items(): + self.validate_all( + f"SELECT INTERVAL '1' {unit}", + write={ + "bigquery": f"SELECT INTERVAL '1' {expected}", + "presto": f"SELECT INTERVAL '1' {expected}", + "trino": f"SELECT INTERVAL '1' {expected}", + }, + ) + def test_time(self): + self.validate_identity("FROM_UNIXTIME(a, b)") + self.validate_identity("FROM_UNIXTIME(a, b, c)") + self.validate_identity("TRIM(a, b)") + self.validate_identity("VAR_POP(a)") + self.validate_all( "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ @@ -158,10 +184,6 @@ class TestPresto(Validator): "spark": "FROM_UNIXTIME(x)", }, ) - self.validate_identity("FROM_UNIXTIME(a, b)") - self.validate_identity("FROM_UNIXTIME(a, b, c)") - self.validate_identity("TRIM(a, b)") - self.validate_identity("VAR_POP(a)") self.validate_all( "TO_UNIXTIME(x)", write={ @@ -243,7 +265,7 @@ class TestPresto(Validator): }, ) self.validate_all( - "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1", + "CREATE TABLE test STORED AS 'PARQUET' AS SELECT 1", write={ "duckdb": "CREATE TABLE test AS SELECT 1", "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", @@ -362,6 +384,14 @@ class TestPresto(Validator): }, ) + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(ARRAY(y)) AS t (a) CROSS JOIN b", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a) CROSS JOIN b", + "hive": "SELECT a FROM x CROSS JOIN b LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + }, + ) + def test_presto(self): self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") self.validate_identity("SELECT * FROM (VALUES (1))") @@ -369,6 +399,9 @@ class TestPresto(Validator): self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") + self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"}) + self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) + self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"}) self.validate_all( "SELECT JSON_OBJECT(KEY 'key1' VALUE 1, KEY 'key2' VALUE TRUE)", write={ @@ -643,3 +676,41 @@ class TestPresto(Validator): "presto": "SELECT CAST(ARRAY[1, 23, 456] AS JSON)", }, ) + + def test_explode_to_unnest(self): + self.validate_all( + "SELECT col FROM tbl CROSS JOIN UNNEST(x) AS _u(col)", + read={"spark": "SELECT EXPLODE(x) FROM tbl"}, + ) + self.validate_all( + "SELECT col_2 FROM _u CROSS JOIN UNNEST(col) AS _u_2(col_2)", + read={"spark": "SELECT EXPLODE(col) FROM _u"}, + ) + self.validate_all( + "SELECT exploded FROM schema.tbl CROSS JOIN UNNEST(col) AS _u(exploded)", + read={"spark": "SELECT EXPLODE(col) AS exploded FROM schema.tbl"}, + ) + self.validate_all( + "SELECT col FROM UNNEST(SEQUENCE(1, 2)) AS _u(col)", + read={"spark": "SELECT EXPLODE(SEQUENCE(1, 2))"}, + ) + self.validate_all( + "SELECT col FROM tbl AS t CROSS JOIN UNNEST(t.c) AS _u(col)", + read={"spark": "SELECT EXPLODE(t.c) FROM tbl t"}, + ) + self.validate_all( + "SELECT pos, col FROM UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", + read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3))"}, + ) + self.validate_all( + "SELECT pos, col FROM tbl CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", + read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3)) FROM tbl"}, + ) + self.validate_all( + "SELECT pos, col FROM tbl AS t CROSS JOIN UNNEST(t.c) WITH ORDINALITY AS _u(col, pos)", + read={"spark": "SELECT POSEXPLODE(t.c) FROM tbl t"}, + ) + self.validate_all( + "SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)", + read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"}, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 0933051..e5bd0e5 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -5,6 +5,44 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): + self.validate_identity("SELECT * FROM #x") + self.validate_identity("SELECT INTERVAL '5 day'") + self.validate_identity("foo$") + self.validate_identity("$foo") + + self.validate_all( + "SELECT SNAPSHOT", + write={ + "": "SELECT SNAPSHOT", + "redshift": 'SELECT "SNAPSHOT"', + }, + ) + + self.validate_all( + "SELECT SYSDATE", + write={ + "": "SELECT CURRENT_TIMESTAMP()", + "postgres": "SELECT CURRENT_TIMESTAMP", + "redshift": "SELECT SYSDATE", + }, + ) + self.validate_all( + "SELECT DATE_PART(minute, timestamp '2023-01-04 04:05:06.789')", + write={ + "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "snowflake": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + }, + ) + self.validate_all( + "SELECT DATE_PART(month, date '20220502')", + write={ + "postgres": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "redshift": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "snowflake": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + }, + ) + self.validate_all("SELECT INTERVAL '5 day'", read={"": "SELECT INTERVAL '5' days"}) self.validate_all("CONVERT(INTEGER, x)", write={"redshift": "CAST(x AS INTEGER)"}) self.validate_all( "DATEADD('day', ndays, caldate)", write={"redshift": "DATEADD(day, ndays, caldate)"} @@ -27,7 +65,7 @@ class TestRedshift(Validator): "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)", write={ "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", - "bigquery": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", + "bigquery": "SELECT ST_ASEWKT(TRY_CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index eb423a5..5c8b096 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,12 +6,16 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("OBJECT_CONSTRUCT(*)") + self.validate_identity("SELECT TO_DATE('2019-02-28') + INTERVAL '1 day, 1 year'") + self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'") self.validate_identity("SELECT HLL(*)") self.validate_identity("SELECT HLL(a)") self.validate_identity("SELECT HLL(DISTINCT t.a)") self.validate_identity("SELECT HLL(a, b, c)") self.validate_identity("SELECT HLL(DISTINCT a, b, c)") - self.validate_identity("$x") + self.validate_identity("$x") # parameter + self.validate_identity("a$b") # valid snowflake identifier self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") self.validate_identity("PUT file:///dir/tmp.csv @%table") self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") @@ -255,19 +259,18 @@ class TestSnowflake(Validator): write={ "bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')", - "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')", + "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-d HH:mm:ss')", }, ) self.validate_all( - "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')", read={ "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", - "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", }, write={ "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", - "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')", "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", }, ) @@ -841,11 +844,13 @@ MATCH_RECOGNIZE ( PARTITION BY a, b ORDER BY x DESC - MEASURES y AS b + MEASURES + y AS b {row} {after} PATTERN (^ S1 S2*? ( {{- S3 -}} S4 )+ | PERMUTE(S1, S2){{1,2}} $) - DEFINE x AS y + DEFINE + x AS y )""", pretty=True, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 0da2931..bfaed53 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -215,6 +215,41 @@ TBLPROPERTIES ( self.validate_identity("SPLIT(str, pattern, lim)") self.validate_all( + "BOOLEAN(x)", + write={ + "": "CAST(x AS BOOLEAN)", + "spark": "CAST(x AS BOOLEAN)", + }, + ) + self.validate_all( + "INT(x)", + write={ + "": "CAST(x AS INT)", + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "STRING(x)", + write={ + "": "CAST(x AS TEXT)", + "spark": "CAST(x AS STRING)", + }, + ) + self.validate_all( + "DATE(x)", + write={ + "": "CAST(x AS DATE)", + "spark": "CAST(x AS DATE)", + }, + ) + self.validate_all( + "TIMESTAMP(x)", + write={ + "": "CAST(x AS TIMESTAMP)", + "spark": "CAST(x AS TIMESTAMP)", + }, + ) + self.validate_all( "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"} ) self.validate_all( diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index 35d8b45..b33231c 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -6,6 +6,7 @@ class TestMySQL(Validator): def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x") def test_time(self): self.validate_identity("TIMESTAMP('2022-01-01')") diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 5d4f7db..dcb513d 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -39,6 +39,31 @@ class TestTeradata(Validator): write={"teradata": "CREATE OR REPLACE VIEW a AS (SELECT b FROM c)"}, ) + self.validate_all( + "CREATE VOLATILE TABLE a", + write={ + "teradata": "CREATE VOLATILE TABLE a", + "bigquery": "CREATE TABLE a", + "clickhouse": "CREATE TABLE a", + "databricks": "CREATE TABLE a", + "drill": "CREATE TABLE a", + "duckdb": "CREATE TABLE a", + "hive": "CREATE TABLE a", + "mysql": "CREATE TABLE a", + "oracle": "CREATE TABLE a", + "postgres": "CREATE TABLE a", + "presto": "CREATE TABLE a", + "redshift": "CREATE TABLE a", + "snowflake": "CREATE TABLE a", + "spark": "CREATE TABLE a", + "sqlite": "CREATE TABLE a", + "starrocks": "CREATE TABLE a", + "tableau": "CREATE TABLE a", + "trino": "CREATE TABLE a", + "tsql": "CREATE TABLE a", + }, + ) + def test_insert(self): self.validate_all( "INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"} @@ -71,3 +96,15 @@ class TestTeradata(Validator): ) self.validate_identity("CREATE TABLE z (a SYSUDTLIB.INT)") + + def test_cast(self): + self.validate_all( + "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", + write={ + "teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", + "databricks": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "mysql": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "spark": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "": "TIME_TO_STR('1992-01', 'YYYY-DD')", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d9ee4ae..b6e893c 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -66,6 +66,54 @@ class TestTSQL(Validator): "postgres": "STRING_AGG(x, '|')", }, ) + self.validate_all( + "SELECT CAST([a].[b] AS SMALLINT) FROM foo", + write={ + "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', + "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + }, + ) + self.validate_all( + "HASHBYTES('SHA1', x)", + read={ + "spark": "SHA(x)", + }, + write={ + "tsql": "HASHBYTES('SHA1', x)", + "spark": "SHA(x)", + }, + ) + self.validate_all( + "HASHBYTES('SHA2_256', x)", + read={ + "spark": "SHA2(x, 256)", + }, + write={ + "tsql": "HASHBYTES('SHA2_256', x)", + "spark": "SHA2(x, 256)", + }, + ) + self.validate_all( + "HASHBYTES('SHA2_512', x)", + read={ + "spark": "SHA2(x, 512)", + }, + write={ + "tsql": "HASHBYTES('SHA2_512', x)", + "spark": "SHA2(x, 512)", + }, + ) + self.validate_all( + "HASHBYTES('MD5', 'x')", + read={ + "spark": "MD5('x')", + }, + write={ + "tsql": "HASHBYTES('MD5', 'x')", + "spark": "MD5('x')", + }, + ) + self.validate_identity("HASHBYTES('MD2', 'x')") def test_types(self): self.validate_identity("CAST(x AS XML)") @@ -399,7 +447,7 @@ WHERE self.validate_all( "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) y FROM testdb.dbo.test", write={ - "mysql": "SELECT CAST(TIME_TO_STR(testdb.dbo.test.x, '%Y-%m-%d %H:%M:%S') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + "mysql": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, '%Y-%m-%d %T') AS VARCHAR(10)) AS y FROM testdb.dbo.test", "spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test", }, ) @@ -482,6 +530,12 @@ WHERE "spark": "SELECT x.a, x.b, t.v, t.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", }, ) + self.validate_all( + "SELECT x.a, x.b, t.v, t.y, s.v, s.y FROM x OUTER APPLY (SELECT v, y FROM t) t(v, y) OUTER APPLY (SELECT v, y FROM t) s(v, y) LEFT JOIN z ON z.id = s.id", + write={ + "spark": "SELECT x.a, x.b, t.v, t.y, s.v, s.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y) LEFT JOIN LATERAL (SELECT v, y FROM t) AS s(v, y) LEFT JOIN z ON z.id = s.id", + }, + ) def test_lateral_table_valued_function(self): self.validate_all( @@ -631,3 +685,38 @@ WHERE "SUSER_SNAME()", write={"spark": "CURRENT_USER()"}, ) + self.validate_all( + "SYSTEM_USER()", + write={"spark": "CURRENT_USER()"}, + ) + self.validate_all( + "SYSTEM_USER", + write={"spark": "CURRENT_USER()"}, + ) + + def test_hints(self): + self.validate_all( + "SELECT x FROM a INNER HASH JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a INNER LOOP JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a INNER REMOTE JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a INNER MERGE JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a WITH (NOLOCK)", + write={ + "spark": "SELECT x FROM a", + "tsql": "SELECT x FROM a WITH (NOLOCK)", + "": "SELECT x FROM a WITH (NOLOCK)", + }, + ) + self.validate_identity("SELECT x FROM a INNER LOOP JOIN b ON b.id = a.id") |