| author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
|---|---|---|
| committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
| commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
| tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /tests/dialects | |
| parent | Releasing debian version 19.0.1-1. (diff) | |
| download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz, sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip | |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'tests/dialects')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | tests/dialects/test_bigquery.py | 59 |
| -rw-r--r-- | tests/dialects/test_clickhouse.py | 68 |
| -rw-r--r-- | tests/dialects/test_databricks.py | 2 |
| -rw-r--r-- | tests/dialects/test_dialect.py | 219 |
| -rw-r--r-- | tests/dialects/test_duckdb.py | 124 |
| -rw-r--r-- | tests/dialects/test_hive.py | 34 |
| -rw-r--r-- | tests/dialects/test_mysql.py | 56 |
| -rw-r--r-- | tests/dialects/test_oracle.py | 18 |
| -rw-r--r-- | tests/dialects/test_postgres.py | 86 |
| -rw-r--r-- | tests/dialects/test_presto.py | 62 |
| -rw-r--r-- | tests/dialects/test_redshift.py | 103 |
| -rw-r--r-- | tests/dialects/test_snowflake.py | 122 |
| -rw-r--r-- | tests/dialects/test_spark.py | 45 |
| -rw-r--r-- | tests/dialects/test_teradata.py | 14 |
| -rw-r--r-- | tests/dialects/test_tsql.py | 191 |
15 files changed, 967 insertions, 236 deletions
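Much of the new coverage in this merge exercises dialect-specific array indexing (BigQuery is zero-based, DuckDB one-based), `FOR ... IN` statements, and `TIMESTAMP_*` conversions. A minimal sketch of the kind of transpilation the new BigQuery/DuckDB tests validate, assuming sqlglot 20.1.0 and its public `transpile` API (the query is taken from the test hunk below):

```python
# Sketch of the array-index offset behavior covered by the new tests
# (assumes sqlglot >= 20.1.0 is installed).
from sqlglot import transpile

# BigQuery's a[1] is a zero-based lookup; DuckDB arrays are one-based,
# so the index is shifted when transpiling. A warning about the applied
# offset is logged via sqlglot.helper, as the tests assert.
print(transpile("SELECT a[1]", read="bigquery", write="duckdb")[0])
# Expected per tests/dialects/test_bigquery.py: SELECT a[2]
```

The raw diff follows.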
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 3601e47..420803a 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -1,6 +1,14 @@ from unittest import mock -from sqlglot import ErrorLevel, ParseError, TokenError, UnsupportedError, transpile +from sqlglot import ( + ErrorLevel, + ParseError, + TokenError, + UnsupportedError, + parse, + transpile, +) +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -9,6 +17,28 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]", + write={ + "duckdb": "SELECT a[2], b[2], c[1], d[2], e[1]", + "bigquery": "SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]", + "presto": "SELECT a[2], b[2], c[1], ELEMENT_AT(d, 2), ELEMENT_AT(e, 1)", + }, + ) + + self.validate_all( + "a[0]", + read={ + "duckdb": "a[1]", + "presto": "a[1]", + }, + ) + + self.validate_identity( + "select array_contains([1, 2, 3], 1)", + "SELECT EXISTS(SELECT 1 FROM UNNEST([1, 2, 3]) AS _col WHERE _col = 1)", + ) self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'") self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'") self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')") @@ -37,6 +67,15 @@ class TestBigQuery(Validator): with self.assertRaises(ParseError): transpile("DATE_ADD(x, day)", read="bigquery") + for_in_stmts = parse( + "FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word; END FOR;", + read="bigquery", + ) + self.assertEqual( + [s.sql(dialect="bigquery") for s in for_in_stmts], + ["FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word", "END FOR"], + ) + self.validate_identity("SELECT test.Unknown FROM test") self.validate_identity(r"SELECT '\n\r\a\v\f\t'") self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") @@ -89,6 +128,11 @@ class TestBigQuery(Validator): self.validate_identity("ROLLBACK TRANSACTION") self.validate_identity("CAST(x AS BIGNUMERIC)") self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") + self.validate_identity("SELECT TIMESTAMP_SECONDS(2) AS t") + self.validate_identity("SELECT TIMESTAMP_MILLIS(2) AS t") + self.validate_identity( + "FOR record IN (SELECT word, word_count FROM bigquery-public-data.samples.shakespeare LIMIT 5) DO SELECT record.word, record.word_count" + ) self.validate_identity( "DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')" ) @@ -143,6 +187,19 @@ class TestBigQuery(Validator): self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) self.validate_all( + "SELECT TIMESTAMP_MICROS(x)", + read={ + "duckdb": "SELECT MAKE_TIMESTAMP(x)", + "spark": "SELECT TIMESTAMP_MICROS(x)", + }, + write={ + "bigquery": "SELECT TIMESTAMP_MICROS(x)", + "duckdb": "SELECT MAKE_TIMESTAMP(x)", + "snowflake": "SELECT TO_TIMESTAMP(x / 1000, 3)", + "spark": "SELECT TIMESTAMP_MICROS(x)", + }, + ) + self.validate_all( "SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)", write={ "bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 93d1ced..86ddb00 100644 --- a/tests/dialects/test_clickhouse.py +++ 
b/tests/dialects/test_clickhouse.py @@ -6,22 +6,6 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): - self.validate_identity("x <> y") - - self.validate_all( - "has([1], x)", - read={ - "postgres": "x = any(array[1])", - }, - ) - self.validate_all( - "NOT has([1], x)", - read={ - "postgres": "any(array[1]) <> x", - }, - ) - self.validate_identity("x = y") - string_types = [ "BLOB", "LONGBLOB", @@ -40,6 +24,8 @@ class TestClickhouse(Validator): self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) + self.validate_identity("x = y") + self.validate_identity("x <> y") self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 0.01)") self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)") self.validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE 10000000") @@ -81,7 +67,17 @@ class TestClickhouse(Validator): self.validate_identity("position(haystack, needle, position)") self.validate_identity("CAST(x AS DATETIME)") self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)") - + self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src") + self.validate_identity( + "SELECT SUM(1) AS impressions, arrayJoin(cities) AS city, arrayJoin(browsers) AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 'Chrome'] AS browsers) GROUP BY 2, 3" + ) + self.validate_identity( + "SELECT sum(1) AS impressions, (arrayJoin(arrayZip(cities, browsers)) AS t).1 AS city, t.2 AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 'Chrome'] AS browsers) GROUP BY 2, 3" + ) + self.validate_identity( + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']", + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')", + ) self.validate_identity( 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))' ) @@ -102,6 +98,25 @@ class TestClickhouse(Validator): ) self.validate_all( + "SELECT arrayJoin([1,2,3])", + write={ + "clickhouse": "SELECT arrayJoin([1, 2, 3])", + "postgres": "SELECT UNNEST(ARRAY[1, 2, 3])", + }, + ) + self.validate_all( + "has([1], x)", + read={ + "postgres": "x = any(array[1])", + }, + ) + self.validate_all( + "NOT has([1], x)", + read={ + "postgres": "any(array[1]) <> x", + }, + ) + self.validate_all( "SELECT CAST('2020-01-01' AS TIMESTAMP) + INTERVAL '500' microsecond", read={ "duckdb": "SELECT TIMESTAMP '2020-01-01' + INTERVAL '500 us'", @@ -197,12 +212,15 @@ class TestClickhouse(Validator): }, ) self.validate_all( - "CONCAT(CASE WHEN COALESCE(CAST(a AS String), '') IS NULL THEN COALESCE(CAST(a AS String), '') ELSE CAST(COALESCE(CAST(a AS String), '') AS String) END, CASE WHEN COALESCE(CAST(b AS String), '') IS NULL THEN COALESCE(CAST(b AS String), '') ELSE CAST(COALESCE(CAST(b AS String), '') AS String) END)", - read={"postgres": "CONCAT(a, b)"}, - ) - self.validate_all( - "CONCAT(CASE WHEN a IS NULL THEN a ELSE CAST(a AS String) END, CASE WHEN b IS NULL THEN b ELSE CAST(b AS String) END)", - read={"mysql": "CONCAT(a, b)"}, + "CONCAT(a, b)", + read={ + "clickhouse": "CONCAT(a, b)", + "mysql": "CONCAT(a, b)", + }, + write={ + "mysql": "CONCAT(a, b)", + "postgres": "CONCAT(a, b)", + }, ) self.validate_all( r"'Enum8(\'Sunday\' = 0)'", write={"clickhouse": "'Enum8(''Sunday'' = 0)'"} @@ -320,6 +338,10 @@ class 
TestClickhouse(Validator): self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") + query = parse_one("""WITH (SELECT 1) AS y SELECT * FROM y""", read="clickhouse") + self.assertIsInstance(query.args["with"].expressions[0].this, exp.Subquery) + self.assertEqual(query.args["with"].expressions[0].alias, "y") + def test_ternary(self): self.validate_all("x ? 1 : 2", write={"clickhouse": "CASE WHEN x THEN 1 ELSE 2 END"}) self.validate_all( diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 8bb88b3..7c13e79 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -131,7 +131,7 @@ class TestDatabricks(Validator): "SELECT DATEDIFF(week, 'start', 'end')", write={ "databricks": "SELECT DATEDIFF(week, 'start', 'end')", - "postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 48 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(day FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)", + "postgres": "SELECT CAST(EXTRACT(days FROM (CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)", }, ) self.validate_all( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 2546c98..49afc62 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -7,9 +7,10 @@ from sqlglot import ( ParseError, TokenError, UnsupportedError, + exp, parse_one, ) -from sqlglot.dialects import Hive +from sqlglot.dialects import BigQuery, Hive, Snowflake class Validator(unittest.TestCase): @@ -78,9 +79,56 @@ class TestDialect(Validator): self.assertIsNotNone(Dialect[dialect.value]) def test_get_or_raise(self): - self.assertEqual(Dialect.get_or_raise(Hive), Hive) - self.assertEqual(Dialect.get_or_raise(Hive()), Hive) - self.assertEqual(Dialect.get_or_raise("hive"), Hive) + self.assertIsInstance(Dialect.get_or_raise(Hive), Hive) + self.assertIsInstance(Dialect.get_or_raise(Hive()), Hive) + self.assertIsInstance(Dialect.get_or_raise("hive"), Hive) + + with self.assertRaises(ValueError): + Dialect.get_or_raise(1) + + default_mysql = Dialect.get_or_raise("mysql") + self.assertEqual(default_mysql.normalization_strategy, "CASE_SENSITIVE") + + lowercase_mysql = Dialect.get_or_raise("mysql,normalization_strategy=lowercase") + self.assertEqual(lowercase_mysql.normalization_strategy, "LOWERCASE") + + lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase") + self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE") + + with self.assertRaises(ValueError) as cm: + Dialect.get_or_raise("mysql, normalization_strategy") + + self.assertEqual( + str(cm.exception), + "Invalid dialect format: 'mysql, normalization_strategy'. 
" + "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.", + ) + + def test_compare_dialects(self): + bigquery_class = Dialect["bigquery"] + bigquery_object = BigQuery() + bigquery_string = "bigquery" + + snowflake_class = Dialect["snowflake"] + snowflake_object = Snowflake() + snowflake_string = "snowflake" + + self.assertEqual(snowflake_class, snowflake_class) + self.assertEqual(snowflake_class, snowflake_object) + self.assertEqual(snowflake_class, snowflake_string) + self.assertEqual(snowflake_object, snowflake_object) + self.assertEqual(snowflake_object, snowflake_string) + + self.assertNotEqual(snowflake_class, bigquery_class) + self.assertNotEqual(snowflake_class, bigquery_object) + self.assertNotEqual(snowflake_class, bigquery_string) + self.assertNotEqual(snowflake_object, bigquery_object) + self.assertNotEqual(snowflake_object, bigquery_string) + + self.assertTrue(snowflake_class in {"snowflake", "bigquery"}) + self.assertTrue(snowflake_object in {"snowflake", "bigquery"}) + self.assertFalse(snowflake_class in {"bigquery", "redshift"}) + self.assertFalse(snowflake_object in {"bigquery", "redshift"}) def test_cast(self): self.validate_all( @@ -561,6 +609,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_STR(x, '%Y-%m-%d')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d', x)", "drill": "TO_CHAR(x, 'yyyy-MM-dd')", "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", @@ -866,9 +915,9 @@ class TestDialect(Validator): write={ "drill": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)", - "hive": "TO_DATE(x)", - "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", - "spark": "TO_DATE(x)", + "hive": "CAST(x AS DATE)", + "presto": "CAST(x AS DATE)", + "spark": "CAST(x AS DATE)", "sqlite": "x", }, ) @@ -893,7 +942,7 @@ class TestDialect(Validator): self.validate_all( "TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')", write={ - "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", + "presto": "DATE_ADD('DAY', 1, CAST(CAST(CURRENT_DATE AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD(CURRENT_DATE, 1)", }, ) @@ -1269,13 +1318,6 @@ class TestDialect(Validator): }, ) self.validate_all( - "SELECT * FROM a ORDER BY col_a NULLS LAST", - write={ - "mysql": UnsupportedError, - "starrocks": UnsupportedError, - }, - ) - self.validate_all( "POSITION(needle in haystack)", write={ "drill": "STRPOS(haystack, needle)", @@ -1315,35 +1357,37 @@ class TestDialect(Validator): self.validate_all( "CONCAT_WS('-', 'a', 'b')", write={ + "clickhouse": "CONCAT_WS('-', 'a', 'b')", "duckdb": "CONCAT_WS('-', 'a', 'b')", - "presto": "CONCAT_WS('-', 'a', 'b')", + "presto": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))", "hive": "CONCAT_WS('-', 'a', 'b')", "spark": "CONCAT_WS('-', 'a', 'b')", - "trino": "CONCAT_WS('-', 'a', 'b')", + "trino": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))", }, ) self.validate_all( "CONCAT_WS('-', x)", write={ + "clickhouse": "CONCAT_WS('-', x)", "duckdb": "CONCAT_WS('-', x)", "hive": "CONCAT_WS('-', x)", - "presto": "CONCAT_WS('-', x)", + "presto": "CONCAT_WS('-', CAST(x AS VARCHAR))", "spark": "CONCAT_WS('-', x)", - "trino": "CONCAT_WS('-', x)", + "trino": "CONCAT_WS('-', CAST(x AS VARCHAR))", }, ) self.validate_all( "CONCAT(a)", write={ - "clickhouse": "a", - "presto": "a", - "trino": "a", + "clickhouse": "CONCAT(a)", + "presto": "CAST(a AS VARCHAR)", + "trino": "CAST(a AS VARCHAR)", "tsql": "a", }, ) self.validate_all( - "COALESCE(CAST(a AS TEXT), '')", + "CONCAT(COALESCE(a, ''))", read={ "drill": "CONCAT(a)", "duckdb": "CONCAT(a)", @@ -1442,6 
+1486,76 @@ class TestDialect(Validator): "spark": "FILTER(the_array, x -> x > 0)", }, ) + self.validate_all( + "a / b", + write={ + "bigquery": "a / b", + "clickhouse": "a / b", + "databricks": "a / b", + "duckdb": "a / b", + "hive": "a / b", + "mysql": "a / b", + "oracle": "a / b", + "snowflake": "a / b", + "spark": "a / b", + "starrocks": "a / b", + "drill": "CAST(a AS DOUBLE) / b", + "postgres": "CAST(a AS DOUBLE PRECISION) / b", + "presto": "CAST(a AS DOUBLE) / b", + "redshift": "CAST(a AS DOUBLE PRECISION) / b", + "sqlite": "CAST(a AS REAL) / b", + "teradata": "CAST(a AS DOUBLE) / b", + "trino": "CAST(a AS DOUBLE) / b", + "tsql": "CAST(a AS FLOAT) / b", + }, + ) + + def test_typeddiv(self): + typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True) + div = exp.Div(this=exp.column("a"), expression=exp.column("b")) + typed_div_dialect = "presto" + div_dialect = "hive" + INT = exp.DataType.Type.INT + FLOAT = exp.DataType.Type.FLOAT + + for expression, types, dialect, expected in [ + (typed_div, (None, None), typed_div_dialect, "a / b"), + (typed_div, (None, None), div_dialect, "a / b"), + (div, (None, None), typed_div_dialect, "CAST(a AS DOUBLE) / b"), + (div, (None, None), div_dialect, "a / b"), + (typed_div, (INT, INT), typed_div_dialect, "a / b"), + (typed_div, (INT, INT), div_dialect, "CAST(a / b AS BIGINT)"), + (div, (INT, INT), typed_div_dialect, "CAST(a AS DOUBLE) / b"), + (div, (INT, INT), div_dialect, "a / b"), + (typed_div, (FLOAT, FLOAT), typed_div_dialect, "a / b"), + (typed_div, (FLOAT, FLOAT), div_dialect, "a / b"), + (div, (FLOAT, FLOAT), typed_div_dialect, "a / b"), + (div, (FLOAT, FLOAT), div_dialect, "a / b"), + (typed_div, (INT, FLOAT), typed_div_dialect, "a / b"), + (typed_div, (INT, FLOAT), div_dialect, "a / b"), + (div, (INT, FLOAT), typed_div_dialect, "a / b"), + (div, (INT, FLOAT), div_dialect, "a / b"), + ]: + with self.subTest(f"{expression.__class__.__name__} {types} {dialect} -> {expected}"): + expression = expression.copy() + expression.left.type = types[0] + expression.right.type = types[1] + self.assertEqual(expected, expression.sql(dialect=dialect)) + + def test_safediv(self): + safe_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), safe=True) + div = exp.Div(this=exp.column("a"), expression=exp.column("b")) + safe_div_dialect = "mysql" + div_dialect = "snowflake" + + for expression, dialect, expected in [ + (safe_div, safe_div_dialect, "a / b"), + (safe_div, div_dialect, "a / NULLIF(b, 0)"), + (div, safe_div_dialect, "a / b"), + (div, div_dialect, "a / b"), + ]: + with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"): + self.assertEqual(expected, expression.sql(dialect=dialect)) def test_limit(self): self.validate_all( @@ -1547,7 +1661,7 @@ class TestDialect(Validator): "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))", write={ "duckdb": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", - "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))", + "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 VARCHAR(1024))", "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", @@ -1864,7 +1978,7 @@ SELECT write={ "bigquery": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "clickhouse": "SELECT * FROM (WITH t AS (SELECT 1 AS c) 
SELECT c FROM t) AS subq", - "databricks": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", + "databricks": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "duckdb": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "mysql": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", @@ -1872,11 +1986,11 @@ SELECT "presto": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "redshift": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "snowflake": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "spark": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", + "spark": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "spark2": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "sqlite": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "trino": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", + "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c AS c FROM t) AS subq", }, ) self.validate_all( @@ -1885,13 +1999,60 @@ SELECT "bigquery": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2", "duckdb": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2", "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2", - "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2", + "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c AS c FROM t) AS subq1) AS subq2", }, ) self.validate_all( "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", write={ "duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", - "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y FROM t2) AS subq", + "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq", }, ) + + def test_unsupported_null_ordering(self): + # We'll transpile a portable query from the following dialects to MySQL / T-SQL, which + # both treat NULLs as small values, so the expected output queries should be equivalent + with_last_nulls = "duckdb" + with_small_nulls = "spark" + with_large_nulls = "postgres" + + sql = "SELECT * FROM t ORDER BY c" + sql_nulls_last = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END, c" + sql_nulls_first = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c" + + for read_dialect, desc, nulls_first, expected_sql in ( + (with_last_nulls, False, None, sql_nulls_last), + (with_last_nulls, True, None, sql), + (with_last_nulls, False, True, sql), + (with_last_nulls, True, True, sql_nulls_first), + (with_last_nulls, False, False, sql_nulls_last), + (with_last_nulls, True, False, sql), + (with_small_nulls, False, None, sql), + (with_small_nulls, True, None, sql), + (with_small_nulls, False, True, sql), + (with_small_nulls, True, True, sql_nulls_first), + (with_small_nulls, False, False, sql_nulls_last), + (with_small_nulls, True, False, sql), + (with_large_nulls, False, None, sql_nulls_last), + (with_large_nulls, True, None, sql_nulls_first), + (with_large_nulls, False, True, sql), + 
(with_large_nulls, True, True, sql_nulls_first), + (with_large_nulls, False, False, sql_nulls_last), + (with_large_nulls, True, False, sql), + ): + with self.subTest( + f"read: {read_dialect}, descending: {desc}, nulls first: {nulls_first}" + ): + sort_order = " DESC" if desc else "" + null_order = ( + " NULLS FIRST" + if nulls_first + else (" NULLS LAST" if nulls_first is not None else "") + ) + + expected_sql = f"{expected_sql}{sort_order}" + expression = parse_one(f"{sql}{sort_order}{null_order}", read=read_dialect) + + self.assertEqual(expression.sql(dialect="mysql"), expected_sql) + self.assertEqual(expression.sql(dialect="tsql"), expected_sql) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f9de953..687a807 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1,4 +1,5 @@ from sqlglot import ErrorLevel, UnsupportedError, exp, parse_one, transpile +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -71,7 +72,7 @@ class TestDuckDB(Validator): "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6])", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", - "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col, IF(_u.pos = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u.pos = _u_4.pos_4, _u_4.col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((_u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1, 2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (_u.pos = _u_3.pos_3 OR (_u.pos > CARDINALITY(ARRAY[4, 5]) AND _u_3.pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (_u.pos = _u_4.pos_4 OR (_u.pos > CARDINALITY(ARRAY[6]) AND _u_4.pos_4 = CARDINALITY(ARRAY[6])))", }, ) @@ -79,7 +80,7 @@ class TestDuckDB(Validator): "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6]) FROM x", write={ "bigquery": "SELECT 
IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM x, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", - "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col, IF(_u.pos = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u.pos = _u_4.pos_4, _u_4.col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((_u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1, 2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (_u.pos = _u_3.pos_3 OR (_u.pos > CARDINALITY(ARRAY[4, 5]) AND _u_3.pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (_u.pos = _u_4.pos_4 OR (_u.pos > CARDINALITY(ARRAY[6]) AND _u_4.pos_4 = CARDINALITY(ARRAY[6])))", }, ) self.validate_all( @@ -96,7 +97,6 @@ class TestDuckDB(Validator): ) self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC") - self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]") self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y") self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x") self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)") @@ -109,6 +109,10 @@ class TestDuckDB(Validator): parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b" ) + self.validate_identity("SELECT EPOCH_MS(10) AS t") + self.validate_identity("SELECT MAKE_TIMESTAMP(10) AS t") + self.validate_identity("SELECT TO_TIMESTAMP(10) AS t") + self.validate_identity("SELECT UNNEST(column, recursive := TRUE) FROM table") self.validate_identity("VAR_POP(a)") self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b") self.validate_identity("PIVOT Cities ON Year USING SUM(Population)") @@ -152,10 +156,17 @@ class TestDuckDB(Validator): self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( + "SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))", 
+ read={ + "duckdb": "SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))", + "snowflake": "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))", + }, + ) + self.validate_all( "SELECT UNNEST([1, 2, 3])", write={ "duckdb": "SELECT UNNEST([1, 2, 3])", - "snowflake": "SELECT IFF(pos = pos_2, col, NULL) AS col FROM (SELECT value FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1)))) AS _u(pos) CROSS JOIN (SELECT value, index FROM TABLE(FLATTEN(INPUT => [1, 2, 3]))) AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", + "snowflake": "SELECT IFF(_u.pos = _u_2.pos_2, _u_2.col, NULL) AS col FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => [1, 2, 3])) AS _u_2(seq, key, path, pos_2, col, this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND _u_2.pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", }, ) self.validate_all( @@ -355,14 +366,14 @@ class TestDuckDB(Validator): "STRUCT_PACK(x := 1, y := '2')", write={ "duckdb": "{'x': 1, 'y': '2'}", - "spark": "STRUCT(x = 1, y = '2')", + "spark": "STRUCT(1 AS x, '2' AS y)", }, ) self.validate_all( "STRUCT_PACK(key1 := 'value1', key2 := 42)", write={ "duckdb": "{'key1': 'value1', 'key2': 42}", - "spark": "STRUCT(key1 = 'value1', key2 = 42)", + "spark": "STRUCT('value1' AS key1, 42 AS key2)", }, ) self.validate_all( @@ -441,6 +452,16 @@ class TestDuckDB(Validator): }, ) self.validate_all( + "SELECT CAST('2018-01-01 00:00:00' AS DATE) + INTERVAL 3 DAY", + read={ + "hive": "SELECT DATE_ADD('2018-01-01 00:00:00', 3)", + }, + write={ + "duckdb": "SELECT CAST('2018-01-01 00:00:00' AS DATE) + INTERVAL '3' DAY", + "hive": "SELECT CAST('2018-01-01 00:00:00' AS DATE) + INTERVAL '3' DAY", + }, + ) + self.validate_all( "SELECT CAST('2020-05-06' AS DATE) - INTERVAL 5 DAY", read={"bigquery": "SELECT DATE_SUB(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, ) @@ -483,6 +504,35 @@ class TestDuckDB(Validator): self.validate_identity("SELECT ISNAN(x)") + def test_array_index(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT some_arr[1] AS first FROM blah", + read={ + "bigquery": "SELECT some_arr[0] AS first FROM blah", + }, + write={ + "bigquery": "SELECT some_arr[0] AS first FROM blah", + "duckdb": "SELECT some_arr[1] AS first FROM blah", + "presto": "SELECT some_arr[1] AS first FROM blah", + }, + ) + self.validate_identity( + "[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]" + ) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) + def test_time(self): self.validate_identity("SELECT CURRENT_DATE") self.validate_identity("SELECT CURRENT_TIMESTAMP") @@ -533,16 +583,16 @@ class TestDuckDB(Validator): self.validate_all( "EPOCH_MS(x)", write={ - "bigquery": "UNIX_TO_TIME(x / 1000)", - "duckdb": "TO_TIMESTAMP(x / 1000)", - "presto": "FROM_UNIXTIME(x / 1000)", - "spark": "CAST(FROM_UNIXTIME(x / 1000) AS TIMESTAMP)", + "bigquery": "TIMESTAMP_MILLIS(x)", + "duckdb": "EPOCH_MS(x)", + "presto": "FROM_UNIXTIME(CAST(x AS DOUBLE) / 1000)", + 
"spark": "TIMESTAMP_MILLIS(x)", }, ) self.validate_all( "STRFTIME(x, '%y-%-m-%S')", write={ - "bigquery": "TIME_TO_STR(x, '%y-%-m-%S')", + "bigquery": "FORMAT_DATE('%y-%-m-%S', x)", "duckdb": "STRFTIME(x, '%y-%-m-%S')", "postgres": "TO_CHAR(x, 'YY-FMMM-SS')", "presto": "DATE_FORMAT(x, '%y-%c-%s')", @@ -552,6 +602,7 @@ class TestDuckDB(Validator): self.validate_all( "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d %H:%M:%S', x)", "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", @@ -570,7 +621,7 @@ class TestDuckDB(Validator): self.validate_all( "TO_TIMESTAMP(x)", write={ - "bigquery": "UNIX_TO_TIME(x)", + "bigquery": "TIMESTAMP_SECONDS(x)", "duckdb": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)", @@ -651,22 +702,25 @@ class TestDuckDB(Validator): "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" ) - self.validate_all("CAST(x AS NUMERIC(1, 2))", write={"duckdb": "CAST(x AS DECIMAL(1, 2))"}) - self.validate_all("CAST(x AS HUGEINT)", write={"duckdb": "CAST(x AS INT128)"}) - self.validate_all("CAST(x AS CHAR)", write={"duckdb": "CAST(x AS TEXT)"}) - self.validate_all("CAST(x AS BPCHAR)", write={"duckdb": "CAST(x AS TEXT)"}) - self.validate_all("CAST(x AS STRING)", write={"duckdb": "CAST(x AS TEXT)"}) - self.validate_all("CAST(x AS INT1)", write={"duckdb": "CAST(x AS TINYINT)"}) - self.validate_all("CAST(x AS FLOAT4)", write={"duckdb": "CAST(x AS REAL)"}) - self.validate_all("CAST(x AS FLOAT)", write={"duckdb": "CAST(x AS REAL)"}) - self.validate_all("CAST(x AS INT4)", write={"duckdb": "CAST(x AS INT)"}) - self.validate_all("CAST(x AS INTEGER)", write={"duckdb": "CAST(x AS INT)"}) - self.validate_all("CAST(x AS SIGNED)", write={"duckdb": "CAST(x AS INT)"}) - self.validate_all("CAST(x AS BLOB)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS BYTEA)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS BINARY)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS VARBINARY)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS LOGICAL)", write={"duckdb": "CAST(x AS BOOLEAN)"}) + self.validate_identity("CAST(x AS INT64)", "CAST(x AS BIGINT)") + self.validate_identity("CAST(x AS INT32)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)") + self.validate_identity("CAST(x AS NUMERIC(1, 2))", "CAST(x AS DECIMAL(1, 2))") + self.validate_identity("CAST(x AS HUGEINT)", "CAST(x AS INT128)") + self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)") + self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)") + self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)") + self.validate_identity("CAST(x AS INT4)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS INTEGER)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS SIGNED)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS BLOB)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS BYTEA)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)") self.validate_all( "CAST(x AS NUMERIC)", write={ @@ 
-799,3 +853,17 @@ class TestDuckDB(Validator): "duckdb": "SELECT CAST(w AS TIMESTAMP_S), CAST(x AS TIMESTAMP_MS), CAST(y AS TIMESTAMP), CAST(z AS TIMESTAMP_NS)", }, ) + + def test_isnan(self): + self.validate_all( + "ISNAN(x)", + read={"bigquery": "IS_NAN(x)"}, + write={"bigquery": "IS_NAN(x)", "duckdb": "ISNAN(x)"}, + ) + + def test_isinf(self): + self.validate_all( + "ISINF(x)", + read={"bigquery": "IS_INF(x)"}, + write={"bigquery": "IS_INF(x)", "duckdb": "ISINF(x)"}, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index ba95442..b3366a2 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -315,6 +315,7 @@ class TestHive(Validator): self.validate_all( "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d %H:%M:%S', CAST('2020-01-01' AS DATETIME))", "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')", "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %T')", "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", @@ -324,21 +325,29 @@ class TestHive(Validator): self.validate_all( "DATE_ADD('2020-01-01', 1)", write={ + "": "TS_OR_DS_ADD('2020-01-01', 1, DAY)", + "bigquery": "DATE_ADD(CAST(CAST('2020-01-01' AS DATETIME) AS DATE), INTERVAL 1 DAY)", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", - "presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD('2020-01-01', 1)", + "presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", + "redshift": "DATEADD(DAY, 1, '2020-01-01')", + "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))", "spark": "DATE_ADD('2020-01-01', 1)", - "": "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')", + "tsql": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))", }, ) self.validate_all( "DATE_SUB('2020-01-01', 1)", write={ + "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, DAY)", + "bigquery": "DATE_ADD(CAST(CAST('2020-01-01' AS DATETIME) AS DATE), INTERVAL (1 * -1) DAY)", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL (1 * -1) DAY", - "presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD('2020-01-01', 1 * -1)", + "presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", + "redshift": "DATEADD(DAY, 1 * -1, '2020-01-01')", + "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))", "spark": "DATE_ADD('2020-01-01', 1 * -1)", - "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", + "tsql": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))", }, ) self.validate_all("DATE_ADD('2020-01-01', -1)", read={"": "DATE_SUB('2020-01-01', 1)"}) @@ -351,8 +360,8 @@ class TestHive(Validator): write={ "duckdb": "DATE_DIFF('day', CAST(x AS DATE), CAST(CAST(y AS DATE) AS DATE))", "presto": "DATE_DIFF('day', CAST(CAST(x AS TIMESTAMP) AS DATE), CAST(CAST(CAST(CAST(y AS TIMESTAMP) AS DATE) AS TIMESTAMP) AS DATE))", - "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", - "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", + "hive": "DATEDIFF(TO_DATE(y), TO_DATE(x))", + "spark": "DATEDIFF(TO_DATE(y), TO_DATE(x))", "": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", }, ) @@ -522,11 +531,16 @@ class TestHive(Validator): ) self.validate_all( "ARRAY_CONTAINS(x, 1)", + read={ + "duckdb": "LIST_HAS(x, 1)", + "snowflake": "ARRAY_CONTAINS(1, x)", + }, write={ "duckdb": "ARRAY_CONTAINS(x, 1)", "presto": "CONTAINS(x, 
1)", "hive": "ARRAY_CONTAINS(x, 1)", "spark": "ARRAY_CONTAINS(x, 1)", + "snowflake": "ARRAY_CONTAINS(1, x)", }, ) self.validate_all( @@ -687,7 +701,7 @@ class TestHive(Validator): "x div y", write={ "duckdb": "x // y", - "presto": "CAST(x / y AS INTEGER)", + "presto": "CAST(CAST(x AS DOUBLE) / y AS INTEGER)", "hive": "CAST(x / y AS INT)", "spark": "CAST(x / y AS INT)", }, @@ -707,11 +721,15 @@ class TestHive(Validator): self.validate_all( "COLLECT_SET(x)", read={ + "doris": "COLLECT_SET(x)", "presto": "SET_AGG(x)", + "snowflake": "ARRAY_UNIQUE_AGG(x)", }, write={ - "presto": "SET_AGG(x)", + "doris": "COLLECT_SET(x)", "hive": "COLLECT_SET(x)", + "presto": "SET_AGG(x)", + "snowflake": "ARRAY_UNIQUE_AGG(x)", "spark": "COLLECT_SET(x)", }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 3c165a3..19245f0 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -114,8 +114,17 @@ class TestMySQL(Validator): "mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)", }, ) + self.validate_all( + "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT", + write={ + "mysql": "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", + }, + ) + self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") def test_identity(self): + self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") + self.validate_identity("SELECT @var1 := 1, @var2") self.validate_identity("UNLOCK TABLES") self.validate_identity("LOCK TABLES `app_fields` WRITE") self.validate_identity("SELECT 1 XOR 0") @@ -523,7 +532,15 @@ class TestMySQL(Validator): ) self.validate_all( "SELECT DATEDIFF(x, y)", - write={"mysql": "SELECT DATEDIFF(x, y)", "presto": "SELECT DATE_DIFF('day', y, x)"}, + read={ + "presto": "SELECT DATE_DIFF('day', y, x)", + "redshift": "SELECT DATEDIFF(day, y, x)", + }, + write={ + "mysql": "SELECT DATEDIFF(x, y)", + "presto": "SELECT DATE_DIFF('day', y, x)", + "redshift": "SELECT DATEDIFF(day, y, x)", + }, ) self.validate_all( "DAYOFYEAR(x)", @@ -574,10 +591,16 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION SELECT * FROM x RIGHT JOIN y ON x.id = y.id LIMIT 0", + read={ + "postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id LIMIT 0", + }, + ) + self.validate_all( # MySQL doesn't support FULL OUTER joins - "SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x", + "WITH t1 AS (SELECT 1) SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x", read={ - "postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x", + "postgres": "WITH t1 AS (SELECT 1) SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x", }, ) self.validate_all( @@ -601,7 +624,9 @@ class TestMySQL(Validator): "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", "postgres": "SELECT * FROM test LIMIT 0 + 1 OFFSET 0 + 1", "presto": "SELECT * FROM test OFFSET 1 LIMIT 1", + "snowflake": "SELECT * FROM test LIMIT 1 OFFSET 1", "trino": "SELECT * FROM test OFFSET 1 LIMIT 1", + "bigquery": "SELECT * FROM test LIMIT 1 OFFSET 1", }, ) self.validate_all( @@ -984,3 +1009,28 @@ COMMENT='客户账户表'""" "mysql": "DATE_FORMAT(x, '%M')", }, ) + + def test_safe_div(self): + self.validate_all( + "a / b", + write={ + "bigquery": "a / NULLIF(b, 0)", + "clickhouse": "a / b", + "databricks": "a / NULLIF(b, 0)", + "duckdb": "a / b", + "hive": "a / b", + "mysql": "a 
/ b", + "oracle": "a / NULLIF(b, 0)", + "snowflake": "a / NULLIF(b, 0)", + "spark": "a / b", + "starrocks": "a / b", + "drill": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "postgres": "CAST(a AS DOUBLE PRECISION) / NULLIF(b, 0)", + "presto": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "redshift": "CAST(a AS DOUBLE PRECISION) / NULLIF(b, 0)", + "sqlite": "CAST(a AS REAL) / b", + "teradata": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "trino": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "tsql": "CAST(a AS FLOAT) / NULLIF(b, 0)", + }, + ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index d92eea5..e9ebac1 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -1,3 +1,4 @@ +from sqlglot import exp, parse_one from sqlglot.errors import UnsupportedError from tests.dialects.test_dialect import Validator @@ -6,6 +7,15 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol") + self.assertIsInstance( + parse_one("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol", dialect="oracle"), + exp.AlterTable, + ) + self.validate_identity( + "ALTER TABLE Payments ADD (Stock NUMBER NOT NULL, dropid VARCHAR2(500) NOT NULL)" + ) + self.validate_identity("ALTER TABLE Payments ADD Stock NUMBER NOT NULL") self.validate_identity("SELECT x FROM t WHERE cond FOR UPDATE") self.validate_identity("SELECT JSON_OBJECT(k1: v1 FORMAT JSON, k2: v2 FORMAT JSON)") self.validate_identity("SELECT JSON_OBJECT('name': first_name || ' ' || last_name) FROM t") @@ -57,8 +67,16 @@ class TestOracle(Validator): "SELECT * FROM t SAMPLE (.25)", "SELECT * FROM t SAMPLE (0.25)", ) + self.validate_identity("SELECT TO_CHAR(-100, 'L99', 'NL_CURRENCY = '' AusDollars '' ')") self.validate_all( + "SELECT TO_CHAR(TIMESTAMP '1999-12-01 10:00:00')", + write={ + "oracle": "SELECT TO_CHAR(CAST('1999-12-01 10:00:00' AS TIMESTAMP), 'YYYY-MM-DD HH24:MI:SS')", + "postgres": "SELECT TO_CHAR(CAST('1999-12-01 10:00:00' AS TIMESTAMP), 'YYYY-MM-DD HH24:MI:SS')", + }, + ) + self.validate_all( "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1", write={ "oracle": "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 0e5f1a1..17a65d7 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1,6 +1,5 @@ -from unittest import mock - from sqlglot import ParseError, exp, parse_one, transpile +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -22,6 +21,9 @@ class TestPostgres(Validator): self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a") self.validate_identity("CREATE TABLE cities_partdef PARTITION OF cities DEFAULT") self.validate_identity( + "CREATE CONSTRAINT TRIGGER my_trigger AFTER INSERT OR DELETE OR UPDATE OF col_a, col_b ON public.my_table DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION do_sth()" + ) + self.validate_identity( "CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)" ) self.validate_identity( @@ -43,6 +45,9 @@ class TestPostgres(Validator): "CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)" ) self.validate_identity( + "CREATE INDEX index_issues_on_title_trigram ON public.issues USING gin(title public.gin_trgm_ops)" + ) + self.validate_identity( "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *" ) self.validate_identity( @@ -148,7 +153,7 @@ class 
TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(c) FROM t", "postgres": "SELECT UNNEST(c) FROM t", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(c) AND pos_2 = CARDINALITY(c))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(c) AND _u_2.pos_2 = CARDINALITY(c))", }, ) self.validate_all( @@ -156,20 +161,46 @@ class TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(ARRAY(1))", "postgres": "SELECT UNNEST(ARRAY[1])", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1]) AND pos_2 = CARDINALITY(ARRAY[1]))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1]))", }, ) - @mock.patch("sqlglot.helper.logger") - def test_array_offset(self, logger): - self.validate_all( - "SELECT col[1]", - write={ - "hive": "SELECT col[0]", - "postgres": "SELECT col[1]", - "presto": "SELECT col[1]", - }, - ) + def test_array_offset(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT col[1]", + write={ + "bigquery": "SELECT col[0]", + "duckdb": "SELECT col[1]", + "hive": "SELECT col[0]", + "postgres": "SELECT col[1]", + "presto": "SELECT col[1]", + }, + ) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) + + def test_operator(self): + expr = parse_one("1 OPERATOR(+) 2 OPERATOR(*) 3", read="postgres") + + expr.left.assert_is(exp.Operator) + expr.left.left.assert_is(exp.Literal) + expr.left.right.assert_is(exp.Literal) + expr.right.assert_is(exp.Literal) + self.assertEqual(expr.sql(dialect="postgres"), "1 OPERATOR(+) 2 OPERATOR(*) 3") + + self.validate_identity("SELECT operator FROM t") + self.validate_identity("SELECT 1 OPERATOR(+) 2") + self.validate_identity("SELECT 1 OPERATOR(+) /* foo */ 2") + self.validate_identity("SELECT 1 OPERATOR(pg_catalog.+) 2") def test_postgres(self): expr = parse_one( @@ -198,6 +229,14 @@ class TestPostgres(Validator): self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) self.validate_identity( + "SELECT c.oid, n.nspname, c.relname " + "FROM pg_catalog.pg_class AS c " + "LEFT JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace " + "WHERE c.relname OPERATOR(pg_catalog.~) '^(courses)$' COLLATE pg_catalog.default AND " + "pg_catalog.PG_TABLE_IS_VISIBLE(c.oid) " + "ORDER BY 2, 3" + ) + self.validate_identity( "SELECT ARRAY[]::INT[] AS foo", "SELECT CAST(ARRAY[] AS INT[]) AS foo", ) @@ -728,26 +767,23 @@ class TestPostgres(Validator): ) def test_string_concat(self): - self.validate_all( - "SELECT CONCAT('abcde', 2, NULL, 22)", - write={ - "postgres": "SELECT CONCAT(COALESCE(CAST('abcde' AS TEXT), ''), COALESCE(CAST(2 AS TEXT), ''), COALESCE(CAST(NULL AS TEXT), ''), 
COALESCE(CAST(22 AS TEXT), ''))", - }, - ) + self.validate_identity("SELECT CONCAT('abcde', 2, NULL, 22)") + self.validate_all( "CONCAT(a, b)", write={ - "": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", - "duckdb": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", - "postgres": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", - "presto": "CONCAT(CAST(COALESCE(CAST(a AS VARCHAR), '') AS VARCHAR), CAST(COALESCE(CAST(b AS VARCHAR), '') AS VARCHAR))", + "": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", + "clickhouse": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", + "duckdb": "CONCAT(a, b)", + "postgres": "CONCAT(a, b)", + "presto": "CONCAT(COALESCE(CAST(a AS VARCHAR), ''), COALESCE(CAST(b AS VARCHAR), ''))", }, ) self.validate_all( "a || b", write={ "": "a || b", - "clickhouse": "CONCAT(CAST(a AS String), CAST(b AS String))", + "clickhouse": "a || b", "duckdb": "a || b", "postgres": "a || b", "presto": "CONCAT(CAST(a AS VARCHAR), CAST(b AS VARCHAR))", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ed734b6..6a82756 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -1,6 +1,5 @@ -from unittest import mock - from sqlglot import UnsupportedError, exp, parse_one +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -34,14 +33,6 @@ class TestPresto(Validator): }, ) self.validate_all( - "SELECT DATE_DIFF('week', CAST(CAST('2009-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2009-12-31' AS TIMESTAMP) AS DATE))", - read={"redshift": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')"}, - ) - self.validate_all( - "SELECT DATE_ADD('month', 18, CAST(CAST('2008-02-28' AS TIMESTAMP) AS DATE))", - read={"redshift": "SELECT DATEADD(month, 18, '2008-02-28')"}, - ) - self.validate_all( "SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP)", read={"postgres": "SELECT 'epoch'::TIMESTAMP"}, ) @@ -229,6 +220,7 @@ class TestPresto(Validator): self.validate_all( "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d %H:%M:%S', x)", "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", @@ -308,7 +300,12 @@ class TestPresto(Validator): "hive": "CURRENT_TIMESTAMP()", }, ) - + self.validate_all( + "SELECT DATE_ADD('DAY', 1, CAST(CURRENT_DATE AS TIMESTAMP))", + read={ + "redshift": "SELECT DATEADD(DAY, 1, CURRENT_DATE)", + }, + ) self.validate_all( "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')", write={ @@ -537,8 +534,7 @@ class TestPresto(Validator): }, ) - @mock.patch("sqlglot.helper.logger") - def test_presto(self, logger): + def test_presto(self): self.validate_identity("string_agg(x, ',')", "ARRAY_JOIN(ARRAY_AGG(x), ',')") self.validate_identity( "SELECT * FROM example.testdb.customer_orders FOR VERSION AS OF 8954597067493422955" @@ -558,6 +554,24 @@ class TestPresto(Validator): "SELECT SPLIT_TO_MAP('a:1;b:2;a:3', ';', ':', (k, v1, v2) -> CONCAT(v1, v2))" ) + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", + write={ + "postgres": UnsupportedError, + "presto": "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", + }, + ) + self.validate_all( + "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", + write={ + "": "SELECT ARRAY(1, 2, 3)[3]", + "bigquery": "SELECT [1, 2, 3][SAFE_ORDINAL(4)]", + 
"postgres": "SELECT (ARRAY[1, 2, 3])[4]", + "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", + }, + ) + self.validate_all( "SELECT MAX_BY(a.id, a.timestamp) FROM a", read={ @@ -669,21 +683,6 @@ class TestPresto(Validator): self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"}) self.validate_all( - "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", - write={ - "postgres": UnsupportedError, - "presto": "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", - }, - ) - self.validate_all( - "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", - write={ - "": "SELECT ARRAY(1, 2, 3)[3]", - "postgres": "SELECT (ARRAY[1, 2, 3])[4]", - "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", - }, - ) - self.validate_all( "SELECT SUBSTRING(a, 1, 3), SUBSTRING(a, LENGTH(a) - (3 - 1))", read={ "redshift": "SELECT LEFT(a, 3), RIGHT(a, 3)", @@ -890,6 +889,13 @@ class TestPresto(Validator): }, ) self.validate_all( + "JSON_FORMAT(CAST(MAP_FROM_ENTRIES(ARRAY[('action_type', 'at')]) AS JSON))", + write={ + "presto": "JSON_FORMAT(CAST(MAP_FROM_ENTRIES(ARRAY[('action_type', 'at')]) AS JSON))", + "spark": "TO_JSON(MAP_FROM_ENTRIES(ARRAY(('action_type', 'at'))))", + }, + ) + self.validate_all( "JSON_FORMAT(x)", read={ "spark": "TO_JSON(x)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 3e42525..c6be789 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,15 +6,6 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): - self.validate_identity( - "SELECT DATE_DIFF('month', CAST('2020-02-29 00:00:00' AS TIMESTAMP), CAST('2020-03-02 00:00:00' AS TIMESTAMP))", - "SELECT DATEDIFF(month, CAST(CAST('2020-02-29 00:00:00' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-02 00:00:00' AS TIMESTAMP) AS DATE))", - ) - self.validate_identity( - "SELECT * FROM x WHERE y = DATEADD('month', -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", - "SELECT * FROM x WHERE y = DATEADD(month, -1, CAST(DATE_TRUNC('month', (SELECT y FROM #temp_table)) AS DATE))", - ) - self.validate_all( "LISTAGG(sellerid, ', ')", read={ @@ -72,8 +63,11 @@ class TestRedshift(Validator): self.validate_all( "SELECT ADD_MONTHS('2008-03-31', 1)", write={ - "redshift": "SELECT DATEADD(month, 1, CAST('2008-03-31' AS DATE))", - "trino": "SELECT DATE_ADD('month', 1, CAST(CAST('2008-03-31' AS TIMESTAMP) AS DATE))", + "bigquery": "SELECT DATE_ADD(CAST('2008-03-31' AS DATETIME), INTERVAL 1 MONTH)", + "duckdb": "SELECT CAST('2008-03-31' AS TIMESTAMP) + INTERVAL 1 month", + "redshift": "SELECT DATEADD(month, 1, '2008-03-31')", + "trino": "SELECT DATE_ADD('month', 1, CAST('2008-03-31' AS TIMESTAMP))", + "tsql": "SELECT DATEADD(month, 1, CAST('2008-03-31' AS DATETIME2))", }, ) self.validate_all( @@ -205,18 +199,18 @@ class TestRedshift(Validator): "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS 
_row_number FROM x) AS _t WHERE _row_number = 1", + "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1", "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "tsql": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", }, ) self.validate_all( @@ -240,18 +234,43 @@ class TestRedshift(Validator): self.validate_all( "DATEDIFF('day', a, b)", write={ - "redshift": "DATEDIFF(day, CAST(a AS DATE), CAST(b AS DATE))", - "presto": "DATE_DIFF('day', CAST(CAST(a AS TIMESTAMP) AS DATE), CAST(CAST(b AS TIMESTAMP) AS DATE))", + "bigquery": "DATE_DIFF(CAST(b AS DATETIME), CAST(a AS DATETIME), day)", + "duckdb": "DATE_DIFF('day', CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP))", + "hive": "DATEDIFF(b, a)", + "redshift": "DATEDIFF(day, a, b)", + "presto": "DATE_DIFF('day', CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP))", }, ) self.validate_all( - "SELECT TOP 1 x FROM y", + "SELECT DATEADD(month, 18, '2008-02-28')", + write={ + "bigquery": "SELECT DATE_ADD(CAST('2008-02-28' AS DATETIME), INTERVAL 18 MONTH)", + "duckdb": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL 18 month", + "hive": "SELECT ADD_MONTHS('2008-02-28', 18)", + "mysql": "SELECT DATE_ADD('2008-02-28', INTERVAL 18 MONTH)", + "postgres": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL '18 month'", + "presto": "SELECT DATE_ADD('month', 18, CAST('2008-02-28' AS TIMESTAMP))", + "redshift": "SELECT DATEADD(month, 18, '2008-02-28')", + "snowflake": "SELECT 
DATEADD(month, 18, CAST('2008-02-28' AS TIMESTAMPNTZ))", + "tsql": "SELECT DATEADD(month, 18, CAST('2008-02-28' AS DATETIME2))", + }, + ) + self.validate_all( + "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", write={ - "redshift": "SELECT x FROM y LIMIT 1", + "bigquery": "SELECT DATE_DIFF(CAST('2009-12-31' AS DATETIME), CAST('2009-01-01' AS DATETIME), week)", + "duckdb": "SELECT DATE_DIFF('week', CAST('2009-01-01' AS TIMESTAMP), CAST('2009-12-31' AS TIMESTAMP))", + "hive": "SELECT CAST(DATEDIFF('2009-12-31', '2009-01-01') / 7 AS INT)", + "postgres": "SELECT CAST(EXTRACT(days FROM (CAST('2009-12-31' AS TIMESTAMP) - CAST('2009-01-01' AS TIMESTAMP))) / 7 AS BIGINT)", + "presto": "SELECT DATE_DIFF('week', CAST('2009-01-01' AS TIMESTAMP), CAST('2009-12-31' AS TIMESTAMP))", + "redshift": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", + "snowflake": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", + "tsql": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", }, ) def test_identity(self): + self.validate_identity("SELECT DATEADD(day, 1, 'today')") self.validate_identity("SELECT * FROM #x") self.validate_identity("SELECT INTERVAL '5 day'") self.validate_identity("foo$") @@ -263,6 +282,26 @@ class TestRedshift(Validator): self.validate_identity("SELECT APPROXIMATE AS y") self.validate_identity("CREATE TABLE t (c BIGINT IDENTITY(0, 1))") self.validate_identity( + "SELECT CONCAT('abc', 'def')", + "SELECT 'abc' || 'def'", + ) + self.validate_identity( + "SELECT CONCAT_WS('DELIM', 'abc', 'def', 'ghi')", + "SELECT 'abc' || 'DELIM' || 'def' || 'DELIM' || 'ghi'", + ) + self.validate_identity( + "SELECT TOP 1 x FROM y", + "SELECT x FROM y LIMIT 1", + ) + self.validate_identity( + "SELECT DATE_DIFF('month', CAST('2020-02-29 00:00:00' AS TIMESTAMP), CAST('2020-03-02 00:00:00' AS TIMESTAMP))", + "SELECT DATEDIFF(month, CAST('2020-02-29 00:00:00' AS TIMESTAMP), CAST('2020-03-02 00:00:00' AS TIMESTAMP))", + ) + self.validate_identity( + "SELECT * FROM x WHERE y = DATEADD('month', -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", + "SELECT * FROM x WHERE y = DATEADD(month, -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", + ) + self.validate_identity( "SELECT 'a''b'", "SELECT 'a\\'b'", ) @@ -271,6 +310,12 @@ class TestRedshift(Validator): "CREATE TABLE t (c BIGINT IDENTITY(0, 1))", ) self.validate_identity( + "SELECT DATEADD(hour, 0, CAST('2020-02-02 01:03:05.124' AS TIMESTAMP))" + ) + self.validate_identity( + "SELECT DATEDIFF(second, '2020-02-02 00:00:00.000', '2020-02-02 01:03:05.124')" + ) + self.validate_identity( "CREATE OR REPLACE VIEW v1 AS SELECT id, AVG(average_metric1) AS m1, AVG(average_metric2) AS m2 FROM t GROUP BY id WITH NO SCHEMA BINDING" ) self.validate_identity( @@ -295,12 +340,8 @@ class TestRedshift(Validator): "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" ) self.validate_identity( - "SELECT DATEADD(day, 1, 'today')", - "SELECT DATEADD(day, 1, CAST('today' AS DATE))", - ) - self.validate_identity( "SELECT DATEADD('day', ndays, caldate)", - "SELECT DATEADD(day, ndays, CAST(caldate AS DATE))", + "SELECT DATEADD(day, ndays, caldate)", ) self.validate_identity( "CONVERT(INT, x)", @@ -308,7 +349,7 @@ class TestRedshift(Validator): ) self.validate_identity( "SELECT DATE_ADD('day', 1, DATE('2023-01-01'))", - "SELECT DATEADD(day, 1, CAST(DATE('2023-01-01') AS DATE))", + "SELECT DATEADD(day, 1, DATE('2023-01-01'))", ) self.validate_identity( """SELECT @@ -449,17 +490,3 @@ FROM ( "redshift": "CREATE OR REPLACE VIEW v1 AS 
SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", }, ) - - def test_concat(self): - self.validate_all( - "SELECT CONCAT('abc', 'def')", - write={ - "redshift": "SELECT COALESCE(CAST('abc' AS VARCHAR(MAX)), '') || COALESCE(CAST('def' AS VARCHAR(MAX)), '')", - }, - ) - self.validate_all( - "SELECT CONCAT_WS('DELIM', 'abc', 'def', 'ghi')", - write={ - "redshift": "SELECT COALESCE(CAST('abc' AS VARCHAR(MAX)), '') || 'DELIM' || COALESCE(CAST('def' AS VARCHAR(MAX)), '') || 'DELIM' || COALESCE(CAST('ghi' AS VARCHAR(MAX)), '')", - }, - ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 2cad1d2..997c27b 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,13 +6,38 @@ from tests.dialects.test_dialect import Validator class TestSnowflake(Validator): + maxDiff = None dialect = "snowflake" def test_snowflake(self): + self.validate_identity("SELECT rename, replace") expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t") expr.selects[0].assert_is(exp.AggFunc) self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t") + self.assertEqual( + exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql( + "snowflake", pretty=True + ), + """SELECT + IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y" +FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, ( + GREATEST(ARRAY_SIZE(x)) - 1 +) + 1))) AS _u(seq, key, path, index, pos, this) +CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this) +WHERE + _u.pos = _u_2.pos_2 + OR ( + _u.pos > ( + ARRAY_SIZE(x) - 1 + ) AND _u_2.pos_2 = ( + ARRAY_SIZE(x) - 1 + ) + )""", + ) + + self.validate_identity("SELECT user_id, value FROM table_name sample ($s) SEED (0)") + self.validate_identity("SELECT ARRAY_UNIQUE_AGG(x)") self.validate_identity("SELECT OBJECT_CONSTRUCT()") self.validate_identity("SELECT DAYOFMONTH(CURRENT_TIMESTAMP())") self.validate_identity("SELECT DAYOFYEAR(CURRENT_TIMESTAMP())") @@ -48,6 +73,14 @@ class TestSnowflake(Validator): 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' ) self.validate_identity( + "SELECT * FROM unnest(x) with ordinality", + "SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS _u(seq, key, path, index, value, this)", + ) + self.validate_identity( + "CREATE TABLE foo (ID INT COMMENT $$some comment$$)", + "CREATE TABLE foo (ID INT COMMENT 'some comment')", + ) + self.validate_identity( "SELECT state, city, SUM(retail_price * quantity) AS gross_revenue FROM sales GROUP BY ALL" ) self.validate_identity( @@ -88,6 +121,21 @@ class TestSnowflake(Validator): self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all( + # We need to qualify the columns in this query because "value" would be ambiguous + 'WITH t(x, "value") AS (SELECT [1, 2, 3], 1) SELECT IFF(_u.pos = _u_2.pos_2, _u_2."value", NULL) AS "value" FROM t, TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE(t.x)) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => t.x)) AS _u_2(seq, key, path, pos_2, "value", this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE(t.x) - 1) AND _u_2.pos_2 = (ARRAY_SIZE(t.x) - 1))', + read={ + "duckdb": 'WITH t(x, "value") AS (SELECT [1,2,3], 1) SELECT UNNEST(t.x) AS "value" FROM t', + }, + ) + self.validate_all( + "SELECT { 'Manitoba': 'Winnipeg', 'foo': 'bar' } AS province_capital", + write={ + 
"duckdb": "SELECT {'Manitoba': 'Winnipeg', 'foo': 'bar'} AS province_capital", + "snowflake": "SELECT OBJECT_CONSTRUCT('Manitoba', 'Winnipeg', 'foo', 'bar') AS province_capital", + "spark": "SELECT STRUCT('Manitoba' AS Winnipeg, 'foo' AS bar) AS province_capital", + }, + ) + self.validate_all( "SELECT COLLATE('B', 'und:ci')", write={ "bigquery": "SELECT COLLATE('B', 'und:ci')", @@ -225,6 +273,7 @@ class TestSnowflake(Validator): "spark": "POWER(x, 2)", "sqlite": "POWER(x, 2)", "starrocks": "POWER(x, 2)", + "teradata": "x ** 2", "trino": "POWER(x, 2)", "tsql": "POWER(x, 2)", }, @@ -241,8 +290,8 @@ class TestSnowflake(Validator): "DIV0(foo, bar)", write={ "snowflake": "IFF(bar = 0, 0, foo / bar)", - "sqlite": "CASE WHEN bar = 0 THEN 0 ELSE foo / bar END", - "presto": "IF(bar = 0, 0, foo / bar)", + "sqlite": "CASE WHEN bar = 0 THEN 0 ELSE CAST(foo AS REAL) / bar END", + "presto": "IF(bar = 0, 0, CAST(foo AS DOUBLE) / bar)", "spark": "IF(bar = 0, 0, foo / bar)", "hive": "IF(bar = 0, 0, foo / bar)", "duckdb": "CASE WHEN bar = 0 THEN 0 ELSE foo / bar END", @@ -355,7 +404,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP(1659981729)", write={ - "bigquery": "SELECT UNIX_TO_TIME(1659981729)", + "bigquery": "SELECT TIMESTAMP_SECONDS(1659981729)", "snowflake": "SELECT TO_TIMESTAMP(1659981729)", "spark": "SELECT CAST(FROM_UNIXTIME(1659981729) AS TIMESTAMP)", }, @@ -363,7 +412,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP(1659981729000, 3)", write={ - "bigquery": "SELECT UNIX_TO_TIME(1659981729000, 'millis')", + "bigquery": "SELECT TIMESTAMP_MILLIS(1659981729000)", "snowflake": "SELECT TO_TIMESTAMP(1659981729000, 3)", "spark": "SELECT TIMESTAMP_MILLIS(1659981729000)", }, @@ -371,7 +420,6 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP('1659981729')", write={ - "bigquery": "SELECT UNIX_TO_TIME('1659981729')", "snowflake": "SELECT TO_TIMESTAMP('1659981729')", "spark": "SELECT CAST(FROM_UNIXTIME('1659981729') AS TIMESTAMP)", }, @@ -379,9 +427,11 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP(1659981729000000000, 9)", write={ - "bigquery": "SELECT UNIX_TO_TIME(1659981729000000000, 'micros')", + "bigquery": "SELECT TIMESTAMP_MICROS(CAST(1659981729000000000 / 1000 AS INT64))", + "duckdb": "SELECT TO_TIMESTAMP(1659981729000000000 / 1000000000)", + "presto": "SELECT FROM_UNIXTIME(CAST(1659981729000000000 AS DOUBLE) / 1000000000)", "snowflake": "SELECT TO_TIMESTAMP(1659981729000000000, 9)", - "spark": "SELECT TIMESTAMP_MICROS(1659981729000000000)", + "spark": "SELECT TIMESTAMP_SECONDS(1659981729000000000 / 1000000000)", }, ) self.validate_all( @@ -404,7 +454,6 @@ class TestSnowflake(Validator): "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", }, ) - self.validate_all( "SELECT IFF(TRUE, 'true', 'false')", write={ @@ -551,6 +600,7 @@ class TestSnowflake(Validator): staged_file.sql(dialect="snowflake"), ) + self.validate_identity("SELECT metadata$filename FROM @s1/") self.validate_identity("SELECT * FROM @~") self.validate_identity("SELECT * FROM @~/some/path/to/file.csv") self.validate_identity("SELECT * FROM @mystage") @@ -610,6 +660,13 @@ class TestSnowflake(Validator): "snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)", }, ) + self.validate_all( + "SELECT * FROM (SELECT * FROM t1 join t2 on t1.a = t2.c) SAMPLE (1)", + write={ + "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) SAMPLE (1)", + "spark": "SELECT * FROM (SELECT * FROM 
t1 JOIN t2 ON t1.a = t2.c) SAMPLE (1 PERCENT)", + }, + ) def test_timestamps(self): self.validate_identity("SELECT CAST('12:00:00' AS TIME)") @@ -719,6 +776,17 @@ class TestSnowflake(Validator): ) def test_ddl(self): + self.validate_identity( + """create external table et2( + col1 date as (parse_json(metadata$external_table_partition):COL1::date), + col2 varchar as (parse_json(metadata$external_table_partition):COL2::varchar), + col3 number as (parse_json(metadata$external_table_partition):COL3::number)) + partition by (col1,col2,col3) + location=@s2/logs/ + partition_type = user_specified + file_format = (type = parquet)""", + "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(PARSE_JSON(metadata$external_table_partition)['COL1'] AS DATE)), col2 VARCHAR AS (CAST(PARSE_JSON(metadata$external_table_partition)['COL2'] AS VARCHAR)), col3 DECIMAL AS (CAST(PARSE_JSON(metadata$external_table_partition)['COL3'] AS DECIMAL))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)", + ) self.validate_identity("CREATE OR REPLACE VIEW foo (uid) COPY GRANTS AS (SELECT 1)") self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)") self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") @@ -733,7 +801,7 @@ class TestSnowflake(Validator): "CREATE TABLE orders_clone_restore CLONE orders BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726')" ) self.validate_identity( - "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" + "CREATE TABLE a (x DATE, y BIGINT) PARTITION BY (x) integration='q' auto_refresh=TRUE file_format=(type = parquet)" ) self.validate_identity( "CREATE SCHEMA mytestschema_clone_restore CLONE testschema BEFORE (TIMESTAMP => TO_TIMESTAMP(40 * 365 * 86400))" @@ -1179,3 +1247,39 @@ MATCH_RECOGNIZE ( ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake") assert isinstance(ast, exp.AlterTable) assert isinstance(ast.args["actions"][0], exp.SwapTable) + + def test_try_cast(self): + self.validate_identity("SELECT TRY_CAST(x AS DOUBLE)") + + self.validate_all("TRY_CAST('foo' AS TEXT)", read={"hive": "CAST('foo' AS STRING)"}) + self.validate_all("CAST(5 + 5 AS TEXT)", read={"hive": "CAST(5 + 5 AS STRING)"}) + self.validate_all( + "CAST(TRY_CAST('2020-01-01' AS DATE) AS TEXT)", + read={ + "hive": "CAST(CAST('2020-01-01' AS DATE) AS STRING)", + "snowflake": "CAST(TRY_CAST('2020-01-01' AS DATE) AS TEXT)", + }, + ) + self.validate_all( + "TRY_CAST(x AS TEXT)", + read={ + "hive": "CAST(x AS STRING)", + "snowflake": "TRY_CAST(x AS TEXT)", + }, + ) + + from sqlglot.optimizer.annotate_types import annotate_types + + expression = parse_one("SELECT CAST(t.x AS STRING) FROM t", read="hive") + + expression = annotate_types(expression, schema={"t": {"x": "string"}}) + self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(t.x AS TEXT) FROM t") + + expression = annotate_types(expression, schema={"t": {"x": "int"}}) + self.assertEqual(expression.sql(dialect="snowflake"), "SELECT CAST(t.x AS TEXT) FROM t") + + # We can't infer FOO's type since it's a UDF in this case, so we don't get rid of TRY_CAST + expression = parse_one("SELECT TRY_CAST(FOO() AS TEXT)", read="snowflake") + + expression = annotate_types(expression) + self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)") diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 841a005..fe37027 100644 --- 
a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -8,6 +8,7 @@ class TestSpark(Validator): dialect = "spark" def test_ddl(self): + self.validate_identity("CREATE TABLE foo AS WITH t AS (SELECT 1 AS col) SELECT col FROM t") self.validate_identity("CREATE TEMPORARY VIEW test AS SELECT 1") self.validate_identity("CREATE TABLE foo (col VARCHAR(50))") self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)") @@ -226,15 +227,23 @@ TBLPROPERTIES ( ) def test_spark(self): + self.validate_identity("FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), 'utc')") expr = parse_one("any_value(col, true)", read="spark") self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.assertEqual( + parse_one("REFRESH TABLE t", read="spark").assert_is(exp.Refresh).sql(dialect="spark"), + "REFRESH TABLE t", + ) + + self.validate_identity("REFRESH 'hdfs://path/to/table'") + self.validate_identity("REFRESH TABLE tempDB.view1") self.validate_identity("SELECT CASE WHEN a = NULL THEN 1 ELSE 2 END") self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") - self.validate_identity("REFRESH table a.b.c") + self.validate_identity("REFRESH TABLE a.b.c") self.validate_identity("INTERVAL -86 days") self.validate_identity("SELECT UNIX_TIMESTAMP()") self.validate_identity("TRIM(' SparkSQL ')") @@ -300,6 +309,18 @@ TBLPROPERTIES ( }, ) self.validate_all( + "SELECT DATEDIFF(week, '2020-01-01', '2020-12-31')", + write={ + "bigquery": "SELECT DATE_DIFF(CAST('2020-12-31' AS DATE), CAST('2020-01-01' AS DATE), week)", + "duckdb": "SELECT DATE_DIFF('week', CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "hive": "SELECT CAST(DATEDIFF(TO_DATE('2020-12-31'), TO_DATE('2020-01-01')) / 7 AS INT)", + "postgres": "SELECT CAST(EXTRACT(days FROM (CAST(CAST('2020-12-31' AS DATE) AS TIMESTAMP) - CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP))) / 7 AS BIGINT)", + "redshift": "SELECT DATEDIFF(week, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "snowflake": "SELECT DATEDIFF(week, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "spark": "SELECT DATEDIFF(week, TO_DATE('2020-01-01'), TO_DATE('2020-12-31'))", + }, + ) + self.validate_all( "SELECT MONTHS_BETWEEN('1997-02-28 10:30:00', '1996-10-30')", write={ "duckdb": "SELECT DATEDIFF('month', CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))", @@ -435,7 +456,7 @@ TBLPROPERTIES ( "SELECT DATE_ADD(my_date_column, 1)", write={ "spark": "SELECT DATE_ADD(my_date_column, 1)", - "bigquery": "SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)", + "bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)", }, ) self.validate_all( @@ -592,9 +613,9 @@ TBLPROPERTIES ( self.validate_all( "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", write={ - "databricks": "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", + "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", "hive": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", - "spark": "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", + "spark": "WITH cte AS (SELECT cola 
FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", "spark2": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", }, ) @@ -604,7 +625,7 @@ TBLPROPERTIES ( "SELECT EXPLODE(x) FROM tbl", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(x) AND pos_2 = CARDINALITY(x))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(x) AND _u_2.pos_2 = CARDINALITY(x))", "spark": "SELECT EXPLODE(x) FROM tbl", }, ) @@ -612,46 +633,46 @@ TBLPROPERTIES ( "SELECT EXPLODE(col) FROM _u", write={ "bigquery": "SELECT IF(pos = pos_2, col_2, NULL) AS col_2 FROM _u, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(col)) - 1)) AS pos CROSS JOIN UNNEST(col) AS col_2 WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(col) - 1) AND pos_2 = (ARRAY_LENGTH(col) - 1))", - "presto": "SELECT IF(pos = pos_2, col_2) AS col_2 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u_2(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_3(col_2, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + "presto": "SELECT IF(_u_2.pos = _u_3.pos_2, _u_3.col_2) AS col_2 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u_2(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_3(col_2, pos_2) WHERE _u_2.pos = _u_3.pos_2 OR (_u_2.pos > CARDINALITY(col) AND _u_3.pos_2 = CARDINALITY(col))", "spark": "SELECT EXPLODE(col) FROM _u", }, ) self.validate_all( "SELECT EXPLODE(col) AS exploded FROM schema.tbl", write={ - "presto": "SELECT IF(pos = pos_2, exploded) AS exploded FROM schema.tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_2(exploded, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.exploded) AS exploded FROM schema.tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_2(exploded, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(col) AND _u_2.pos_2 = CARDINALITY(col))", }, ) self.validate_all( "SELECT EXPLODE(ARRAY(1, 2))", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2])) - 1)) AS pos CROSS JOIN UNNEST([1, 2]) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2]) - 1))", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2]) AND pos_2 = CARDINALITY(ARRAY[1, 2]))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > 
CARDINALITY(ARRAY[1, 2]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1, 2]))", }, ) self.validate_all( "SELECT POSEXPLODE(ARRAY(2, 3)) AS x", write={ "bigquery": "SELECT IF(pos = pos_2, x, NULL) AS x, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS x WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))", - "presto": "SELECT IF(pos = pos_2, x) AS x, IF(pos = pos_2, pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.x) AS x, IF(_u.pos = _u_2.pos_2, _u_2.pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[2, 3]))", }, ) self.validate_all( "SELECT POSEXPLODE(x) AS (a, b)", write={ - "presto": "SELECT IF(pos = a, b) AS b, IF(pos = a, a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE pos = a OR (pos > CARDINALITY(x) AND a = CARDINALITY(x))", + "presto": "SELECT IF(_u.pos = _u_2.a, _u_2.b) AS b, IF(_u.pos = _u_2.a, _u_2.a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE _u.pos = _u_2.a OR (_u.pos > CARDINALITY(x) AND _u_2.a = CARDINALITY(x))", }, ) self.validate_all( "SELECT POSEXPLODE(ARRAY(2, 3)), EXPLODE(ARRAY(4, 5, 6)) FROM tbl", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_2, pos_2, NULL) AS pos_2, IF(pos = pos_3, col_2, NULL) AS col_2 FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3]), ARRAY_LENGTH([4, 5, 6])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5, 6]) AS col_2 WITH OFFSET AS pos_3 WHERE (pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5, 6]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5, 6]) - 1)))", - "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_2, pos_2) AS pos_2, IF(pos = pos_3, col_2) AS col_2 FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3]), CARDINALITY(ARRAY[4, 5, 6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5, 6]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE (pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5, 6]) AND pos_3 = CARDINALITY(ARRAY[4, 5, 6])))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col, IF(_u.pos = _u_2.pos_2, _u_2.pos_2) AS pos_2, IF(_u.pos = _u_3.pos_3, _u_3.col_2) AS col_2 FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3]), CARDINALITY(ARRAY[4, 5, 6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5, 6]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE (_u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[2, 3]))) AND (_u.pos = _u_3.pos_3 OR (_u.pos > CARDINALITY(ARRAY[4, 5, 6]) AND _u_3.pos_3 = CARDINALITY(ARRAY[4, 5, 6])))", }, ) self.validate_all( "SELECT col, 
pos, POSEXPLODE(ARRAY(2, 3)) FROM _u", write={ - "presto": "SELECT col, pos, IF(pos_2 = pos_3, col_2) AS col_2, IF(pos_2 = pos_3, pos_3) AS pos_3 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE pos_2 = pos_3 OR (pos_2 > CARDINALITY(ARRAY[2, 3]) AND pos_3 = CARDINALITY(ARRAY[2, 3]))", + "presto": "SELECT col, pos, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.pos_3) AS pos_3 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE _u_2.pos_2 = _u_3.pos_3 OR (_u_2.pos_2 > CARDINALITY(ARRAY[2, 3]) AND _u_3.pos_3 = CARDINALITY(ARRAY[2, 3]))", }, ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 14703c4..85d4ebf 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -147,6 +147,9 @@ class TestTeradata(Validator): def test_mod(self): self.validate_all("a MOD b", write={"teradata": "a MOD b", "mysql": "a % b"}) + def test_power(self): + self.validate_all("a ** b", write={"teradata": "a ** b", "mysql": "POWER(a, b)"}) + def test_abbrev(self): self.validate_identity("a LT b", "a < b") self.validate_identity("a LE b", "a <= b") @@ -191,3 +194,14 @@ class TestTeradata(Validator): }, ) self.validate_identity("CAST('1992-01' AS FORMAT 'YYYY-DD')") + + self.validate_all( + "TRYCAST('-2.5' AS DECIMAL(5, 2))", + read={ + "snowflake": "TRY_CAST('-2.5' AS DECIMAL(5, 2))", + }, + write={ + "snowflake": "TRY_CAST('-2.5' AS DECIMAL(5, 2))", + "teradata": "TRYCAST('-2.5' AS DECIMAL(5, 2))", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 0ac94f2..07179ef 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -6,8 +6,22 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): - self.validate_all( - "WITH t(c) AS (SELECT 1) SELECT * INTO foo FROM (SELECT c FROM t) AS temp", + # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN + # tsql allows .. 
which means use the default schema + self.validate_identity("SELECT * FROM a..b") + + self.validate_identity("SELECT CONCAT(column1, column2)") + self.validate_identity("SELECT TestSpecialChar.Test# FROM TestSpecialChar") + self.validate_identity("SELECT TestSpecialChar.Test@ FROM TestSpecialChar") + self.validate_identity("SELECT TestSpecialChar.Test$ FROM TestSpecialChar") + self.validate_identity("SELECT TestSpecialChar.Test_ FROM TestSpecialChar") + self.validate_identity("SELECT TOP (2 + 1) 1") + self.validate_identity("SELECT * FROM t WHERE NOT c", "SELECT * FROM t WHERE NOT c <> 0") + self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)") + self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") + + self.validate_all( + "WITH t(c) AS (SELECT 1) SELECT * INTO foo FROM (SELECT c AS c FROM t) AS temp", read={ "duckdb": "CREATE TABLE foo AS WITH t(c) AS (SELECT 1) SELECT c FROM t", }, @@ -25,7 +39,7 @@ class TestTSQL(Validator): }, ) self.validate_all( - "WITH t(c) AS (SELECT 1) MERGE INTO x AS z USING (SELECT c FROM t) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "WITH t(c) AS (SELECT 1) MERGE INTO x AS z USING (SELECT c AS c FROM t) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", read={ "postgres": "MERGE INTO x AS z USING (WITH t(c) AS (SELECT 1) SELECT c FROM t) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", }, @@ -167,18 +181,6 @@ class TestTSQL(Validator): ) self.validate_all( - "SELECT DATEPART(year, CAST('2017-01-01' AS DATE))", - read={"postgres": "SELECT DATE_PART('year', '2017-01-01'::DATE)"}, - ) - self.validate_all( - "SELECT DATEPART(month, CAST('2017-03-01' AS DATE))", - read={"postgres": "SELECT DATE_PART('month', '2017-03-01'::DATE)"}, - ) - self.validate_all( - "SELECT DATEPART(day, CAST('2017-01-02' AS DATE))", - read={"postgres": "SELECT DATE_PART('day', '2017-01-02'::DATE)"}, - ) - self.validate_all( "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', @@ -229,11 +231,13 @@ class TestTSQL(Validator): self.validate_all( "HASHBYTES('SHA1', x)", read={ + "snowflake": "SHA1(x)", "spark": "SHA(x)", }, write={ - "tsql": "HASHBYTES('SHA1', x)", + "snowflake": "SHA1(x)", "spark": "SHA(x)", + "tsql": "HASHBYTES('SHA1', x)", }, ) self.validate_all( @@ -561,6 +565,21 @@ class TestTSQL(Validator): ) def test_ddl(self): + expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql") + self.assertIsInstance(expression, exp.AlterTable) + self.assertIsInstance(expression.args["actions"][0], exp.Drop) + self.assertEqual( + expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B" + ) + + for clusterd_keyword in ("CLUSTERED", "NONCLUSTERED"): + self.validate_identity( + 'CREATE TABLE "dbo"."benchmark" (' + '"name" CHAR(7) NOT NULL, ' + '"internal_id" VARCHAR(10) NOT NULL, ' + f'UNIQUE {clusterd_keyword} ("internal_id" ASC))' + ) + self.validate_identity( "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END", "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END", @@ -589,6 +608,12 @@ class TestTSQL(Validator): }, ) self.validate_all( + "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", + read={ + "": "CREATE TABLE foo.bar.baz AS (SELECT * FROM a.b.c)", + }, + ) + self.validate_all( "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id('db.tbl') AND name = 'idx') EXEC('CREATE INDEX idx ON db.tbl')", read={ "": "CREATE INDEX IF NOT EXISTS idx 
ON db.tbl", @@ -622,7 +647,6 @@ class TestTSQL(Validator): "tsql": "CREATE OR ALTER VIEW a.b AS SELECT 1", }, ) - self.validate_all( "ALTER TABLE a ADD b INTEGER, c INTEGER", read={ @@ -633,7 +657,6 @@ class TestTSQL(Validator): "tsql": "ALTER TABLE a ADD b INTEGER, c INTEGER", }, ) - self.validate_all( "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ @@ -833,7 +856,7 @@ WHERE ) def test_len(self): - self.validate_all("LEN(x)", write={"spark": "LENGTH(x)"}) + self.validate_all("LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(x)"}) def test_replicate(self): self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"}) @@ -870,11 +893,68 @@ WHERE ) def test_datepart(self): + self.validate_identity( + "DATEPART(QUARTER, x)", + "DATEPART(quarter, CAST(x AS DATETIME2))", + ) + self.validate_identity( + "DATEPART(YEAR, x)", + "FORMAT(CAST(x AS DATETIME2), 'yyyy')", + ) + self.validate_identity( + "DATEPART(HOUR, date_and_time)", + "DATEPART(hour, CAST(date_and_time AS DATETIME2))", + ) + self.validate_identity( + "DATEPART(WEEKDAY, date_and_time)", + "DATEPART(dw, CAST(date_and_time AS DATETIME2))", + ) + self.validate_identity( + "DATEPART(DW, date_and_time)", + "DATEPART(dw, CAST(date_and_time AS DATETIME2))", + ) + self.validate_all( "SELECT DATEPART(month,'1970-01-01')", - write={"spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'MM')"}, + write={ + "postgres": "SELECT TO_CHAR(CAST('1970-01-01' AS TIMESTAMP), 'MM')", + "spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'MM')", + "tsql": "SELECT FORMAT(CAST('1970-01-01' AS DATETIME2), 'MM')", + }, + ) + self.validate_all( + "SELECT DATEPART(year, CAST('2017-01-01' AS DATE))", + read={ + "postgres": "SELECT DATE_PART('year', '2017-01-01'::DATE)", + }, + write={ + "postgres": "SELECT TO_CHAR(CAST(CAST('2017-01-01' AS DATE) AS TIMESTAMP), 'YYYY')", + "spark": "SELECT DATE_FORMAT(CAST(CAST('2017-01-01' AS DATE) AS TIMESTAMP), 'yyyy')", + "tsql": "SELECT FORMAT(CAST(CAST('2017-01-01' AS DATE) AS DATETIME2), 'yyyy')", + }, + ) + self.validate_all( + "SELECT DATEPART(month, CAST('2017-03-01' AS DATE))", + read={ + "postgres": "SELECT DATE_PART('month', '2017-03-01'::DATE)", + }, + write={ + "postgres": "SELECT TO_CHAR(CAST(CAST('2017-03-01' AS DATE) AS TIMESTAMP), 'MM')", + "spark": "SELECT DATE_FORMAT(CAST(CAST('2017-03-01' AS DATE) AS TIMESTAMP), 'MM')", + "tsql": "SELECT FORMAT(CAST(CAST('2017-03-01' AS DATE) AS DATETIME2), 'MM')", + }, + ) + self.validate_all( + "SELECT DATEPART(day, CAST('2017-01-02' AS DATE))", + read={ + "postgres": "SELECT DATE_PART('day', '2017-01-02'::DATE)", + }, + write={ + "postgres": "SELECT TO_CHAR(CAST(CAST('2017-01-02' AS DATE) AS TIMESTAMP), 'DD')", + "spark": "SELECT DATE_FORMAT(CAST(CAST('2017-01-02' AS DATE) AS TIMESTAMP), 'dd')", + "tsql": "SELECT FORMAT(CAST(CAST('2017-01-02' AS DATE) AS DATETIME2), 'dd')", + }, ) - self.validate_identity("DATEPART(YEAR, x)", "FORMAT(CAST(x AS DATETIME2), 'yyyy')") def test_convert_date_format(self): self.validate_all( @@ -1073,10 +1153,7 @@ WHERE def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(hour, 1.5, '2021-01-01')") - self.validate_identity( - "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')", - "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", - ) + self.validate_all( "SELECT DATEDIFF(quarter, 0, '2021-01-01')", write={ @@ -1098,7 +1175,7 @@ WHERE write={ "tsql": "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS 
DATETIME2))", "spark": "SELECT DATEDIFF(year, CAST('2020-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) AS INT) / 12", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) / 12 AS INT)", }, ) self.validate_all( @@ -1114,16 +1191,18 @@ WHERE write={ "databricks": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", "spark": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT) / 3", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3 AS INT)", "tsql": "SELECT DATEDIFF(quarter, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) def test_iif(self): self.validate_identity( - "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond, 'True', 'False')" + "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond <> 0, 'True', 'False')" + ) + self.validate_identity( + "SELECT IIF(cond, 'True', 'False')", "SELECT IIF(cond <> 0, 'True', 'False')" ) - self.validate_identity("SELECT IIF(cond, 'True', 'False')") self.validate_all( "SELECT IIF(cond, 'True', 'False');", write={ @@ -1173,9 +1252,14 @@ WHERE def test_top(self): self.validate_all( - "SELECT TOP 3 * FROM A", + "SELECT DISTINCT TOP 3 * FROM A", + read={ + "spark": "SELECT DISTINCT * FROM A LIMIT 3", + }, write={ - "spark": "SELECT * FROM A LIMIT 3", + "spark": "SELECT DISTINCT * FROM A LIMIT 3", + "teradata": "SELECT DISTINCT TOP 3 * FROM A", + "tsql": "SELECT DISTINCT TOP 3 * FROM A", }, ) self.validate_all( @@ -1292,6 +1376,26 @@ WHERE }, ) + def test_temporal_table(self): + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON)""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START HIDDEN NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END HIDDEN NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history", DATA_CONSISTENCY_CHECK=ON))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history", DATA_CONSISTENCY_CHECK=ON))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history", DATA_CONSISTENCY_CHECK=OFF))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history"))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" 
DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history"))""" + ) + def test_system_time(self): self.validate_all( "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", @@ -1433,3 +1537,28 @@ FROM OPENJSON(@json) WITH ( "spark": "SET count = (SELECT COUNT(1) FROM x)", }, ) + + def test_qualify_derived_table_outputs(self): + self.validate_identity( + "WITH t AS (SELECT 1) SELECT * FROM t", + 'WITH t AS (SELECT 1 AS "1") SELECT * FROM t', + ) + self.validate_identity( + 'WITH t AS (SELECT "c") SELECT * FROM t', + 'WITH t AS (SELECT "c" AS "c") SELECT * FROM t', + ) + self.validate_identity( + "SELECT * FROM (SELECT 1) AS subq", + 'SELECT * FROM (SELECT 1 AS "1") AS subq', + ) + self.validate_identity( + 'SELECT * FROM (SELECT "c") AS subq', + 'SELECT * FROM (SELECT "c" AS "c") AS subq', + ) + + self.validate_all( + "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) AS c FROM t1) SELECT * FROM t2", + read={ + "duckdb": "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) FROM t1) SELECT * FROM t2", + }, + ) |
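
The dialect round-trips exercised by the test cases merged above can be reproduced outside the unittest harness with sqlglot's public API. A minimal sketch follows (assumes sqlglot 20.1 as merged here; the expected strings are copied verbatim from the Redshift DATEADD case added in tests/dialects/test_redshift.py above):

    import sqlglot

    # One of the DATEADD expectations added to test_redshift.py in this commit:
    # the Redshift statement is parsed once and re-generated for two target dialects.
    sql = "SELECT DATEADD(month, 18, '2008-02-28')"

    for dialect, expected in [
        ("duckdb", "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL 18 month"),
        ("presto", "SELECT DATE_ADD('month', 18, CAST('2008-02-28' AS TIMESTAMP))"),
    ]:
        # transpile() returns one output string per input statement.
        result = sqlglot.transpile(sql, read="redshift", write=dialect)[0]
        assert result == expected, result

The validate_all helper used throughout these tests is effectively a wrapper around the same read/write round-trip, so the read= and write= dictionaries in the diff map directly onto transpile(sql, read=..., write=...) calls like the one above.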