author: Daniel Baumann <mail@daniel-baumann.ch> 2023-12-10 10:45:55 +0000
committer: Daniel Baumann <mail@daniel-baumann.ch> 2023-12-10 10:45:55 +0000
commit: 02df6cdb000c8dbf739abda2af321a4f90d1b059 (patch)
tree: 2fc1daf848082ff67a11e60025cac260e3c318b2 /tests
parent: Adding upstream version 19.0.1. (diff)
download: sqlglot-02df6cdb000c8dbf739abda2af321a4f90d1b059.tar.xz, sqlglot-02df6cdb000c8dbf739abda2af321a4f90d1b059.zip
Adding upstream version 20.1.0. (upstream/20.1.0)
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'tests')
34 files changed, 1448 insertions, 334 deletions
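
Among other things, the updated tests below (tests/dialects/test_bigquery.py, tests/dialects/test_duckdb.py, tests/dialects/test_postgres.py) assert that array indexes are rebased between 0-based and 1-based dialects and that a warning is logged when the offset is applied. A minimal sketch of reproducing one of those expectations with sqlglot's transpile API, assuming this 20.1.0 snapshot is installed (the expected output string is copied from the updated BigQuery test):

```python
import sqlglot

# BigQuery arrays are 0-based while DuckDB arrays are 1-based, so transpiling
# shifts the literal index by one; sqlglot.helper also logs
# "Applying array index offset", which the new tests capture with assertLogs.
result = sqlglot.transpile("SELECT a[1]", read="bigquery", write="duckdb")
print(result)  # expected per the updated tests: ['SELECT a[2]']
```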
diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index 117789e..7a12808 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -146,7 +146,7 @@ class TestDataframeColumn(unittest.TestCase): self.assertEqual("cola BETWEEN 1 AND 3", F.col("cola").between(1, 3).sql()) self.assertEqual("cola BETWEEN 10.1 AND 12.1", F.col("cola").between(10.1, 12.1).sql()) self.assertEqual( - "cola BETWEEN TO_DATE('2022-01-01') AND TO_DATE('2022-03-01')", + "cola BETWEEN CAST('2022-01-01' AS DATE) AND CAST('2022-03-01' AS DATE)", F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), ) self.assertEqual( diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 586b8fc..54b327c 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -27,7 +27,7 @@ class TestFunctions(unittest.TestCase): test_null = SF.lit(None) self.assertEqual("NULL", test_null.sql()) test_date = SF.lit(datetime.date(2022, 1, 1)) - self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) + self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql()) test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.lit({"cola": 1, "colb": "test"}) @@ -49,7 +49,7 @@ class TestFunctions(unittest.TestCase): test_array = SF.col([1, 2, "3"]) self.assertEqual("ARRAY(1, 2, '3')", test_array.sql()) test_date = SF.col(datetime.date(2022, 1, 1)) - self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) + self.assertEqual("CAST('2022-01-01' AS DATE)", test_date.sql()) test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.col({"cola": 1, "colb": "test"}) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 3601e47..420803a 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -1,6 +1,14 @@ from unittest import mock -from sqlglot import ErrorLevel, ParseError, TokenError, UnsupportedError, transpile +from sqlglot import ( + ErrorLevel, + ParseError, + TokenError, + UnsupportedError, + parse, + transpile, +) +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -9,6 +17,28 @@ class TestBigQuery(Validator): maxDiff = None def test_bigquery(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]", + write={ + "duckdb": "SELECT a[2], b[2], c[1], d[2], e[1]", + "bigquery": "SELECT a[1], b[OFFSET(1)], c[ORDINAL(1)], d[SAFE_OFFSET(1)], e[SAFE_ORDINAL(1)]", + "presto": "SELECT a[2], b[2], c[1], ELEMENT_AT(d, 2), ELEMENT_AT(e, 1)", + }, + ) + + self.validate_all( + "a[0]", + read={ + "duckdb": "a[1]", + "presto": "a[1]", + }, + ) + + self.validate_identity( + "select array_contains([1, 2, 3], 1)", + "SELECT EXISTS(SELECT 1 FROM UNNEST([1, 2, 3]) AS _col WHERE _col = 1)", + ) self.validate_identity("CREATE SCHEMA x DEFAULT COLLATE 'en'") self.validate_identity("CREATE TABLE x (y INT64) DEFAULT COLLATE 'en'") self.validate_identity("PARSE_JSON('{}', wide_number_mode => 'exact')") @@ -37,6 +67,15 @@ class TestBigQuery(Validator): with self.assertRaises(ParseError): transpile("DATE_ADD(x, day)", read="bigquery") + for_in_stmts = parse( + "FOR record IN 
(SELECT word FROM shakespeare) DO SELECT record.word; END FOR;", + read="bigquery", + ) + self.assertEqual( + [s.sql(dialect="bigquery") for s in for_in_stmts], + ["FOR record IN (SELECT word FROM shakespeare) DO SELECT record.word", "END FOR"], + ) + self.validate_identity("SELECT test.Unknown FROM test") self.validate_identity(r"SELECT '\n\r\a\v\f\t'") self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") @@ -89,6 +128,11 @@ class TestBigQuery(Validator): self.validate_identity("ROLLBACK TRANSACTION") self.validate_identity("CAST(x AS BIGNUMERIC)") self.validate_identity("SELECT y + 1 FROM x GROUP BY y + 1 ORDER BY 1") + self.validate_identity("SELECT TIMESTAMP_SECONDS(2) AS t") + self.validate_identity("SELECT TIMESTAMP_MILLIS(2) AS t") + self.validate_identity( + "FOR record IN (SELECT word, word_count FROM bigquery-public-data.samples.shakespeare LIMIT 5) DO SELECT record.word, record.word_count" + ) self.validate_identity( "DATE(CAST('2016-12-25 05:30:00+07' AS DATETIME), 'America/Los_Angeles')" ) @@ -143,6 +187,19 @@ class TestBigQuery(Validator): self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) self.validate_all( + "SELECT TIMESTAMP_MICROS(x)", + read={ + "duckdb": "SELECT MAKE_TIMESTAMP(x)", + "spark": "SELECT TIMESTAMP_MICROS(x)", + }, + write={ + "bigquery": "SELECT TIMESTAMP_MICROS(x)", + "duckdb": "SELECT MAKE_TIMESTAMP(x)", + "snowflake": "SELECT TO_TIMESTAMP(x / 1000, 3)", + "spark": "SELECT TIMESTAMP_MICROS(x)", + }, + ) + self.validate_all( "SELECT * FROM t WHERE EXISTS(SELECT * FROM unnest(nums) AS x WHERE x > 1)", write={ "bigquery": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS x WHERE x > 1)", diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 93d1ced..86ddb00 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -6,22 +6,6 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): - self.validate_identity("x <> y") - - self.validate_all( - "has([1], x)", - read={ - "postgres": "x = any(array[1])", - }, - ) - self.validate_all( - "NOT has([1], x)", - read={ - "postgres": "any(array[1]) <> x", - }, - ) - self.validate_identity("x = y") - string_types = [ "BLOB", "LONGBLOB", @@ -40,6 +24,8 @@ class TestClickhouse(Validator): self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) + self.validate_identity("x = y") + self.validate_identity("x <> y") self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 0.01)") self.validate_identity("SELECT * FROM (SELECT a FROM b SAMPLE 1 / 10 OFFSET 1 / 2)") self.validate_identity("SELECT sum(foo * bar) FROM bla SAMPLE 10000000") @@ -81,7 +67,17 @@ class TestClickhouse(Validator): self.validate_identity("position(haystack, needle, position)") self.validate_identity("CAST(x AS DATETIME)") self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)") - + self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src") + self.validate_identity( + "SELECT SUM(1) AS impressions, arrayJoin(cities) AS city, arrayJoin(browsers) AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 'Chrome'] AS browsers) GROUP BY 2, 3" + ) + self.validate_identity( + "SELECT sum(1) AS impressions, (arrayJoin(arrayZip(cities, browsers)) AS t).1 AS city, t.2 AS browser FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities, ['Firefox', 'Chrome', 
'Chrome'] AS browsers) GROUP BY 2, 3" + ) + self.validate_identity( + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ['Istanbul', 'Berlin']", + "SELECT SUM(1) AS impressions FROM (SELECT ['Istanbul', 'Berlin', 'Bobruisk'] AS cities) WHERE arrayJoin(cities) IN ('Istanbul', 'Berlin')", + ) self.validate_identity( 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))' ) @@ -102,6 +98,25 @@ class TestClickhouse(Validator): ) self.validate_all( + "SELECT arrayJoin([1,2,3])", + write={ + "clickhouse": "SELECT arrayJoin([1, 2, 3])", + "postgres": "SELECT UNNEST(ARRAY[1, 2, 3])", + }, + ) + self.validate_all( + "has([1], x)", + read={ + "postgres": "x = any(array[1])", + }, + ) + self.validate_all( + "NOT has([1], x)", + read={ + "postgres": "any(array[1]) <> x", + }, + ) + self.validate_all( "SELECT CAST('2020-01-01' AS TIMESTAMP) + INTERVAL '500' microsecond", read={ "duckdb": "SELECT TIMESTAMP '2020-01-01' + INTERVAL '500 us'", @@ -197,12 +212,15 @@ class TestClickhouse(Validator): }, ) self.validate_all( - "CONCAT(CASE WHEN COALESCE(CAST(a AS String), '') IS NULL THEN COALESCE(CAST(a AS String), '') ELSE CAST(COALESCE(CAST(a AS String), '') AS String) END, CASE WHEN COALESCE(CAST(b AS String), '') IS NULL THEN COALESCE(CAST(b AS String), '') ELSE CAST(COALESCE(CAST(b AS String), '') AS String) END)", - read={"postgres": "CONCAT(a, b)"}, - ) - self.validate_all( - "CONCAT(CASE WHEN a IS NULL THEN a ELSE CAST(a AS String) END, CASE WHEN b IS NULL THEN b ELSE CAST(b AS String) END)", - read={"mysql": "CONCAT(a, b)"}, + "CONCAT(a, b)", + read={ + "clickhouse": "CONCAT(a, b)", + "mysql": "CONCAT(a, b)", + }, + write={ + "mysql": "CONCAT(a, b)", + "postgres": "CONCAT(a, b)", + }, ) self.validate_all( r"'Enum8(\'Sunday\' = 0)'", write={"clickhouse": "'Enum8(''Sunday'' = 0)'"} @@ -320,6 +338,10 @@ class TestClickhouse(Validator): self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5") self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1") + query = parse_one("""WITH (SELECT 1) AS y SELECT * FROM y""", read="clickhouse") + self.assertIsInstance(query.args["with"].expressions[0].this, exp.Subquery) + self.assertEqual(query.args["with"].expressions[0].alias, "y") + def test_ternary(self): self.validate_all("x ? 
1 : 2", write={"clickhouse": "CASE WHEN x THEN 1 ELSE 2 END"}) self.validate_all( diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 8bb88b3..7c13e79 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -131,7 +131,7 @@ class TestDatabricks(Validator): "SELECT DATEDIFF(week, 'start', 'end')", write={ "databricks": "SELECT DATEDIFF(week, 'start', 'end')", - "postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 48 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(day FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)", + "postgres": "SELECT CAST(EXTRACT(days FROM (CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)", }, ) self.validate_all( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 2546c98..49afc62 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -7,9 +7,10 @@ from sqlglot import ( ParseError, TokenError, UnsupportedError, + exp, parse_one, ) -from sqlglot.dialects import Hive +from sqlglot.dialects import BigQuery, Hive, Snowflake class Validator(unittest.TestCase): @@ -78,9 +79,56 @@ class TestDialect(Validator): self.assertIsNotNone(Dialect[dialect.value]) def test_get_or_raise(self): - self.assertEqual(Dialect.get_or_raise(Hive), Hive) - self.assertEqual(Dialect.get_or_raise(Hive()), Hive) - self.assertEqual(Dialect.get_or_raise("hive"), Hive) + self.assertIsInstance(Dialect.get_or_raise(Hive), Hive) + self.assertIsInstance(Dialect.get_or_raise(Hive()), Hive) + self.assertIsInstance(Dialect.get_or_raise("hive"), Hive) + + with self.assertRaises(ValueError): + Dialect.get_or_raise(1) + + default_mysql = Dialect.get_or_raise("mysql") + self.assertEqual(default_mysql.normalization_strategy, "CASE_SENSITIVE") + + lowercase_mysql = Dialect.get_or_raise("mysql,normalization_strategy=lowercase") + self.assertEqual(lowercase_mysql.normalization_strategy, "LOWERCASE") + + lowercase_mysql = Dialect.get_or_raise("mysql, normalization_strategy = lowercase") + self.assertEqual(lowercase_mysql.normalization_strategy.value, "LOWERCASE") + + with self.assertRaises(ValueError) as cm: + Dialect.get_or_raise("mysql, normalization_strategy") + + self.assertEqual( + str(cm.exception), + "Invalid dialect format: 'mysql, normalization_strategy'. 
" + "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'.", + ) + + def test_compare_dialects(self): + bigquery_class = Dialect["bigquery"] + bigquery_object = BigQuery() + bigquery_string = "bigquery" + + snowflake_class = Dialect["snowflake"] + snowflake_object = Snowflake() + snowflake_string = "snowflake" + + self.assertEqual(snowflake_class, snowflake_class) + self.assertEqual(snowflake_class, snowflake_object) + self.assertEqual(snowflake_class, snowflake_string) + self.assertEqual(snowflake_object, snowflake_object) + self.assertEqual(snowflake_object, snowflake_string) + + self.assertNotEqual(snowflake_class, bigquery_class) + self.assertNotEqual(snowflake_class, bigquery_object) + self.assertNotEqual(snowflake_class, bigquery_string) + self.assertNotEqual(snowflake_object, bigquery_object) + self.assertNotEqual(snowflake_object, bigquery_string) + + self.assertTrue(snowflake_class in {"snowflake", "bigquery"}) + self.assertTrue(snowflake_object in {"snowflake", "bigquery"}) + self.assertFalse(snowflake_class in {"bigquery", "redshift"}) + self.assertFalse(snowflake_object in {"bigquery", "redshift"}) def test_cast(self): self.validate_all( @@ -561,6 +609,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_STR(x, '%Y-%m-%d')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d', x)", "drill": "TO_CHAR(x, 'yyyy-MM-dd')", "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", @@ -866,9 +915,9 @@ class TestDialect(Validator): write={ "drill": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)", - "hive": "TO_DATE(x)", - "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", - "spark": "TO_DATE(x)", + "hive": "CAST(x AS DATE)", + "presto": "CAST(x AS DATE)", + "spark": "CAST(x AS DATE)", "sqlite": "x", }, ) @@ -893,7 +942,7 @@ class TestDialect(Validator): self.validate_all( "TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')", write={ - "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", + "presto": "DATE_ADD('DAY', 1, CAST(CAST(CURRENT_DATE AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD(CURRENT_DATE, 1)", }, ) @@ -1269,13 +1318,6 @@ class TestDialect(Validator): }, ) self.validate_all( - "SELECT * FROM a ORDER BY col_a NULLS LAST", - write={ - "mysql": UnsupportedError, - "starrocks": UnsupportedError, - }, - ) - self.validate_all( "POSITION(needle in haystack)", write={ "drill": "STRPOS(haystack, needle)", @@ -1315,35 +1357,37 @@ class TestDialect(Validator): self.validate_all( "CONCAT_WS('-', 'a', 'b')", write={ + "clickhouse": "CONCAT_WS('-', 'a', 'b')", "duckdb": "CONCAT_WS('-', 'a', 'b')", - "presto": "CONCAT_WS('-', 'a', 'b')", + "presto": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))", "hive": "CONCAT_WS('-', 'a', 'b')", "spark": "CONCAT_WS('-', 'a', 'b')", - "trino": "CONCAT_WS('-', 'a', 'b')", + "trino": "CONCAT_WS('-', CAST('a' AS VARCHAR), CAST('b' AS VARCHAR))", }, ) self.validate_all( "CONCAT_WS('-', x)", write={ + "clickhouse": "CONCAT_WS('-', x)", "duckdb": "CONCAT_WS('-', x)", "hive": "CONCAT_WS('-', x)", - "presto": "CONCAT_WS('-', x)", + "presto": "CONCAT_WS('-', CAST(x AS VARCHAR))", "spark": "CONCAT_WS('-', x)", - "trino": "CONCAT_WS('-', x)", + "trino": "CONCAT_WS('-', CAST(x AS VARCHAR))", }, ) self.validate_all( "CONCAT(a)", write={ - "clickhouse": "a", - "presto": "a", - "trino": "a", + "clickhouse": "CONCAT(a)", + "presto": "CAST(a AS VARCHAR)", + "trino": "CAST(a AS VARCHAR)", "tsql": "a", }, ) self.validate_all( - "COALESCE(CAST(a AS TEXT), '')", + "CONCAT(COALESCE(a, ''))", read={ "drill": "CONCAT(a)", "duckdb": "CONCAT(a)", @@ -1442,6 
+1486,76 @@ class TestDialect(Validator): "spark": "FILTER(the_array, x -> x > 0)", }, ) + self.validate_all( + "a / b", + write={ + "bigquery": "a / b", + "clickhouse": "a / b", + "databricks": "a / b", + "duckdb": "a / b", + "hive": "a / b", + "mysql": "a / b", + "oracle": "a / b", + "snowflake": "a / b", + "spark": "a / b", + "starrocks": "a / b", + "drill": "CAST(a AS DOUBLE) / b", + "postgres": "CAST(a AS DOUBLE PRECISION) / b", + "presto": "CAST(a AS DOUBLE) / b", + "redshift": "CAST(a AS DOUBLE PRECISION) / b", + "sqlite": "CAST(a AS REAL) / b", + "teradata": "CAST(a AS DOUBLE) / b", + "trino": "CAST(a AS DOUBLE) / b", + "tsql": "CAST(a AS FLOAT) / b", + }, + ) + + def test_typeddiv(self): + typed_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), typed=True) + div = exp.Div(this=exp.column("a"), expression=exp.column("b")) + typed_div_dialect = "presto" + div_dialect = "hive" + INT = exp.DataType.Type.INT + FLOAT = exp.DataType.Type.FLOAT + + for expression, types, dialect, expected in [ + (typed_div, (None, None), typed_div_dialect, "a / b"), + (typed_div, (None, None), div_dialect, "a / b"), + (div, (None, None), typed_div_dialect, "CAST(a AS DOUBLE) / b"), + (div, (None, None), div_dialect, "a / b"), + (typed_div, (INT, INT), typed_div_dialect, "a / b"), + (typed_div, (INT, INT), div_dialect, "CAST(a / b AS BIGINT)"), + (div, (INT, INT), typed_div_dialect, "CAST(a AS DOUBLE) / b"), + (div, (INT, INT), div_dialect, "a / b"), + (typed_div, (FLOAT, FLOAT), typed_div_dialect, "a / b"), + (typed_div, (FLOAT, FLOAT), div_dialect, "a / b"), + (div, (FLOAT, FLOAT), typed_div_dialect, "a / b"), + (div, (FLOAT, FLOAT), div_dialect, "a / b"), + (typed_div, (INT, FLOAT), typed_div_dialect, "a / b"), + (typed_div, (INT, FLOAT), div_dialect, "a / b"), + (div, (INT, FLOAT), typed_div_dialect, "a / b"), + (div, (INT, FLOAT), div_dialect, "a / b"), + ]: + with self.subTest(f"{expression.__class__.__name__} {types} {dialect} -> {expected}"): + expression = expression.copy() + expression.left.type = types[0] + expression.right.type = types[1] + self.assertEqual(expected, expression.sql(dialect=dialect)) + + def test_safediv(self): + safe_div = exp.Div(this=exp.column("a"), expression=exp.column("b"), safe=True) + div = exp.Div(this=exp.column("a"), expression=exp.column("b")) + safe_div_dialect = "mysql" + div_dialect = "snowflake" + + for expression, dialect, expected in [ + (safe_div, safe_div_dialect, "a / b"), + (safe_div, div_dialect, "a / NULLIF(b, 0)"), + (div, safe_div_dialect, "a / b"), + (div, div_dialect, "a / b"), + ]: + with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"): + self.assertEqual(expected, expression.sql(dialect=dialect)) def test_limit(self): self.validate_all( @@ -1547,7 +1661,7 @@ class TestDialect(Validator): "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))", write={ "duckdb": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", - "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))", + "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 VARCHAR(1024))", "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", @@ -1864,7 +1978,7 @@ SELECT write={ "bigquery": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "clickhouse": "SELECT * FROM (WITH t AS (SELECT 1 AS c) 
SELECT c FROM t) AS subq", - "databricks": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", + "databricks": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "duckdb": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "mysql": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", @@ -1872,11 +1986,11 @@ SELECT "presto": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "redshift": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "snowflake": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "spark": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", + "spark": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "spark2": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", "sqlite": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", "trino": "SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq", - "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c FROM t) AS subq", + "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT c AS c FROM t) AS subq", }, ) self.validate_all( @@ -1885,13 +1999,60 @@ SELECT "bigquery": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2", "duckdb": "SELECT * FROM (SELECT * FROM (WITH t AS (SELECT 1 AS c) SELECT c FROM t) AS subq1) AS subq2", "hive": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2", - "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c FROM t) AS subq1) AS subq2", + "tsql": "WITH t AS (SELECT 1 AS c) SELECT * FROM (SELECT * FROM (SELECT c AS c FROM t) AS subq1) AS subq2", }, ) self.validate_all( "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", write={ "duckdb": "WITH t1(x) AS (SELECT 1) SELECT * FROM (WITH t2(y) AS (SELECT 2) SELECT y FROM t2) AS subq", - "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y FROM t2) AS subq", + "tsql": "WITH t1(x) AS (SELECT 1), t2(y) AS (SELECT 2) SELECT * FROM (SELECT y AS y FROM t2) AS subq", }, ) + + def test_unsupported_null_ordering(self): + # We'll transpile a portable query from the following dialects to MySQL / T-SQL, which + # both treat NULLs as small values, so the expected output queries should be equivalent + with_last_nulls = "duckdb" + with_small_nulls = "spark" + with_large_nulls = "postgres" + + sql = "SELECT * FROM t ORDER BY c" + sql_nulls_last = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END, c" + sql_nulls_first = "SELECT * FROM t ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c" + + for read_dialect, desc, nulls_first, expected_sql in ( + (with_last_nulls, False, None, sql_nulls_last), + (with_last_nulls, True, None, sql), + (with_last_nulls, False, True, sql), + (with_last_nulls, True, True, sql_nulls_first), + (with_last_nulls, False, False, sql_nulls_last), + (with_last_nulls, True, False, sql), + (with_small_nulls, False, None, sql), + (with_small_nulls, True, None, sql), + (with_small_nulls, False, True, sql), + (with_small_nulls, True, True, sql_nulls_first), + (with_small_nulls, False, False, sql_nulls_last), + (with_small_nulls, True, False, sql), + (with_large_nulls, False, None, sql_nulls_last), + (with_large_nulls, True, None, sql_nulls_first), + (with_large_nulls, False, True, sql), + 
(with_large_nulls, True, True, sql_nulls_first), + (with_large_nulls, False, False, sql_nulls_last), + (with_large_nulls, True, False, sql), + ): + with self.subTest( + f"read: {read_dialect}, descending: {desc}, nulls first: {nulls_first}" + ): + sort_order = " DESC" if desc else "" + null_order = ( + " NULLS FIRST" + if nulls_first + else (" NULLS LAST" if nulls_first is not None else "") + ) + + expected_sql = f"{expected_sql}{sort_order}" + expression = parse_one(f"{sql}{sort_order}{null_order}", read=read_dialect) + + self.assertEqual(expression.sql(dialect="mysql"), expected_sql) + self.assertEqual(expression.sql(dialect="tsql"), expected_sql) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index f9de953..687a807 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -1,4 +1,5 @@ from sqlglot import ErrorLevel, UnsupportedError, exp, parse_one, transpile +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -71,7 +72,7 @@ class TestDuckDB(Validator): "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6])", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", - "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col, IF(_u.pos = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u.pos = _u_4.pos_4, _u_4.col_3) AS col_3 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((_u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1, 2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (_u.pos = _u_3.pos_3 OR (_u.pos > CARDINALITY(ARRAY[4, 5]) AND _u_3.pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (_u.pos = _u_4.pos_4 OR (_u.pos > CARDINALITY(ARRAY[6]) AND _u_4.pos_4 = CARDINALITY(ARRAY[6])))", }, ) @@ -79,7 +80,7 @@ class TestDuckDB(Validator): "SELECT UNNEST(ARRAY[1, 2, 3]), UNNEST(ARRAY[4, 5]), UNNEST(ARRAY[6]) FROM x", write={ "bigquery": "SELECT 
IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_3, col_2, NULL) AS col_2, IF(pos = pos_4, col_3, NULL) AS col_3 FROM x, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2, 3]), ARRAY_LENGTH([4, 5]), ARRAY_LENGTH([6])) - 1)) AS pos CROSS JOIN UNNEST([1, 2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5]) AS col_2 WITH OFFSET AS pos_3 CROSS JOIN UNNEST([6]) AS col_3 WITH OFFSET AS pos_4 WHERE ((pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5]) - 1)))) AND (pos = pos_4 OR (pos > (ARRAY_LENGTH([6]) - 1) AND pos_4 = (ARRAY_LENGTH([6]) - 1)))", - "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_3, col_2) AS col_2, IF(pos = pos_4, col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2, 3]) AND pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5]) AND pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (pos = pos_4 OR (pos > CARDINALITY(ARRAY[6]) AND pos_4 = CARDINALITY(ARRAY[6])))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col, IF(_u.pos = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u.pos = _u_4.pos_4, _u_4.col_3) AS col_3 FROM x, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2, 3]), CARDINALITY(ARRAY[4, 5]), CARDINALITY(ARRAY[6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5]) WITH ORDINALITY AS _u_3(col_2, pos_3) CROSS JOIN UNNEST(ARRAY[6]) WITH ORDINALITY AS _u_4(col_3, pos_4) WHERE ((_u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1, 2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1, 2, 3]))) AND (_u.pos = _u_3.pos_3 OR (_u.pos > CARDINALITY(ARRAY[4, 5]) AND _u_3.pos_3 = CARDINALITY(ARRAY[4, 5])))) AND (_u.pos = _u_4.pos_4 OR (_u.pos > CARDINALITY(ARRAY[6]) AND _u_4.pos_4 = CARDINALITY(ARRAY[6])))", }, ) self.validate_all( @@ -96,7 +97,6 @@ class TestDuckDB(Validator): ) self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC") - self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]") self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y") self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x") self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)") @@ -109,6 +109,10 @@ class TestDuckDB(Validator): parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b" ) + self.validate_identity("SELECT EPOCH_MS(10) AS t") + self.validate_identity("SELECT MAKE_TIMESTAMP(10) AS t") + self.validate_identity("SELECT TO_TIMESTAMP(10) AS t") + self.validate_identity("SELECT UNNEST(column, recursive := TRUE) FROM table") self.validate_identity("VAR_POP(a)") self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b") self.validate_identity("PIVOT Cities ON Year USING SUM(Population)") @@ -152,10 +156,17 @@ class TestDuckDB(Validator): self.validate_all("x ~ y", write={"duckdb": "REGEXP_MATCHES(x, y)"}) self.validate_all("SELECT * FROM 'x.y'", write={"duckdb": 'SELECT * FROM "x.y"'}) self.validate_all( + "SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))", 
+ read={ + "duckdb": "SELECT * FROM produce PIVOT(SUM(sales) FOR quarter IN ('Q1', 'Q2'))", + "snowflake": "SELECT * FROM produce PIVOT(SUM(produce.sales) FOR produce.quarter IN ('Q1', 'Q2'))", + }, + ) + self.validate_all( "SELECT UNNEST([1, 2, 3])", write={ "duckdb": "SELECT UNNEST([1, 2, 3])", - "snowflake": "SELECT IFF(pos = pos_2, col, NULL) AS col FROM (SELECT value FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1)))) AS _u(pos) CROSS JOIN (SELECT value, index FROM TABLE(FLATTEN(INPUT => [1, 2, 3]))) AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", + "snowflake": "SELECT IFF(_u.pos = _u_2.pos_2, _u_2.col, NULL) AS col FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE([1, 2, 3])) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => [1, 2, 3])) AS _u_2(seq, key, path, pos_2, col, this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE([1, 2, 3]) - 1) AND _u_2.pos_2 = (ARRAY_SIZE([1, 2, 3]) - 1))", }, ) self.validate_all( @@ -355,14 +366,14 @@ class TestDuckDB(Validator): "STRUCT_PACK(x := 1, y := '2')", write={ "duckdb": "{'x': 1, 'y': '2'}", - "spark": "STRUCT(x = 1, y = '2')", + "spark": "STRUCT(1 AS x, '2' AS y)", }, ) self.validate_all( "STRUCT_PACK(key1 := 'value1', key2 := 42)", write={ "duckdb": "{'key1': 'value1', 'key2': 42}", - "spark": "STRUCT(key1 = 'value1', key2 = 42)", + "spark": "STRUCT('value1' AS key1, 42 AS key2)", }, ) self.validate_all( @@ -441,6 +452,16 @@ class TestDuckDB(Validator): }, ) self.validate_all( + "SELECT CAST('2018-01-01 00:00:00' AS DATE) + INTERVAL 3 DAY", + read={ + "hive": "SELECT DATE_ADD('2018-01-01 00:00:00', 3)", + }, + write={ + "duckdb": "SELECT CAST('2018-01-01 00:00:00' AS DATE) + INTERVAL '3' DAY", + "hive": "SELECT CAST('2018-01-01 00:00:00' AS DATE) + INTERVAL '3' DAY", + }, + ) + self.validate_all( "SELECT CAST('2020-05-06' AS DATE) - INTERVAL 5 DAY", read={"bigquery": "SELECT DATE_SUB(CAST('2020-05-06' AS DATE), INTERVAL 5 DAY)"}, ) @@ -483,6 +504,35 @@ class TestDuckDB(Validator): self.validate_identity("SELECT ISNAN(x)") + def test_array_index(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT some_arr[1] AS first FROM blah", + read={ + "bigquery": "SELECT some_arr[0] AS first FROM blah", + }, + write={ + "bigquery": "SELECT some_arr[0] AS first FROM blah", + "duckdb": "SELECT some_arr[1] AS first FROM blah", + "presto": "SELECT some_arr[1] AS first FROM blah", + }, + ) + self.validate_identity( + "[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]" + ) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) + def test_time(self): self.validate_identity("SELECT CURRENT_DATE") self.validate_identity("SELECT CURRENT_TIMESTAMP") @@ -533,16 +583,16 @@ class TestDuckDB(Validator): self.validate_all( "EPOCH_MS(x)", write={ - "bigquery": "UNIX_TO_TIME(x / 1000)", - "duckdb": "TO_TIMESTAMP(x / 1000)", - "presto": "FROM_UNIXTIME(x / 1000)", - "spark": "CAST(FROM_UNIXTIME(x / 1000) AS TIMESTAMP)", + "bigquery": "TIMESTAMP_MILLIS(x)", + "duckdb": "EPOCH_MS(x)", + "presto": "FROM_UNIXTIME(CAST(x AS DOUBLE) / 1000)", + 
"spark": "TIMESTAMP_MILLIS(x)", }, ) self.validate_all( "STRFTIME(x, '%y-%-m-%S')", write={ - "bigquery": "TIME_TO_STR(x, '%y-%-m-%S')", + "bigquery": "FORMAT_DATE('%y-%-m-%S', x)", "duckdb": "STRFTIME(x, '%y-%-m-%S')", "postgres": "TO_CHAR(x, 'YY-FMMM-SS')", "presto": "DATE_FORMAT(x, '%y-%c-%s')", @@ -552,6 +602,7 @@ class TestDuckDB(Validator): self.validate_all( "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d %H:%M:%S', x)", "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", @@ -570,7 +621,7 @@ class TestDuckDB(Validator): self.validate_all( "TO_TIMESTAMP(x)", write={ - "bigquery": "UNIX_TO_TIME(x)", + "bigquery": "TIMESTAMP_SECONDS(x)", "duckdb": "TO_TIMESTAMP(x)", "presto": "FROM_UNIXTIME(x)", "hive": "FROM_UNIXTIME(x)", @@ -651,22 +702,25 @@ class TestDuckDB(Validator): "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" ) - self.validate_all("CAST(x AS NUMERIC(1, 2))", write={"duckdb": "CAST(x AS DECIMAL(1, 2))"}) - self.validate_all("CAST(x AS HUGEINT)", write={"duckdb": "CAST(x AS INT128)"}) - self.validate_all("CAST(x AS CHAR)", write={"duckdb": "CAST(x AS TEXT)"}) - self.validate_all("CAST(x AS BPCHAR)", write={"duckdb": "CAST(x AS TEXT)"}) - self.validate_all("CAST(x AS STRING)", write={"duckdb": "CAST(x AS TEXT)"}) - self.validate_all("CAST(x AS INT1)", write={"duckdb": "CAST(x AS TINYINT)"}) - self.validate_all("CAST(x AS FLOAT4)", write={"duckdb": "CAST(x AS REAL)"}) - self.validate_all("CAST(x AS FLOAT)", write={"duckdb": "CAST(x AS REAL)"}) - self.validate_all("CAST(x AS INT4)", write={"duckdb": "CAST(x AS INT)"}) - self.validate_all("CAST(x AS INTEGER)", write={"duckdb": "CAST(x AS INT)"}) - self.validate_all("CAST(x AS SIGNED)", write={"duckdb": "CAST(x AS INT)"}) - self.validate_all("CAST(x AS BLOB)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS BYTEA)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS BINARY)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS VARBINARY)", write={"duckdb": "CAST(x AS BLOB)"}) - self.validate_all("CAST(x AS LOGICAL)", write={"duckdb": "CAST(x AS BOOLEAN)"}) + self.validate_identity("CAST(x AS INT64)", "CAST(x AS BIGINT)") + self.validate_identity("CAST(x AS INT32)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS INT16)", "CAST(x AS SMALLINT)") + self.validate_identity("CAST(x AS NUMERIC(1, 2))", "CAST(x AS DECIMAL(1, 2))") + self.validate_identity("CAST(x AS HUGEINT)", "CAST(x AS INT128)") + self.validate_identity("CAST(x AS CHAR)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS BPCHAR)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS STRING)", "CAST(x AS TEXT)") + self.validate_identity("CAST(x AS INT1)", "CAST(x AS TINYINT)") + self.validate_identity("CAST(x AS FLOAT4)", "CAST(x AS REAL)") + self.validate_identity("CAST(x AS FLOAT)", "CAST(x AS REAL)") + self.validate_identity("CAST(x AS INT4)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS INTEGER)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS SIGNED)", "CAST(x AS INT)") + self.validate_identity("CAST(x AS BLOB)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS BYTEA)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS BINARY)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS VARBINARY)", "CAST(x AS BLOB)") + self.validate_identity("CAST(x AS LOGICAL)", "CAST(x AS BOOLEAN)") self.validate_all( "CAST(x AS NUMERIC)", write={ @@ 
-799,3 +853,17 @@ class TestDuckDB(Validator): "duckdb": "SELECT CAST(w AS TIMESTAMP_S), CAST(x AS TIMESTAMP_MS), CAST(y AS TIMESTAMP), CAST(z AS TIMESTAMP_NS)", }, ) + + def test_isnan(self): + self.validate_all( + "ISNAN(x)", + read={"bigquery": "IS_NAN(x)"}, + write={"bigquery": "IS_NAN(x)", "duckdb": "ISNAN(x)"}, + ) + + def test_isinf(self): + self.validate_all( + "ISINF(x)", + read={"bigquery": "IS_INF(x)"}, + write={"bigquery": "IS_INF(x)", "duckdb": "ISINF(x)"}, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index ba95442..b3366a2 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -315,6 +315,7 @@ class TestHive(Validator): self.validate_all( "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d %H:%M:%S', CAST('2020-01-01' AS DATETIME))", "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')", "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %T')", "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", @@ -324,21 +325,29 @@ class TestHive(Validator): self.validate_all( "DATE_ADD('2020-01-01', 1)", write={ + "": "TS_OR_DS_ADD('2020-01-01', 1, DAY)", + "bigquery": "DATE_ADD(CAST(CAST('2020-01-01' AS DATETIME) AS DATE), INTERVAL 1 DAY)", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", - "presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD('2020-01-01', 1)", + "presto": "DATE_ADD('DAY', 1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", + "redshift": "DATEADD(DAY, 1, '2020-01-01')", + "snowflake": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))", "spark": "DATE_ADD('2020-01-01', 1)", - "": "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')", + "tsql": "DATEADD(DAY, 1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))", }, ) self.validate_all( "DATE_SUB('2020-01-01', 1)", write={ + "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, DAY)", + "bigquery": "DATE_ADD(CAST(CAST('2020-01-01' AS DATETIME) AS DATE), INTERVAL (1 * -1) DAY)", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL (1 * -1) DAY", - "presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", "hive": "DATE_ADD('2020-01-01', 1 * -1)", + "presto": "DATE_ADD('DAY', 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMP) AS DATE))", + "redshift": "DATEADD(DAY, 1 * -1, '2020-01-01')", + "snowflake": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS TIMESTAMPNTZ) AS DATE))", "spark": "DATE_ADD('2020-01-01', 1 * -1)", - "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", + "tsql": "DATEADD(DAY, 1 * -1, CAST(CAST('2020-01-01' AS DATETIME2) AS DATE))", }, ) self.validate_all("DATE_ADD('2020-01-01', -1)", read={"": "DATE_SUB('2020-01-01', 1)"}) @@ -351,8 +360,8 @@ class TestHive(Validator): write={ "duckdb": "DATE_DIFF('day', CAST(x AS DATE), CAST(CAST(y AS DATE) AS DATE))", "presto": "DATE_DIFF('day', CAST(CAST(x AS TIMESTAMP) AS DATE), CAST(CAST(CAST(CAST(y AS TIMESTAMP) AS DATE) AS TIMESTAMP) AS DATE))", - "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", - "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", + "hive": "DATEDIFF(TO_DATE(y), TO_DATE(x))", + "spark": "DATEDIFF(TO_DATE(y), TO_DATE(x))", "": "DATEDIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", }, ) @@ -522,11 +531,16 @@ class TestHive(Validator): ) self.validate_all( "ARRAY_CONTAINS(x, 1)", + read={ + "duckdb": "LIST_HAS(x, 1)", + "snowflake": "ARRAY_CONTAINS(1, x)", + }, write={ "duckdb": "ARRAY_CONTAINS(x, 1)", "presto": "CONTAINS(x, 
1)", "hive": "ARRAY_CONTAINS(x, 1)", "spark": "ARRAY_CONTAINS(x, 1)", + "snowflake": "ARRAY_CONTAINS(1, x)", }, ) self.validate_all( @@ -687,7 +701,7 @@ class TestHive(Validator): "x div y", write={ "duckdb": "x // y", - "presto": "CAST(x / y AS INTEGER)", + "presto": "CAST(CAST(x AS DOUBLE) / y AS INTEGER)", "hive": "CAST(x / y AS INT)", "spark": "CAST(x / y AS INT)", }, @@ -707,11 +721,15 @@ class TestHive(Validator): self.validate_all( "COLLECT_SET(x)", read={ + "doris": "COLLECT_SET(x)", "presto": "SET_AGG(x)", + "snowflake": "ARRAY_UNIQUE_AGG(x)", }, write={ - "presto": "SET_AGG(x)", + "doris": "COLLECT_SET(x)", "hive": "COLLECT_SET(x)", + "presto": "SET_AGG(x)", + "snowflake": "ARRAY_UNIQUE_AGG(x)", "spark": "COLLECT_SET(x)", }, ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 3c165a3..19245f0 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -114,8 +114,17 @@ class TestMySQL(Validator): "mysql": "CREATE TABLE test (ts DATETIME, ts_tz TIMESTAMP, ts_ltz TIMESTAMP)", }, ) + self.validate_all( + "ALTER TABLE test_table ALTER COLUMN test_column SET DATA TYPE LONGTEXT", + write={ + "mysql": "ALTER TABLE test_table MODIFY COLUMN test_column LONGTEXT", + }, + ) + self.validate_identity("ALTER TABLE test_table ALTER COLUMN test_column SET DEFAULT 1") def test_identity(self): + self.validate_identity("SELECT DATE_FORMAT(NOW(), '%Y-%m-%d %H:%i:00.0000')") + self.validate_identity("SELECT @var1 := 1, @var2") self.validate_identity("UNLOCK TABLES") self.validate_identity("LOCK TABLES `app_fields` WRITE") self.validate_identity("SELECT 1 XOR 0") @@ -523,7 +532,15 @@ class TestMySQL(Validator): ) self.validate_all( "SELECT DATEDIFF(x, y)", - write={"mysql": "SELECT DATEDIFF(x, y)", "presto": "SELECT DATE_DIFF('day', y, x)"}, + read={ + "presto": "SELECT DATE_DIFF('day', y, x)", + "redshift": "SELECT DATEDIFF(day, y, x)", + }, + write={ + "mysql": "SELECT DATEDIFF(x, y)", + "presto": "SELECT DATE_DIFF('day', y, x)", + "redshift": "SELECT DATEDIFF(day, y, x)", + }, ) self.validate_all( "DAYOFYEAR(x)", @@ -574,10 +591,16 @@ class TestMySQL(Validator): def test_mysql(self): self.validate_all( + "SELECT * FROM x LEFT JOIN y ON x.id = y.id UNION SELECT * FROM x RIGHT JOIN y ON x.id = y.id LIMIT 0", + read={ + "postgres": "SELECT * FROM x FULL JOIN y ON x.id = y.id LIMIT 0", + }, + ) + self.validate_all( # MySQL doesn't support FULL OUTER joins - "SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x", + "WITH t1 AS (SELECT 1) SELECT * FROM t1 LEFT OUTER JOIN t2 ON t1.x = t2.x UNION SELECT * FROM t1 RIGHT OUTER JOIN t2 ON t1.x = t2.x", read={ - "postgres": "SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x", + "postgres": "WITH t1 AS (SELECT 1) SELECT * FROM t1 FULL OUTER JOIN t2 ON t1.x = t2.x", }, ) self.validate_all( @@ -601,7 +624,9 @@ class TestMySQL(Validator): "mysql": "SELECT * FROM test LIMIT 1 OFFSET 1", "postgres": "SELECT * FROM test LIMIT 0 + 1 OFFSET 0 + 1", "presto": "SELECT * FROM test OFFSET 1 LIMIT 1", + "snowflake": "SELECT * FROM test LIMIT 1 OFFSET 1", "trino": "SELECT * FROM test OFFSET 1 LIMIT 1", + "bigquery": "SELECT * FROM test LIMIT 1 OFFSET 1", }, ) self.validate_all( @@ -984,3 +1009,28 @@ COMMENT='客户账户表'""" "mysql": "DATE_FORMAT(x, '%M')", }, ) + + def test_safe_div(self): + self.validate_all( + "a / b", + write={ + "bigquery": "a / NULLIF(b, 0)", + "clickhouse": "a / b", + "databricks": "a / NULLIF(b, 0)", + "duckdb": "a / b", + "hive": "a / b", + "mysql": "a 
/ b", + "oracle": "a / NULLIF(b, 0)", + "snowflake": "a / NULLIF(b, 0)", + "spark": "a / b", + "starrocks": "a / b", + "drill": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "postgres": "CAST(a AS DOUBLE PRECISION) / NULLIF(b, 0)", + "presto": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "redshift": "CAST(a AS DOUBLE PRECISION) / NULLIF(b, 0)", + "sqlite": "CAST(a AS REAL) / b", + "teradata": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "trino": "CAST(a AS DOUBLE) / NULLIF(b, 0)", + "tsql": "CAST(a AS FLOAT) / NULLIF(b, 0)", + }, + ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index d92eea5..e9ebac1 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -1,3 +1,4 @@ +from sqlglot import exp, parse_one from sqlglot.errors import UnsupportedError from tests.dialects.test_dialect import Validator @@ -6,6 +7,15 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol") + self.assertIsInstance( + parse_one("ALTER TABLE tbl_name DROP FOREIGN KEY fk_symbol", dialect="oracle"), + exp.AlterTable, + ) + self.validate_identity( + "ALTER TABLE Payments ADD (Stock NUMBER NOT NULL, dropid VARCHAR2(500) NOT NULL)" + ) + self.validate_identity("ALTER TABLE Payments ADD Stock NUMBER NOT NULL") self.validate_identity("SELECT x FROM t WHERE cond FOR UPDATE") self.validate_identity("SELECT JSON_OBJECT(k1: v1 FORMAT JSON, k2: v2 FORMAT JSON)") self.validate_identity("SELECT JSON_OBJECT('name': first_name || ' ' || last_name) FROM t") @@ -57,8 +67,16 @@ class TestOracle(Validator): "SELECT * FROM t SAMPLE (.25)", "SELECT * FROM t SAMPLE (0.25)", ) + self.validate_identity("SELECT TO_CHAR(-100, 'L99', 'NL_CURRENCY = '' AusDollars '' ')") self.validate_all( + "SELECT TO_CHAR(TIMESTAMP '1999-12-01 10:00:00')", + write={ + "oracle": "SELECT TO_CHAR(CAST('1999-12-01 10:00:00' AS TIMESTAMP), 'YYYY-MM-DD HH24:MI:SS')", + "postgres": "SELECT TO_CHAR(CAST('1999-12-01 10:00:00' AS TIMESTAMP), 'YYYY-MM-DD HH24:MI:SS')", + }, + ) + self.validate_all( "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1", write={ "oracle": "SELECT CAST(NULL AS VARCHAR2(2328 CHAR)) AS COL1", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 0e5f1a1..17a65d7 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1,6 +1,5 @@ -from unittest import mock - from sqlglot import ParseError, exp, parse_one, transpile +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -22,6 +21,9 @@ class TestPostgres(Validator): self.validate_identity("UPDATE tbl_name SET foo = 123 RETURNING a") self.validate_identity("CREATE TABLE cities_partdef PARTITION OF cities DEFAULT") self.validate_identity( + "CREATE CONSTRAINT TRIGGER my_trigger AFTER INSERT OR DELETE OR UPDATE OF col_a, col_b ON public.my_table DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION do_sth()" + ) + self.validate_identity( "CREATE TABLE cust_part3 PARTITION OF customers FOR VALUES WITH (MODULUS 3, REMAINDER 2)" ) self.validate_identity( @@ -43,6 +45,9 @@ class TestPostgres(Validator): "CREATE INDEX foo ON bar.baz USING btree(col1 varchar_pattern_ops ASC, col2)" ) self.validate_identity( + "CREATE INDEX index_issues_on_title_trigram ON public.issues USING gin(title public.gin_trgm_ops)" + ) + self.validate_identity( "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *" ) self.validate_identity( @@ -148,7 +153,7 @@ class 
TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(c) FROM t", "postgres": "SELECT UNNEST(c) FROM t", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(c) AND pos_2 = CARDINALITY(c))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM t, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(c)))) AS _u(pos) CROSS JOIN UNNEST(c) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(c) AND _u_2.pos_2 = CARDINALITY(c))", }, ) self.validate_all( @@ -156,20 +161,46 @@ class TestPostgres(Validator): write={ "hive": "SELECT EXPLODE(ARRAY(1))", "postgres": "SELECT UNNEST(ARRAY[1])", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1]) AND pos_2 = CARDINALITY(ARRAY[1]))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[1]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1]))", }, ) - @mock.patch("sqlglot.helper.logger") - def test_array_offset(self, logger): - self.validate_all( - "SELECT col[1]", - write={ - "hive": "SELECT col[0]", - "postgres": "SELECT col[1]", - "presto": "SELECT col[1]", - }, - ) + def test_array_offset(self): + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT col[1]", + write={ + "bigquery": "SELECT col[0]", + "duckdb": "SELECT col[1]", + "hive": "SELECT col[0]", + "postgres": "SELECT col[1]", + "presto": "SELECT col[1]", + }, + ) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) + + def test_operator(self): + expr = parse_one("1 OPERATOR(+) 2 OPERATOR(*) 3", read="postgres") + + expr.left.assert_is(exp.Operator) + expr.left.left.assert_is(exp.Literal) + expr.left.right.assert_is(exp.Literal) + expr.right.assert_is(exp.Literal) + self.assertEqual(expr.sql(dialect="postgres"), "1 OPERATOR(+) 2 OPERATOR(*) 3") + + self.validate_identity("SELECT operator FROM t") + self.validate_identity("SELECT 1 OPERATOR(+) 2") + self.validate_identity("SELECT 1 OPERATOR(+) /* foo */ 2") + self.validate_identity("SELECT 1 OPERATOR(pg_catalog.+) 2") def test_postgres(self): expr = parse_one( @@ -198,6 +229,14 @@ class TestPostgres(Validator): self.assertEqual(expr.sql(dialect="postgres"), alter_table_only) self.validate_identity( + "SELECT c.oid, n.nspname, c.relname " + "FROM pg_catalog.pg_class AS c " + "LEFT JOIN pg_catalog.pg_namespace AS n ON n.oid = c.relnamespace " + "WHERE c.relname OPERATOR(pg_catalog.~) '^(courses)$' COLLATE pg_catalog.default AND " + "pg_catalog.PG_TABLE_IS_VISIBLE(c.oid) " + "ORDER BY 2, 3" + ) + self.validate_identity( "SELECT ARRAY[]::INT[] AS foo", "SELECT CAST(ARRAY[] AS INT[]) AS foo", ) @@ -728,26 +767,23 @@ class TestPostgres(Validator): ) def test_string_concat(self): - self.validate_all( - "SELECT CONCAT('abcde', 2, NULL, 22)", - write={ - "postgres": "SELECT CONCAT(COALESCE(CAST('abcde' AS TEXT), ''), COALESCE(CAST(2 AS TEXT), ''), COALESCE(CAST(NULL AS TEXT), ''), 
COALESCE(CAST(22 AS TEXT), ''))", - }, - ) + self.validate_identity("SELECT CONCAT('abcde', 2, NULL, 22)") + self.validate_all( "CONCAT(a, b)", write={ - "": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", - "duckdb": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", - "postgres": "CONCAT(COALESCE(CAST(a AS TEXT), ''), COALESCE(CAST(b AS TEXT), ''))", - "presto": "CONCAT(CAST(COALESCE(CAST(a AS VARCHAR), '') AS VARCHAR), CAST(COALESCE(CAST(b AS VARCHAR), '') AS VARCHAR))", + "": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", + "clickhouse": "CONCAT(COALESCE(a, ''), COALESCE(b, ''))", + "duckdb": "CONCAT(a, b)", + "postgres": "CONCAT(a, b)", + "presto": "CONCAT(COALESCE(CAST(a AS VARCHAR), ''), COALESCE(CAST(b AS VARCHAR), ''))", }, ) self.validate_all( "a || b", write={ "": "a || b", - "clickhouse": "CONCAT(CAST(a AS String), CAST(b AS String))", + "clickhouse": "a || b", "duckdb": "a || b", "postgres": "a || b", "presto": "CONCAT(CAST(a AS VARCHAR), CAST(b AS VARCHAR))", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index ed734b6..6a82756 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -1,6 +1,5 @@ -from unittest import mock - from sqlglot import UnsupportedError, exp, parse_one +from sqlglot.helper import logger as helper_logger from tests.dialects.test_dialect import Validator @@ -34,14 +33,6 @@ class TestPresto(Validator): }, ) self.validate_all( - "SELECT DATE_DIFF('week', CAST(CAST('2009-01-01' AS TIMESTAMP) AS DATE), CAST(CAST('2009-12-31' AS TIMESTAMP) AS DATE))", - read={"redshift": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')"}, - ) - self.validate_all( - "SELECT DATE_ADD('month', 18, CAST(CAST('2008-02-28' AS TIMESTAMP) AS DATE))", - read={"redshift": "SELECT DATEADD(month, 18, '2008-02-28')"}, - ) - self.validate_all( "SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP)", read={"postgres": "SELECT 'epoch'::TIMESTAMP"}, ) @@ -229,6 +220,7 @@ class TestPresto(Validator): self.validate_all( "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ + "bigquery": "FORMAT_DATE('%Y-%m-%d %H:%M:%S', x)", "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", "presto": "DATE_FORMAT(x, '%Y-%m-%d %T')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", @@ -308,7 +300,12 @@ class TestPresto(Validator): "hive": "CURRENT_TIMESTAMP()", }, ) - + self.validate_all( + "SELECT DATE_ADD('DAY', 1, CAST(CURRENT_DATE AS TIMESTAMP))", + read={ + "redshift": "SELECT DATEADD(DAY, 1, CURRENT_DATE)", + }, + ) self.validate_all( "DAY_OF_WEEK(timestamp '2012-08-08 01:00:00')", write={ @@ -537,8 +534,7 @@ class TestPresto(Validator): }, ) - @mock.patch("sqlglot.helper.logger") - def test_presto(self, logger): + def test_presto(self): self.validate_identity("string_agg(x, ',')", "ARRAY_JOIN(ARRAY_AGG(x), ',')") self.validate_identity( "SELECT * FROM example.testdb.customer_orders FOR VERSION AS OF 8954597067493422955" @@ -558,6 +554,24 @@ class TestPresto(Validator): "SELECT SPLIT_TO_MAP('a:1;b:2;a:3', ';', ':', (k, v1, v2) -> CONCAT(v1, v2))" ) + with self.assertLogs(helper_logger) as cm: + self.validate_all( + "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", + write={ + "postgres": UnsupportedError, + "presto": "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", + }, + ) + self.validate_all( + "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", + write={ + "": "SELECT ARRAY(1, 2, 3)[3]", + "bigquery": "SELECT [1, 2, 3][SAFE_ORDINAL(4)]", + 
"postgres": "SELECT (ARRAY[1, 2, 3])[4]", + "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", + }, + ) + self.validate_all( "SELECT MAX_BY(a.id, a.timestamp) FROM a", read={ @@ -669,21 +683,6 @@ class TestPresto(Validator): self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"}) self.validate_all( - "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", - write={ - "postgres": UnsupportedError, - "presto": "SELECT COALESCE(ELEMENT_AT(MAP_FROM_ENTRIES(ARRAY[(51, '1')]), id), quantity) FROM my_table", - }, - ) - self.validate_all( - "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", - write={ - "": "SELECT ARRAY(1, 2, 3)[3]", - "postgres": "SELECT (ARRAY[1, 2, 3])[4]", - "presto": "SELECT ELEMENT_AT(ARRAY[1, 2, 3], 4)", - }, - ) - self.validate_all( "SELECT SUBSTRING(a, 1, 3), SUBSTRING(a, LENGTH(a) - (3 - 1))", read={ "redshift": "SELECT LEFT(a, 3), RIGHT(a, 3)", @@ -890,6 +889,13 @@ class TestPresto(Validator): }, ) self.validate_all( + "JSON_FORMAT(CAST(MAP_FROM_ENTRIES(ARRAY[('action_type', 'at')]) AS JSON))", + write={ + "presto": "JSON_FORMAT(CAST(MAP_FROM_ENTRIES(ARRAY[('action_type', 'at')]) AS JSON))", + "spark": "TO_JSON(MAP_FROM_ENTRIES(ARRAY(('action_type', 'at'))))", + }, + ) + self.validate_all( "JSON_FORMAT(x)", read={ "spark": "TO_JSON(x)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 3e42525..c6be789 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,15 +6,6 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): - self.validate_identity( - "SELECT DATE_DIFF('month', CAST('2020-02-29 00:00:00' AS TIMESTAMP), CAST('2020-03-02 00:00:00' AS TIMESTAMP))", - "SELECT DATEDIFF(month, CAST(CAST('2020-02-29 00:00:00' AS TIMESTAMP) AS DATE), CAST(CAST('2020-03-02 00:00:00' AS TIMESTAMP) AS DATE))", - ) - self.validate_identity( - "SELECT * FROM x WHERE y = DATEADD('month', -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", - "SELECT * FROM x WHERE y = DATEADD(month, -1, CAST(DATE_TRUNC('month', (SELECT y FROM #temp_table)) AS DATE))", - ) - self.validate_all( "LISTAGG(sellerid, ', ')", read={ @@ -72,8 +63,11 @@ class TestRedshift(Validator): self.validate_all( "SELECT ADD_MONTHS('2008-03-31', 1)", write={ - "redshift": "SELECT DATEADD(month, 1, CAST('2008-03-31' AS DATE))", - "trino": "SELECT DATE_ADD('month', 1, CAST(CAST('2008-03-31' AS TIMESTAMP) AS DATE))", + "bigquery": "SELECT DATE_ADD(CAST('2008-03-31' AS DATETIME), INTERVAL 1 MONTH)", + "duckdb": "SELECT CAST('2008-03-31' AS TIMESTAMP) + INTERVAL 1 month", + "redshift": "SELECT DATEADD(month, 1, '2008-03-31')", + "trino": "SELECT DATE_ADD('month', 1, CAST('2008-03-31' AS TIMESTAMP))", + "tsql": "SELECT DATEADD(month, 1, CAST('2008-03-31' AS DATETIME2))", }, ) self.validate_all( @@ -205,18 +199,18 @@ class TestRedshift(Validator): "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS 
_row_number FROM x) AS _t WHERE _row_number = 1", + "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) _t WHERE _row_number = 1", "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", - "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) AS _t WHERE _row_number = 1", + "tsql": "SELECT a, b FROM (SELECT a AS a, b AS b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY CASE WHEN c IS NULL THEN 1 ELSE 0 END DESC, c DESC) AS _row_number FROM x) AS _t WHERE _row_number = 1", }, ) self.validate_all( @@ -240,18 +234,43 @@ class TestRedshift(Validator): self.validate_all( "DATEDIFF('day', a, b)", write={ - "redshift": "DATEDIFF(day, CAST(a AS DATE), CAST(b AS DATE))", - "presto": "DATE_DIFF('day', CAST(CAST(a AS TIMESTAMP) AS DATE), CAST(CAST(b AS TIMESTAMP) AS DATE))", + "bigquery": "DATE_DIFF(CAST(b AS DATETIME), CAST(a AS DATETIME), day)", + "duckdb": "DATE_DIFF('day', CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP))", + "hive": "DATEDIFF(b, a)", + "redshift": "DATEDIFF(day, a, b)", + "presto": "DATE_DIFF('day', CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP))", }, ) self.validate_all( - "SELECT TOP 1 x FROM y", + "SELECT DATEADD(month, 18, '2008-02-28')", + write={ + "bigquery": "SELECT DATE_ADD(CAST('2008-02-28' AS DATETIME), INTERVAL 18 MONTH)", + "duckdb": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL 18 month", + "hive": "SELECT ADD_MONTHS('2008-02-28', 18)", + "mysql": "SELECT DATE_ADD('2008-02-28', INTERVAL 18 MONTH)", + "postgres": "SELECT CAST('2008-02-28' AS TIMESTAMP) + INTERVAL '18 month'", + "presto": "SELECT DATE_ADD('month', 18, CAST('2008-02-28' AS TIMESTAMP))", + "redshift": "SELECT DATEADD(month, 18, '2008-02-28')", + "snowflake": "SELECT 
DATEADD(month, 18, CAST('2008-02-28' AS TIMESTAMPNTZ))", + "tsql": "SELECT DATEADD(month, 18, CAST('2008-02-28' AS DATETIME2))", + }, + ) + self.validate_all( + "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", write={ - "redshift": "SELECT x FROM y LIMIT 1", + "bigquery": "SELECT DATE_DIFF(CAST('2009-12-31' AS DATETIME), CAST('2009-01-01' AS DATETIME), week)", + "duckdb": "SELECT DATE_DIFF('week', CAST('2009-01-01' AS TIMESTAMP), CAST('2009-12-31' AS TIMESTAMP))", + "hive": "SELECT CAST(DATEDIFF('2009-12-31', '2009-01-01') / 7 AS INT)", + "postgres": "SELECT CAST(EXTRACT(days FROM (CAST('2009-12-31' AS TIMESTAMP) - CAST('2009-01-01' AS TIMESTAMP))) / 7 AS BIGINT)", + "presto": "SELECT DATE_DIFF('week', CAST('2009-01-01' AS TIMESTAMP), CAST('2009-12-31' AS TIMESTAMP))", + "redshift": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", + "snowflake": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", + "tsql": "SELECT DATEDIFF(week, '2009-01-01', '2009-12-31')", }, ) def test_identity(self): + self.validate_identity("SELECT DATEADD(day, 1, 'today')") self.validate_identity("SELECT * FROM #x") self.validate_identity("SELECT INTERVAL '5 day'") self.validate_identity("foo$") @@ -263,6 +282,26 @@ class TestRedshift(Validator): self.validate_identity("SELECT APPROXIMATE AS y") self.validate_identity("CREATE TABLE t (c BIGINT IDENTITY(0, 1))") self.validate_identity( + "SELECT CONCAT('abc', 'def')", + "SELECT 'abc' || 'def'", + ) + self.validate_identity( + "SELECT CONCAT_WS('DELIM', 'abc', 'def', 'ghi')", + "SELECT 'abc' || 'DELIM' || 'def' || 'DELIM' || 'ghi'", + ) + self.validate_identity( + "SELECT TOP 1 x FROM y", + "SELECT x FROM y LIMIT 1", + ) + self.validate_identity( + "SELECT DATE_DIFF('month', CAST('2020-02-29 00:00:00' AS TIMESTAMP), CAST('2020-03-02 00:00:00' AS TIMESTAMP))", + "SELECT DATEDIFF(month, CAST('2020-02-29 00:00:00' AS TIMESTAMP), CAST('2020-03-02 00:00:00' AS TIMESTAMP))", + ) + self.validate_identity( + "SELECT * FROM x WHERE y = DATEADD('month', -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", + "SELECT * FROM x WHERE y = DATEADD(month, -1, DATE_TRUNC('month', (SELECT y FROM #temp_table)))", + ) + self.validate_identity( "SELECT 'a''b'", "SELECT 'a\\'b'", ) @@ -271,6 +310,12 @@ class TestRedshift(Validator): "CREATE TABLE t (c BIGINT IDENTITY(0, 1))", ) self.validate_identity( + "SELECT DATEADD(hour, 0, CAST('2020-02-02 01:03:05.124' AS TIMESTAMP))" + ) + self.validate_identity( + "SELECT DATEDIFF(second, '2020-02-02 00:00:00.000', '2020-02-02 01:03:05.124')" + ) + self.validate_identity( "CREATE OR REPLACE VIEW v1 AS SELECT id, AVG(average_metric1) AS m1, AVG(average_metric2) AS m2 FROM t GROUP BY id WITH NO SCHEMA BINDING" ) self.validate_identity( @@ -295,12 +340,8 @@ class TestRedshift(Validator): "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" ) self.validate_identity( - "SELECT DATEADD(day, 1, 'today')", - "SELECT DATEADD(day, 1, CAST('today' AS DATE))", - ) - self.validate_identity( "SELECT DATEADD('day', ndays, caldate)", - "SELECT DATEADD(day, ndays, CAST(caldate AS DATE))", + "SELECT DATEADD(day, ndays, caldate)", ) self.validate_identity( "CONVERT(INT, x)", @@ -308,7 +349,7 @@ class TestRedshift(Validator): ) self.validate_identity( "SELECT DATE_ADD('day', 1, DATE('2023-01-01'))", - "SELECT DATEADD(day, 1, CAST(DATE('2023-01-01') AS DATE))", + "SELECT DATEADD(day, 1, DATE('2023-01-01'))", ) self.validate_identity( """SELECT @@ -449,17 +490,3 @@ FROM ( "redshift": "CREATE OR REPLACE VIEW v1 AS 
SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING", }, ) - - def test_concat(self): - self.validate_all( - "SELECT CONCAT('abc', 'def')", - write={ - "redshift": "SELECT COALESCE(CAST('abc' AS VARCHAR(MAX)), '') || COALESCE(CAST('def' AS VARCHAR(MAX)), '')", - }, - ) - self.validate_all( - "SELECT CONCAT_WS('DELIM', 'abc', 'def', 'ghi')", - write={ - "redshift": "SELECT COALESCE(CAST('abc' AS VARCHAR(MAX)), '') || 'DELIM' || COALESCE(CAST('def' AS VARCHAR(MAX)), '') || 'DELIM' || COALESCE(CAST('ghi' AS VARCHAR(MAX)), '')", - }, - ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 2cad1d2..997c27b 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,13 +6,38 @@ from tests.dialects.test_dialect import Validator class TestSnowflake(Validator): + maxDiff = None dialect = "snowflake" def test_snowflake(self): + self.validate_identity("SELECT rename, replace") expr = parse_one("SELECT APPROX_TOP_K(C4, 3, 5) FROM t") expr.selects[0].assert_is(exp.AggFunc) self.assertEqual(expr.sql(dialect="snowflake"), "SELECT APPROX_TOP_K(C4, 3, 5) FROM t") + self.assertEqual( + exp.select(exp.Explode(this=exp.column("x")).as_("y", quoted=True)).sql( + "snowflake", pretty=True + ), + """SELECT + IFF(_u.pos = _u_2.pos_2, _u_2."y", NULL) AS "y" +FROM TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, ( + GREATEST(ARRAY_SIZE(x)) - 1 +) + 1))) AS _u(seq, key, path, index, pos, this) +CROSS JOIN TABLE(FLATTEN(INPUT => x)) AS _u_2(seq, key, path, pos_2, "y", this) +WHERE + _u.pos = _u_2.pos_2 + OR ( + _u.pos > ( + ARRAY_SIZE(x) - 1 + ) AND _u_2.pos_2 = ( + ARRAY_SIZE(x) - 1 + ) + )""", + ) + + self.validate_identity("SELECT user_id, value FROM table_name sample ($s) SEED (0)") + self.validate_identity("SELECT ARRAY_UNIQUE_AGG(x)") self.validate_identity("SELECT OBJECT_CONSTRUCT()") self.validate_identity("SELECT DAYOFMONTH(CURRENT_TIMESTAMP())") self.validate_identity("SELECT DAYOFYEAR(CURRENT_TIMESTAMP())") @@ -48,6 +73,14 @@ class TestSnowflake(Validator): 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' ) self.validate_identity( + "SELECT * FROM unnest(x) with ordinality", + "SELECT * FROM TABLE(FLATTEN(INPUT => x)) AS _u(seq, key, path, index, value, this)", + ) + self.validate_identity( + "CREATE TABLE foo (ID INT COMMENT $$some comment$$)", + "CREATE TABLE foo (ID INT COMMENT 'some comment')", + ) + self.validate_identity( "SELECT state, city, SUM(retail_price * quantity) AS gross_revenue FROM sales GROUP BY ALL" ) self.validate_identity( @@ -88,6 +121,21 @@ class TestSnowflake(Validator): self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"}) self.validate_all( + # We need to qualify the columns in this query because "value" would be ambiguous + 'WITH t(x, "value") AS (SELECT [1, 2, 3], 1) SELECT IFF(_u.pos = _u_2.pos_2, _u_2."value", NULL) AS "value" FROM t, TABLE(FLATTEN(INPUT => ARRAY_GENERATE_RANGE(0, (GREATEST(ARRAY_SIZE(t.x)) - 1) + 1))) AS _u(seq, key, path, index, pos, this) CROSS JOIN TABLE(FLATTEN(INPUT => t.x)) AS _u_2(seq, key, path, pos_2, "value", this) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > (ARRAY_SIZE(t.x) - 1) AND _u_2.pos_2 = (ARRAY_SIZE(t.x) - 1))', + read={ + "duckdb": 'WITH t(x, "value") AS (SELECT [1,2,3], 1) SELECT UNNEST(t.x) AS "value" FROM t', + }, + ) + self.validate_all( + "SELECT { 'Manitoba': 'Winnipeg', 'foo': 'bar' } AS province_capital", + write={ + 
"duckdb": "SELECT {'Manitoba': 'Winnipeg', 'foo': 'bar'} AS province_capital", + "snowflake": "SELECT OBJECT_CONSTRUCT('Manitoba', 'Winnipeg', 'foo', 'bar') AS province_capital", + "spark": "SELECT STRUCT('Manitoba' AS Winnipeg, 'foo' AS bar) AS province_capital", + }, + ) + self.validate_all( "SELECT COLLATE('B', 'und:ci')", write={ "bigquery": "SELECT COLLATE('B', 'und:ci')", @@ -225,6 +273,7 @@ class TestSnowflake(Validator): "spark": "POWER(x, 2)", "sqlite": "POWER(x, 2)", "starrocks": "POWER(x, 2)", + "teradata": "x ** 2", "trino": "POWER(x, 2)", "tsql": "POWER(x, 2)", }, @@ -241,8 +290,8 @@ class TestSnowflake(Validator): "DIV0(foo, bar)", write={ "snowflake": "IFF(bar = 0, 0, foo / bar)", - "sqlite": "CASE WHEN bar = 0 THEN 0 ELSE foo / bar END", - "presto": "IF(bar = 0, 0, foo / bar)", + "sqlite": "CASE WHEN bar = 0 THEN 0 ELSE CAST(foo AS REAL) / bar END", + "presto": "IF(bar = 0, 0, CAST(foo AS DOUBLE) / bar)", "spark": "IF(bar = 0, 0, foo / bar)", "hive": "IF(bar = 0, 0, foo / bar)", "duckdb": "CASE WHEN bar = 0 THEN 0 ELSE foo / bar END", @@ -355,7 +404,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP(1659981729)", write={ - "bigquery": "SELECT UNIX_TO_TIME(1659981729)", + "bigquery": "SELECT TIMESTAMP_SECONDS(1659981729)", "snowflake": "SELECT TO_TIMESTAMP(1659981729)", "spark": "SELECT CAST(FROM_UNIXTIME(1659981729) AS TIMESTAMP)", }, @@ -363,7 +412,7 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP(1659981729000, 3)", write={ - "bigquery": "SELECT UNIX_TO_TIME(1659981729000, 'millis')", + "bigquery": "SELECT TIMESTAMP_MILLIS(1659981729000)", "snowflake": "SELECT TO_TIMESTAMP(1659981729000, 3)", "spark": "SELECT TIMESTAMP_MILLIS(1659981729000)", }, @@ -371,7 +420,6 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP('1659981729')", write={ - "bigquery": "SELECT UNIX_TO_TIME('1659981729')", "snowflake": "SELECT TO_TIMESTAMP('1659981729')", "spark": "SELECT CAST(FROM_UNIXTIME('1659981729') AS TIMESTAMP)", }, @@ -379,9 +427,11 @@ class TestSnowflake(Validator): self.validate_all( "SELECT TO_TIMESTAMP(1659981729000000000, 9)", write={ - "bigquery": "SELECT UNIX_TO_TIME(1659981729000000000, 'micros')", + "bigquery": "SELECT TIMESTAMP_MICROS(CAST(1659981729000000000 / 1000 AS INT64))", + "duckdb": "SELECT TO_TIMESTAMP(1659981729000000000 / 1000000000)", + "presto": "SELECT FROM_UNIXTIME(CAST(1659981729000000000 AS DOUBLE) / 1000000000)", "snowflake": "SELECT TO_TIMESTAMP(1659981729000000000, 9)", - "spark": "SELECT TIMESTAMP_MICROS(1659981729000000000)", + "spark": "SELECT TIMESTAMP_SECONDS(1659981729000000000 / 1000000000)", }, ) self.validate_all( @@ -404,7 +454,6 @@ class TestSnowflake(Validator): "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", }, ) - self.validate_all( "SELECT IFF(TRUE, 'true', 'false')", write={ @@ -551,6 +600,7 @@ class TestSnowflake(Validator): staged_file.sql(dialect="snowflake"), ) + self.validate_identity("SELECT metadata$filename FROM @s1/") self.validate_identity("SELECT * FROM @~") self.validate_identity("SELECT * FROM @~/some/path/to/file.csv") self.validate_identity("SELECT * FROM @mystage") @@ -610,6 +660,13 @@ class TestSnowflake(Validator): "snowflake": "SELECT * FROM testtable SAMPLE BLOCK (0.012) SEED (99992)", }, ) + self.validate_all( + "SELECT * FROM (SELECT * FROM t1 join t2 on t1.a = t2.c) SAMPLE (1)", + write={ + "snowflake": "SELECT * FROM (SELECT * FROM t1 JOIN t2 ON t1.a = t2.c) SAMPLE (1)", + "spark": "SELECT * FROM (SELECT * FROM 
t1 JOIN t2 ON t1.a = t2.c) SAMPLE (1 PERCENT)", + }, + ) def test_timestamps(self): self.validate_identity("SELECT CAST('12:00:00' AS TIME)") @@ -719,6 +776,17 @@ class TestSnowflake(Validator): ) def test_ddl(self): + self.validate_identity( + """create external table et2( + col1 date as (parse_json(metadata$external_table_partition):COL1::date), + col2 varchar as (parse_json(metadata$external_table_partition):COL2::varchar), + col3 number as (parse_json(metadata$external_table_partition):COL3::number)) + partition by (col1,col2,col3) + location=@s2/logs/ + partition_type = user_specified + file_format = (type = parquet)""", + "CREATE EXTERNAL TABLE et2 (col1 DATE AS (CAST(PARSE_JSON(metadata$external_table_partition)['COL1'] AS DATE)), col2 VARCHAR AS (CAST(PARSE_JSON(metadata$external_table_partition)['COL2'] AS VARCHAR)), col3 DECIMAL AS (CAST(PARSE_JSON(metadata$external_table_partition)['COL3'] AS DECIMAL))) LOCATION @s2/logs/ PARTITION BY (col1, col2, col3) partition_type=user_specified file_format=(type = parquet)", + ) self.validate_identity("CREATE OR REPLACE VIEW foo (uid) COPY GRANTS AS (SELECT 1)") self.validate_identity("CREATE TABLE geospatial_table (id INT, g GEOGRAPHY)") self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") @@ -733,7 +801,7 @@ class TestSnowflake(Validator): "CREATE TABLE orders_clone_restore CLONE orders BEFORE (STATEMENT => '8e5d0ca9-005e-44e6-b858-a8f5b37c5726')" ) self.validate_identity( - "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" + "CREATE TABLE a (x DATE, y BIGINT) PARTITION BY (x) integration='q' auto_refresh=TRUE file_format=(type = parquet)" ) self.validate_identity( "CREATE SCHEMA mytestschema_clone_restore CLONE testschema BEFORE (TIMESTAMP => TO_TIMESTAMP(40 * 365 * 86400))" @@ -1179,3 +1247,39 @@ MATCH_RECOGNIZE ( ast = parse_one("ALTER TABLE a SWAP WITH b", read="snowflake") assert isinstance(ast, exp.AlterTable) assert isinstance(ast.args["actions"][0], exp.SwapTable) + + def test_try_cast(self): + self.validate_identity("SELECT TRY_CAST(x AS DOUBLE)") + + self.validate_all("TRY_CAST('foo' AS TEXT)", read={"hive": "CAST('foo' AS STRING)"}) + self.validate_all("CAST(5 + 5 AS TEXT)", read={"hive": "CAST(5 + 5 AS STRING)"}) + self.validate_all( + "CAST(TRY_CAST('2020-01-01' AS DATE) AS TEXT)", + read={ + "hive": "CAST(CAST('2020-01-01' AS DATE) AS STRING)", + "snowflake": "CAST(TRY_CAST('2020-01-01' AS DATE) AS TEXT)", + }, + ) + self.validate_all( + "TRY_CAST(x AS TEXT)", + read={ + "hive": "CAST(x AS STRING)", + "snowflake": "TRY_CAST(x AS TEXT)", + }, + ) + + from sqlglot.optimizer.annotate_types import annotate_types + + expression = parse_one("SELECT CAST(t.x AS STRING) FROM t", read="hive") + + expression = annotate_types(expression, schema={"t": {"x": "string"}}) + self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(t.x AS TEXT) FROM t") + + expression = annotate_types(expression, schema={"t": {"x": "int"}}) + self.assertEqual(expression.sql(dialect="snowflake"), "SELECT CAST(t.x AS TEXT) FROM t") + + # We can't infer FOO's type since it's a UDF in this case, so we don't get rid of TRY_CAST + expression = parse_one("SELECT TRY_CAST(FOO() AS TEXT)", read="snowflake") + + expression = annotate_types(expression) + self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)") diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 841a005..fe37027 100644 --- 
a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -8,6 +8,7 @@ class TestSpark(Validator): dialect = "spark" def test_ddl(self): + self.validate_identity("CREATE TABLE foo AS WITH t AS (SELECT 1 AS col) SELECT col FROM t") self.validate_identity("CREATE TEMPORARY VIEW test AS SELECT 1") self.validate_identity("CREATE TABLE foo (col VARCHAR(50))") self.validate_identity("CREATE TABLE foo (col STRUCT<struct_col_a: VARCHAR((50))>)") @@ -226,15 +227,23 @@ TBLPROPERTIES ( ) def test_spark(self): + self.validate_identity("FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), 'utc')") expr = parse_one("any_value(col, true)", read="spark") self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.assertEqual( + parse_one("REFRESH TABLE t", read="spark").assert_is(exp.Refresh).sql(dialect="spark"), + "REFRESH TABLE t", + ) + + self.validate_identity("REFRESH 'hdfs://path/to/table'") + self.validate_identity("REFRESH TABLE tempDB.view1") self.validate_identity("SELECT CASE WHEN a = NULL THEN 1 ELSE 2 END") self.validate_identity("SELECT * FROM t1 SEMI JOIN t2 ON t1.x = t2.x") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") - self.validate_identity("REFRESH table a.b.c") + self.validate_identity("REFRESH TABLE a.b.c") self.validate_identity("INTERVAL -86 days") self.validate_identity("SELECT UNIX_TIMESTAMP()") self.validate_identity("TRIM(' SparkSQL ')") @@ -300,6 +309,18 @@ TBLPROPERTIES ( }, ) self.validate_all( + "SELECT DATEDIFF(week, '2020-01-01', '2020-12-31')", + write={ + "bigquery": "SELECT DATE_DIFF(CAST('2020-12-31' AS DATE), CAST('2020-01-01' AS DATE), week)", + "duckdb": "SELECT DATE_DIFF('week', CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "hive": "SELECT CAST(DATEDIFF(TO_DATE('2020-12-31'), TO_DATE('2020-01-01')) / 7 AS INT)", + "postgres": "SELECT CAST(EXTRACT(days FROM (CAST(CAST('2020-12-31' AS DATE) AS TIMESTAMP) - CAST(CAST('2020-01-01' AS DATE) AS TIMESTAMP))) / 7 AS BIGINT)", + "redshift": "SELECT DATEDIFF(week, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "snowflake": "SELECT DATEDIFF(week, CAST('2020-01-01' AS DATE), CAST('2020-12-31' AS DATE))", + "spark": "SELECT DATEDIFF(week, TO_DATE('2020-01-01'), TO_DATE('2020-12-31'))", + }, + ) + self.validate_all( "SELECT MONTHS_BETWEEN('1997-02-28 10:30:00', '1996-10-30')", write={ "duckdb": "SELECT DATEDIFF('month', CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))", @@ -435,7 +456,7 @@ TBLPROPERTIES ( "SELECT DATE_ADD(my_date_column, 1)", write={ "spark": "SELECT DATE_ADD(my_date_column, 1)", - "bigquery": "SELECT DATE_ADD(my_date_column, INTERVAL 1 DAY)", + "bigquery": "SELECT DATE_ADD(CAST(CAST(my_date_column AS DATETIME) AS DATE), INTERVAL 1 DAY)", }, ) self.validate_all( @@ -592,9 +613,9 @@ TBLPROPERTIES ( self.validate_all( "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", write={ - "databricks": "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", + "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", "hive": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", - "spark": "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", + "spark": "WITH cte AS (SELECT cola 
FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", "spark2": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", }, ) @@ -604,7 +625,7 @@ TBLPROPERTIES ( "SELECT EXPLODE(x) FROM tbl", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(x)) - 1)) AS pos CROSS JOIN UNNEST(x) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(x) - 1) AND pos_2 = (ARRAY_LENGTH(x) - 1))", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(x) AND pos_2 = CARDINALITY(x))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(x) AND _u_2.pos_2 = CARDINALITY(x))", "spark": "SELECT EXPLODE(x) FROM tbl", }, ) @@ -612,46 +633,46 @@ TBLPROPERTIES ( "SELECT EXPLODE(col) FROM _u", write={ "bigquery": "SELECT IF(pos = pos_2, col_2, NULL) AS col_2 FROM _u, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(col)) - 1)) AS pos CROSS JOIN UNNEST(col) AS col_2 WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH(col) - 1) AND pos_2 = (ARRAY_LENGTH(col) - 1))", - "presto": "SELECT IF(pos = pos_2, col_2) AS col_2 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u_2(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_3(col_2, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + "presto": "SELECT IF(_u_2.pos = _u_3.pos_2, _u_3.col_2) AS col_2 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u_2(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_3(col_2, pos_2) WHERE _u_2.pos = _u_3.pos_2 OR (_u_2.pos > CARDINALITY(col) AND _u_3.pos_2 = CARDINALITY(col))", "spark": "SELECT EXPLODE(col) FROM _u", }, ) self.validate_all( "SELECT EXPLODE(col) AS exploded FROM schema.tbl", write={ - "presto": "SELECT IF(pos = pos_2, exploded) AS exploded FROM schema.tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_2(exploded, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(col) AND pos_2 = CARDINALITY(col))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.exploded) AS exploded FROM schema.tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(col)))) AS _u(pos) CROSS JOIN UNNEST(col) WITH ORDINALITY AS _u_2(exploded, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(col) AND _u_2.pos_2 = CARDINALITY(col))", }, ) self.validate_all( "SELECT EXPLODE(ARRAY(1, 2))", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([1, 2])) - 1)) AS pos CROSS JOIN UNNEST([1, 2]) AS col WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([1, 2]) - 1) AND pos_2 = (ARRAY_LENGTH([1, 2]) - 1))", - "presto": "SELECT IF(pos = pos_2, col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[1, 2]) AND pos_2 = CARDINALITY(ARRAY[1, 2]))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[1, 2])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[1, 2]) WITH ORDINALITY AS _u_2(col, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > 
CARDINALITY(ARRAY[1, 2]) AND _u_2.pos_2 = CARDINALITY(ARRAY[1, 2]))", }, ) self.validate_all( "SELECT POSEXPLODE(ARRAY(2, 3)) AS x", write={ "bigquery": "SELECT IF(pos = pos_2, x, NULL) AS x, IF(pos = pos_2, pos_2, NULL) AS pos_2 FROM UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS x WITH OFFSET AS pos_2 WHERE pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))", - "presto": "SELECT IF(pos = pos_2, x) AS x, IF(pos = pos_2, pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.x) AS x, IF(_u.pos = _u_2.pos_2, _u_2.pos_2) AS pos_2 FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(x, pos_2) WHERE _u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[2, 3]))", }, ) self.validate_all( "SELECT POSEXPLODE(x) AS (a, b)", write={ - "presto": "SELECT IF(pos = a, b) AS b, IF(pos = a, a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE pos = a OR (pos > CARDINALITY(x) AND a = CARDINALITY(x))", + "presto": "SELECT IF(_u.pos = _u_2.a, _u_2.b) AS b, IF(_u.pos = _u_2.a, _u_2.a) AS a FROM UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(x)))) AS _u(pos) CROSS JOIN UNNEST(x) WITH ORDINALITY AS _u_2(b, a) WHERE _u.pos = _u_2.a OR (_u.pos > CARDINALITY(x) AND _u_2.a = CARDINALITY(x))", }, ) self.validate_all( "SELECT POSEXPLODE(ARRAY(2, 3)), EXPLODE(ARRAY(4, 5, 6)) FROM tbl", write={ "bigquery": "SELECT IF(pos = pos_2, col, NULL) AS col, IF(pos = pos_2, pos_2, NULL) AS pos_2, IF(pos = pos_3, col_2, NULL) AS col_2 FROM tbl, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH([2, 3]), ARRAY_LENGTH([4, 5, 6])) - 1)) AS pos CROSS JOIN UNNEST([2, 3]) AS col WITH OFFSET AS pos_2 CROSS JOIN UNNEST([4, 5, 6]) AS col_2 WITH OFFSET AS pos_3 WHERE (pos = pos_2 OR (pos > (ARRAY_LENGTH([2, 3]) - 1) AND pos_2 = (ARRAY_LENGTH([2, 3]) - 1))) AND (pos = pos_3 OR (pos > (ARRAY_LENGTH([4, 5, 6]) - 1) AND pos_3 = (ARRAY_LENGTH([4, 5, 6]) - 1)))", - "presto": "SELECT IF(pos = pos_2, col) AS col, IF(pos = pos_2, pos_2) AS pos_2, IF(pos = pos_3, col_2) AS col_2 FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3]), CARDINALITY(ARRAY[4, 5, 6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5, 6]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE (pos = pos_2 OR (pos > CARDINALITY(ARRAY[2, 3]) AND pos_2 = CARDINALITY(ARRAY[2, 3]))) AND (pos = pos_3 OR (pos > CARDINALITY(ARRAY[4, 5, 6]) AND pos_3 = CARDINALITY(ARRAY[4, 5, 6])))", + "presto": "SELECT IF(_u.pos = _u_2.pos_2, _u_2.col) AS col, IF(_u.pos = _u_2.pos_2, _u_2.pos_2) AS pos_2, IF(_u.pos = _u_3.pos_3, _u_3.col_2) AS col_2 FROM tbl, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3]), CARDINALITY(ARRAY[4, 5, 6])))) AS _u(pos) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_2(col, pos_2) CROSS JOIN UNNEST(ARRAY[4, 5, 6]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE (_u.pos = _u_2.pos_2 OR (_u.pos > CARDINALITY(ARRAY[2, 3]) AND _u_2.pos_2 = CARDINALITY(ARRAY[2, 3]))) AND (_u.pos = _u_3.pos_3 OR (_u.pos > CARDINALITY(ARRAY[4, 5, 6]) AND _u_3.pos_3 = CARDINALITY(ARRAY[4, 5, 6])))", }, ) self.validate_all( "SELECT col, 
pos, POSEXPLODE(ARRAY(2, 3)) FROM _u", write={ - "presto": "SELECT col, pos, IF(pos_2 = pos_3, col_2) AS col_2, IF(pos_2 = pos_3, pos_3) AS pos_3 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE pos_2 = pos_3 OR (pos_2 > CARDINALITY(ARRAY[2, 3]) AND pos_3 = CARDINALITY(ARRAY[2, 3]))", + "presto": "SELECT col, pos, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.pos_3) AS pos_3 FROM _u, UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE _u_2.pos_2 = _u_3.pos_3 OR (_u_2.pos_2 > CARDINALITY(ARRAY[2, 3]) AND _u_3.pos_3 = CARDINALITY(ARRAY[2, 3]))", }, ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 14703c4..85d4ebf 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -147,6 +147,9 @@ class TestTeradata(Validator): def test_mod(self): self.validate_all("a MOD b", write={"teradata": "a MOD b", "mysql": "a % b"}) + def test_power(self): + self.validate_all("a ** b", write={"teradata": "a ** b", "mysql": "POWER(a, b)"}) + def test_abbrev(self): self.validate_identity("a LT b", "a < b") self.validate_identity("a LE b", "a <= b") @@ -191,3 +194,14 @@ class TestTeradata(Validator): }, ) self.validate_identity("CAST('1992-01' AS FORMAT 'YYYY-DD')") + + self.validate_all( + "TRYCAST('-2.5' AS DECIMAL(5, 2))", + read={ + "snowflake": "TRY_CAST('-2.5' AS DECIMAL(5, 2))", + }, + write={ + "snowflake": "TRY_CAST('-2.5' AS DECIMAL(5, 2))", + "teradata": "TRYCAST('-2.5' AS DECIMAL(5, 2))", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 0ac94f2..07179ef 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -6,8 +6,22 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): - self.validate_all( - "WITH t(c) AS (SELECT 1) SELECT * INTO foo FROM (SELECT c FROM t) AS temp", + # https://learn.microsoft.com/en-us/previous-versions/sql/sql-server-2008-r2/ms187879(v=sql.105)?redirectedfrom=MSDN + # tsql allows .. 
which means use the default schema + self.validate_identity("SELECT * FROM a..b") + + self.validate_identity("SELECT CONCAT(column1, column2)") + self.validate_identity("SELECT TestSpecialChar.Test# FROM TestSpecialChar") + self.validate_identity("SELECT TestSpecialChar.Test@ FROM TestSpecialChar") + self.validate_identity("SELECT TestSpecialChar.Test$ FROM TestSpecialChar") + self.validate_identity("SELECT TestSpecialChar.Test_ FROM TestSpecialChar") + self.validate_identity("SELECT TOP (2 + 1) 1") + self.validate_identity("SELECT * FROM t WHERE NOT c", "SELECT * FROM t WHERE NOT c <> 0") + self.validate_identity("1 AND true", "1 <> 0 AND (1 = 1)") + self.validate_identity("CAST(x AS int) OR y", "CAST(x AS INTEGER) <> 0 OR y <> 0") + + self.validate_all( + "WITH t(c) AS (SELECT 1) SELECT * INTO foo FROM (SELECT c AS c FROM t) AS temp", read={ "duckdb": "CREATE TABLE foo AS WITH t(c) AS (SELECT 1) SELECT c FROM t", }, @@ -25,7 +39,7 @@ class TestTSQL(Validator): }, ) self.validate_all( - "WITH t(c) AS (SELECT 1) MERGE INTO x AS z USING (SELECT c FROM t) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "WITH t(c) AS (SELECT 1) MERGE INTO x AS z USING (SELECT c AS c FROM t) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", read={ "postgres": "MERGE INTO x AS z USING (WITH t(c) AS (SELECT 1) SELECT c FROM t) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", }, @@ -167,18 +181,6 @@ class TestTSQL(Validator): ) self.validate_all( - "SELECT DATEPART(year, CAST('2017-01-01' AS DATE))", - read={"postgres": "SELECT DATE_PART('year', '2017-01-01'::DATE)"}, - ) - self.validate_all( - "SELECT DATEPART(month, CAST('2017-03-01' AS DATE))", - read={"postgres": "SELECT DATE_PART('month', '2017-03-01'::DATE)"}, - ) - self.validate_all( - "SELECT DATEPART(day, CAST('2017-01-02' AS DATE))", - read={"postgres": "SELECT DATE_PART('day', '2017-01-02'::DATE)"}, - ) - self.validate_all( "SELECT CAST([a].[b] AS SMALLINT) FROM foo", write={ "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', @@ -229,11 +231,13 @@ class TestTSQL(Validator): self.validate_all( "HASHBYTES('SHA1', x)", read={ + "snowflake": "SHA1(x)", "spark": "SHA(x)", }, write={ - "tsql": "HASHBYTES('SHA1', x)", + "snowflake": "SHA1(x)", "spark": "SHA(x)", + "tsql": "HASHBYTES('SHA1', x)", }, ) self.validate_all( @@ -561,6 +565,21 @@ class TestTSQL(Validator): ) def test_ddl(self): + expression = parse_one("ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B", dialect="tsql") + self.assertIsInstance(expression, exp.AlterTable) + self.assertIsInstance(expression.args["actions"][0], exp.Drop) + self.assertEqual( + expression.sql(dialect="tsql"), "ALTER TABLE dbo.DocExe DROP CONSTRAINT FK_Column_B" + ) + + for clusterd_keyword in ("CLUSTERED", "NONCLUSTERED"): + self.validate_identity( + 'CREATE TABLE "dbo"."benchmark" (' + '"name" CHAR(7) NOT NULL, ' + '"internal_id" VARCHAR(10) NOT NULL, ' + f'UNIQUE {clusterd_keyword} ("internal_id" ASC))' + ) + self.validate_identity( "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < CURRENT_TIMESTAMP - 7 END", "CREATE PROCEDURE foo AS BEGIN DELETE FROM bla WHERE foo < GETDATE() - 7 END", @@ -589,6 +608,12 @@ class TestTSQL(Validator): }, ) self.validate_all( + "SELECT * INTO foo.bar.baz FROM (SELECT * FROM a.b.c) AS temp", + read={ + "": "CREATE TABLE foo.bar.baz AS (SELECT * FROM a.b.c)", + }, + ) + self.validate_all( "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id('db.tbl') AND name = 'idx') EXEC('CREATE INDEX idx ON db.tbl')", read={ "": "CREATE INDEX IF NOT EXISTS idx 
ON db.tbl", @@ -622,7 +647,6 @@ class TestTSQL(Validator): "tsql": "CREATE OR ALTER VIEW a.b AS SELECT 1", }, ) - self.validate_all( "ALTER TABLE a ADD b INTEGER, c INTEGER", read={ @@ -633,7 +657,6 @@ class TestTSQL(Validator): "tsql": "ALTER TABLE a ADD b INTEGER, c INTEGER", }, ) - self.validate_all( "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ @@ -833,7 +856,7 @@ WHERE ) def test_len(self): - self.validate_all("LEN(x)", write={"spark": "LENGTH(x)"}) + self.validate_all("LEN(x)", read={"": "LENGTH(x)"}, write={"spark": "LENGTH(x)"}) def test_replicate(self): self.validate_all("REPLICATE('x', 2)", write={"spark": "REPEAT('x', 2)"}) @@ -870,11 +893,68 @@ WHERE ) def test_datepart(self): + self.validate_identity( + "DATEPART(QUARTER, x)", + "DATEPART(quarter, CAST(x AS DATETIME2))", + ) + self.validate_identity( + "DATEPART(YEAR, x)", + "FORMAT(CAST(x AS DATETIME2), 'yyyy')", + ) + self.validate_identity( + "DATEPART(HOUR, date_and_time)", + "DATEPART(hour, CAST(date_and_time AS DATETIME2))", + ) + self.validate_identity( + "DATEPART(WEEKDAY, date_and_time)", + "DATEPART(dw, CAST(date_and_time AS DATETIME2))", + ) + self.validate_identity( + "DATEPART(DW, date_and_time)", + "DATEPART(dw, CAST(date_and_time AS DATETIME2))", + ) + self.validate_all( "SELECT DATEPART(month,'1970-01-01')", - write={"spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'MM')"}, + write={ + "postgres": "SELECT TO_CHAR(CAST('1970-01-01' AS TIMESTAMP), 'MM')", + "spark": "SELECT DATE_FORMAT(CAST('1970-01-01' AS TIMESTAMP), 'MM')", + "tsql": "SELECT FORMAT(CAST('1970-01-01' AS DATETIME2), 'MM')", + }, + ) + self.validate_all( + "SELECT DATEPART(year, CAST('2017-01-01' AS DATE))", + read={ + "postgres": "SELECT DATE_PART('year', '2017-01-01'::DATE)", + }, + write={ + "postgres": "SELECT TO_CHAR(CAST(CAST('2017-01-01' AS DATE) AS TIMESTAMP), 'YYYY')", + "spark": "SELECT DATE_FORMAT(CAST(CAST('2017-01-01' AS DATE) AS TIMESTAMP), 'yyyy')", + "tsql": "SELECT FORMAT(CAST(CAST('2017-01-01' AS DATE) AS DATETIME2), 'yyyy')", + }, + ) + self.validate_all( + "SELECT DATEPART(month, CAST('2017-03-01' AS DATE))", + read={ + "postgres": "SELECT DATE_PART('month', '2017-03-01'::DATE)", + }, + write={ + "postgres": "SELECT TO_CHAR(CAST(CAST('2017-03-01' AS DATE) AS TIMESTAMP), 'MM')", + "spark": "SELECT DATE_FORMAT(CAST(CAST('2017-03-01' AS DATE) AS TIMESTAMP), 'MM')", + "tsql": "SELECT FORMAT(CAST(CAST('2017-03-01' AS DATE) AS DATETIME2), 'MM')", + }, + ) + self.validate_all( + "SELECT DATEPART(day, CAST('2017-01-02' AS DATE))", + read={ + "postgres": "SELECT DATE_PART('day', '2017-01-02'::DATE)", + }, + write={ + "postgres": "SELECT TO_CHAR(CAST(CAST('2017-01-02' AS DATE) AS TIMESTAMP), 'DD')", + "spark": "SELECT DATE_FORMAT(CAST(CAST('2017-01-02' AS DATE) AS TIMESTAMP), 'dd')", + "tsql": "SELECT FORMAT(CAST(CAST('2017-01-02' AS DATE) AS DATETIME2), 'dd')", + }, ) - self.validate_identity("DATEPART(YEAR, x)", "FORMAT(CAST(x AS DATETIME2), 'yyyy')") def test_convert_date_format(self): self.validate_all( @@ -1073,10 +1153,7 @@ WHERE def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(hour, 1.5, '2021-01-01')") - self.validate_identity( - "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')", - "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", - ) + self.validate_all( "SELECT DATEDIFF(quarter, 0, '2021-01-01')", write={ @@ -1098,7 +1175,7 @@ WHERE write={ "tsql": "SELECT DATEDIFF(year, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS 
DATETIME2))", "spark": "SELECT DATEDIFF(year, CAST('2020-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) AS INT) / 12", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) / 12 AS INT)", }, ) self.validate_all( @@ -1114,16 +1191,18 @@ WHERE write={ "databricks": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", "spark": "SELECT DATEDIFF(quarter, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT) / 3", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3 AS INT)", "tsql": "SELECT DATEDIFF(quarter, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", }, ) def test_iif(self): self.validate_identity( - "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond, 'True', 'False')" + "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond <> 0, 'True', 'False')" + ) + self.validate_identity( + "SELECT IIF(cond, 'True', 'False')", "SELECT IIF(cond <> 0, 'True', 'False')" ) - self.validate_identity("SELECT IIF(cond, 'True', 'False')") self.validate_all( "SELECT IIF(cond, 'True', 'False');", write={ @@ -1173,9 +1252,14 @@ WHERE def test_top(self): self.validate_all( - "SELECT TOP 3 * FROM A", + "SELECT DISTINCT TOP 3 * FROM A", + read={ + "spark": "SELECT DISTINCT * FROM A LIMIT 3", + }, write={ - "spark": "SELECT * FROM A LIMIT 3", + "spark": "SELECT DISTINCT * FROM A LIMIT 3", + "teradata": "SELECT DISTINCT TOP 3 * FROM A", + "tsql": "SELECT DISTINCT TOP 3 * FROM A", }, ) self.validate_all( @@ -1292,6 +1376,26 @@ WHERE }, ) + def test_temporal_table(self): + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON)""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START HIDDEN NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END HIDDEN NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history", DATA_CONSISTENCY_CHECK=ON))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history", DATA_CONSISTENCY_CHECK=ON))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history", DATA_CONSISTENCY_CHECK=OFF))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history"))""" + ) + self.validate_identity( + """CREATE TABLE test ("data" CHAR(7), "valid_from" 
DATETIME2(2) GENERATED ALWAYS AS ROW START NOT NULL, "valid_to" DATETIME2(2) GENERATED ALWAYS AS ROW END NOT NULL, PERIOD FOR SYSTEM_TIME ("valid_from", "valid_to")) WITH(SYSTEM_VERSIONING=ON(HISTORY_TABLE="dbo"."benchmark_history"))""" + ) + def test_system_time(self): self.validate_all( "SELECT [x] FROM [a].[b] FOR SYSTEM_TIME AS OF 'foo'", @@ -1433,3 +1537,28 @@ FROM OPENJSON(@json) WITH ( "spark": "SET count = (SELECT COUNT(1) FROM x)", }, ) + + def test_qualify_derived_table_outputs(self): + self.validate_identity( + "WITH t AS (SELECT 1) SELECT * FROM t", + 'WITH t AS (SELECT 1 AS "1") SELECT * FROM t', + ) + self.validate_identity( + 'WITH t AS (SELECT "c") SELECT * FROM t', + 'WITH t AS (SELECT "c" AS "c") SELECT * FROM t', + ) + self.validate_identity( + "SELECT * FROM (SELECT 1) AS subq", + 'SELECT * FROM (SELECT 1 AS "1") AS subq', + ) + self.validate_identity( + 'SELECT * FROM (SELECT "c") AS subq', + 'SELECT * FROM (SELECT "c" AS "c") AS subq', + ) + + self.validate_all( + "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) AS c FROM t1) SELECT * FROM t2", + read={ + "duckdb": "WITH t1(c) AS (SELECT 1), t2 AS (SELECT CAST(c AS INTEGER) FROM t1) SELECT * FROM t2", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 6e0a3e5..effebca 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -731,7 +731,6 @@ WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a WITH a AS (SELECT * FROM b) UPDATE a SET col = 1 WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a WITH a AS (SELECT * FROM b) DELETE FROM a -WITH a AS (SELECT * FROM b) CACHE TABLE a SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? SELECT :hello, ? FROM x LIMIT :my_limit SELECT a FROM b WHERE c IS ? 
@@ -867,3 +866,4 @@ KILL QUERY '123' CHR(97) SELECT * FROM UNNEST(x) WITH ORDINALITY UNION ALL SELECT * FROM UNNEST(y) WITH ORDINALITY WITH use(use) AS (SELECT 1) SELECT use FROM use +SELECT recursive FROM t diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 954b1c1..302acb9 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -10,15 +10,27 @@ SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w"; SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; SELECT 1 + 3.2 AS "a" FROM "w" AS "w"; +SELECT '1' + 1 AS "col"; +SELECT '1' + 1 AS "col"; + +SELECT '1' + '1' AS "col"; +SELECT CONCAT('1', '1') AS "col"; + SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day; SELECT CAST('2022-01-01' AS DATE) + INTERVAL '1' day AS "_col_0"; +SELECT CAST('2022-01-01' AS DATE) IS NULL AS "a"; +SELECT CAST('2022-01-01' AS DATE) IS NULL AS "a"; + -------------------------------------- -- Ensure boolean predicates -------------------------------------- SELECT a FROM x WHERE b; SELECT "x"."a" AS "a" FROM "x" AS "x" WHERE "x"."b" <> 0; +SELECT NOT b FROM x; +SELECT NOT "x"."b" <> 0 AS "_col_0" FROM "x" AS "x"; + SELECT a FROM x GROUP BY a HAVING SUM(b); SELECT "x"."a" AS "a" FROM "x" AS "x" GROUP BY "x"."a" HAVING SUM("x"."b") <> 0; @@ -46,8 +58,41 @@ CAST('2023-01-01' AS TIMESTAMP); TIMESTAMP('2023-01-01', '12:00:00'); TIMESTAMP('2023-01-01', '12:00:00'); +-------------------------------------- +-- Coerce date function args +-------------------------------------- +'2023-01-01' + INTERVAL '1' DAY; +CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY; + +'2023-01-01' + INTERVAL '1' HOUR; +CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR; + +'2023-01-01 00:00:01' + INTERVAL '1' HOUR; +CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR; + +CAST('2023-01-01' AS DATE) + INTERVAL '1' HOUR; +CAST(CAST('2023-01-01' AS DATE) AS DATETIME) + INTERVAL '1' HOUR; + +SELECT t.d + INTERVAL '1' HOUR FROM temporal AS t; +SELECT CAST("t"."d" AS DATETIME) + INTERVAL '1' HOUR AS "_col_0" FROM "temporal" AS "t"; + DATE_ADD(CAST("x" AS DATE), 1, 'YEAR'); DATE_ADD(CAST("x" AS DATE), 1, 'YEAR'); DATE_ADD('2023-01-01', 1, 'YEAR'); DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'YEAR'); + +DATE_ADD('2023-01-01 00:00:00', 1, 'DAY'); +DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY'); + +SELECT DATE_ADD(t.d, 1, 'HOUR') FROM temporal AS t; +SELECT DATE_ADD(CAST("t"."d" AS DATETIME), 1, 'HOUR') AS "_col_0" FROM "temporal" AS "t"; + +SELECT DATE_TRUNC('SECOND', t.d) FROM temporal AS t; +SELECT DATE_TRUNC('SECOND', CAST("t"."d" AS DATETIME)) AS "_col_0" FROM "temporal" AS "t"; + +DATE_TRUNC('DAY', '2023-01-01'); +DATE_TRUNC('DAY', CAST('2023-01-01' AS DATE)); + +DATEDIFF('2023-01-01', '2023-01-02', DAY); +DATEDIFF(CAST('2023-01-01' AS DATETIME), CAST('2023-01-02' AS DATETIME), DAY); diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a9d6584..f81d54a 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -548,6 +548,23 @@ FROM ( FROM "sc"."tb" AS "tb" ) AS "_q_0" PIVOT(SUM("_q_0"."c") FOR "_q_0"."b" IN ('x', 'y', 'z')) AS "_q_1"; +# title: pivoted source with explicit selections where one of them is excluded & selected at the same time +# note: we need to respect the exclude when selecting * from pivoted source and not include the computed column twice +# execute: false +SELECT * EXCEPT (x), CAST(x AS TEXT) AS x FROM (SELECT a, b, c FROM sc.tb) PIVOT 
(SUM(c) FOR b IN ('x','y','z')); +SELECT + "_q_1"."a" AS "a", + "_q_1"."y" AS "y", + "_q_1"."z" AS "z", + CAST("_q_1"."x" AS TEXT) AS "x" +FROM ( + SELECT + "tb"."a" AS "a", + "tb"."b" AS "b", + "tb"."c" AS "c" + FROM "sc"."tb" AS "tb" +) AS "_q_0" PIVOT(SUM("_q_0"."c") FOR "_q_0"."b" IN ('x', 'y', 'z')) AS "_q_1"; + # title: pivoted source with implicit selections # execute: false SELECT * FROM (SELECT * FROM u) PIVOT (SUM(f) FOR h IN ('x', 'y')); @@ -1074,3 +1091,27 @@ SELECT `_q_0`.`fruitstruct`.`$id` AS `$id`, `_q_0`.`fruitstruct`.`value` AS `value` FROM `_q_0` AS `_q_0`; + +# title: mysql is case-sensitive by default +# dialect: mysql +# execute: false +WITH T AS (SELECT 1 AS CoL) SELECT * FROM `T`; +WITH `T` AS ( + SELECT + 1 AS `CoL` +) +SELECT + `T`.`CoL` AS `CoL` +FROM `T`; + +# title: override mysql's settings so it normalizes to lowercase +# dialect: mysql, normalization_strategy = lowercase +# execute: false +WITH T AS (SELECT 1 AS `CoL`) SELECT * FROM T; +WITH `t` AS ( + SELECT + 1 AS `CoL` +) +SELECT + `t`.`CoL` AS `CoL` +FROM `t`; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 3224a83..43127a9 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -65,10 +65,10 @@ SELECT a AS j, b FROM x ORDER BY j; SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; SELECT a AS j, b AS a FROM x ORDER BY 1; -SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY x.a; +SELECT x.a AS j, x.b AS a FROM x AS x ORDER BY j; SELECT SUM(a) AS c, SUM(b) AS d FROM x ORDER BY 1, 2; -SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b); +SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY c, d; # execute: false SELECT CAST(a AS INT) FROM x ORDER BY a; @@ -76,7 +76,7 @@ SELECT CAST(x.a AS INT) AS a FROM x AS x ORDER BY a; # execute: false SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2; -SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); +SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY _col_0, c; SELECT a AS j, b FROM x GROUP BY j, b; SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b; @@ -85,7 +85,10 @@ SELECT a, b FROM x GROUP BY 1, 2; SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b; SELECT a, b FROM x ORDER BY 1, 2; -SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b; + +SELECT DISTINCT a AS c, b AS d FROM x ORDER BY 1; +SELECT DISTINCT x.a AS c, x.b AS d FROM x AS x ORDER BY c; SELECT 2 FROM x GROUP BY 1; SELECT 2 AS "2" FROM x AS x GROUP BY 1; @@ -306,6 +309,10 @@ WITH cte AS (SELECT 1 AS x) SELECT cte.a AS a FROM cte AS cte(a); WITH cte(x, y) AS (SELECT 1, 2) SELECT cte.* FROM cte AS cte(a); WITH cte AS (SELECT 1 AS x, 2 AS y) SELECT cte.a AS a, cte.y AS y FROM cte AS cte(a); +# execute: false +WITH player AS (SELECT player.name, player.asset.info FROM players) SELECT * FROM player; +WITH player AS (SELECT players.player.name AS name, players.player.asset.info AS info FROM players AS players) SELECT player.name AS name, player.info AS info FROM player; + -------------------------------------- -- Except and Replace -------------------------------------- @@ -488,7 +495,7 @@ FROM ( ); SELECT _q_0.i AS i, _q_0.j AS j FROM (SELECT x.a + 1 AS i, x.a + 1 + 1 AS j FROM x AS x) AS _q_0; -# title: wrap expanded alias to ensure operator precedence isn't broken +# title: wrap expanded alias to ensure operator precedence isnt broken # execute: false SELECT x.a 
+ x.b AS f, f * x.b FROM x; SELECT x.a + x.b AS f, (x.a + x.b) * x.b AS _col_1 FROM x AS x; diff --git a/tests/fixtures/optimizer/quote_identifiers.sql b/tests/fixtures/optimizer/quote_identifiers.sql new file mode 100644 index 0000000..21181f7 --- /dev/null +++ b/tests/fixtures/optimizer/quote_identifiers.sql @@ -0,0 +1,31 @@ +SELECT a FROM x; +SELECT "a" FROM "x"; + +SELECT "a" FROM "x"; +SELECT "a" FROM "x"; + +SELECT x.a AS a FROM db.x; +SELECT "x"."a" AS "a" FROM "db"."x"; + +SELECT @x; +SELECT @x; + +# dialect: snowflake +SELECT * FROM DUAL; +SELECT * FROM DUAL; + +# dialect: snowflake +SELECT * FROM "DUAL"; +SELECT * FROM "DUAL"; + +# dialect: snowflake +SELECT * FROM "dual"; +SELECT * FROM "dual"; + +# dialect: snowflake +SELECT dual FROM t; +SELECT "dual" FROM "t"; + +# dialect: snowflake +SELECT * FROM t AS dual; +SELECT * FROM "t" AS "dual"; diff --git a/tests/fixtures/optimizer/quote_identities.sql b/tests/fixtures/optimizer/quote_identities.sql deleted file mode 100644 index d6cfbf8..0000000 --- a/tests/fixtures/optimizer/quote_identities.sql +++ /dev/null @@ -1,11 +0,0 @@ -SELECT a FROM x; -SELECT "a" FROM "x"; - -SELECT "a" FROM "x"; -SELECT "a" FROM "x"; - -SELECT x.a AS a FROM db.x; -SELECT "x"."a" AS "a" FROM "db"."x"; - -SELECT @x; -SELECT @x; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index f50f688..2206e28 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -43,6 +43,9 @@ TRUE; 1.0 = 1; TRUE; +CAST('2023-01-01' AS DATE) = CAST('2023-01-01' AS DATE); +TRUE; + 'x' = 'y'; FALSE; @@ -360,6 +363,15 @@ x * (1 - y); (1.0 * 3) * 4 - 2 * (5 / 2); 12.0 - 2 * (5 / 2); +a * 0.5 / 10 / (2.0 + 3); +a * 0.5 / 10 / 5.0; + +a * 0.5 - 10 - (2.0 + 3); +a * 0.5 - 10 - 5.0; + +x * (10 - 5); +x * 5; + 6 - 2 + 4 * 2 + a; 12 + a; @@ -414,6 +426,9 @@ FALSE; 1 IS NOT NULL; TRUE; +date '1998-12-01' - interval x day; +CAST('1998-12-01' AS DATE) - INTERVAL x day; + date '1998-12-01' - interval '90' day; CAST('1998-09-02' AS DATE); @@ -447,6 +462,24 @@ CAST(x AS DATETIME) + INTERVAL '1' week; TS_OR_DS_TO_DATE('1998-12-01 00:00:01') - interval '90' day; CAST('1998-09-02' AS DATE); +DATE_ADD(CAST('2023-01-02' AS DATE), -2, 'MONTH'); +CAST('2022-11-02' AS DATE); + +DATE_SUB(CAST('2023-01-02' AS DATE), 1 + 1, 'DAY'); +CAST('2022-12-31' AS DATE); + +DATE_ADD(CAST('2023-01-02' AS DATETIME), -2, 'HOUR'); +CAST('2023-01-01 22:00:00' AS DATETIME); + +DATETIME_ADD(CAST('2023-01-02' AS DATETIME), -2, 'HOUR'); +CAST('2023-01-01 22:00:00' AS DATETIME); + +DATETIME_SUB(CAST('2023-01-02' AS DATETIME), 1 + 1, 'HOUR'); +CAST('2023-01-01 22:00:00' AS DATETIME); + +DATE_ADD(x, 1, 'MONTH'); +DATE_ADD(x, 1, 'MONTH'); + -------------------------------------- -- Comparisons -------------------------------------- @@ -663,6 +696,15 @@ ROW() OVER () = 1 OR ROW() OVER () IS NULL; a AND b AND COALESCE(ROW() OVER (), 1) = 1; a AND b AND (ROW() OVER () = 1 OR ROW() OVER () IS NULL); +COALESCE(1, 2); +1; + +COALESCE(CAST(CAST('2023-01-01' AS TIMESTAMP) AS DATE), x); +CAST(CAST('2023-01-01' AS TIMESTAMP) AS DATE); + +COALESCE(CAST(NULL AS DATE), x); +COALESCE(CAST(NULL AS DATE), x); + -------------------------------------- -- CONCAT -------------------------------------- @@ -673,7 +715,7 @@ CONCAT_WS(sep, x, y); CONCAT_WS(sep, x, y); CONCAT(x); -x; +CONCAT(x); CONCAT('a', 'b', 'c'); 'abc'; @@ -776,6 +818,9 @@ x >= CAST('2022-01-01' AS DATE); DATE_TRUNC('year', x) > TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE('2021-01-02')); x >= CAST('2022-01-01' 
AS DATE); +DATE_TRUNC('year', x) > TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE('2021-01-02', '%Y')); +DATE_TRUNC('year', x) > TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE('2021-01-02', '%Y')); + -- right is not a date DATE_TRUNC('year', x) <> '2021-01-02'; DATE_TRUNC('year', x) <> '2021-01-02'; diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index 62f1d79..d8cf64f 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -5699,7 +5699,7 @@ WHERE END > 0.1 ORDER BY "v1"."sum_sales" - "v1"."avg_monthly_sales", - "v1"."d_moy" + "d_moy" LIMIT 100; -------------------------------------- @@ -6020,9 +6020,9 @@ WITH "date_dim_2" AS ( WHERE "store"."currency_rank" <= 10 OR "store"."return_rank" <= 10 ORDER BY - 1, - "store"."return_rank", - "store"."currency_rank" + "channel", + "return_rank", + "currency_rank" LIMIT 100 ), "cte_4" AS ( SELECT @@ -6997,7 +6997,7 @@ WHERE END > 0.1 ORDER BY "v1"."sum_sales" - "v1"."avg_monthly_sales", - "v1"."avg_monthly_sales" + "avg_monthly_sales" LIMIT 100; -------------------------------------- @@ -10061,9 +10061,9 @@ WHERE AND "t_s_firstyear"."year1" = 1999 AND "t_s_firstyear"."year_total" > 0 ORDER BY - "t_s_secyear"."customer_id", - "t_s_secyear"."customer_first_name", - "t_s_secyear"."customer_last_name" + "customer_id", + "customer_first_name", + "customer_last_name" LIMIT 100; -------------------------------------- diff --git a/tests/test_build.py b/tests/test_build.py index 4dc993f..087bc7e 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -3,6 +3,7 @@ import unittest from sqlglot import ( alias, and_, + case, condition, except_, exp, @@ -77,9 +78,13 @@ class TestBuild(unittest.TestCase): (lambda: x.ilike("y"), "x ILIKE 'y'"), (lambda: x.rlike("y"), "REGEXP_LIKE(x, 'y')"), ( - lambda: exp.Case().when("x = 1", "x").else_("bar"), + lambda: case().when("x = 1", "x").else_("bar"), "CASE WHEN x = 1 THEN x ELSE bar END", ), + ( + lambda: case("x").when("1", "x").else_("bar"), + "CASE x WHEN 1 THEN x ELSE bar END", + ), (lambda: exp.func("COALESCE", "x", 1), "COALESCE(x, 1)"), (lambda: select("x"), "SELECT x"), (lambda: select("x"), "SELECT x"), @@ -614,6 +619,10 @@ class TestBuild(unittest.TestCase): "INSERT INTO tbl SELECT * FROM tbl2", ), ( + lambda: exp.insert("SELECT * FROM tbl2", "tbl", returning="*"), + "INSERT INTO tbl SELECT * FROM tbl2 RETURNING *", + ), + ( lambda: exp.insert("SELECT * FROM tbl2", "tbl", overwrite=True), "INSERT OVERWRITE TABLE tbl SELECT * FROM tbl2", ), @@ -630,6 +639,14 @@ class TestBuild(unittest.TestCase): "(x, y) IN ((1, 2), (3, 4))", "postgres", ), + ( + lambda: exp.cast_unless("CAST(x AS INT)", "int", "int"), + "CAST(x AS INT)", + ), + ( + lambda: exp.cast_unless("CAST(x AS TEXT)", "int", "int"), + "CAST(CAST(x AS TEXT) AS INT)", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_executor.py b/tests/test_executor.py index 721550e..ffe00a7 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -289,11 +289,47 @@ class TestExecutor(unittest.TestCase): ["a"], [(1,), (2,), (3,)], ), + ( + "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a", + ["a"], + [(1,), (2,), (3,)], + ), + ( + "SELECT 1 / 2 AS a", + ["a"], + [ + (0.5,), + ], + ), + ("SELECT 1 / 0 AS a", ["a"], ZeroDivisionError), + ( + exp.select( + exp.alias_(exp.Literal.number(1).div(exp.Literal.number(2), typed=True), "a") + ), + ["a"], + [ + (0,), + ], + ), + ( + exp.select( + 
exp.alias_(exp.Literal.number(1).div(exp.Literal.number(0), safe=True), "a") + ), + ["a"], + [ + (None,), + ], + ), ]: with self.subTest(sql): - result = execute(sql, schema=schema, tables=tables) - self.assertEqual(result.columns, tuple(cols)) - self.assertEqual(set(result.rows), set(rows)) + if isinstance(rows, list): + result = execute(sql, schema=schema, tables=tables) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(set(result.rows), set(rows)) + else: + with self.assertRaises(ExecuteError) as ctx: + execute(sql, schema=schema, tables=tables) + self.assertIsInstance(ctx.exception.__cause__, rows) def test_execute_catalog_db_table(self): tables = { @@ -632,6 +668,10 @@ class TestExecutor(unittest.TestCase): ("DATEDIFF('2022-01-03'::date, '2022-01-01'::TIMESTAMP::DATE)", 2), ("TRIM(' foo ')", "foo"), ("TRIM('afoob', 'ab')", "foo"), + ("ARRAY_JOIN(['foo', 'bar'], ':')", "foo:bar"), + ("ARRAY_JOIN(['hello', null ,'world'], ' ', ',')", "hello , world"), + ("ARRAY_JOIN(['', null ,'world'], ' ', ',')", " , world"), + ("STRUCT('foo', 'bar', null, null)", {"foo": "bar"}), ]: with self.subTest(sql): result = execute(f"SELECT {sql}") @@ -726,6 +766,11 @@ class TestExecutor(unittest.TestCase): [(1, 50), (2, 45), (3, 28)], ("a", "_col_1"), ), + ( + "SELECT a, ARRAY_UNIQUE_AGG(b) FROM x GROUP BY a", + [(1, [40, 10]), (2, [25, 20]), (3, [28])], + ("a", "_col_1"), + ), ): with self.subTest(sql): result = execute(sql, tables=tables) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 1fbe2d7..118b992 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -115,7 +115,26 @@ class TestExpressions(unittest.TestCase): self.assertIsNone(column.find_ancestor(exp.Join)) def test_to_dot(self): - column = parse_one('a.b.c."d".e.f').find(exp.Column) + orig = parse_one('a.b.c."d".e.f') + self.assertEqual(".".join(str(p) for p in orig.parts), 'a.b.c."d".e.f') + + self.assertEqual( + ".".join( + str(p) + for p in exp.Dot.build( + [ + exp.to_table("a.b.c"), + exp.to_identifier("d"), + exp.to_identifier("e"), + exp.to_identifier("f"), + ] + ).parts + ), + "a.b.c.d.e.f", + ) + + self.assertEqual(".".join(str(p) for p in orig.parts), 'a.b.c."d".e.f') + column = orig.find(exp.Column) dot = column.to_dot() self.assertEqual(dot.sql(), 'a.b.c."d".e.f') @@ -198,17 +217,42 @@ class TestExpressions(unittest.TestCase): "foo.`{bar,er}`", ) + self.assertEqual(exp.table_name(bq_dashed_table, identify=True), '"a-1"."b"."c"') + def test_table(self): self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table)) self.assertEqual(exp.table_("a", "").sql(), "a") + self.assertEqual(exp.Table(db=exp.to_identifier("a")).sql(), "a") def test_replace_tables(self): self.assertEqual( exp.replace_tables( - parse_one("select * from a AS a, b, c.a, d.a cross join e.a"), - {"a": "a1", "b": "b.a", "c.a": "c.a2", "d.a": "d2"}, + parse_one( + 'select * from a AS a, b, c.a, d.a cross join e.a cross join "f-F"."A" cross join G' + ), + { + "a": "a1", + "b": "b.a", + "c.a": "c.a2", + "d.a": "d2", + "`f-F`.`A`": '"F"', + "g": "g1.a", + }, + dialect="bigquery", ).sql(), - "SELECT * FROM a1 AS a, b.a, c.a2, d2 CROSS JOIN e.a", + 'SELECT * FROM a1 AS a /* a */, b.a /* b */, c.a2 /* c.a */, d2 /* d.a */ CROSS JOIN e.a CROSS JOIN "F" /* f-F.A */ CROSS JOIN g1.a /* g */', + ) + + def test_expand(self): + self.assertEqual( + exp.expand( + parse_one('select * from "a-b"."C" AS a'), + { + "`a-b`.`c`": parse_one("select 1"), + }, + dialect="spark", + ).sql(), + "SELECT * FROM 
(SELECT 1) AS a /* source: a-b.c */", ) def test_replace_placeholders(self): @@ -267,9 +311,18 @@ class TestExpressions(unittest.TestCase): self.assertEqual(exp.func("bla", 1, "foo").sql(), "BLA(1, foo)") self.assertEqual(exp.func("COUNT", exp.Star()).sql(), "COUNT(*)") self.assertEqual(exp.func("bloo").sql(), "BLOO()") + self.assertEqual(exp.func("concat", exp.convert("a")).sql("duckdb"), "CONCAT('a')") self.assertEqual( exp.func("locate", "'x'", "'xo'", dialect="hive").sql("hive"), "LOCATE('x', 'xo')" ) + self.assertEqual( + exp.func("log", exp.to_identifier("x"), 2, dialect="bigquery").sql("bigquery"), + "LOG(x, 2)", + ) + self.assertEqual( + exp.func("log", dialect="bigquery", expression="x", this=2).sql("bigquery"), + "LOG(x, 2)", + ) self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition) self.assertIsInstance(exp.func("bla", 1, "foo"), exp.Anonymous) @@ -284,6 +337,15 @@ class TestExpressions(unittest.TestCase): with self.assertRaises(ValueError): exp.func("abs") + with self.assertRaises(ValueError) as cm: + exp.func("to_hex", dialect="bigquery", this=5) + + self.assertEqual( + str(cm.exception), + "Unable to convert 'to_hex' into a Func. Either manually construct the Func " + "expression of interest or parse the function call.", + ) + def test_named_selects(self): expression = parse_one( "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" diff --git a/tests/test_helper.py b/tests/test_helper.py index a8872e9..e17b281 100644 --- a/tests/test_helper.py +++ b/tests/test_helper.py @@ -1,6 +1,5 @@ import unittest -from sqlglot.dialects import BigQuery, Dialect, Snowflake from sqlglot.helper import merge_ranges, name_sequence, tsort @@ -30,32 +29,6 @@ class TestHelper(unittest.TestCase): } ) - def test_compare_dialects(self): - bigquery_class = Dialect["bigquery"] - bigquery_object = BigQuery() - bigquery_string = "bigquery" - - snowflake_class = Dialect["snowflake"] - snowflake_object = Snowflake() - snowflake_string = "snowflake" - - self.assertEqual(snowflake_class, snowflake_class) - self.assertEqual(snowflake_class, snowflake_object) - self.assertEqual(snowflake_class, snowflake_string) - self.assertEqual(snowflake_object, snowflake_object) - self.assertEqual(snowflake_object, snowflake_string) - - self.assertNotEqual(snowflake_class, bigquery_class) - self.assertNotEqual(snowflake_class, bigquery_object) - self.assertNotEqual(snowflake_class, bigquery_string) - self.assertNotEqual(snowflake_object, bigquery_object) - self.assertNotEqual(snowflake_object, bigquery_string) - - self.assertTrue(snowflake_class in {"snowflake", "bigquery"}) - self.assertTrue(snowflake_object in {"snowflake", "bigquery"}) - self.assertFalse(snowflake_class in {"bigquery", "redshift"}) - self.assertFalse(snowflake_object in {"bigquery", "redshift"}) - def test_name_sequence(self): s1 = name_sequence("a") s2 = name_sequence("b") diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 25329e2..8755b42 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -229,13 +229,36 @@ class TestLineage(unittest.TestCase): "output", "SELECT (SELECT max(t3.my_column) my_column FROM foo t3) AS output FROM table3", ) - self.assertEqual(node.name, "SUBQUERY") + self.assertEqual(node.name, "output") node = node.downstream[0] self.assertEqual(node.name, "my_column") node = node.downstream[0] self.assertEqual(node.name, "t3.my_column") self.assertEqual(node.source.sql(), "foo AS t3") + node = lineage( + "y", + "SELECT SUM((SELECT max(a) a from x) + (SELECT 
min(b) b from x) + c) AS y FROM x", + ) + self.assertEqual(node.name, "y") + self.assertEqual(len(node.downstream), 3) + self.assertEqual(node.downstream[0].name, "a") + self.assertEqual(node.downstream[1].name, "b") + self.assertEqual(node.downstream[2].name, "x.c") + + node = lineage( + "x", + "WITH cte AS (SELECT a, b FROM z) SELECT sum(SELECT a FROM cte) AS x, (SELECT b FROM cte) as y FROM cte", + ) + self.assertEqual(node.name, "x") + self.assertEqual(len(node.downstream), 1) + node = node.downstream[0] + self.assertEqual(node.name, "a") + node = node.downstream[0] + self.assertEqual(node.name, "cte.a") + node = node.downstream[0] + self.assertEqual(node.name, "z.a") + def test_lineage_cte_union(self) -> None: query = """ WITH dataset AS ( diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index fd95577..141203d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -103,6 +103,10 @@ class TestOptimizer(unittest.TestCase): "d": "TEXT", "e": "TEXT", }, + "temporal": { + "d": "DATE", + "t": "DATETIME", + }, } def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs): @@ -179,6 +183,18 @@ class TestOptimizer(unittest.TestCase): ) def test_qualify_tables(self): + self.assertEqual( + optimizer.qualify_tables.qualify_tables( + parse_one("select a from b"), catalog="catalog" + ).sql(), + "SELECT a FROM b AS b", + ) + + self.assertEqual( + optimizer.qualify_tables.qualify_tables(parse_one("select a from b"), db='"DB"').sql(), + 'SELECT a FROM "DB".b AS b', + ) + self.check_file( "qualify_tables", optimizer.qualify_tables.qualify_tables, @@ -282,6 +298,13 @@ class TestOptimizer(unittest.TestCase): self.assertEqual(optimizer.normalize_identifiers.normalize_identifiers("a%").sql(), '"a%"') + def test_quote_identifiers(self): + self.check_file( + "quote_identifiers", + optimizer.qualify_columns.quote_identifiers, + set_dialect=True, + ) + def test_pushdown_projection(self): self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) @@ -300,8 +323,8 @@ class TestOptimizer(unittest.TestCase): safe_concat = parse_one("CONCAT('a', x, 'b', 'c')") simplified_safe_concat = optimizer.simplify.simplify(safe_concat) - self.assertIs(type(simplified_concat), exp.Concat) - self.assertIs(type(simplified_safe_concat), exp.SafeConcat) + self.assertEqual(simplified_concat.args["safe"], False) + self.assertEqual(simplified_safe_concat.args["safe"], True) self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto")) self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql()) @@ -561,6 +584,19 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) + for numeric_type in ("BIGINT", "DOUBLE", "INT"): + query = f"SELECT '1' + CAST(x AS {numeric_type})" + expression = annotate_types(parse_one(query)).expressions[0] + self.assertEqual(expression.type, exp.DataType.build(numeric_type)) + + def test_typeddiv_annotation(self): + expressions = annotate_types( + parse_one("SELECT 2 / 3, 2 / 3.0", dialect="presto") + ).expressions + + self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT) + self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE) + def test_bracket_annotation(self): expression = annotate_types(parse_one("SELECT A[:]")).expressions[0] @@ -609,45 +645,60 @@ FROM 
READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') "b": "DATETIME", } } - for sql, expected_type, *expected_sql in [ + for sql, expected_type in [ ( "SELECT '2023-01-01' + INTERVAL '1' DAY", exp.DataType.Type.DATE, - "SELECT CAST('2023-01-01' AS DATE) + INTERVAL '1' DAY", ), ( "SELECT '2023-01-01' + INTERVAL '1' HOUR", exp.DataType.Type.DATETIME, - "SELECT CAST('2023-01-01' AS DATETIME) + INTERVAL '1' HOUR", ), ( "SELECT '2023-01-01 00:00:01' + INTERVAL '1' HOUR", exp.DataType.Type.DATETIME, - "SELECT CAST('2023-01-01 00:00:01' AS DATETIME) + INTERVAL '1' HOUR", ), ("SELECT 'nonsense' + INTERVAL '1' DAY", exp.DataType.Type.UNKNOWN), ("SELECT x.a + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATE), - ("SELECT x.a + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT x.a + INTERVAL '1' HOUR FROM x AS x", + exp.DataType.Type.DATETIME, + ), ("SELECT x.b + INTERVAL '1' DAY FROM x AS x", exp.DataType.Type.DATETIME), ("SELECT x.b + INTERVAL '1' HOUR FROM x AS x", exp.DataType.Type.DATETIME), ( "SELECT DATE_ADD('2023-01-01', 1, 'DAY')", exp.DataType.Type.DATE, - "SELECT DATE_ADD(CAST('2023-01-01' AS DATE), 1, 'DAY')", ), ( "SELECT DATE_ADD('2023-01-01 00:00:00', 1, 'DAY')", exp.DataType.Type.DATETIME, - "SELECT DATE_ADD(CAST('2023-01-01 00:00:00' AS DATETIME), 1, 'DAY')", ), ("SELECT DATE_ADD(x.a, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATE), - ("SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT DATE_ADD(x.a, 1, 'HOUR') FROM x AS x", + exp.DataType.Type.DATETIME, + ), ("SELECT DATE_ADD(x.b, 1, 'DAY') FROM x AS x", exp.DataType.Type.DATETIME), + ("SELECT DATE_TRUNC('DAY', x.a) FROM x AS x", exp.DataType.Type.DATE), + ("SELECT DATE_TRUNC('DAY', x.b) FROM x AS x", exp.DataType.Type.DATETIME), + ( + "SELECT DATE_TRUNC('SECOND', x.a) FROM x AS x", + exp.DataType.Type.DATETIME, + ), + ( + "SELECT DATE_TRUNC('DAY', '2023-01-01') FROM x AS x", + exp.DataType.Type.DATE, + ), + ( + "SELECT DATEDIFF('2023-01-01', '2023-01-02', DAY) FROM x AS x", + exp.DataType.Type.INT, + ), ]: with self.subTest(sql): expression = annotate_types(parse_one(sql), schema=schema) self.assertEqual(expected_type, expression.expressions[0].type.this) - self.assertEqual(expected_sql[0] if expected_sql else sql, expression.sql()) + self.assertEqual(sql, expression.sql()) def test_lateral_annotation(self): expression = optimizer.optimize( @@ -843,6 +894,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ("MAX", "cold"): exp.DataType.Type.DATE, ("COUNT", "colb"): exp.DataType.Type.BIGINT, ("STDDEV", "cola"): exp.DataType.Type.DOUBLE, + ("ABS", "cola"): exp.DataType.Type.SMALLINT, + ("ABS", "colb"): exp.DataType.Type.FLOAT, } for (func, col), target_type in tests.items(): @@ -989,10 +1042,3 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') query = parse_one("select a.b:c from d", read="snowflake") qualified = optimizer.qualify.qualify(query) self.assertEqual(qualified.expressions[0].alias, "c") - - def test_qualify_tables_no_schema(self): - query = parse_one("select a from b") - self.assertEqual( - optimizer.qualify_tables.qualify_tables(query, catalog="catalog").sql(), - "SELECT a FROM b AS b", - ) diff --git a/tests/test_schema.py b/tests/test_schema.py index 34c507d..8bdd312 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -216,15 +216,15 @@ class TestSchema(unittest.TestCase): # Check that add_table normalizes both the table and the column names to be added / 
updated schema = MappingSchema() schema.add_table("Foo", {"SomeColumn": "INT", '"SomeColumn"': "DOUBLE"}) - self.assertEqual(schema.column_names(exp.Table(this="fOO")), ["somecolumn", "SomeColumn"]) + self.assertEqual(schema.column_names(exp.table_("fOO")), ["somecolumn", "SomeColumn"]) # Check that names are normalized to uppercase for Snowflake schema = MappingSchema(schema={"x": {"foo": "int", '"bLa"': "int"}}, dialect="snowflake") - self.assertEqual(schema.column_names(exp.Table(this="x")), ["FOO", "bLa"]) + self.assertEqual(schema.column_names(exp.table_("x")), ["FOO", "bLa"]) # Check that switching off the normalization logic works as expected schema = MappingSchema(schema={"x": {"foo": "int"}}, normalize=False, dialect="snowflake") - self.assertEqual(schema.column_names(exp.Table(this="x")), ["foo"]) + self.assertEqual(schema.column_names(exp.table_("x")), ["foo"]) # Check that the correct dialect is used when calling schema methods # Note: T-SQL is case-insensitive by default, so `fo` in clickhouse will match the normalized table name diff --git a/tests/test_tokens.py b/tests/test_tokens.py index f4d3858..b97f54a 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -13,6 +13,7 @@ class TestTokens(unittest.TestCase): (" group bys ", 2), ("group by)", 2), ("group bys)", 3), + ("group \r", 1), ): tokens = Tokenizer().tokenize(string) self.assertTrue("GROUP" in tokens[0].text.upper()) @@ -63,6 +64,13 @@ x""" self.assertEqual(Tokenizer().tokenize("'''abc'")[0].end, 6) self.assertEqual(Tokenizer().tokenize("'abc'")[0].start, 0) + tokens = Tokenizer().tokenize("SELECT\r\n 1,\r\n 2") + + self.assertEqual(tokens[0].line, 1) + self.assertEqual(tokens[1].line, 2) + self.assertEqual(tokens[2].line, 2) + self.assertEqual(tokens[3].line, 3) + def test_command(self): tokens = Tokenizer().tokenize("SHOW;") self.assertEqual(tokens[0].token_type, TokenType.SHOW) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index c16b1f6..b732b45 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -4,6 +4,7 @@ from unittest import mock from sqlglot import parse_one, transpile from sqlglot.errors import ErrorLevel, ParseError, UnsupportedError +from sqlglot.helper import logger as helper_logger from tests.helpers import ( assert_logger_contains, load_sql_fixture_pairs, @@ -91,6 +92,10 @@ class TestTranspile(unittest.TestCase): def test_comments(self): self.validate( + "SELECT c AS /* foo */ (a, b, c) FROM t", + "SELECT c AS (a, b, c) /* foo */ FROM t", + ) + self.validate( "SELECT * FROM t1\n/*x*/\nUNION ALL SELECT * FROM t2", "SELECT * FROM t1 /* x */ UNION ALL SELECT * FROM t2", ) @@ -434,6 +439,40 @@ SELECT FROM dw_1_dw_1_1.exactonline_2.transactionlines""", pretty=True, ) + self.validate( + """/* The result of some calculations + */ +with + base as ( + select + sum(sb.hep_amount) as hep_amount, + -- I AM REMOVED + sum(sb.hep_budget) + /* Budget defined in sharepoint */ + as blub + , 1 as bla + from gold.data_budget sb + group by all + ) +select + * +from base +""", + """/* The result of some calculations + */ +WITH base AS ( + SELECT + SUM(sb.hep_amount) AS hep_amount, + SUM(sb.hep_budget) /* I AM REMOVED */ AS blub, /* Budget defined in sharepoint */ + 1 AS bla + FROM gold.data_budget AS sb + GROUP BY ALL +) +SELECT + * +FROM base""", + pretty=True, + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") @@ -661,19 +700,27 @@ FROM dw_1_dw_1_1.exactonline_2.transactionlines""", write="spark2", ) - @mock.patch("sqlglot.helper.logger") - def 
test_index_offset(self, logger): - self.validate("x[0]", "x[1]", write="presto", identity=False) - self.validate("x[1]", "x[0]", read="presto", identity=False) - logger.warning.assert_any_call("Applying array index offset (%s)", 1) - logger.warning.assert_any_call("Applying array index offset (%s)", -1) + def test_index_offset(self): + with self.assertLogs(helper_logger) as cm: + self.validate("x[0]", "x[1]", write="presto", identity=False) + self.validate("x[1]", "x[0]", read="presto", identity=False) - self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) - self.validate( - "x[array_size(y) - 1]", "x[CARDINALITY(y) - 1 + 1]", write="presto", identity=False - ) - self.validate("x[3 - 1]", "x[3]", write="presto", identity=False) - self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False) + self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) + self.validate( + "x[array_size(y) - 1]", "x[CARDINALITY(y) - 1 + 1]", write="presto", identity=False + ) + self.validate("x[3 - 1]", "x[3]", write="presto", identity=False) + self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False) + + self.assertEqual( + cm.output, + [ + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (-1)", + "WARNING:sqlglot:Applying array index offset (1)", + "WARNING:sqlglot:Applying array index offset (1)", + ], + ) def test_identify_lambda(self): self.validate("x(y -> y)", 'X("y" -> "y")', identify=True) @@ -706,6 +753,10 @@ FROM dw_1_dw_1_1.exactonline_2.transactionlines""", def test_pretty_line_breaks(self): self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'") + self.assertEqual( + transpile("SELECT '1\n2'", pretty=True, unsupported_level=ErrorLevel.IGNORE)[0], + "SELECT\n '1\n2'", + ) @mock.patch("sqlglot.parser.logger") def test_error_level(self, logger): |
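
For readers following the executor changes above, a minimal sketch (not part of the diff; assuming sqlglot 20.x and the same API the new cases in tests/test_executor.py exercise) of the plain, typed, and safe division semantics those cases cover:

    from sqlglot import exp
    from sqlglot.executor import execute

    # Plain division keeps float semantics: 1 / 2 evaluates to 0.5, while 1 / 0
    # raises (surfaced as an ExecuteError whose __cause__ is ZeroDivisionError).
    print(execute("SELECT 1 / 2 AS a").rows)  # [(0.5,)]

    # typed=True builds a typed (integer) division, so 1 / 2 truncates to 0.
    typed_div = exp.select(
        exp.alias_(exp.Literal.number(1).div(exp.Literal.number(2), typed=True), "a")
    )
    print(execute(typed_div).rows)  # [(0,)]

    # safe=True builds a safe division that yields NULL instead of raising on x / 0.
    safe_div = exp.select(
        exp.alias_(exp.Literal.number(1).div(exp.Literal.number(0), safe=True), "a")
    )
    print(execute(safe_div).rows)  # [(None,)]

This mirrors the new executor cases above: plain division stays float-valued and raises on division by zero, while the typed and safe variants control truncation and divide-by-zero behavior respectively.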