diff options
Diffstat (limited to 'tests')
34 files changed, 1006 insertions, 144 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index 2dcdb39..4363b0d 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -1,14 +1,11 @@ -import typing as t -import unittest - from sqlglot.dataframe.sql import types -from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.helper import ensure_list +from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase -class DataFrameSQLValidator(unittest.TestCase): +class DataFrameSQLValidator(DataFrameTestBase): def setUp(self) -> None: + super().setUp() self.spark = SparkSession() self.employee_schema = types.StructType( [ @@ -29,12 +26,3 @@ class DataFrameSQLValidator(unittest.TestCase): self.df_employee = self.spark.createDataFrame( data=employee_data, schema=self.employee_schema ) - - def compare_sql( - self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False - ): - actual_sqls = df.sql(pretty=pretty) - expected_statements = ensure_list(expected_statements) - self.assertEqual(len(expected_statements), len(actual_sqls)) - for expected, actual in zip(expected_statements, actual_sqls): - self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/dataframe_test_base.py b/tests/dataframe/unit/dataframe_test_base.py new file mode 100644 index 0000000..6b07df9 --- /dev/null +++ b/tests/dataframe/unit/dataframe_test_base.py @@ -0,0 +1,23 @@ +import typing as t +import unittest + +import sqlglot +from sqlglot import MappingSchema +from sqlglot.dataframe.sql import SparkSession +from sqlglot.dataframe.sql.dataframe import DataFrame +from sqlglot.helper import ensure_list + + +class DataFrameTestBase(unittest.TestCase): + def setUp(self) -> None: + sqlglot.schema = MappingSchema() + SparkSession._instance = None + + def compare_sql( + self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False + ): + actual_sqls = df.sql(pretty=pretty) + expected_statements = ensure_list(expected_statements) + self.assertEqual(len(expected_statements), len(actual_sqls)) + for expected, actual in zip(expected_statements, actual_sqls): + self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 4c275e9..9758033 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -1,9 +1,6 @@ -from unittest import mock - import sqlglot from sqlglot.dataframe.sql import functions as F, types from sqlglot.dataframe.sql.session import SparkSession -from sqlglot.schema import MappingSchema from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator @@ -68,7 +65,6 @@ class TestDataframeSession(DataFrameSQLValidator): self.compare_sql(df, expected) - @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_select_only(self): query = "SELECT cola, colb FROM table" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") @@ -78,16 +74,6 @@ class TestDataframeSession(DataFrameSQLValidator): df.sql(pretty=False)[0], ) - @mock.patch("sqlglot.schema", MappingSchema()) - def test_select_quoted(self): - sqlglot.schema.add_table("`TEST`", {"name": "string"}, dialect="spark") - - self.assertEqual( - SparkSession().table("`TEST`").select(F.col("name")).sql(dialect="snowflake")[0], - '''SELECT "test"."name" AS "name" FROM "test" AS "test"''', - ) - - @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_with_aggs(self): query = "SELECT cola, colb FROM table" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") @@ -97,7 +83,6 @@ class TestDataframeSession(DataFrameSQLValidator): df.sql(pretty=False, optimize=False)[0], ) - @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_create(self): query = "CREATE TABLE new_table AS WITH t1 AS (SELECT cola, colb FROM table) SELECT cola, colb, FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") @@ -105,7 +90,6 @@ class TestDataframeSession(DataFrameSQLValidator): expected = "CREATE TABLE new_table AS SELECT `table`.`cola` AS `cola`, `table`.`colb` AS `colb` FROM `table` AS `table`" self.compare_sql(df, expected) - @mock.patch("sqlglot.schema", MappingSchema()) def test_sql_insert(self): query = "WITH t1 AS (SELECT cola, colb FROM table) INSERT INTO new_table SELECT cola, colb FROM t1" sqlglot.schema.add_table("table", {"cola": "string", "colb": "string"}, dialect="spark") @@ -114,5 +98,4 @@ class TestDataframeSession(DataFrameSQLValidator): self.compare_sql(df, expected) def test_session_create_builder_patterns(self): - spark = SparkSession() - self.assertEqual(spark.builder.appName("abc").getOrCreate(), spark) + self.assertEqual(SparkSession.builder.appName("abc").getOrCreate(), SparkSession()) diff --git a/tests/dataframe/unit/test_session_case_sensitivity.py b/tests/dataframe/unit/test_session_case_sensitivity.py new file mode 100644 index 0000000..7e35289 --- /dev/null +++ b/tests/dataframe/unit/test_session_case_sensitivity.py @@ -0,0 +1,81 @@ +import sqlglot +from sqlglot.dataframe.sql import functions as F +from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.errors import OptimizeError +from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase + + +class TestSessionCaseSensitivity(DataFrameTestBase): + def setUp(self) -> None: + super().setUp() + self.spark = SparkSession.builder.config("sqlframe.dialect", "snowflake").getOrCreate() + + tests = [ + ( + "All lower no intention of CS", + "test", + "test", + {"name": "VARCHAR"}, + "name", + '''SELECT "TEST"."NAME" AS "NAME" FROM "TEST" AS "TEST"''', + ), + ( + "Table has CS while column does not", + '"Test"', + '"Test"', + {"name": "VARCHAR"}, + "name", + '''SELECT "TEST"."NAME" AS "NAME" FROM "Test" AS "TEST"''', + ), + ( + "Column has CS while table does not", + "test", + "test", + {'"Name"': "VARCHAR"}, + '"Name"', + '''SELECT "TEST"."Name" AS "Name" FROM "TEST" AS "TEST"''', + ), + ( + "Both Table and column have CS", + '"Test"', + '"Test"', + {'"Name"': "VARCHAR"}, + '"Name"', + '''SELECT "TEST"."Name" AS "Name" FROM "Test" AS "TEST"''', + ), + ( + "Lowercase CS table and column", + '"test"', + '"test"', + {'"name"': "VARCHAR"}, + '"name"', + '''SELECT "TEST"."name" AS "name" FROM "test" AS "TEST"''', + ), + ( + "CS table and column and query table but no CS in query column", + '"test"', + '"test"', + {'"name"': "VARCHAR"}, + "name", + OptimizeError(), + ), + ( + "CS table and column and query column but no CS in query table", + '"test"', + "test", + {'"name"': "VARCHAR"}, + '"name"', + OptimizeError(), + ), + ] + + def test_basic_case_sensitivity(self): + for test_name, table_name, spark_table, schema, spark_column, expected in self.tests: + with self.subTest(test_name): + sqlglot.schema.add_table(table_name, schema, dialect=self.spark.dialect) + df = self.spark.table(spark_table).select(F.col(spark_column)) + if isinstance(expected, OptimizeError): + with self.assertRaises(OptimizeError): + df.sql() + else: + self.compare_sql(df, expected) diff --git a/tests/dataframe/unit/test_window.py b/tests/dataframe/unit/test_window.py index 45d736f..9c4c897 100644 --- a/tests/dataframe/unit/test_window.py +++ b/tests/dataframe/unit/test_window.py @@ -1,10 +1,9 @@ -import unittest - from sqlglot.dataframe.sql import functions as F from sqlglot.dataframe.sql.window import Window, WindowSpec +from tests.dataframe.unit.dataframe_test_base import DataFrameTestBase -class TestDataframeWindow(unittest.TestCase): +class TestDataframeWindow(DataFrameTestBase): def test_window_spec_partition_by(self): partition_by = WindowSpec().partitionBy(F.col("cola"), F.col("colb")) self.assertEqual("OVER (PARTITION BY cola, colb)", partition_by.sql()) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 52f86bd..b776bdd 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,8 +6,36 @@ from tests.dialects.test_dialect import Validator class TestBigQuery(Validator): dialect = "bigquery" + maxDiff = None def test_bigquery(self): + self.validate_identity("""SELECT JSON '"foo"' AS json_data""") + self.validate_identity("SELECT * FROM tbl FOR SYSTEM_TIME AS OF z") + + self.validate_all( + """SELECT + `u`.`harness_user_email` AS `harness_user_email`, + `d`.`harness_user_id` AS `harness_user_id`, + `harness_account_id` AS `harness_account_id` +FROM `analytics_staging`.`stg_mongodb__users` AS `u`, UNNEST(`u`.`harness_cluster_details`) AS `d`, UNNEST(`d`.`harness_account_ids`) AS `harness_account_id` +WHERE + NOT `harness_account_id` IS NULL""", + read={ + "": """ + SELECT + "u"."harness_user_email" AS "harness_user_email", + "_q_0"."d"."harness_user_id" AS "harness_user_id", + "_q_1"."harness_account_id" AS "harness_account_id" + FROM + "analytics_staging"."stg_mongodb__users" AS "u", + UNNEST("u"."harness_cluster_details") AS "_q_0"("d"), + UNNEST("_q_0"."d"."harness_account_ids") AS "_q_1"("harness_account_id") + WHERE + NOT "_q_1"."harness_account_id" IS NULL + """ + }, + pretty=True, + ) with self.assertRaises(TokenError): transpile("'\\'", read="bigquery") @@ -57,6 +85,10 @@ class TestBigQuery(Validator): self.validate_identity("SELECT * FROM my-table") self.validate_identity("SELECT * FROM my-project.mydataset.mytable") self.validate_identity("SELECT * FROM pro-ject_id.c.d CROSS JOIN foo-bar") + self.validate_identity("SELECT * FROM foo.bar.25", "SELECT * FROM foo.bar.`25`") + self.validate_identity("SELECT * FROM foo.bar.25_", "SELECT * FROM foo.bar.`25_`") + self.validate_identity("SELECT * FROM foo.bar.25x a", "SELECT * FROM foo.bar.`25x` AS a") + self.validate_identity("SELECT * FROM foo.bar.25ab c", "SELECT * FROM foo.bar.`25ab` AS c") self.validate_identity("x <> ''") self.validate_identity("DATE_TRUNC(col, WEEK(MONDAY))") self.validate_identity("SELECT b'abc'") @@ -105,6 +137,34 @@ class TestBigQuery(Validator): self.validate_all('x <> """"""', write={"bigquery": "x <> ''"}) self.validate_all("x <> ''''''", write={"bigquery": "x <> ''"}) self.validate_all("CAST(x AS DATETIME)", read={"": "x::timestamp"}) + self.validate_all( + "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", + write={ + "bigquery": "SELECT DATETIME_DIFF('2023-01-01T00:00:00', '2023-01-01T05:00:00', MILLISECOND)", + "databricks": "SELECT TIMESTAMPDIFF(MILLISECOND, '2023-01-01T05:00:00', '2023-01-01T00:00:00')", + }, + ), + self.validate_all( + "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + write={ + "bigquery": "SELECT DATETIME_ADD('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1, '2023-01-01T00:00:00')", + }, + ), + self.validate_all( + "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + write={ + "bigquery": "SELECT DATETIME_SUB('2023-01-01T00:00:00', INTERVAL 1 MILLISECOND)", + "databricks": "SELECT TIMESTAMPADD(MILLISECOND, 1 * -1, '2023-01-01T00:00:00')", + }, + ), + self.validate_all( + "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)", + write={ + "bigquery": "SELECT DATETIME_TRUNC('2023-01-01T01:01:01', HOUR)", + "databricks": "SELECT DATE_TRUNC('HOUR', '2023-01-01T01:01:01')", + }, + ), self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) @@ -141,6 +201,20 @@ class TestBigQuery(Validator): }, ) self.validate_all( + "SHA256(x)", + write={ + "bigquery": "SHA256(x)", + "spark2": "SHA2(x, 256)", + }, + ) + self.validate_all( + "SHA512(x)", + write={ + "bigquery": "SHA512(x)", + "spark2": "SHA2(x, 512)", + }, + ) + self.validate_all( "SELECT CAST('20201225' AS TIMESTAMP FORMAT 'YYYYMMDD' AT TIME ZONE 'America/New_York')", write={"bigquery": "SELECT PARSE_TIMESTAMP('%Y%m%d', '20201225', 'America/New_York')"}, ) @@ -249,7 +323,7 @@ class TestBigQuery(Validator): self.validate_all( "r'x\\y'", write={ - "bigquery": "'x\\\y'", + "bigquery": "'x\\\\y'", "hive": "'x\\\\y'", }, ) @@ -329,14 +403,14 @@ class TestBigQuery(Validator): self.validate_all( "[1, 2, 3]", read={ - "duckdb": "LIST_VALUE(1, 2, 3)", + "duckdb": "[1, 2, 3]", "presto": "ARRAY[1, 2, 3]", "hive": "ARRAY(1, 2, 3)", "spark": "ARRAY(1, 2, 3)", }, write={ "bigquery": "[1, 2, 3]", - "duckdb": "LIST_VALUE(1, 2, 3)", + "duckdb": "[1, 2, 3]", "presto": "ARRAY[1, 2, 3]", "hive": "ARRAY(1, 2, 3)", "spark": "ARRAY(1, 2, 3)", @@ -710,3 +784,28 @@ class TestBigQuery(Validator): "WITH cte AS (SELECT 1 AS foo UNION ALL SELECT 2) SELECT foo FROM cte", read={"postgres": "WITH cte(foo) AS (SELECT 1 UNION ALL SELECT 2) SELECT foo FROM cte"}, ) + + def test_json_object(self): + self.validate_identity("SELECT JSON_OBJECT() AS json_data") + self.validate_identity("SELECT JSON_OBJECT('foo', 10, 'bar', TRUE) AS json_data") + self.validate_identity("SELECT JSON_OBJECT('foo', 10, 'bar', ['a', 'b']) AS json_data") + self.validate_identity("SELECT JSON_OBJECT('a', 10, 'a', 'foo') AS json_data") + self.validate_identity( + "SELECT JSON_OBJECT(['a', 'b'], [10, NULL]) AS json_data", + "SELECT JSON_OBJECT('a', 10, 'b', NULL) AS json_data", + ) + self.validate_identity( + """SELECT JSON_OBJECT(['a', 'b'], [JSON '10', JSON '"foo"']) AS json_data""", + """SELECT JSON_OBJECT('a', JSON '10', 'b', JSON '"foo"') AS json_data""", + ) + self.validate_identity( + "SELECT JSON_OBJECT(['a', 'b'], [STRUCT(10 AS id, 'Red' AS color), STRUCT(20 AS id, 'Blue' AS color)]) AS json_data", + "SELECT JSON_OBJECT('a', STRUCT(10 AS id, 'Red' AS color), 'b', STRUCT(20 AS id, 'Blue' AS color)) AS json_data", + ) + self.validate_identity( + "SELECT JSON_OBJECT(['a', 'b'], [TO_JSON(10), TO_JSON(['foo', 'bar'])]) AS json_data", + "SELECT JSON_OBJECT('a', TO_JSON(10), 'b', TO_JSON(['foo', 'bar'])) AS json_data", + ) + + with self.assertRaises(ParseError): + transpile("SELECT JSON_OBJECT('a', 1, 'b') AS json_data", read="bigquery") diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 583be3e..ab2379d 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -6,6 +6,31 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_all( + "DATE_ADD('day', 1, x)", + read={ + "clickhouse": "dateAdd(day, 1, x)", + "presto": "DATE_ADD('day', 1, x)", + }, + write={ + "clickhouse": "DATE_ADD('day', 1, x)", + "presto": "DATE_ADD('day', 1, x)", + "": "DATE_ADD(x, 1, 'day')", + }, + ) + self.validate_all( + "DATE_DIFF('day', a, b)", + read={ + "clickhouse": "dateDiff('day', a, b)", + "presto": "DATE_DIFF('day', a, b)", + }, + write={ + "clickhouse": "DATE_DIFF('day', a, b)", + "presto": "DATE_DIFF('day', a, b)", + "": "DATEDIFF(b, a, day)", + }, + ) + expr = parse_one("count(x)") self.assertEqual(expr.sql(dialect="clickhouse"), "COUNT(x)") self.assertIsNone(expr._meta) @@ -47,8 +72,10 @@ class TestClickhouse(Validator): self.validate_identity("position(haystack, needle)") self.validate_identity("position(haystack, needle, position)") self.validate_identity("CAST(x AS DATETIME)") + self.validate_identity("CAST(x AS VARCHAR(255))", "CAST(x AS String)") + self.validate_identity("CAST(x AS BLOB)", "CAST(x AS String)") self.validate_identity( - 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(TEXT))' + 'SELECT CAST(tuple(1 AS "a", 2 AS "b", 3.0 AS "c").2 AS Nullable(String))' ) self.validate_identity( "CREATE TABLE test (id UInt8) ENGINE=AggregatingMergeTree() ORDER BY tuple()" @@ -95,11 +122,11 @@ class TestClickhouse(Validator): }, ) self.validate_all( - "CONCAT(CASE WHEN COALESCE(CAST(a AS TEXT), '') IS NULL THEN COALESCE(CAST(a AS TEXT), '') ELSE CAST(COALESCE(CAST(a AS TEXT), '') AS TEXT) END, CASE WHEN COALESCE(CAST(b AS TEXT), '') IS NULL THEN COALESCE(CAST(b AS TEXT), '') ELSE CAST(COALESCE(CAST(b AS TEXT), '') AS TEXT) END)", + "CONCAT(CASE WHEN COALESCE(CAST(a AS String), '') IS NULL THEN COALESCE(CAST(a AS String), '') ELSE CAST(COALESCE(CAST(a AS String), '') AS String) END, CASE WHEN COALESCE(CAST(b AS String), '') IS NULL THEN COALESCE(CAST(b AS String), '') ELSE CAST(COALESCE(CAST(b AS String), '') AS String) END)", read={"postgres": "CONCAT(a, b)"}, ) self.validate_all( - "CONCAT(CASE WHEN a IS NULL THEN a ELSE CAST(a AS TEXT) END, CASE WHEN b IS NULL THEN b ELSE CAST(b AS TEXT) END)", + "CONCAT(CASE WHEN a IS NULL THEN a ELSE CAST(a AS String) END, CASE WHEN b IS NULL THEN b ELSE CAST(b AS String) END)", read={"mysql": "CONCAT(a, b)"}, ) self.validate_all( @@ -233,7 +260,7 @@ class TestClickhouse(Validator): self.validate_all( "SELECT {abc: UInt32}, {b: String}, {c: DateTime},{d: Map(String, Array(UInt8))}, {e: Tuple(UInt8, String)}", write={ - "clickhouse": "SELECT {abc: UInt32}, {b: TEXT}, {c: DATETIME}, {d: Map(TEXT, Array(UInt8))}, {e: Tuple(UInt8, String)}", + "clickhouse": "SELECT {abc: UInt32}, {b: String}, {c: DATETIME}, {d: Map(String, Array(UInt8))}, {e: Tuple(UInt8, String)}", "": "SELECT :abc, :b, :c, :d, :e", }, ) @@ -283,8 +310,8 @@ class TestClickhouse(Validator): "clickhouse": """CREATE TABLE example1 ( timestamp DATETIME, x UInt32 TTL now() + INTERVAL '1' MONTH, - y TEXT TTL timestamp + INTERVAL '1' DAY, - z TEXT + y String TTL timestamp + INTERVAL '1' DAY, + z String ) ENGINE=MergeTree ORDER BY tuple()""", @@ -305,7 +332,7 @@ ORDER BY tuple()""", "clickhouse": """CREATE TABLE test ( id UInt64, timestamp DateTime64, - data TEXT, + data String, max_hits UInt64, sum_hits UInt64 ) @@ -332,8 +359,8 @@ SET """, write={ "clickhouse": """CREATE TABLE test ( - id TEXT, - data TEXT + id String, + data String ) ENGINE=AggregatingMergeTree() ORDER BY tuple() @@ -416,7 +443,7 @@ WHERE "clickhouse": """CREATE TABLE table_for_recompression ( d DATETIME, key UInt64, - value TEXT + value String ) ENGINE=MergeTree() ORDER BY tuple() @@ -512,9 +539,9 @@ RANGE(MIN discount_start_date MAX discount_end_date)""", """, write={ "clickhouse": """CREATE DICTIONARY my_ip_trie_dictionary ( - prefix TEXT, + prefix String, asn UInt32, - cca2 TEXT DEFAULT '??' + cca2 String DEFAULT '??' ) PRIMARY KEY (prefix) SOURCE(CLICKHOUSE( @@ -540,7 +567,7 @@ LIFETIME(MIN 0 MAX 3600)""", write={ "clickhouse": """CREATE DICTIONARY polygons_test_dictionary ( key Array(Array(Array(Tuple(Float64, Float64)))), - name TEXT + name String ) PRIMARY KEY (key) SOURCE(CLICKHOUSE( diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 38a7952..f13d0f2 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -5,6 +5,7 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("CREATE TABLE target SHALLOW CLONE source") self.validate_identity("INSERT INTO a REPLACE WHERE cond VALUES (1), (2)") self.validate_identity("SELECT c1 : price") self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 63f789f..6a41218 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1,6 +1,13 @@ import unittest -from sqlglot import Dialect, Dialects, ErrorLevel, UnsupportedError, parse_one +from sqlglot import ( + Dialect, + Dialects, + ErrorLevel, + ParseError, + UnsupportedError, + parse_one, +) from sqlglot.dialects import Hive @@ -23,9 +30,10 @@ class Validator(unittest.TestCase): Args: sql (str): Main SQL expression - dialect (str): dialect of `sql` read (dict): Mapping of dialect -> SQL write (dict): Mapping of dialect -> SQL + pretty (bool): prettify both read and write + identify (bool): quote identifiers in both read and write """ expression = self.parse_one(sql) @@ -78,7 +86,7 @@ class TestDialect(Validator): "CAST(a AS TEXT)", write={ "bigquery": "CAST(a AS STRING)", - "clickhouse": "CAST(a AS TEXT)", + "clickhouse": "CAST(a AS String)", "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS CHAR)", @@ -116,7 +124,7 @@ class TestDialect(Validator): "CAST(a AS VARBINARY(4))", write={ "bigquery": "CAST(a AS BYTES)", - "clickhouse": "CAST(a AS VARBINARY(4))", + "clickhouse": "CAST(a AS String)", "duckdb": "CAST(a AS BLOB(4))", "mysql": "CAST(a AS VARBINARY(4))", "hive": "CAST(a AS BINARY(4))", @@ -133,7 +141,7 @@ class TestDialect(Validator): self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ - "clickhouse": "CAST(map('a', '1') AS Map(TEXT, TEXT))", + "clickhouse": "CAST(map('a', '1') AS Map(String, String))", }, ) self.validate_all( @@ -367,6 +375,60 @@ class TestDialect(Validator): }, ) + def test_nvl2(self): + self.validate_all( + "SELECT NVL2(a, b, c)", + write={ + "": "SELECT NVL2(a, b, c)", + "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "databricks": "SELECT NVL2(a, b, c)", + "doris": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "drill": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "hive": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "oracle": "SELECT NVL2(a, b, c)", + "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "presto": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "redshift": "SELECT NVL2(a, b, c)", + "snowflake": "SELECT NVL2(a, b, c)", + "spark": "SELECT NVL2(a, b, c)", + "spark2": "SELECT NVL2(a, b, c)", + "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "teradata": "SELECT NVL2(a, b, c)", + "trino": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b ELSE c END", + }, + ) + self.validate_all( + "SELECT NVL2(a, b)", + write={ + "": "SELECT NVL2(a, b)", + "bigquery": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "clickhouse": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "databricks": "SELECT NVL2(a, b)", + "doris": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "drill": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "duckdb": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "hive": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "mysql": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "oracle": "SELECT NVL2(a, b)", + "postgres": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "presto": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "redshift": "SELECT NVL2(a, b)", + "snowflake": "SELECT NVL2(a, b)", + "spark": "SELECT NVL2(a, b)", + "spark2": "SELECT NVL2(a, b)", + "sqlite": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "starrocks": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "teradata": "SELECT NVL2(a, b)", + "trino": "SELECT CASE WHEN NOT a IS NULL THEN b END", + "tsql": "SELECT CASE WHEN NOT a IS NULL THEN b END", + }, + ) + def test_time(self): self.validate_all( "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')", @@ -860,7 +922,7 @@ class TestDialect(Validator): "ARRAY(0, 1, 2)", write={ "bigquery": "[0, 1, 2]", - "duckdb": "LIST_VALUE(0, 1, 2)", + "duckdb": "[0, 1, 2]", "presto": "ARRAY[0, 1, 2]", "spark": "ARRAY(0, 1, 2)", }, @@ -879,7 +941,7 @@ class TestDialect(Validator): "ARRAY_SUM(ARRAY(1, 2))", write={ "trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)", - "duckdb": "LIST_SUM(LIST_VALUE(1, 2))", + "duckdb": "LIST_SUM([1, 2])", "hive": "ARRAY_SUM(ARRAY(1, 2))", "presto": "ARRAY_SUM(ARRAY[1, 2])", "spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)", @@ -1403,27 +1465,27 @@ class TestDialect(Validator): }, ) self.validate_all( - "CREATE INDEX my_idx ON tbl (a, b)", + "CREATE INDEX my_idx ON tbl(a, b)", read={ - "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", - "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)", + "sqlite": "CREATE INDEX my_idx ON tbl(a, b)", }, write={ - "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", - "postgres": "CREATE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)", - "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE INDEX my_idx ON TABLE tbl(a, b)", + "postgres": "CREATE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)", + "sqlite": "CREATE INDEX my_idx ON tbl(a, b)", }, ) self.validate_all( - "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "CREATE UNIQUE INDEX my_idx ON tbl(a, b)", read={ - "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", - "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)", }, write={ - "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", - "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a NULLS FIRST, b NULLS FIRST)", - "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl(a, b)", + "postgres": "CREATE UNIQUE INDEX my_idx ON tbl(a NULLS FIRST, b NULLS FIRST)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl(a, b)", }, ) self.validate_all( @@ -1710,3 +1772,19 @@ SELECT "tsql": "SELECT COUNT_IF(col % 2 = 0) FILTER(WHERE col < 1000) FROM foo", }, ) + + def test_cast_to_user_defined_type(self): + self.validate_all( + "CAST(x AS some_udt)", + write={ + "": "CAST(x AS some_udt)", + "oracle": "CAST(x AS some_udt)", + "postgres": "CAST(x AS some_udt)", + "presto": "CAST(x AS some_udt)", + "teradata": "CAST(x AS some_udt)", + "tsql": "CAST(x AS some_udt)", + }, + ) + + with self.assertRaises(ParseError): + parse_one("CAST(x AS some_udt)", read="bigquery") diff --git a/tests/dialects/test_doris.py b/tests/dialects/test_doris.py index 63325a6..9591269 100644 --- a/tests/dialects/test_doris.py +++ b/tests/dialects/test_doris.py @@ -5,6 +5,7 @@ class TestDoris(Validator): dialect = "doris" def test_identity(self): + self.validate_identity("COALECSE(a, b, c, d)") self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index c33c899..aca0d7a 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -6,6 +6,9 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_duckdb(self): + self.validate_identity("[x.STRING_SPLIT(' ')[1] FOR x IN ['1', '2', 3] IF x.CONTAINS('1')]") + self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y") + self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x") self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)") # https://github.com/duckdb/duckdb/releases/tag/v0.8.0 @@ -50,6 +53,7 @@ class TestDuckDB(Validator): "SELECT * FROM (PIVOT Cities ON Year USING SUM(Population) GROUP BY Country) AS pivot_alias" ) + self.validate_identity("FROM x SELECT x UNION SELECT 1", "SELECT x FROM x UNION SELECT 1") self.validate_all("FROM (FROM tbl)", write={"duckdb": "SELECT * FROM (SELECT * FROM tbl)"}) self.validate_all("FROM tbl", write={"duckdb": "SELECT * FROM tbl"}) self.validate_all("0b1010", write={"": "0 AS b1010"}) @@ -123,20 +127,20 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "LIST_VALUE(0, 1, 2)", + "[0, 1, 2]", read={ "spark": "ARRAY(0, 1, 2)", }, write={ "bigquery": "[0, 1, 2]", - "duckdb": "LIST_VALUE(0, 1, 2)", + "duckdb": "[0, 1, 2]", "presto": "ARRAY[0, 1, 2]", "spark": "ARRAY(0, 1, 2)", }, ) self.validate_all( "SELECT ARRAY_LENGTH([0], 1) AS x", - write={"duckdb": "SELECT ARRAY_LENGTH(LIST_VALUE(0), 1) AS x"}, + write={"duckdb": "SELECT ARRAY_LENGTH([0], 1) AS x"}, ) self.validate_all( "REGEXP_MATCHES(x, y)", @@ -178,18 +182,18 @@ class TestDuckDB(Validator): "STRUCT_EXTRACT(x, 'abc')", write={ "duckdb": "STRUCT_EXTRACT(x, 'abc')", - "presto": 'x."abc"', - "hive": "x.`abc`", - "spark": "x.`abc`", + "presto": "x.abc", + "hive": "x.abc", + "spark": "x.abc", }, ) self.validate_all( "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')", write={ "duckdb": "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')", - "presto": 'x."y"."abc"', - "hive": "x.`y`.`abc`", - "spark": "x.`y`.`abc`", + "presto": "x.y.abc", + "hive": "x.y.abc", + "spark": "x.y.abc", }, ) self.validate_all( @@ -226,7 +230,7 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "LIST_SUM(LIST_VALUE(1, 2))", + "LIST_SUM([1, 2])", read={ "spark": "ARRAY_SUM(ARRAY(1, 2))", }, @@ -304,14 +308,20 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", + "ARRAY_CONCAT([1, 2], [3, 4])", + read={ + "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + "postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])", + "snowflake": "ARRAY_CAT([1, 2], [3, 4])", + }, write={ - "duckdb": "ARRAY_CONCAT(LIST_VALUE(1, 2), LIST_VALUE(3, 4))", - "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])", + "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + "duckdb": "ARRAY_CONCAT([1, 2], [3, 4])", "hive": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", - "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", + "postgres": "ARRAY_CAT(ARRAY[1, 2], ARRAY[3, 4])", + "presto": "CONCAT(ARRAY[1, 2], ARRAY[3, 4])", "snowflake": "ARRAY_CAT([1, 2], [3, 4])", - "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", + "spark": "CONCAT(ARRAY(1, 2), ARRAY(3, 4))", }, ) self.validate_all( @@ -502,6 +512,10 @@ class TestDuckDB(Validator): self.validate_identity("CAST(x AS INT128)") self.validate_identity("CAST(x AS DOUBLE)") self.validate_identity("CAST(x AS DECIMAL(15, 4))") + self.validate_identity("CAST(x AS STRUCT(number BIGINT))") + self.validate_identity( + "CAST(ROW(1, ROW(1)) AS STRUCT(number BIGINT, row STRUCT(number BIGINT)))" + ) self.validate_all("CAST(x AS NUMERIC(1, 2))", write={"duckdb": "CAST(x AS DECIMAL(1, 2))"}) self.validate_all("CAST(x AS HUGEINT)", write={"duckdb": "CAST(x AS INT128)"}) @@ -552,7 +566,7 @@ class TestDuckDB(Validator): self.validate_all( "cast([[1]] as int[][])", write={ - "duckdb": "CAST(LIST_VALUE(LIST_VALUE(1)) AS INT[][])", + "duckdb": "CAST([[1]] AS INT[][])", "spark": "CAST(ARRAY(ARRAY(1)) AS ARRAY<ARRAY<INT>>)", }, ) @@ -587,13 +601,13 @@ class TestDuckDB(Validator): self.validate_all( "CAST([STRUCT_PACK(a := 1)] AS STRUCT(a BIGINT)[])", write={ - "duckdb": "CAST(LIST_VALUE({'a': 1}) AS STRUCT(a BIGINT)[])", + "duckdb": "CAST([{'a': 1}] AS STRUCT(a BIGINT)[])", }, ) self.validate_all( "CAST([[STRUCT_PACK(a := 1)]] AS STRUCT(a BIGINT)[][])", write={ - "duckdb": "CAST(LIST_VALUE(LIST_VALUE({'a': 1})) AS STRUCT(a BIGINT)[][])", + "duckdb": "CAST([[{'a': 1}]] AS STRUCT(a BIGINT)[][])", }, ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 4c463f7..70a05fd 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -390,6 +390,13 @@ class TestHive(Validator): ) def test_hive(self): + self.validate_identity("SELECT * FROM my_table TIMESTAMP AS OF DATE_ADD(CURRENT_DATE, -1)") + self.validate_identity("SELECT * FROM my_table VERSION AS OF DATE_ADD(CURRENT_DATE, -1)") + + self.validate_identity( + "SELECT ROW() OVER (DISTRIBUTE BY x SORT BY y)", + "SELECT ROW() OVER (PARTITION BY x ORDER BY y)", + ) self.validate_identity("SELECT transform") self.validate_identity("SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l") self.validate_identity( @@ -591,7 +598,7 @@ class TestHive(Validator): read={ "": "VAR_MAP(a, b, c, d)", "clickhouse": "map(a, b, c, d)", - "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))", + "duckdb": "MAP([a, c], [b, d])", "hive": "MAP(a, b, c, d)", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "spark": "MAP(a, b, c, d)", @@ -599,7 +606,7 @@ class TestHive(Validator): write={ "": "MAP(ARRAY(a, c), ARRAY(b, d))", "clickhouse": "map(a, b, c, d)", - "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))", + "duckdb": "MAP([a, c], [b, d])", "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", "hive": "MAP(a, b, c, d)", "spark": "MAP(a, b, c, d)", @@ -609,7 +616,7 @@ class TestHive(Validator): self.validate_all( "MAP(a, b)", write={ - "duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))", + "duckdb": "MAP([a], [b])", "presto": "MAP(ARRAY[a], ARRAY[b])", "hive": "MAP(a, b)", "spark": "MAP(a, b)", @@ -717,9 +724,7 @@ class TestHive(Validator): self.validate_identity("'\\\\n'") self.validate_identity("''") self.validate_identity("'\\\\'") - self.validate_identity("'\z'") self.validate_identity("'\\z'") - self.validate_identity("'\\\z'") self.validate_identity("'\\\\z'") def test_data_type(self): diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index d60f09d..fc63f9f 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -7,12 +7,16 @@ class TestMySQL(Validator): def test_ddl(self): self.validate_identity("CREATE TABLE foo (id BIGINT)") + self.validate_identity("CREATE TABLE 00f (1d BIGINT)") self.validate_identity("UPDATE items SET items.price = 0 WHERE items.id >= 5 LIMIT 10") self.validate_identity("DELETE FROM t WHERE a <= 10 LIMIT 10") self.validate_identity("CREATE TABLE foo (a BIGINT, INDEX USING BTREE (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, FULLTEXT INDEX (b))") self.validate_identity("CREATE TABLE foo (a BIGINT, SPATIAL INDEX (b))") self.validate_identity( + "UPDATE items SET items.price = 0 WHERE items.id >= 5 ORDER BY items.id LIMIT 10" + ) + self.validate_identity( "CREATE TABLE foo (a BIGINT, INDEX b USING HASH (c) COMMENT 'd' VISIBLE ENGINE_ATTRIBUTE = 'e' WITH PARSER foo)" ) self.validate_identity( @@ -81,6 +85,9 @@ class TestMySQL(Validator): ) def test_identity(self): + self.validate_identity( + "SELECT * FROM x ORDER BY BINARY a", "SELECT * FROM x ORDER BY CAST(a AS BINARY)" + ) self.validate_identity("SELECT 1 XOR 0") self.validate_identity("SELECT 1 && 0", "SELECT 1 AND 0") self.validate_identity("SELECT /*+ BKA(t1) NO_BKA(t2) */ * FROM t1 INNER JOIN t2") @@ -171,8 +178,12 @@ class TestMySQL(Validator): self.validate_identity( "SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000" ) + self.validate_identity("INTERVAL '1' YEAR") + self.validate_identity("DATE_ADD(x, INTERVAL 1 YEAR)") def test_types(self): + self.validate_identity("CAST(x AS MEDIUMINT) + CAST(y AS YEAR(4))") + self.validate_all( "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)", read={ @@ -353,6 +364,7 @@ class TestMySQL(Validator): write={ "": "MATCH(col1, col2, col3) AGAINST('abc')", "mysql": "MATCH(col1, col2, col3) AGAINST('abc')", + "postgres": "(col1 @@ 'abc' OR col2 @@ 'abc' OR col3 @@ 'abc')", # not quite correct because it's not ts_query }, ) self.validate_all( @@ -818,3 +830,6 @@ COMMENT='客户账户表'""" cmd = self.parse_one("SET x = 1, y = 2") self.assertEqual(len(cmd.expressions), 2) + + def test_json_object(self): + self.validate_identity("SELECT JSON_OBJECT('id', 87, 'name', 'carrot')") diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 0c3b09f..01a9ca3 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -6,6 +6,11 @@ class TestOracle(Validator): dialect = "oracle" def test_oracle(self): + self.validate_identity("SELECT JSON_OBJECT('name': first_name || ' ' || last_name) FROM t") + self.validate_identity("COALESCE(c1, c2, c3)") + self.validate_identity("SELECT * FROM TABLE(foo)") + self.validate_identity("SELECT a$x#b") + self.validate_identity("SELECT :OBJECT") self.validate_identity("SELECT * FROM t FOR UPDATE") self.validate_identity("SELECT * FROM t FOR UPDATE WAIT 5") self.validate_identity("SELECT * FROM t FOR UPDATE NOWAIT") @@ -21,6 +26,9 @@ class TestOracle(Validator): self.validate_identity("SELECT * FROM table_name SAMPLE (25) s") self.validate_identity("SELECT * FROM V$SESSION") self.validate_identity( + "SELECT COUNT(1) INTO V_Temp FROM TABLE(CAST(somelist AS data_list)) WHERE col LIKE '%contact'" + ) + self.validate_identity( "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" ) self.validate_identity( @@ -28,12 +36,16 @@ class TestOracle(Validator): 'OVER (PARTITION BY department_id) AS "Worst", MAX(salary) KEEP (DENSE_RANK LAST ORDER BY commission_pct) ' 'OVER (PARTITION BY department_id) AS "Best" FROM employees ORDER BY department_id, salary, last_name' ) + self.validate_identity( + "SELECT UNIQUE col1, col2 FROM table", + "SELECT DISTINCT col1, col2 FROM table", + ) self.validate_all( "NVL(NULL, 1)", write={ "": "COALESCE(NULL, 1)", - "oracle": "NVL(NULL, 1)", + "oracle": "COALESCE(NULL, 1)", }, ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index a7719a9..8740aca 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -126,6 +126,8 @@ class TestPostgres(Validator): ) def test_postgres(self): + self.validate_identity("x @@ y") + expr = parse_one("SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)") unnest = expr.args["joins"][0].this.this unnest.assert_is(exp.Unnest) @@ -535,6 +537,54 @@ class TestPostgres(Validator): write={"postgres": "CAST(x AS CSTRING)"}, ) self.validate_all( + "x::oid", + write={"postgres": "CAST(x AS OID)"}, + ) + self.validate_all( + "x::regclass", + write={"postgres": "CAST(x AS REGCLASS)"}, + ) + self.validate_all( + "x::regcollation", + write={"postgres": "CAST(x AS REGCOLLATION)"}, + ) + self.validate_all( + "x::regconfig", + write={"postgres": "CAST(x AS REGCONFIG)"}, + ) + self.validate_all( + "x::regdictionary", + write={"postgres": "CAST(x AS REGDICTIONARY)"}, + ) + self.validate_all( + "x::regnamespace", + write={"postgres": "CAST(x AS REGNAMESPACE)"}, + ) + self.validate_all( + "x::regoper", + write={"postgres": "CAST(x AS REGOPER)"}, + ) + self.validate_all( + "x::regoperator", + write={"postgres": "CAST(x AS REGOPERATOR)"}, + ) + self.validate_all( + "x::regproc", + write={"postgres": "CAST(x AS REGPROC)"}, + ) + self.validate_all( + "x::regprocedure", + write={"postgres": "CAST(x AS REGPROCEDURE)"}, + ) + self.validate_all( + "x::regrole", + write={"postgres": "CAST(x AS REGROLE)"}, + ) + self.validate_all( + "x::regtype", + write={"postgres": "CAST(x AS REGTYPE)"}, + ) + self.validate_all( "TRIM(BOTH 'as' FROM 'as string as')", write={ "postgres": "TRIM(BOTH 'as' FROM 'as string as')", @@ -606,7 +656,7 @@ class TestPostgres(Validator): "a || b", write={ "": "a || b", - "clickhouse": "CONCAT(CAST(a AS TEXT), CAST(b AS TEXT))", + "clickhouse": "CONCAT(CAST(a AS String), CAST(b AS String))", "duckdb": "a || b", "postgres": "a || b", "presto": "CONCAT(CAST(a AS VARCHAR), CAST(b AS VARCHAR))", diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 5091540..dbca5b3 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -88,7 +88,7 @@ class TestPresto(Validator): "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", write={ "bigquery": "CAST([1, 2] AS ARRAY<INT64>)", - "duckdb": "CAST(LIST_VALUE(1, 2) AS BIGINT[])", + "duckdb": "CAST([1, 2] AS BIGINT[])", "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", "spark": "CAST(ARRAY(1, 2) AS ARRAY<BIGINT>)", "snowflake": "CAST([1, 2] AS ARRAY)", @@ -98,7 +98,7 @@ class TestPresto(Validator): "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INT,INT))", write={ "bigquery": "CAST(MAP([1], [1]) AS MAP<INT64, INT64>)", - "duckdb": "CAST(MAP(LIST_VALUE(1), LIST_VALUE(1)) AS MAP(INT, INT))", + "duckdb": "CAST(MAP([1], [1]) AS MAP(INT, INT))", "presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))", "hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)", @@ -109,7 +109,7 @@ class TestPresto(Validator): "CAST(MAP(ARRAY['a','b','c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INT)))", write={ "bigquery": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP<STRING, ARRAY<INT64>>)", - "duckdb": "CAST(MAP(LIST_VALUE('a', 'b', 'c'), LIST_VALUE(LIST_VALUE(1), LIST_VALUE(2), LIST_VALUE(3))) AS MAP(TEXT, INT[]))", + "duckdb": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP(TEXT, INT[]))", "presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))", "hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)", "spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)", @@ -138,6 +138,13 @@ class TestPresto(Validator): def test_regex(self): self.validate_all( + "REGEXP_REPLACE('abcd', '[ab]')", + write={ + "presto": "REGEXP_REPLACE('abcd', '[ab]', '')", + "spark": "REGEXP_REPLACE('abcd', '[ab]', '')", + }, + ) + self.validate_all( "REGEXP_LIKE(a, 'x')", write={ "duckdb": "REGEXP_MATCHES(a, 'x')", @@ -289,6 +296,13 @@ class TestPresto(Validator): }, ) self.validate_all( + "DATE_ADD('DAY', 1 * -1, x)", + write={ + "presto": "DATE_ADD('DAY', 1 * -1, x)", + }, + read={"mysql": "DATE_SUB(x, INTERVAL 1 DAY)"}, + ) + self.validate_all( "NOW()", write={ "presto": "CURRENT_TIMESTAMP", @@ -339,6 +353,11 @@ class TestPresto(Validator): "presto": "SELECT CAST('2012-10-31 00:00' AS TIMESTAMP) AT TIME ZONE 'America/Sao_Paulo'", }, ) + self.validate_all( + "CAST(x AS TIMESTAMP)", + write={"presto": "CAST(x AS TIMESTAMP)"}, + read={"mysql": "CAST(x AS DATETIME)", "clickhouse": "CAST(x AS DATETIME64)"}, + ) def test_ddl(self): self.validate_all( @@ -480,6 +499,13 @@ class TestPresto(Validator): @mock.patch("sqlglot.helper.logger") def test_presto(self, logger): + self.validate_identity( + "SELECT * FROM example.testdb.customer_orders FOR VERSION AS OF 8954597067493422955" + ) + self.validate_identity( + "SELECT * FROM example.testdb.customer_orders FOR TIMESTAMP AS OF CAST('2022-03-23 09:59:29.803 Europe/Vienna' AS TIMESTAMP)" + ) + self.validate_identity("SELECT * FROM x OFFSET 1 LIMIT 1") self.validate_identity("SELECT * FROM x OFFSET 1 FETCH FIRST 1 ROWS ONLY") self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") @@ -487,8 +513,58 @@ class TestPresto(Validator): self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") + self.validate_identity( + "SELECT SPLIT_TO_MAP('a:1;b:2;a:3', ';', ':', (k, v1, v2) -> CONCAT(v1, v2))" + ) self.validate_all( + "SELECT ROW(1, 2)", + read={ + "spark": "SELECT STRUCT(1, 2)", + }, + write={ + "presto": "SELECT ROW(1, 2)", + "spark": "SELECT STRUCT(1, 2)", + }, + ) + self.validate_all( + "ARBITRARY(x)", + read={ + "bigquery": "ANY_VALUE(x)", + "clickhouse": "any(x)", + "databricks": "ANY_VALUE(x)", + "doris": "ANY_VALUE(x)", + "drill": "ANY_VALUE(x)", + "duckdb": "ANY_VALUE(x)", + "hive": "FIRST(x)", + "mysql": "ANY_VALUE(x)", + "oracle": "ANY_VALUE(x)", + "redshift": "ANY_VALUE(x)", + "snowflake": "ANY_VALUE(x)", + "spark": "ANY_VALUE(x)", + "spark2": "FIRST(x)", + }, + write={ + "bigquery": "ANY_VALUE(x)", + "clickhouse": "any(x)", + "databricks": "ANY_VALUE(x)", + "doris": "ANY_VALUE(x)", + "drill": "ANY_VALUE(x)", + "duckdb": "ANY_VALUE(x)", + "hive": "FIRST(x)", + "mysql": "ANY_VALUE(x)", + "oracle": "ANY_VALUE(x)", + "postgres": "MAX(x)", + "presto": "ARBITRARY(x)", + "redshift": "ANY_VALUE(x)", + "snowflake": "ANY_VALUE(x)", + "spark": "ANY_VALUE(x)", + "spark2": "FIRST(x)", + "sqlite": "MAX(x)", + "tsql": "MAX(x)", + }, + ) + self.validate_all( "STARTS_WITH('abc', 'a')", read={"spark": "STARTSWITH('abc', 'a')"}, write={ @@ -596,7 +672,7 @@ class TestPresto(Validator): "SELECT ARRAY[1, 2]", write={ "bigquery": "SELECT [1, 2]", - "duckdb": "SELECT LIST_VALUE(1, 2)", + "duckdb": "SELECT [1, 2]", "presto": "SELECT ARRAY[1, 2]", "spark": "SELECT ARRAY(1, 2)", }, @@ -748,7 +824,7 @@ class TestPresto(Validator): self.validate_all( """JSON_FORMAT(JSON '"x"')""", write={ - "bigquery": """TO_JSON_STRING(CAST('"x"' AS JSON))""", + "bigquery": """TO_JSON_STRING(JSON '"x"')""", "duckdb": """CAST(TO_JSON(CAST('"x"' AS JSON)) AS TEXT)""", "presto": """JSON_FORMAT(CAST('"x"' AS JSON))""", "spark": """REGEXP_EXTRACT(TO_JSON(FROM_JSON('["x"]', SCHEMA_OF_JSON('["x"]'))), '^.(.*).$', 1)""", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 3af27d4..245adf3 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,6 +6,14 @@ class TestRedshift(Validator): def test_redshift(self): self.validate_all( + "x ~* 'pat'", + write={ + "redshift": "x ~* 'pat'", + "snowflake": "REGEXP_LIKE(x, 'pat', 'i')", + }, + ) + + self.validate_all( "SELECT CAST('01:03:05.124' AS TIME(2) WITH TIME ZONE)", read={ "postgres": "SELECT CAST('01:03:05.124' AS TIMETZ(2))", @@ -163,22 +171,22 @@ class TestRedshift(Validator): self.validate_all( "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", write={ - "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", - "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", - "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", - "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", - "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1", - "oracle": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', - "presto": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', - "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', - "snowflake": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', - "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE `_row_number` = 1", - "sqlite": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', - "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE `_row_number` = 1", - "tableau": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', - "teradata": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', - "trino": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', - "tsql": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE "_row_number" = 1', + "bigquery": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "databricks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "drill": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "hive": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "mysql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", + "oracle": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "presto": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "redshift": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", + "snowflake": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", + "spark": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "sqlite": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "starrocks": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", + "tableau": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "teradata": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "trino": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", + "tsql": "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC NULLS FIRST) AS _row_number FROM x) WHERE _row_number = 1", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 3053d47..30a1f03 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -8,6 +8,35 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity( + 'DESCRIBE TABLE "SNOWFLAKE_SAMPLE_DATA"."TPCDS_SF100TCL"."WEB_SITE" type=stage' + ) + + self.validate_all( + "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + read={ + "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + }, + write={ + "oracle": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + "snowflake": "SELECT * FROM x START WITH a = b CONNECT BY c = PRIOR d", + }, + ) + self.validate_all( + "SELECT INSERT(a, 0, 0, 'b')", + read={ + "mysql": "SELECT INSERT(a, 0, 0, 'b')", + "snowflake": "SELECT INSERT(a, 0, 0, 'b')", + "tsql": "SELECT STUFF(a, 0, 0, 'b')", + }, + write={ + "mysql": "SELECT INSERT(a, 0, 0, 'b')", + "snowflake": "SELECT INSERT(a, 0, 0, 'b')", + "tsql": "SELECT STUFF(a, 0, 0, 'b')", + }, + ) + + self.validate_identity("LISTAGG(data['some_field'], ',')") self.validate_identity("WEEKOFYEAR(tstamp)") self.validate_identity("SELECT SUM(amount) FROM mytable GROUP BY ALL") self.validate_identity("WITH x AS (SELECT 1 AS foo) SELECT foo FROM IDENTIFIER('x')") @@ -383,12 +412,6 @@ class TestSnowflake(Validator): }, ) self.validate_all( - "SELECT NVL2(a, b, c)", - write={ - "snowflake": "SELECT NVL2(a, b, c)", - }, - ) - self.validate_all( "SELECT $$a$$", write={ "snowflake": "SELECT 'a'", @@ -598,7 +621,7 @@ class TestSnowflake(Validator): write={ "snowflake": "[0, 1, 2]", "bigquery": "[0, 1, 2]", - "duckdb": "LIST_VALUE(0, 1, 2)", + "duckdb": "[0, 1, 2]", "presto": "ARRAY[0, 1, 2]", "spark": "ARRAY(0, 1, 2)", }, @@ -1011,3 +1034,33 @@ MATCH_RECOGNIZE ( )""", pretty=True, ) + + def test_show(self): + # Parsed as Command + self.validate_identity("SHOW COLUMNS IN TABLE dt_test") + self.validate_identity("SHOW TABLES LIKE 'line%' IN tpch.public") + + ast = parse_one("SHOW TABLES HISTORY IN tpch.public") + self.assertIsInstance(ast, exp.Command) + + # Parsed as Show + self.validate_identity("SHOW PRIMARY KEYS") + self.validate_identity("SHOW PRIMARY KEYS IN ACCOUNT") + self.validate_identity("SHOW PRIMARY KEYS IN DATABASE") + self.validate_identity("SHOW PRIMARY KEYS IN DATABASE foo") + self.validate_identity("SHOW PRIMARY KEYS IN TABLE") + self.validate_identity("SHOW PRIMARY KEYS IN TABLE foo") + self.validate_identity( + 'SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', + 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"', + ) + self.validate_identity( + 'SHOW TERSE PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', + 'SHOW PRIMARY KEYS IN TABLE "TEST"."PUBLIC"."customers"', + ) + + ast = parse_one('SHOW PRIMARY KEYS IN "TEST"."PUBLIC"."customers"', read="snowflake") + table = ast.find(exp.Table) + + self.assertIsNotNone(table) + self.assertEqual(table.sql(dialect="snowflake"), '"TEST"."PUBLIC"."customers"') diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 2afa868..a892b0f 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -1,5 +1,6 @@ from unittest import mock +from sqlglot import exp, parse_one from tests.dialects.test_dialect import Validator @@ -224,6 +225,10 @@ TBLPROPERTIES ( ) def test_spark(self): + expr = parse_one("any_value(col, true)", read="spark") + self.assertIsInstance(expr.args.get("ignore_nulls"), exp.Boolean) + self.assertEqual(expr.sql(dialect="spark"), "ANY_VALUE(col, TRUE)") + self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), x -> x + 1)") self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)") self.validate_identity("REFRESH table a.b.c") @@ -234,8 +239,46 @@ TBLPROPERTIES ( self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')") self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("SPLIT(str, pattern, lim)") + self.validate_identity( + "SELECT STR_TO_MAP('a:1,b:2,c:3')", + "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + ) self.validate_all( + "foo.bar", + read={ + "": "STRUCT_EXTRACT(foo, bar)", + }, + ) + self.validate_all( + "MAP(1, 2, 3, 4)", + write={ + "spark": "MAP(1, 2, 3, 4)", + "trino": "MAP(ARRAY[1, 3], ARRAY[2, 4])", + }, + ) + self.validate_all( + "MAP()", + read={ + "spark": "MAP()", + "trino": "MAP()", + }, + write={ + "trino": "MAP(ARRAY[], ARRAY[])", + }, + ) + self.validate_all( + "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + read={ + "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')", + "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + }, + write={ + "presto": "SELECT SPLIT_TO_MAP('a:1,b:2,c:3', ',', ':')", + "spark": "SELECT STR_TO_MAP('a:1,b:2,c:3', ',', ':')", + }, + ) + self.validate_all( "SELECT DATEDIFF(month, CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))", read={ "duckdb": "SELECT DATEDIFF('month', CAST('1996-10-30' AS TIMESTAMP), CAST('1997-02-28 10:30:00' AS TIMESTAMP))", @@ -399,7 +442,7 @@ TBLPROPERTIES ( "ARRAY(0, 1, 2)", write={ "bigquery": "[0, 1, 2]", - "duckdb": "LIST_VALUE(0, 1, 2)", + "duckdb": "[0, 1, 2]", "presto": "ARRAY[0, 1, 2]", "hive": "ARRAY(0, 1, 2)", "spark": "ARRAY(0, 1, 2)", @@ -466,7 +509,7 @@ TBLPROPERTIES ( self.validate_all( "MAP_FROM_ARRAYS(ARRAY(1), c)", write={ - "duckdb": "MAP(LIST_VALUE(1), c)", + "duckdb": "MAP([1], c)", "presto": "MAP(ARRAY[1], c)", "hive": "MAP(ARRAY(1), c)", "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", @@ -522,3 +565,13 @@ TBLPROPERTIES ( self.validate_identity( "SELECT TRANSFORM(zip_code, name, age) USING 'cat' FROM person WHERE zip_code > 94500" ) + + def test_insert_cte(self): + self.validate_all( + "INSERT OVERWRITE TABLE table WITH cte AS (SELECT cola FROM other_table) SELECT cola FROM cte", + write={ + "spark": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", + "spark2": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", + "databricks": "WITH cte AS (SELECT cola FROM other_table) INSERT OVERWRITE TABLE table SELECT cola FROM cte", + }, + ) diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 4d32241..32bdc71 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -4,6 +4,18 @@ from tests.dialects.test_dialect import Validator class TestTeradata(Validator): dialect = "teradata" + def test_teradata(self): + self.validate_all( + "DATABASE tduser", + read={ + "databricks": "USE tduser", + }, + write={ + "databricks": "USE tduser", + "teradata": "DATABASE tduser", + }, + ) + def test_translate(self): self.validate_all( "TRANSLATE(x USING LATIN_TO_UNICODE)", diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index f43b41b..c27b7fa 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -6,10 +6,55 @@ class TestTSQL(Validator): dialect = "tsql" def test_tsql(self): + self.validate_all( + "CREATE TABLE x ( A INTEGER NOT NULL, B INTEGER NULL )", + write={ + "tsql": "CREATE TABLE x (A INTEGER NOT NULL, B INTEGER NULL)", + "hive": "CREATE TABLE x (A INT NOT NULL, B INT)", + }, + ) + + self.validate_identity( + 'CREATE TABLE x (CONSTRAINT "pk_mytable" UNIQUE NONCLUSTERED (a DESC)) ON b (c)' + ) + + self.validate_identity( + """ + CREATE TABLE x( + [zip_cd] [varchar](5) NULL NOT FOR REPLICATION + CONSTRAINT [pk_mytable] PRIMARY KEY CLUSTERED + ([zip_cd_mkey] ASC) + WITH (PAD_INDEX = ON, STATISTICS_NORECOMPUTE = OFF) ON [PRIMARY] + ) ON [PRIMARY] + """, + 'CREATE TABLE x ("zip_cd" VARCHAR(5) NULL NOT FOR REPLICATION CONSTRAINT "pk_mytable" PRIMARY KEY CLUSTERED ("zip_cd_mkey") WITH (PAD_INDEX=ON, STATISTICS_NORECOMPUTE=OFF) ON "PRIMARY") ON "PRIMARY"', + ) + + self.validate_identity( + "CREATE TABLE tbl (a AS (x + 1) PERSISTED, b AS (y + 2), c AS (y / 3) PERSISTED NOT NULL)" + ) + + self.validate_identity( + "CREATE TABLE [db].[tbl]([a] [int])", 'CREATE TABLE "db"."tbl" ("a" INTEGER)' + ) + projection = parse_one("SELECT a = 1", read="tsql").selects[0] projection.assert_is(exp.Alias) projection.args["alias"].assert_is(exp.Identifier) + self.validate_all( + "IF OBJECT_ID('tempdb.dbo.#TempTableName', 'U') IS NOT NULL DROP TABLE #TempTableName", + write={ + "tsql": "DROP TABLE IF EXISTS #TempTableName", + "spark": "DROP TABLE IF EXISTS TempTableName", + }, + ) + + self.validate_identity( + "MERGE INTO mytable WITH (HOLDLOCK) AS T USING mytable_merge AS S " + "ON (T.user_id = S.user_id) WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES (S.c1, S.c2)" + ) + self.validate_identity("UPDATE STATISTICS x") self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b INTO @y FROM y") self.validate_identity("UPDATE x SET y = 1 OUTPUT x.a, x.b FROM y") self.validate_identity("INSERT INTO x (y) OUTPUT x.a, x.b INTO l SELECT * FROM z") @@ -397,8 +442,68 @@ class TestTSQL(Validator): }, ) + self.validate_all( + "CAST(x AS BOOLEAN)", + write={"tsql": "CAST(x AS BIT)"}, + ) + def test_ddl(self): self.validate_all( + "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)", + read={ + "mysql": "CREATE TABLE tbl (id INT AUTO_INCREMENT PRIMARY KEY)", + "tsql": "CREATE TABLE tbl (id INTEGER IDENTITY PRIMARY KEY)", + }, + ) + self.validate_all( + "CREATE TABLE tbl (id INTEGER NOT NULL IDENTITY(10, 1) PRIMARY KEY)", + read={ + "postgres": "CREATE TABLE tbl (id INT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10) PRIMARY KEY)", + "tsql": "CREATE TABLE tbl (id INTEGER NOT NULL IDENTITY(10, 1) PRIMARY KEY)", + }, + ) + self.validate_all( + "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE object_id = object_id('db.tbl') AND name = 'idx') EXEC('CREATE INDEX idx ON db.tbl')", + read={ + "": "CREATE INDEX IF NOT EXISTS idx ON db.tbl", + }, + ) + + self.validate_all( + "IF NOT EXISTS (SELECT * FROM information_schema.schemata WHERE schema_name = 'foo') EXEC('CREATE SCHEMA foo')", + read={ + "": "CREATE SCHEMA IF NOT EXISTS foo", + }, + ) + self.validate_all( + "IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'foo') EXEC('CREATE TABLE foo (a INTEGER)')", + read={ + "": "CREATE TABLE IF NOT EXISTS foo (a INTEGER)", + }, + ) + + self.validate_all( + "CREATE OR ALTER VIEW a.b AS SELECT 1", + read={ + "": "CREATE OR REPLACE VIEW a.b AS SELECT 1", + }, + write={ + "tsql": "CREATE OR ALTER VIEW a.b AS SELECT 1", + }, + ) + + self.validate_all( + "ALTER TABLE a ADD b INTEGER, c INTEGER", + read={ + "": "ALTER TABLE a ADD COLUMN b INT, ADD COLUMN c INT", + }, + write={ + "": "ALTER TABLE a ADD COLUMN b INT, ADD COLUMN c INT", + "tsql": "ALTER TABLE a ADD b INTEGER, c INTEGER", + }, + ) + + self.validate_all( "CREATE TABLE #mytemp (a INTEGER, b CHAR(2), c TIME(4), d FLOAT(24))", write={ "spark": "CREATE TEMPORARY TABLE mytemp (a INT, b CHAR(2), c TIMESTAMP, d FLOAT)", @@ -898,6 +1003,9 @@ WHERE ) def test_iif(self): + self.validate_identity( + "SELECT IF(cond, 'True', 'False')", "SELECT IIF(cond, 'True', 'False')" + ) self.validate_identity("SELECT IIF(cond, 'True', 'False')") self.validate_all( "SELECT IIF(cond, 'True', 'False');", @@ -961,9 +1069,12 @@ WHERE ) def test_format(self): + self.validate_identity("SELECT FORMAT(foo, 'dddd', 'de-CH')") + self.validate_identity("SELECT FORMAT(EndOfDayRate, 'N', 'en-us')") self.validate_identity("SELECT FORMAT('01-01-1991', 'd.mm.yyyy')") self.validate_identity("SELECT FORMAT(12345, '###.###.###')") self.validate_identity("SELECT FORMAT(1234567, 'f')") + self.validate_all( "SELECT FORMAT(1000000.01,'###,###.###')", write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index b21d65d..0953fee 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -139,6 +139,7 @@ x ILIKE '%y%' ESCAPE '\' 1 AS escape INTERVAL '1' day INTERVAL '1' MONTH +INTERVAL '1' YEAR INTERVAL '-1' CURRENT_DATE INTERVAL '-31' CAST(GETDATE() AS DATE) INTERVAL (1 + 3) DAYS @@ -620,13 +621,13 @@ CREATE FUNCTION a() LANGUAGE sql CREATE FUNCTION a() LANGUAGE sql RETURNS INT CREATE FUNCTION a.b(x INT) RETURNS INT AS RETURN x + 1 CREATE FUNCTION a.b.c() -CREATE INDEX abc ON t (a) -CREATE INDEX "abc" ON t (a) -CREATE INDEX abc ON t (a, b, b) -CREATE INDEX abc ON t (a NULLS LAST) +CREATE INDEX abc ON t(a) +CREATE INDEX "abc" ON t(a) +CREATE INDEX abc ON t(a, b, b) +CREATE INDEX abc ON t(a NULLS LAST) CREATE INDEX pointloc ON points USING GIST(BOX(location, location)) -CREATE UNIQUE INDEX abc ON t (a, b, b) -CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b) +CREATE UNIQUE INDEX abc ON t(a, b, b) +CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl(a, b) CREATE SCHEMA x CREATE SCHEMA IF NOT EXISTS y CREATE DATABASE x @@ -836,6 +837,7 @@ JSON_OBJECT('x': NULL, 'y': 1 WITH UNIQUE KEYS) JSON_OBJECT('x': NULL, 'y': 1 ABSENT ON NULL WITH UNIQUE KEYS) JSON_OBJECT('x': 1 RETURNING VARCHAR(100)) JSON_OBJECT('x': 1 RETURNING VARBINARY FORMAT JSON ENCODING UTF8) +PRIOR AS x SELECT if.x SELECT NEXT VALUE FOR db.schema.sequence_name SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col @@ -855,3 +857,4 @@ SELECT * FROM (tbl1 CROSS JOIN (SELECT * FROM tbl2) AS t1) /* comment1 */ DELETE FROM x /* comment2 */ WHERE y > 1 /* comment */ CREATE TABLE foo AS SELECT 1 SELECT next, transform, if +SELECT "any", "case", "if", "next" diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql index 8542c3e..e18d2a4 100644 --- a/tests/fixtures/optimizer/eliminate_subqueries.sql +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -97,3 +97,7 @@ WITH cte1 AS (SELECT a FROM x), cte2 AS (SELECT a FROM cte1 AS cte3) SELECT a FR -- Wrapped subquery joined with table SELECT * FROM ((SELECT c FROM t1) JOIN t2); WITH cte AS (SELECT c FROM t1) SELECT * FROM (cte AS cte, t2); + +-- Wrapped subquery with redundant parentheses +SELECT * FROM (((SELECT * FROM tbl))); +WITH cte AS (SELECT * FROM tbl) SELECT * FROM cte AS cte; diff --git a/tests/fixtures/optimizer/optimize_joins.sql b/tests/fixtures/optimizer/optimize_joins.sql index b8e39c3..6d5bed2 100644 --- a/tests/fixtures/optimizer/optimize_joins.sql +++ b/tests/fixtures/optimizer/optimize_joins.sql @@ -14,7 +14,7 @@ SELECT * FROM x INNER JOIN z ON x.id = z.id; SELECT * FROM x JOIN z ON x.id = z.id; SELECT * FROM x LEFT OUTER JOIN z; -SELECT * FROM x LEFT JOIN z; +SELECT * FROM x LEFT JOIN z ON TRUE; SELECT * FROM x CROSS JOIN z; SELECT * FROM x CROSS JOIN z; @@ -22,11 +22,17 @@ SELECT * FROM x CROSS JOIN z; SELECT * FROM x JOIN z; SELECT * FROM x CROSS JOIN z; +SELECT * FROM x FULL JOIN z; +SELECT * FROM x FULL JOIN z ON TRUE; + SELECT * FROM x NATURAL JOIN z; -SELECT * FROM x NATURAL JOIN z; +SELECT * FROM x NATURAL JOIN z ON TRUE; SELECT * FROM x RIGHT JOIN z; -SELECT * FROM x RIGHT JOIN z; +SELECT * FROM x RIGHT JOIN z ON TRUE; SELECT * FROM x JOIN z USING (id); SELECT * FROM x JOIN z USING (id); + +SELECT * FROM x CROSS JOIN z ON TRUE; +SELECT * FROM x CROSS JOIN z; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index b318a92..18ee804 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -950,3 +950,40 @@ JOIN "n" AS "foo"("a") SELECT CONCAT('a', 'b') || CONCAT(CONCAT('c', 'd'), CONCAT('e', 'f')) + ('g' || 'h' || 'i'); SELECT 'abcdefghi' AS "_col_0"; + +# title: complex query with derived tables and redundant parentheses +# execute: false +# dialect: snowflake +SELECT + ("SUBQUERY_0"."KEY") AS "SUBQUERY_1_COL_0" +FROM + ( + SELECT + * + FROM + ((( + SELECT + * + FROM + ( + SELECT + event_name AS key, + insert_ts + FROM + ( + SELECT + insert_ts, + event_name + FROM + sales + WHERE + insert_ts > '2023-08-07 21:03:35.590 -0700' + ) + ) + ))) AS "SF_CONNECTOR_QUERY_ALIAS" + ) AS "SUBQUERY_0"; +SELECT + "SALES"."EVENT_NAME" AS "SUBQUERY_1_COL_0" +FROM "SALES" AS "SALES" +WHERE + "SALES"."INSERT_TS" > '2023-08-07 21:03:35.590 -0700'; diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql index 79ce353..cfa69fb 100644 --- a/tests/fixtures/optimizer/pushdown_predicates.sql +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -36,3 +36,11 @@ WITH t1 AS (SELECT x.a, x.b, ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.a) A WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a; WITH m AS (SELECT a, b FROM (VALUES (1, 2)) AS a1(a, b)), n AS (SELECT a, b FROM m WHERE m.a = 1), o AS (SELECT a, b FROM m WHERE m.a = 2) SELECT n.a, n.b, n.a, o.b FROM n FULL OUTER JOIN o ON n.a = o.a; + +-- Pushdown predicate to HAVING (CNF) +SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt FROM x AS x) AS x WHERE x.cnt > 0; +SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt FROM x AS x HAVING COUNT(1) > 0) AS x WHERE TRUE; + +-- Pushdown predicate to HAVING (DNF) +SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt, COUNT(x.a) AS cnt_a, COUNT(x.b) AS cnt_b FROM x AS x) AS x WHERE (x.cnt_a > 0 AND x.cnt_b > 0) OR x.cnt > 0; +SELECT x.cnt AS cnt FROM (SELECT COUNT(1) AS cnt, COUNT(x.a) AS cnt_a, COUNT(x.b) AS cnt_b FROM x AS x HAVING COUNT(1) > 0 OR (COUNT(x.a) > 0 AND COUNT(x.b) > 0)) AS x WHERE x.cnt > 0 OR (x.cnt_a > 0 AND x.cnt_b > 0);
\ No newline at end of file diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 3ed02cd..66fb19c 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -264,6 +264,9 @@ TRUE; (FALSE); FALSE; +((TRUE)); +TRUE; + (FALSE OR TRUE); TRUE; @@ -288,6 +291,9 @@ x = y AND z; x * (1 - y); x * (1 - y); +(((x % 20) = 0) = TRUE); +((x % 20) = 0) = TRUE; + -------------------------------------- -- Literals -------------------------------------- @@ -612,6 +618,9 @@ TRUE; x = 2018 OR x <> 2018; x <> 2018 OR x = 2018; +t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y; +t0.x = t1.x AND t0.y < t1.y AND t0.y <= t1.y; + -------------------------------------- -- Coalesce -------------------------------------- @@ -645,6 +654,12 @@ x = 1 OR x IS NULL; COALESCE(x, 1) IS NULL; FALSE; +COALESCE(ROW() OVER (), 1) = 1; +ROW() OVER () = 1 OR ROW() OVER () IS NULL; + +a AND b AND COALESCE(ROW() OVER (), 1) = 1; +a AND b AND (ROW() OVER () = 1 OR ROW() OVER () IS NULL); + -------------------------------------- -- CONCAT -------------------------------------- diff --git a/tests/test_build.py b/tests/test_build.py index 826a59b..4dc993f 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -496,7 +496,7 @@ class TestBuild(unittest.TestCase): ), ( lambda: exp.update("tbl", {"x": None, "y": {"x": 1}}), - "UPDATE tbl SET x = NULL, y = MAP('x', 1)", + "UPDATE tbl SET x = NULL, y = MAP(ARRAY('x'), ARRAY(1))", ), ( lambda: exp.update("tbl", {"x": 1}, where="y > 0"), diff --git a/tests/test_expressions.py b/tests/test_expressions.py index f68ced2..5d1f810 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -2,7 +2,7 @@ import datetime import math import unittest -from sqlglot import alias, exp, parse_one +from sqlglot import ParseError, alias, exp, parse_one class TestExpressions(unittest.TestCase): @@ -188,6 +188,7 @@ class TestExpressions(unittest.TestCase): def test_table(self): self.assertEqual(exp.table_("a", alias="b"), parse_one("select * from a b").find(exp.Table)) + self.assertEqual(exp.table_("a", "").sql(), "a") def test_replace_tables(self): self.assertEqual( @@ -666,7 +667,7 @@ class TestExpressions(unittest.TestCase): (True, "TRUE"), ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), - ({"x": None}, "MAP('x', NULL)"), + ({"x": None}, "MAP(ARRAY('x'), ARRAY(NULL))"), ( datetime.datetime(2022, 10, 1, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')", @@ -681,6 +682,11 @@ class TestExpressions(unittest.TestCase): with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) + self.assertEqual( + exp.convert({"test": "value"}).sql(dialect="spark"), + "MAP_FROM_ARRAYS(ARRAY('test'), ARRAY('value'))", + ) + def test_comment_alias(self): sql = """ SELECT @@ -841,6 +847,9 @@ FROM foo""", ) self.assertEqual(exp.DataType.build("USER-DEFINED").sql(), "USER-DEFINED") + self.assertEqual(exp.DataType.build("ARRAY<UNKNOWN>").sql(), "ARRAY<UNKNOWN>") + self.assertEqual(exp.DataType.build("ARRAY<NULL>").sql(), "ARRAY<NULL>") + def test_rename_table(self): self.assertEqual( exp.rename_table("t1", "t2").sql(), @@ -879,3 +888,51 @@ FROM foo""", ast.meta["some_other_meta_key"] = "some_other_meta_value" self.assertEqual(ast.meta.get("some_other_meta_key"), "some_other_meta_value") + + def test_unnest(self): + ast = parse_one("SELECT (((1)))") + self.assertIs(ast.selects[0].unnest(), ast.find(exp.Literal)) + + ast = parse_one("SELECT * FROM (((SELECT * FROM t)))") + self.assertIs(ast.args["from"].this.unnest(), list(ast.find_all(exp.Select))[1]) + + ast = parse_one("SELECT * FROM ((((SELECT * FROM t))) AS foo)") + second_subquery = ast.args["from"].this.this + innermost_subquery = list(ast.find_all(exp.Select))[1].parent + self.assertIs(second_subquery, innermost_subquery.unwrap()) + + def test_is_type(self): + ast = parse_one("CAST(x AS VARCHAR)") + assert ast.is_type("VARCHAR") + assert not ast.is_type("VARCHAR(5)") + assert not ast.is_type("FLOAT") + + ast = parse_one("CAST(x AS VARCHAR(5))") + assert ast.is_type("VARCHAR") + assert ast.is_type("VARCHAR(5)") + assert not ast.is_type("VARCHAR(4)") + assert not ast.is_type("FLOAT") + + ast = parse_one("CAST(x AS ARRAY<INT>)") + assert ast.is_type("ARRAY") + assert ast.is_type("ARRAY<INT>") + assert not ast.is_type("ARRAY<FLOAT>") + assert not ast.is_type("INT") + + ast = parse_one("CAST(x AS ARRAY)") + assert ast.is_type("ARRAY") + assert not ast.is_type("ARRAY<INT>") + assert not ast.is_type("ARRAY<FLOAT>") + assert not ast.is_type("INT") + + ast = parse_one("CAST(x AS STRUCT<a INT, b FLOAT>)") + assert ast.is_type("STRUCT") + assert ast.is_type("STRUCT<a INT, b FLOAT>") + assert not ast.is_type("STRUCT<a VARCHAR, b INT>") + + dtype = exp.DataType.build("foo", udt=True) + assert dtype.is_type("foo") + assert not dtype.is_type("bar") + + with self.assertRaises(ParseError): + exp.DataType.build("foo") diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a1bd309..e001c1f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -758,6 +758,24 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') self.assertEqual(exp.DataType.Type.INT, expression.selects[0].type.this) self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this) + def test_nested_type_annotation(self): + schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}} + sql = """ + SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items, + FROM order AS order + GROUP BY order.customer_id + """ + expression = annotate_types(parse_one(sql), schema=schema) + + self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this) + self.assertEqual(expression.selects[0].type.sql(), "ARRAY<BIGINT>") + + expression = annotate_types( + parse_one("SELECT ARRAY_CAT(ARRAY[1,2,3], ARRAY[4,5])", read="postgres") + ) + self.assertEqual(exp.DataType.Type.ARRAY, expression.selects[0].type.this) + self.assertEqual(expression.selects[0].type.sql(), "ARRAY<INT>") + def test_recursive_cte(self): query = parse_one( """ diff --git a/tests/test_parser.py b/tests/test_parser.py index e7b0ca9..7135dd8 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -217,6 +217,9 @@ class TestParser(unittest.TestCase): parse_one("IF(a > 0)") with self.assertRaises(ParseError): + parse_one("SELECT CASE FROM x") + + with self.assertRaises(ParseError): parse_one("WITH cte AS (SELECT * FROM x)") with self.assertRaises(ParseError): @@ -435,7 +438,7 @@ class TestParser(unittest.TestCase): self.assertIsInstance(parse_one("TIMESTAMP('2022-01-01')"), exp.Func) self.assertIsInstance(parse_one("TIMESTAMP()"), exp.Func) self.assertIsInstance(parse_one("map.x"), exp.Column) - self.assertIsInstance(parse_one("CAST(x AS CHAR(5))").to.expressions[0], exp.DataTypeSize) + self.assertIsInstance(parse_one("CAST(x AS CHAR(5))").to.expressions[0], exp.DataTypeParam) self.assertEqual(parse_one("1::int64", dialect="bigquery"), parse_one("CAST(1 AS BIGINT)")) def test_set_expression(self): @@ -708,3 +711,11 @@ class TestParser(unittest.TestCase): parse_one("SELECT a, b ?? c ?? 'No Data' FROM z").sql(), "SELECT a, COALESCE(COALESCE(b, c), 'No Data') FROM z", ) + + def test_parse_intervals(self): + ast = parse_one( + "SELECT a FROM tbl WHERE a <= DATE '1998-12-01' - INTERVAL '71 days' GROUP BY b" + ) + + self.assertEqual(ast.find(exp.Interval).this.sql(), "'71'") + self.assertEqual(ast.find(exp.Interval).unit.assert_is(exp.Var).sql(), "days") diff --git a/tests/test_tokens.py b/tests/test_tokens.py index f3343e7..f4d3858 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -6,6 +6,18 @@ from sqlglot.tokens import Tokenizer, TokenType class TestTokens(unittest.TestCase): + def test_space_keywords(self): + for string, length in ( + ("group bys", 2), + (" group bys", 2), + (" group bys ", 2), + ("group by)", 2), + ("group bys)", 3), + ): + tokens = Tokenizer().tokenize(string) + self.assertTrue("GROUP" in tokens[0].text.upper()) + self.assertEqual(len(tokens), length) + def test_comment_attachment(self): tokenizer = Tokenizer() sql_comment = [ diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 80d12ac..2109f53 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -52,17 +52,17 @@ class TestTransforms(unittest.TestCase): self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", - 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', + "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", ) self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a) a, b FROM x", - 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS _row_number FROM x) WHERE "_row_number" = 1', + "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a) AS _row_number FROM x) WHERE _row_number = 1", ) self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC", - 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', + "SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE _row_number = 1", ) self.validate( eliminate_distinct_on, @@ -72,7 +72,7 @@ class TestTransforms(unittest.TestCase): self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", - 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1', + "SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE _row_number_2 = 1", ) def test_eliminate_qualify(self): diff --git a/tests/test_transpile.py b/tests/test_transpile.py index e58ed86..2b51be2 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -393,11 +393,13 @@ LEFT OUTER JOIN b""", self.validate("x::INT y", "CAST(x AS INT) AS y") self.validate("x::INT AS y", "CAST(x AS INT) AS y") self.validate("x::INT::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)") + self.validate("interval::int", "CAST(interval AS INT)") + self.validate("x::user_defined_type", "CAST(x AS user_defined_type)") self.validate("CAST(x::INT AS BOOLEAN)", "CAST(CAST(x AS INT) AS BOOLEAN)") self.validate("CAST(x AS INT)::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)") with self.assertRaises(ParseError): - transpile("x::z") + transpile("x::z", read="duckdb") def test_not_range(self): self.validate("a NOT LIKE b", "NOT a LIKE b") |