From 67c28dbe67209effad83d93b850caba5ee1e20e3 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Wed, 3 May 2023 11:12:28 +0200 Subject: Merging upstream version 11.7.1. Signed-off-by: Daniel Baumann --- tests/dataframe/integration/test_dataframe.py | 85 +++++++++++++++- tests/dataframe/unit/test_dataframe_writer.py | 7 ++ tests/dataframe/unit/test_functions.py | 13 ++- tests/dialects/test_bigquery.py | 21 +++- tests/dialects/test_clickhouse.py | 1 + tests/dialects/test_databricks.py | 36 +++++++ tests/dialects/test_dialect.py | 41 +++++--- tests/dialects/test_drill.py | 2 +- tests/dialects/test_duckdb.py | 18 +++- tests/dialects/test_hive.py | 55 +++++----- tests/dialects/test_mysql.py | 89 +++++++++++++++- tests/dialects/test_oracle.py | 48 ++++++++- tests/dialects/test_postgres.py | 94 +++++++++++++++-- tests/dialects/test_presto.py | 83 +++++++++++++-- tests/dialects/test_redshift.py | 40 +++++++- tests/dialects/test_snowflake.py | 19 ++-- tests/dialects/test_spark.py | 35 +++++++ tests/dialects/test_starrocks.py | 1 + tests/dialects/test_teradata.py | 37 +++++++ tests/dialects/test_tsql.py | 91 ++++++++++++++++- tests/fixtures/identity.sql | 21 +++- tests/fixtures/optimizer/canonicalize.sql | 2 +- tests/fixtures/optimizer/normalize.sql | 3 + tests/fixtures/optimizer/qualify_columns.sql | 23 +++++ tests/fixtures/optimizer/simplify.sql | 3 + tests/fixtures/optimizer/tpc-ds/tpc-ds.sql | 8 +- tests/test_build.py | 52 ++++++++++ tests/test_expressions.py | 5 +- tests/test_lineage.py | 140 +++++++++++++++++++++++++- tests/test_parser.py | 102 ++++++++++++++++++- tests/test_schema.py | 12 ++- tests/test_tokens.py | 18 +++- tests/test_transforms.py | 5 + tests/test_transpile.py | 94 +++++++++++++++-- 34 files changed, 1190 insertions(+), 114 deletions(-) (limited to 'tests') diff --git a/tests/dataframe/integration/test_dataframe.py b/tests/dataframe/integration/test_dataframe.py index 19e3b89..d00464b 100644 --- a/tests/dataframe/integration/test_dataframe.py +++ b/tests/dataframe/integration/test_dataframe.py @@ -276,6 +276,7 @@ class TestDataframeFunc(DataFrameValidator): self.df_spark_employee.store_id, self.df_spark_store.store_name, self.df_spark_store["num_sales"], + F.lit("literal_value"), ) dfs_joined = self.df_sqlglot_employee.join( self.df_sqlglot_store, @@ -289,6 +290,7 @@ class TestDataframeFunc(DataFrameValidator): self.df_sqlglot_employee.store_id, self.df_sqlglot_store.store_name, self.df_sqlglot_store["num_sales"], + SF.lit("literal_value"), ) self.compare_spark_with_sqlglot(df_joined, dfs_joined) @@ -330,8 +332,8 @@ class TestDataframeFunc(DataFrameValidator): def test_join_inner_equality_multiple_bitwise_and(self): df_joined = self.df_spark_employee.join( self.df_spark_store, - on=(self.df_spark_employee.store_id == self.df_spark_store.store_id) - & (self.df_spark_employee.age == self.df_spark_store.num_sales), + on=(self.df_spark_store.store_id == self.df_spark_employee.store_id) + & (self.df_spark_store.num_sales == self.df_spark_employee.age), how="inner", ).select( self.df_spark_employee.employee_id, @@ -344,8 +346,8 @@ class TestDataframeFunc(DataFrameValidator): ) dfs_joined = self.df_sqlglot_employee.join( self.df_sqlglot_store, - on=(self.df_sqlglot_employee.store_id == self.df_sqlglot_store.store_id) - & (self.df_sqlglot_employee.age == self.df_sqlglot_store.num_sales), + on=(self.df_sqlglot_store.store_id == self.df_sqlglot_employee.store_id) + & (self.df_sqlglot_store.num_sales == self.df_sqlglot_employee.age), how="inner", ).select( self.df_sqlglot_employee.employee_id, @@ -443,6 +445,81 @@ class TestDataframeFunc(DataFrameValidator): ) self.compare_spark_with_sqlglot(df, dfs) + def test_triple_join_no_select(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ) + .join( + self.df_district, + on=self.df_store["store_id"] == self.df_district["district_id"], + how="left", + ) + .orderBy(F.col("employee_id")) + ) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ) + .join( + self.dfs_district, + on=self.dfs_store["store_id"] == self.dfs_district["district_id"], + how="left", + ) + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df, dfs) + + def test_triple_joins_filter(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ).join( + self.df_district, + on=self.df_store["store_id"] == self.df_district["district_id"], + how="left", + ) + ).filter(F.coalesce(self.df_store["num_sales"], F.lit(0)) > 100) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ).join( + self.dfs_district, + on=self.dfs_store["store_id"] == self.dfs_district["district_id"], + how="left", + ) + ).filter(SF.coalesce(self.dfs_store["num_sales"], SF.lit(0)) > 100) + self.compare_spark_with_sqlglot(df, dfs) + + def test_triple_join_column_name_only(self): + df = ( + self.df_employee.join( + self.df_store, + on=self.df_employee["employee_id"] == self.df_store["store_id"], + how="left", + ) + .join(self.df_district, on="district_id", how="left") + .orderBy(F.col("employee_id")) + ) + dfs = ( + self.dfs_employee.join( + self.dfs_store, + on=self.dfs_employee["employee_id"] == self.dfs_store["store_id"], + how="left", + ) + .join(self.dfs_district, on="district_id", how="left") + .orderBy(SF.col("employee_id")) + ) + self.compare_spark_with_sqlglot(df, dfs) + def test_join_select_and_select_start(self): df = self.df_spark_employee.select( F.col("fname"), F.col("lname"), F.col("age"), F.col("store_id") diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 042b915..3f45468 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -86,3 +86,10 @@ class TestDataFrameWriter(DataFrameSQLValidator): "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`", ] self.compare_sql(df, expected_statements) + + def test_quotes(self): + sqlglot.schema.add_table('"Test"', {'"ID"': "STRING"}) + df = self.spark.table('"Test"') + self.compare_sql( + df.select(df['"ID"']), ["SELECT `Test`.`ID` AS `ID` FROM `Test` AS `Test`"] + ) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index d9a32c4..befa68b 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -807,14 +807,17 @@ class TestFunctions(unittest.TestCase): self.assertEqual("DATE_ADD(cola, 2)", col.sql()) col_col_for_add = SF.date_add("cola", "colb") self.assertEqual("DATE_ADD(cola, colb)", col_col_for_add.sql()) + current_date_add = SF.date_add(SF.current_date(), 5) + self.assertEqual("DATE_ADD(CURRENT_DATE, 5)", current_date_add.sql()) + self.assertEqual("DATEADD(day, 5, CURRENT_DATE)", current_date_add.sql(dialect="snowflake")) def test_date_sub(self): col_str = SF.date_sub("cola", 2) - self.assertEqual("DATE_SUB(cola, 2)", col_str.sql()) + self.assertEqual("DATE_ADD(cola, -2)", col_str.sql()) col = SF.date_sub(SF.col("cola"), 2) - self.assertEqual("DATE_SUB(cola, 2)", col.sql()) + self.assertEqual("DATE_ADD(cola, -2)", col.sql()) col_col_for_add = SF.date_sub("cola", "colb") - self.assertEqual("DATE_SUB(cola, colb)", col_col_for_add.sql()) + self.assertEqual("DATE_ADD(cola, colb * -1)", col_col_for_add.sql()) def test_date_diff(self): col_str = SF.date_diff("cola", "colb") @@ -957,9 +960,9 @@ class TestFunctions(unittest.TestCase): def test_sha1(self): col_str = SF.sha1("Spark") - self.assertEqual("SHA1('Spark')", col_str.sql()) + self.assertEqual("SHA('Spark')", col_str.sql()) col = SF.sha1(SF.col("cola")) - self.assertEqual("SHA1(cola)", col.sql()) + self.assertEqual("SHA(cola)", col.sql()) def test_sha2(self): col_str = SF.sha2("Spark", 256) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index e210292..703b7dc 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -14,10 +14,16 @@ class TestBigQuery(Validator): "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" ) + self.validate_all( + "CREATE TEMP TABLE foo AS SELECT 1", + write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"}, + ) self.validate_all("LEAST(x, y)", read={"sqlite": "MIN(x, y)"}) self.validate_all("CAST(x AS CHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NCHAR)", write={"bigquery": "CAST(x AS STRING)"}) self.validate_all("CAST(x AS NVARCHAR)", write={"bigquery": "CAST(x AS STRING)"}) + self.validate_all("CAST(x AS TIMESTAMP)", write={"bigquery": "CAST(x AS DATETIME)"}) + self.validate_all("CAST(x AS TIMESTAMPTZ)", write={"bigquery": "CAST(x AS TIMESTAMP)"}) self.validate_all( "SELECT ARRAY(SELECT AS STRUCT 1 a, 2 b)", write={ @@ -59,7 +65,7 @@ class TestBigQuery(Validator): "spark": r"'/\*.*\*/'", }, ) - with self.assertRaises(RuntimeError): + with self.assertRaises(ValueError): transpile("'\\'", read="bigquery") self.validate_all( @@ -285,6 +291,7 @@ class TestBigQuery(Validator): "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)", write={ "postgres": "CURRENT_DATE - INTERVAL '1' DAY", + "bigquery": "DATE_SUB(CURRENT_DATE, INTERVAL 1 DAY)", }, ) self.validate_all( @@ -359,11 +366,23 @@ class TestBigQuery(Validator): self.validate_identity("BEGIN TRANSACTION") self.validate_identity("COMMIT TRANSACTION") self.validate_identity("ROLLBACK TRANSACTION") + self.validate_identity("CAST(x AS BIGNUMERIC)") + + self.validate_identity("SELECT * FROM UNNEST([1]) WITH ORDINALITY") + self.validate_all( + "SELECT * FROM UNNEST([1]) WITH OFFSET", + write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS offset"}, + ) + self.validate_all( + "SELECT * FROM UNNEST([1]) WITH OFFSET y", + write={"bigquery": "SELECT * FROM UNNEST([1]) WITH OFFSET AS y"}, + ) def test_user_defined_functions(self): self.validate_identity( "CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) RETURNS FLOAT64 NOT DETERMINISTIC LANGUAGE js AS 'return x*y;'" ) + self.validate_identity("CREATE TEMPORARY FUNCTION udf(x ANY TYPE) AS (x)") self.validate_identity("CREATE TEMPORARY FUNCTION a(x FLOAT64, y FLOAT64) AS ((x + 4) / y)") self.validate_identity( "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE AS SELECT s, t" diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 40a3a04..9fd2b45 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -5,6 +5,7 @@ class TestClickhouse(Validator): dialect = "clickhouse" def test_clickhouse(self): + self.validate_identity("SELECT INTERVAL t.days day") self.validate_identity("SELECT match('abc', '([a-z]+)')") self.validate_identity("dictGet(x, 'y')") self.validate_identity("SELECT * FROM x FINAL") diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 48ea6d1..4619108 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -5,10 +5,46 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("SELECT c1 : price") self.validate_identity("CREATE FUNCTION a.b(x INT) RETURNS INT RETURN x + 1") self.validate_identity("CREATE FUNCTION a AS b") self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1") + # https://docs.databricks.com/sql/language-manual/functions/colonsign.html + def test_json(self): + self.validate_identity("""SELECT c1 : price FROM VALUES ('{ "price": 5 }') AS T(c1)""") + + self.validate_all( + """SELECT c1:['price'] FROM VALUES('{ "price": 5 }') AS T(c1)""", + write={ + "databricks": """SELECT c1 : ARRAY('price') FROM VALUES ('{ "price": 5 }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT c1:item[1].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT c1 : item[1].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT c1:item[*].price FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT c1 : item[*].price FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT from_json(c1:item[*].price, 'ARRAY')[0] FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT FROM_JSON(c1 : item[*].price, 'ARRAY')[0] FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + self.validate_all( + """SELECT inline(from_json(c1:item[*], 'ARRAY>')) FROM VALUES('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + write={ + "databricks": """SELECT INLINE(FROM_JSON(c1 : item[*], 'ARRAY>')) FROM VALUES ('{ "item": [ { "model" : "basic", "price" : 6.12 }, { "model" : "medium", "price" : 9.24 } ] }') AS T(c1)""", + }, + ) + def test_datediff(self): self.validate_all( "SELECT DATEDIFF(year, 'start', 'end')", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 3558d62..bcbbfd6 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -435,6 +435,7 @@ class TestDialect(Validator): write={ "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))", "hive": "UNIX_TIMESTAMP('2020-01-01')", + "mysql": "UNIX_TIMESTAMP('2020-01-01')", "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %T'))", }, ) @@ -561,25 +562,25 @@ class TestDialect(Validator): }, ) self.validate_all( - "DATE_ADD(x, 1, 'day')", + "DATE_ADD(x, 1, 'DAY')", read={ "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "snowflake": "DATEADD('day', 1, x)", + "snowflake": "DATEADD('DAY', 1, x)", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", }, write={ "bigquery": "DATE_ADD(x, INTERVAL 1 DAY)", "drill": "DATE_ADD(x, INTERVAL 1 DAY)", - "duckdb": "x + INTERVAL 1 day", + "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", - "postgres": "x + INTERVAL '1' day", - "presto": "DATE_ADD('day', 1, x)", - "snowflake": "DATEADD(day, 1, x)", + "postgres": "x + INTERVAL '1' DAY", + "presto": "DATE_ADD('DAY', 1, x)", + "snowflake": "DATEADD(DAY, 1, x)", "spark": "DATE_ADD(x, 1)", - "sqlite": "DATE(x, '1 day')", + "sqlite": "DATE(x, '1 DAY')", "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", - "tsql": "DATEADD(day, 1, x)", + "tsql": "DATEADD(DAY, 1, x)", }, ) self.validate_all( @@ -631,14 +632,14 @@ class TestDialect(Validator): "snowflake": "DATE_TRUNC('day', x::DATE)", }, ) + self.validate_all( + "TIMESTAMP_TRUNC(TRY_CAST(x AS DATE), day)", + read={"postgres": "DATE_TRUNC('day', x::DATE)"}, + ) self.validate_all( "TIMESTAMP_TRUNC(CAST(x AS DATE), day)", - read={ - "postgres": "DATE_TRUNC('day', x::DATE)", - "starrocks": "DATE_TRUNC('day', x::DATE)", - }, + read={"starrocks": "DATE_TRUNC('day', x::DATE)"}, ) - self.validate_all( "DATE_TRUNC('week', x)", write={ @@ -751,6 +752,20 @@ class TestDialect(Validator): "spark": "DATE_ADD('2021-02-01', 1)", }, ) + self.validate_all( + "TS_OR_DS_ADD(x, 1, 'DAY')", + write={ + "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR(CAST(x AS VARCHAR), 1, 10), '%Y-%m-%d'))", + "hive": "DATE_ADD(x, 1)", + }, + ) + self.validate_all( + "TS_OR_DS_ADD(CURRENT_DATE, 1, 'DAY')", + write={ + "presto": "DATE_ADD('DAY', 1, CURRENT_DATE)", + "hive": "DATE_ADD(CURRENT_DATE, 1)", + }, + ) self.validate_all( "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", write={ diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py index f035176..a7f609a 100644 --- a/tests/dialects/test_drill.py +++ b/tests/dialects/test_drill.py @@ -14,7 +14,7 @@ class TestDrill(Validator): self.validate_all( "SELECT '2021-01-01' + INTERVAL 1 MONTH", write={ - "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + "mysql": "SELECT '2021-01-01' + INTERVAL '1' MONTH", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 245d82a..9e0040c 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -80,7 +80,7 @@ class TestDuckDB(Validator): "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", }, write={ - "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMPTZ), 'America/New_York'))", + "bigquery": "TIMESTAMP(DATETIME(CAST(start AS TIMESTAMP), 'America/New_York'))", "duckdb": "CAST(start AS TIMESTAMPTZ) AT TIME ZONE 'America/New_York'", "snowflake": "CONVERT_TIMEZONE('America/New_York', CAST(start AS TIMESTAMPTZ))", }, @@ -148,6 +148,12 @@ class TestDuckDB(Validator): "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)", }, ) + self.validate_all( + "CREATE TABLE IF NOT EXISTS table (cola INT COMMENT 'cola', colb STRING) USING ICEBERG PARTITIONED BY (colb)", + write={ + "duckdb": "CREATE TABLE IF NOT EXISTS table (cola INT, colb TEXT)", + }, + ) self.validate_all( "LIST_VALUE(0, 1, 2)", read={ @@ -245,7 +251,7 @@ class TestDuckDB(Validator): }, ) self.validate_all( - "POWER(CAST(2 AS SMALLINT), 3)", + "POWER(TRY_CAST(2 AS SMALLINT), 3)", read={ "hive": "POW(2S, 3)", "spark": "POW(2S, 3)", @@ -339,6 +345,12 @@ class TestDuckDB(Validator): "bigquery": "ARRAY_CONCAT([1, 2], [3, 4])", }, ) + self.validate_all( + "SELECT CAST(CAST(x AS DATE) AS DATE) + INTERVAL 1 DAY", + read={ + "hive": "SELECT DATE_ADD(TO_DATE(x), 1)", + }, + ) with self.assertRaises(UnsupportedError): transpile( @@ -408,7 +420,7 @@ class TestDuckDB(Validator): "CAST(x AS DATE) + INTERVAL (7 * -1) DAY", read={"spark": "DATE_SUB(x, 7)"} ) self.validate_all( - "CAST(1 AS DOUBLE)", + "TRY_CAST(1 AS DOUBLE)", read={ "hive": "1d", "spark": "1d", diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 1a83575..c69368c 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -70,8 +70,8 @@ class TestHive(Validator): self.validate_all( "1s", write={ - "duckdb": "CAST(1 AS SMALLINT)", - "presto": "CAST(1 AS SMALLINT)", + "duckdb": "TRY_CAST(1 AS SMALLINT)", + "presto": "TRY_CAST(1 AS SMALLINT)", "hive": "CAST(1 AS SMALLINT)", "spark": "CAST(1 AS SHORT)", }, @@ -79,8 +79,8 @@ class TestHive(Validator): self.validate_all( "1S", write={ - "duckdb": "CAST(1 AS SMALLINT)", - "presto": "CAST(1 AS SMALLINT)", + "duckdb": "TRY_CAST(1 AS SMALLINT)", + "presto": "TRY_CAST(1 AS SMALLINT)", "hive": "CAST(1 AS SMALLINT)", "spark": "CAST(1 AS SHORT)", }, @@ -88,8 +88,8 @@ class TestHive(Validator): self.validate_all( "1Y", write={ - "duckdb": "CAST(1 AS TINYINT)", - "presto": "CAST(1 AS TINYINT)", + "duckdb": "TRY_CAST(1 AS TINYINT)", + "presto": "TRY_CAST(1 AS TINYINT)", "hive": "CAST(1 AS TINYINT)", "spark": "CAST(1 AS BYTE)", }, @@ -97,8 +97,8 @@ class TestHive(Validator): self.validate_all( "1L", write={ - "duckdb": "CAST(1 AS BIGINT)", - "presto": "CAST(1 AS BIGINT)", + "duckdb": "TRY_CAST(1 AS BIGINT)", + "presto": "TRY_CAST(1 AS BIGINT)", "hive": "CAST(1 AS BIGINT)", "spark": "CAST(1 AS LONG)", }, @@ -106,8 +106,8 @@ class TestHive(Validator): self.validate_all( "1.0bd", write={ - "duckdb": "CAST(1.0 AS DECIMAL)", - "presto": "CAST(1.0 AS DECIMAL)", + "duckdb": "TRY_CAST(1.0 AS DECIMAL)", + "presto": "TRY_CAST(1.0 AS DECIMAL)", "hive": "CAST(1.0 AS DECIMAL)", "spark": "CAST(1.0 AS DECIMAL)", }, @@ -148,6 +148,9 @@ class TestHive(Validator): self.validate_identity( """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""", ) + self.validate_identity( + """CREATE EXTERNAL TABLE `my_table` (`a7` ARRAY) ROW FORMAT SERDE 'a' STORED AS INPUTFORMAT 'b' OUTPUTFORMAT 'c' LOCATION 'd' TBLPROPERTIES ('e'='f')""" + ) def test_lateral_view(self): self.validate_all( @@ -318,6 +321,11 @@ class TestHive(Validator): "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", }, ) + self.validate_all("DATE_ADD('2020-01-01', -1)", read={"": "DATE_SUB('2020-01-01', 1)"}) + self.validate_all("DATE_ADD(a, b * -1)", read={"": "DATE_SUB(a, b)"}) + self.validate_all( + "ADD_MONTHS('2020-01-01', -2)", read={"": "DATE_SUB('2020-01-01', 2, month)"} + ) self.validate_all( "DATEDIFF(TO_DATE(y), x)", write={ @@ -504,11 +512,10 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE(10) y", + "SELECT * FROM x TABLESAMPLE(10 PERCENT) y", write={ - "presto": "SELECT * FROM x AS y TABLESAMPLE (10)", - "hive": "SELECT * FROM x TABLESAMPLE (10) AS y", - "spark": "SELECT * FROM x TABLESAMPLE (10) AS y", + "hive": "SELECT * FROM x TABLESAMPLE (10 PERCENT) AS y", + "spark": "SELECT * FROM x TABLESAMPLE (10 PERCENT) AS y", }, ) self.validate_all( @@ -650,25 +657,13 @@ class TestHive(Validator): }, ) self.validate_all( - "SELECT * FROM x TABLESAMPLE (1) AS foo", - read={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", - }, - write={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", - "hive": "SELECT * FROM x TABLESAMPLE (1) AS foo", - "spark": "SELECT * FROM x TABLESAMPLE (1) AS foo", - }, - ) - self.validate_all( - "SELECT * FROM x TABLESAMPLE (1) AS foo", + "SELECT * FROM x TABLESAMPLE (1 PERCENT) AS foo", read={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", + "presto": "SELECT * FROM x AS foo TABLESAMPLE BERNOULLI (1)", }, write={ - "presto": "SELECT * FROM x AS foo TABLESAMPLE (1)", - "hive": "SELECT * FROM x TABLESAMPLE (1) AS foo", - "spark": "SELECT * FROM x TABLESAMPLE (1) AS foo", + "hive": "SELECT * FROM x TABLESAMPLE (1 PERCENT) AS foo", + "spark": "SELECT * FROM x TABLESAMPLE (1 PERCENT) AS foo", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index f618728..524d95e 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -14,8 +14,18 @@ class TestMySQL(Validator): "spark": "CREATE TABLE z (a INT) COMMENT 'x'", }, ) + self.validate_all( + "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP) DEFAULT CHARSET=utf8 ROW_FORMAT=DYNAMIC", + write={ + "mysql": "CREATE TABLE t (c DATETIME DEFAULT CURRENT_TIMESTAMP() ON UPDATE CURRENT_TIMESTAMP()) DEFAULT CHARACTER SET=utf8 ROW_FORMAT=DYNAMIC", + }, + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON DUPLICATE KEY UPDATE SET x.id = 1" + ) def test_identity(self): + self.validate_identity("SELECT CURRENT_TIMESTAMP(6)") self.validate_identity("x ->> '$.name'") self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") self.validate_identity("SELECT TRIM(LEADING 'bla' FROM ' XXX ')") @@ -186,7 +196,7 @@ class TestMySQL(Validator): self.validate_all( 'SELECT "2021-01-01" + INTERVAL 1 MONTH', write={ - "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + "mysql": "SELECT '2021-01-01' + INTERVAL '1' MONTH", }, ) @@ -239,7 +249,83 @@ class TestMySQL(Validator): write={"mysql": "MATCH(a.b) AGAINST('abc')"}, ) + def test_date_format(self): + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%Y')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%m')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%m')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'mm')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%d')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%d')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'DD')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%Y-%m-%d')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'yyyy-mm-DD')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15 22:23:34', '%H')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15 22:23:34' AS TIMESTAMPNTZ), 'hh24')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2017-06-15', '%w')", + write={ + "mysql": "SELECT DATE_FORMAT('2017-06-15', '%w')", + "snowflake": "SELECT TO_CHAR(CAST('2017-06-15' AS TIMESTAMPNTZ), 'dy')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", + write={ + "mysql": "SELECT DATE_FORMAT('2009-10-04 22:23:00', '%W %M %Y')", + "snowflake": "SELECT TO_CHAR(CAST('2009-10-04 22:23:00' AS TIMESTAMPNTZ), 'DY mmmm yyyy')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%H:%i:%s')", + write={ + "mysql": "SELECT DATE_FORMAT('2007-10-04 22:23:00', '%T')", + "snowflake": "SELECT TO_CHAR(CAST('2007-10-04 22:23:00' AS TIMESTAMPNTZ), 'hh24:mi:ss')", + }, + ) + self.validate_all( + "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %a %d %m %b')", + write={ + "mysql": "SELECT DATE_FORMAT('1900-10-04 22:23:00', '%d %y %W %d %m %b')", + "snowflake": "SELECT TO_CHAR(CAST('1900-10-04 22:23:00' AS TIMESTAMPNTZ), 'DD yy DY DD mm mon')", + }, + ) + + def test_mysql_time(self): + self.validate_identity("FROM_UNIXTIME(a, b)") + self.validate_identity("FROM_UNIXTIME(a, b, c)") + self.validate_identity("TIME_STR_TO_UNIX(x)", "UNIX_TIMESTAMP(x)") + def test_mysql(self): + self.validate_all( + "SELECT DATE(DATE_SUB(`dt`, INTERVAL DAYOFMONTH(`dt`) - 1 DAY)) AS __timestamp FROM tableT", + write={ + "mysql": "SELECT DATE(DATE_SUB(`dt`, INTERVAL (DAYOFMONTH(`dt`) - 1) DAY)) AS __timestamp FROM tableT", + }, + ) self.validate_all( "SELECT a FROM tbl FOR UPDATE", write={ @@ -247,6 +333,7 @@ class TestMySQL(Validator): "mysql": "SELECT a FROM tbl FOR UPDATE", "oracle": "SELECT a FROM tbl FOR UPDATE", "postgres": "SELECT a FROM tbl FOR UPDATE", + "redshift": "SELECT a FROM tbl", "tsql": "SELECT a FROM tbl FOR UPDATE", }, ) diff --git a/tests/dialects/test_oracle.py b/tests/dialects/test_oracle.py index 80fa0f1..dd297d6 100644 --- a/tests/dialects/test_oracle.py +++ b/tests/dialects/test_oracle.py @@ -6,6 +6,27 @@ class TestOracle(Validator): def test_oracle(self): self.validate_identity("SELECT * FROM V$SESSION") + self.validate_identity( + "SELECT MIN(column_name) KEEP (DENSE_RANK FIRST ORDER BY column_name DESC) FROM table_name" + ) + + self.validate_all( + "NVL(NULL, 1)", + write={ + "oracle": "NVL(NULL, 1)", + "": "IFNULL(NULL, 1)", + }, + ) + + self.validate_all( + "DATE '2022-01-01'", + write={ + "": "DATE_STR_TO_DATE('2022-01-01')", + "mysql": "CAST('2022-01-01' AS DATE)", + "oracle": "TO_DATE('2022-01-01', 'YYYY-MM-DD')", + "postgres": "CAST('2022-01-01' AS DATE)", + }, + ) def test_join_marker(self): self.validate_identity("SELECT e1.x, e2.x FROM e e1, e e2 WHERE e1.y (+) = e2.y") @@ -81,7 +102,7 @@ FROM warehouses, XMLTABLE( FROM XMLTABLE( 'ROWSET/ROW' PASSING - dbms_xmlgen.getxmltype ("SELECT table_name, column_name, data_default FROM user_tab_columns") + dbms_xmlgen.GETXMLTYPE('SELECT table_name, column_name, data_default FROM user_tab_columns') COLUMNS table_name VARCHAR2(128) PATH '*[1]', column_name VARCHAR2(128) PATH '*[2]', @@ -90,3 +111,28 @@ FROM XMLTABLE( }, pretty=True, ) + + def test_match_recognize(self): + self.validate_identity( + """SELECT + * +FROM sales_history +MATCH_RECOGNIZE ( + PARTITION BY product + ORDER BY + tstamp + MEASURES + STRT.tstamp AS start_tstamp, + LAST(UP.tstamp) AS peak_tstamp, + LAST(DOWN.tstamp) AS end_tstamp, + MATCH_NUMBER() AS mno + ONE ROW PER MATCH + AFTER MATCH SKIP TO LAST DOWN + PATTERN (STRT UP+ FLAT* DOWN+) + DEFINE + UP AS UP.units_sold > PREV(UP.units_sold), + FLAT AS FLAT.units_sold = PREV(FLAT.units_sold), + DOWN AS DOWN.units_sold < PREV(DOWN.units_sold) +) MR""", + pretty=True, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index a89ae30..e2f9c41 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1,4 +1,4 @@ -from sqlglot import ParseError, transpile +from sqlglot import ParseError, exp, parse_one, transpile from tests.dialects.test_dialect import Validator @@ -10,10 +10,24 @@ class TestPostgres(Validator): self.validate_identity("CREATE TABLE test (foo HSTORE)") self.validate_identity("CREATE TABLE test (foo JSONB)") self.validate_identity("CREATE TABLE test (foo VARCHAR(64)[])") - self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a") self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING a, b") self.validate_identity("INSERT INTO x VALUES (1, 'a', 2.0) RETURNING *") + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO NOTHING RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = 1 RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT (id) DO UPDATE SET x.id = excluded.id RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO NOTHING RETURNING *" + ) + self.validate_identity( + "INSERT INTO x VALUES (1, 'a', 2.0) ON CONFLICT ON CONSTRAINT pkey DO UPDATE SET x.id = 1 RETURNING *" + ) self.validate_identity( "DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid RETURNING a" ) @@ -75,6 +89,7 @@ class TestPostgres(Validator): self.validate_identity("SELECT ARRAY[1, 2, 3] <@ ARRAY[1, 2]") self.validate_identity("SELECT ARRAY[1, 2, 3] && ARRAY[1, 2]") self.validate_identity("$x") + self.validate_identity("x$") self.validate_identity("SELECT ARRAY[1, 2, 3]") self.validate_identity("SELECT ARRAY(SELECT 1)") self.validate_identity("SELECT ARRAY_LENGTH(ARRAY[1, 2, 3], 1)") @@ -107,6 +122,12 @@ class TestPostgres(Validator): self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + self.validate_all( + "e'x'", + write={ + "mysql": "x", + }, + ) self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") self.validate_identity( "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" @@ -117,6 +138,28 @@ class TestPostgres(Validator): self.validate_identity("x ~ 'y'") self.validate_identity("x ~* 'y'") + self.validate_all( + "SELECT DATE_PART('isodow'::varchar(6), current_date)", + write={ + "postgres": "SELECT EXTRACT(CAST('isodow' AS VARCHAR(6)) FROM CURRENT_DATE)", + }, + ) + self.validate_all( + "SELECT DATE_PART('minute', timestamp '2023-01-04 04:05:06.789')", + write={ + "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "snowflake": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + }, + ) + self.validate_all( + "SELECT DATE_PART('month', date '20220502')", + write={ + "postgres": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "redshift": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "snowflake": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + }, + ) self.validate_all( "SELECT (DATE '2016-01-10', DATE '2016-02-01') OVERLAPS (DATE '2016-01-20', DATE '2016-02-10')", write={ @@ -141,17 +184,17 @@ class TestPostgres(Validator): self.validate_all( "GENERATE_SERIES(a, b, ' 2 days ')", write={ - "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' days)", - "presto": "SEQUENCE(a, b, INTERVAL '2' days)", - "trino": "SEQUENCE(a, b, INTERVAL '2' days)", + "postgres": "GENERATE_SERIES(a, b, INTERVAL '2' day)", + "presto": "SEQUENCE(a, b, INTERVAL '2' day)", + "trino": "SEQUENCE(a, b, INTERVAL '2' day)", }, ) self.validate_all( "GENERATE_SERIES('2019-01-01'::TIMESTAMP, NOW(), '1day')", write={ "postgres": "GENERATE_SERIES(CAST('2019-01-01' AS TIMESTAMP), CURRENT_TIMESTAMP, INTERVAL '1' day)", - "presto": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", - "trino": "SEQUENCE(CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", + "presto": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", + "trino": "SEQUENCE(TRY_CAST('2019-01-01' AS TIMESTAMP), CAST(CURRENT_TIMESTAMP AS TIMESTAMP), INTERVAL '1' day)", }, ) self.validate_all( @@ -296,7 +339,10 @@ class TestPostgres(Validator): ) self.validate_all( """'{"a":1,"b":2}'::json->'b'""", - write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'"""}, + write={ + "postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'""", + "redshift": """CAST('{"a":1,"b":2}' AS JSON)."b\"""", + }, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'->'y'""", @@ -326,7 +372,7 @@ class TestPostgres(Validator): """SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""", write={ "postgres": """SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""", - "presto": """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""", + "presto": """SELECT JSON_ARRAY_ELEMENTS(TRY_CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""", }, ) self.validate_all( @@ -389,6 +435,36 @@ class TestPostgres(Validator): "spark": "TRIM(BOTH 'as' FROM 'as string as')", }, ) + self.validate_all( + "merge into x as x using (select id) as y on a = b WHEN matched then update set X.a = y.b", + write={ + "postgres": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x AS x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X.a = y.b", + }, + ) + self.validate_all( + "merge into x as z using (select id) as y on a = b WHEN matched then update set X.a = y.b", + write={ + "postgres": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET X.a = y.b", + }, + ) + self.validate_all( + "merge into x as z using (select id) as y on a = b WHEN matched then update set Z.a = y.b", + write={ + "postgres": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x AS z USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET Z.a = y.b", + }, + ) + self.validate_all( + "merge into x using (select id) as y on a = b WHEN matched then update set x.a = y.b", + write={ + "postgres": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET a = y.b", + "snowflake": "MERGE INTO x USING (SELECT id) AS y ON a = b WHEN MATCHED THEN UPDATE SET x.a = y.b", + }, + ) + + self.assertIsInstance(parse_one("id::UUID", read="postgres"), exp.TryCast) def test_bool_or(self): self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 1007899..3080476 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -60,7 +60,7 @@ class TestPresto(Validator): self.validate_all( "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", write={ - "bigquery": "CAST(x AS TIMESTAMPTZ)", + "bigquery": "CAST(x AS TIMESTAMP)", "duckdb": "CAST(x AS TIMESTAMPTZ(9))", "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", "hive": "CAST(x AS TIMESTAMP)", @@ -106,7 +106,33 @@ class TestPresto(Validator): }, ) + def test_interval_plural_to_singular(self): + # Microseconds, weeks and quarters are not supported in Presto/Trino INTERVAL literals + unit_to_expected = { + "SeCoNds": "second", + "minutes": "minute", + "hours": "hour", + "days": "day", + "months": "month", + "years": "year", + } + + for unit, expected in unit_to_expected.items(): + self.validate_all( + f"SELECT INTERVAL '1' {unit}", + write={ + "bigquery": f"SELECT INTERVAL '1' {expected}", + "presto": f"SELECT INTERVAL '1' {expected}", + "trino": f"SELECT INTERVAL '1' {expected}", + }, + ) + def test_time(self): + self.validate_identity("FROM_UNIXTIME(a, b)") + self.validate_identity("FROM_UNIXTIME(a, b, c)") + self.validate_identity("TRIM(a, b)") + self.validate_identity("VAR_POP(a)") + self.validate_all( "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", write={ @@ -158,10 +184,6 @@ class TestPresto(Validator): "spark": "FROM_UNIXTIME(x)", }, ) - self.validate_identity("FROM_UNIXTIME(a, b)") - self.validate_identity("FROM_UNIXTIME(a, b, c)") - self.validate_identity("TRIM(a, b)") - self.validate_identity("VAR_POP(a)") self.validate_all( "TO_UNIXTIME(x)", write={ @@ -243,7 +265,7 @@ class TestPresto(Validator): }, ) self.validate_all( - "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1", + "CREATE TABLE test STORED AS 'PARQUET' AS SELECT 1", write={ "duckdb": "CREATE TABLE test AS SELECT 1", "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", @@ -362,6 +384,14 @@ class TestPresto(Validator): }, ) + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(ARRAY(y)) AS t (a) CROSS JOIN b", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a) CROSS JOIN b", + "hive": "SELECT a FROM x CROSS JOIN b LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + }, + ) + def test_presto(self): self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)") self.validate_identity("SELECT * FROM (VALUES (1))") @@ -369,6 +399,9 @@ class TestPresto(Validator): self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") self.validate_identity("APPROX_PERCENTILE(a, b, c, d)") + self.validate_all("INTERVAL '1 day'", write={"trino": "INTERVAL '1' day"}) + self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' week"}) + self.validate_all("(5 * INTERVAL '7' day)", read={"": "INTERVAL '5' WEEKS"}) self.validate_all( "SELECT JSON_OBJECT(KEY 'key1' VALUE 1, KEY 'key2' VALUE TRUE)", write={ @@ -643,3 +676,41 @@ class TestPresto(Validator): "presto": "SELECT CAST(ARRAY[1, 23, 456] AS JSON)", }, ) + + def test_explode_to_unnest(self): + self.validate_all( + "SELECT col FROM tbl CROSS JOIN UNNEST(x) AS _u(col)", + read={"spark": "SELECT EXPLODE(x) FROM tbl"}, + ) + self.validate_all( + "SELECT col_2 FROM _u CROSS JOIN UNNEST(col) AS _u_2(col_2)", + read={"spark": "SELECT EXPLODE(col) FROM _u"}, + ) + self.validate_all( + "SELECT exploded FROM schema.tbl CROSS JOIN UNNEST(col) AS _u(exploded)", + read={"spark": "SELECT EXPLODE(col) AS exploded FROM schema.tbl"}, + ) + self.validate_all( + "SELECT col FROM UNNEST(SEQUENCE(1, 2)) AS _u(col)", + read={"spark": "SELECT EXPLODE(SEQUENCE(1, 2))"}, + ) + self.validate_all( + "SELECT col FROM tbl AS t CROSS JOIN UNNEST(t.c) AS _u(col)", + read={"spark": "SELECT EXPLODE(t.c) FROM tbl t"}, + ) + self.validate_all( + "SELECT pos, col FROM UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", + read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3))"}, + ) + self.validate_all( + "SELECT pos, col FROM tbl CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u(col, pos)", + read={"spark": "SELECT POSEXPLODE(SEQUENCE(2, 3)) FROM tbl"}, + ) + self.validate_all( + "SELECT pos, col FROM tbl AS t CROSS JOIN UNNEST(t.c) WITH ORDINALITY AS _u(col, pos)", + read={"spark": "SELECT POSEXPLODE(t.c) FROM tbl t"}, + ) + self.validate_all( + "SELECT col, pos, pos_2, col_2 FROM _u CROSS JOIN UNNEST(SEQUENCE(2, 3)) WITH ORDINALITY AS _u_2(col_2, pos_2)", + read={"spark": "SELECT col, pos, POSEXPLODE(SEQUENCE(2, 3)) FROM _u"}, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 0933051..e5bd0e5 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -5,6 +5,44 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): + self.validate_identity("SELECT * FROM #x") + self.validate_identity("SELECT INTERVAL '5 day'") + self.validate_identity("foo$") + self.validate_identity("$foo") + + self.validate_all( + "SELECT SNAPSHOT", + write={ + "": "SELECT SNAPSHOT", + "redshift": 'SELECT "SNAPSHOT"', + }, + ) + + self.validate_all( + "SELECT SYSDATE", + write={ + "": "SELECT CURRENT_TIMESTAMP()", + "postgres": "SELECT CURRENT_TIMESTAMP", + "redshift": "SELECT SYSDATE", + }, + ) + self.validate_all( + "SELECT DATE_PART(minute, timestamp '2023-01-04 04:05:06.789')", + write={ + "postgres": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "redshift": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMP))", + "snowflake": "SELECT EXTRACT(minute FROM CAST('2023-01-04 04:05:06.789' AS TIMESTAMPNTZ))", + }, + ) + self.validate_all( + "SELECT DATE_PART(month, date '20220502')", + write={ + "postgres": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "redshift": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + "snowflake": "SELECT EXTRACT(month FROM CAST('20220502' AS DATE))", + }, + ) + self.validate_all("SELECT INTERVAL '5 day'", read={"": "SELECT INTERVAL '5' days"}) self.validate_all("CONVERT(INTEGER, x)", write={"redshift": "CAST(x AS INTEGER)"}) self.validate_all( "DATEADD('day', ndays, caldate)", write={"redshift": "DATEADD(day, ndays, caldate)"} @@ -27,7 +65,7 @@ class TestRedshift(Validator): "SELECT ST_AsEWKT(ST_GeomFromEWKT('SRID=4326;POINT(10 20)')::geography)", write={ "redshift": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", - "bigquery": "SELECT ST_ASEWKT(CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", + "bigquery": "SELECT ST_ASEWKT(TRY_CAST(ST_GEOMFROMEWKT('SRID=4326;POINT(10 20)') AS GEOGRAPHY))", }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index eb423a5..5c8b096 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -6,12 +6,16 @@ class TestSnowflake(Validator): dialect = "snowflake" def test_snowflake(self): + self.validate_identity("OBJECT_CONSTRUCT(*)") + self.validate_identity("SELECT TO_DATE('2019-02-28') + INTERVAL '1 day, 1 year'") + self.validate_identity("SELECT CAST('2021-01-01' AS DATE) + INTERVAL '1 DAY'") self.validate_identity("SELECT HLL(*)") self.validate_identity("SELECT HLL(a)") self.validate_identity("SELECT HLL(DISTINCT t.a)") self.validate_identity("SELECT HLL(a, b, c)") self.validate_identity("SELECT HLL(DISTINCT a, b, c)") - self.validate_identity("$x") + self.validate_identity("$x") # parameter + self.validate_identity("a$b") # valid snowflake identifier self.validate_identity("SELECT REGEXP_LIKE(a, b, c)") self.validate_identity("PUT file:///dir/tmp.csv @%table") self.validate_identity("CREATE TABLE foo (bar FLOAT AUTOINCREMENT START 0 INCREMENT 1)") @@ -255,19 +259,18 @@ class TestSnowflake(Validator): write={ "bigquery": "SELECT PARSE_TIMESTAMP('%Y-%m-%d %H:%M:%S', '2013-04-05 01:02:03')", "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')", - "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')", + "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-d HH:mm:ss')", }, ) self.validate_all( - "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')", read={ "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", - "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", }, write={ "bigquery": "SELECT PARSE_TIMESTAMP('%m/%d/%Y %H:%M:%S', '04/05/2013 01:02:03')", - "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/DD/yyyy hh24:mi:ss')", "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", }, ) @@ -841,11 +844,13 @@ MATCH_RECOGNIZE ( PARTITION BY a, b ORDER BY x DESC - MEASURES y AS b + MEASURES + y AS b {row} {after} PATTERN (^ S1 S2*? ( {{- S3 -}} S4 )+ | PERMUTE(S1, S2){{1,2}} $) - DEFINE x AS y + DEFINE + x AS y )""", pretty=True, ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 0da2931..bfaed53 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -214,6 +214,41 @@ TBLPROPERTIES ( self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')") self.validate_identity("SPLIT(str, pattern, lim)") + self.validate_all( + "BOOLEAN(x)", + write={ + "": "CAST(x AS BOOLEAN)", + "spark": "CAST(x AS BOOLEAN)", + }, + ) + self.validate_all( + "INT(x)", + write={ + "": "CAST(x AS INT)", + "spark": "CAST(x AS INT)", + }, + ) + self.validate_all( + "STRING(x)", + write={ + "": "CAST(x AS TEXT)", + "spark": "CAST(x AS STRING)", + }, + ) + self.validate_all( + "DATE(x)", + write={ + "": "CAST(x AS DATE)", + "spark": "CAST(x AS DATE)", + }, + ) + self.validate_all( + "TIMESTAMP(x)", + write={ + "": "CAST(x AS TIMESTAMP)", + "spark": "CAST(x AS TIMESTAMP)", + }, + ) self.validate_all( "CAST(x AS TIMESTAMP)", read={"trino": "CAST(x AS TIMESTAMP(6) WITH TIME ZONE)"} ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py index 35d8b45..b33231c 100644 --- a/tests/dialects/test_starrocks.py +++ b/tests/dialects/test_starrocks.py @@ -6,6 +6,7 @@ class TestMySQL(Validator): def test_identity(self): self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + self.validate_identity("SELECT APPROX_COUNT_DISTINCT(a) FROM x") def test_time(self): self.validate_identity("TIMESTAMP('2022-01-01')") diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py index 5d4f7db..dcb513d 100644 --- a/tests/dialects/test_teradata.py +++ b/tests/dialects/test_teradata.py @@ -39,6 +39,31 @@ class TestTeradata(Validator): write={"teradata": "CREATE OR REPLACE VIEW a AS (SELECT b FROM c)"}, ) + self.validate_all( + "CREATE VOLATILE TABLE a", + write={ + "teradata": "CREATE VOLATILE TABLE a", + "bigquery": "CREATE TABLE a", + "clickhouse": "CREATE TABLE a", + "databricks": "CREATE TABLE a", + "drill": "CREATE TABLE a", + "duckdb": "CREATE TABLE a", + "hive": "CREATE TABLE a", + "mysql": "CREATE TABLE a", + "oracle": "CREATE TABLE a", + "postgres": "CREATE TABLE a", + "presto": "CREATE TABLE a", + "redshift": "CREATE TABLE a", + "snowflake": "CREATE TABLE a", + "spark": "CREATE TABLE a", + "sqlite": "CREATE TABLE a", + "starrocks": "CREATE TABLE a", + "tableau": "CREATE TABLE a", + "trino": "CREATE TABLE a", + "tsql": "CREATE TABLE a", + }, + ) + def test_insert(self): self.validate_all( "INS INTO x SELECT * FROM y", write={"teradata": "INSERT INTO x SELECT * FROM y"} @@ -71,3 +96,15 @@ class TestTeradata(Validator): ) self.validate_identity("CREATE TABLE z (a SYSUDTLIB.INT)") + + def test_cast(self): + self.validate_all( + "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", + write={ + "teradata": "CAST('1992-01' AS DATE FORMAT 'YYYY-DD')", + "databricks": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "mysql": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "spark": "DATE_FORMAT('1992-01', 'YYYY-DD')", + "": "TIME_TO_STR('1992-01', 'YYYY-DD')", + }, + ) diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index d9ee4ae..b6e893c 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -66,6 +66,54 @@ class TestTSQL(Validator): "postgres": "STRING_AGG(x, '|')", }, ) + self.validate_all( + "SELECT CAST([a].[b] AS SMALLINT) FROM foo", + write={ + "tsql": 'SELECT CAST("a"."b" AS SMALLINT) FROM foo', + "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + }, + ) + self.validate_all( + "HASHBYTES('SHA1', x)", + read={ + "spark": "SHA(x)", + }, + write={ + "tsql": "HASHBYTES('SHA1', x)", + "spark": "SHA(x)", + }, + ) + self.validate_all( + "HASHBYTES('SHA2_256', x)", + read={ + "spark": "SHA2(x, 256)", + }, + write={ + "tsql": "HASHBYTES('SHA2_256', x)", + "spark": "SHA2(x, 256)", + }, + ) + self.validate_all( + "HASHBYTES('SHA2_512', x)", + read={ + "spark": "SHA2(x, 512)", + }, + write={ + "tsql": "HASHBYTES('SHA2_512', x)", + "spark": "SHA2(x, 512)", + }, + ) + self.validate_all( + "HASHBYTES('MD5', 'x')", + read={ + "spark": "MD5('x')", + }, + write={ + "tsql": "HASHBYTES('MD5', 'x')", + "spark": "MD5('x')", + }, + ) + self.validate_identity("HASHBYTES('MD2', 'x')") def test_types(self): self.validate_identity("CAST(x AS XML)") @@ -399,7 +447,7 @@ WHERE self.validate_all( "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) y FROM testdb.dbo.test", write={ - "mysql": "SELECT CAST(TIME_TO_STR(testdb.dbo.test.x, '%Y-%m-%d %H:%M:%S') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + "mysql": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, '%Y-%m-%d %T') AS VARCHAR(10)) AS y FROM testdb.dbo.test", "spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test", }, ) @@ -482,6 +530,12 @@ WHERE "spark": "SELECT x.a, x.b, t.v, t.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", }, ) + self.validate_all( + "SELECT x.a, x.b, t.v, t.y, s.v, s.y FROM x OUTER APPLY (SELECT v, y FROM t) t(v, y) OUTER APPLY (SELECT v, y FROM t) s(v, y) LEFT JOIN z ON z.id = s.id", + write={ + "spark": "SELECT x.a, x.b, t.v, t.y, s.v, s.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y) LEFT JOIN LATERAL (SELECT v, y FROM t) AS s(v, y) LEFT JOIN z ON z.id = s.id", + }, + ) def test_lateral_table_valued_function(self): self.validate_all( @@ -631,3 +685,38 @@ WHERE "SUSER_SNAME()", write={"spark": "CURRENT_USER()"}, ) + self.validate_all( + "SYSTEM_USER()", + write={"spark": "CURRENT_USER()"}, + ) + self.validate_all( + "SYSTEM_USER", + write={"spark": "CURRENT_USER()"}, + ) + + def test_hints(self): + self.validate_all( + "SELECT x FROM a INNER HASH JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a INNER LOOP JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a INNER REMOTE JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a INNER MERGE JOIN b ON b.id = a.id", + write={"spark": "SELECT x FROM a INNER JOIN b ON b.id = a.id"}, + ) + self.validate_all( + "SELECT x FROM a WITH (NOLOCK)", + write={ + "spark": "SELECT x FROM a", + "tsql": "SELECT x FROM a WITH (NOLOCK)", + "": "SELECT x FROM a WITH (NOLOCK)", + }, + ) + self.validate_identity("SELECT x FROM a INNER LOOP JOIN b ON b.id = a.id") diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 54e5583..a08a7a8 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -100,6 +100,9 @@ CURRENT_DATE AT TIME ZONE zone_column CURRENT_DATE AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Tokio' ARRAY() ARRAY(1, 2) +ARRAY(time, foo) +ARRAY(foo, time) +ARRAY(LENGTH(waiter_name) > 0) ARRAY_CONTAINS(x, 1) EXTRACT(x FROM y) EXTRACT(DATE FROM y) @@ -126,12 +129,14 @@ x ILIKE '%y%' ESCAPE '\' 1 AS escape INTERVAL '1' day INTERVAL '1' MONTH -INTERVAL '1 day' INTERVAL '-1' CURRENT_DATE INTERVAL '-31' CAST(GETDATE() AS DATE) -INTERVAL 2 months INTERVAL (1 + 3) DAYS +INTERVAL '1' day * 5 +5 * INTERVAL '1' day +CASE WHEN TRUE THEN INTERVAL '15' days END CAST('45' AS INTERVAL DAYS) +CAST(x AS UUID) FILTER(a, x -> x.a.b.c.d.e.f.g) FILTER(a, x -> FOO(x.a.b.c.d.e.f.g) + x.a.b.c.d.e.f.g) TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) @@ -250,6 +255,8 @@ SELECT * FROM test LIMIT 1 + 1 SELECT * FROM test LIMIT 100 OFFSET 200 SELECT * FROM test FETCH FIRST ROWS ONLY SELECT * FROM test FETCH FIRST 1 ROWS ONLY +SELECT * FROM test ORDER BY id DESC FETCH FIRST 10 ROWS WITH TIES +SELECT * FROM test ORDER BY id DESC FETCH FIRST 10 PERCENT ROWS WITH TIES SELECT * FROM test FETCH NEXT 1 ROWS ONLY SELECT (1 > 2) AS x FROM test SELECT NOT (1 > 2) FROM test @@ -554,6 +561,7 @@ CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA CREATE TABLE asd AS SELECT asd FROM asd WITH DATA CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY) CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY) +CREATE TABLE konyvszerzo (szerzo_azon INT CONSTRAINT konyvszerzo_szerzo_fk REFERENCES szerzo) CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1 MINVALUE -1 MAXVALUE 1 NO CYCLE)) CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10)) @@ -640,6 +648,7 @@ DELETE FROM y DELETE FROM event USING sales WHERE event.eventid = sales.eventid DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid +DELETE FROM event AS event USING sales AS s WHERE event.eventid = s.eventid PREPARE statement EXECUTE statement DROP TABLE a @@ -648,6 +657,7 @@ DROP TABLE IF EXISTS a DROP TABLE IF EXISTS a.b DROP TABLE a CASCADE DROP TABLE s_hajo CASCADE CONSTRAINTS +DROP TABLE a PURGE DROP VIEW a DROP VIEW a.b DROP VIEW IF EXISTS a @@ -717,12 +727,14 @@ SELECT a /* x */ /* y */ /* z */, b /* k */ /* m */ SELECT * FROM foo /* x */, bla /* x */ SELECT 1 /* comment */ + 1 SELECT 1 /* c1 */ + 2 /* c2 */ +SELECT 1 /* c1 */ + /* c2 */ 2 /* c3 */ SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */ INSERT INTO foo SELECT * FROM bar /* comment */ +/* c */ WITH x AS (SELECT 1) SELECT * FROM x SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' SELECT x AS INTO FROM bla SELECT * INTO newevent FROM event @@ -736,6 +748,7 @@ ALTER TABLE IF EXISTS integers ADD COLUMN k INT ALTER TABLE integers ADD COLUMN l INT DEFAULT 10 ALTER TABLE measurements ADD COLUMN mtime TIMESTAMPTZ DEFAULT NOW() ALTER TABLE integers DROP COLUMN k +ALTER TABLE integers DROP PRIMARY KEY ALTER TABLE integers DROP COLUMN IF EXISTS k ALTER TABLE integers DROP COLUMN k CASCADE ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR @@ -760,6 +773,7 @@ STRUCT("bla") STRUCT(5) STRUCT("2011-05-05") STRUCT(1, t.str_col) +STRUCT SELECT CAST(NULL AS ARRAY) IS NULL AS array_is_null ALTER TABLE "schema"."tablename" ADD CONSTRAINT "CHK_Name" CHECK (NOT "IdDwh" IS NULL AND "IdDwh" <> (0)) ALTER TABLE persons ADD CONSTRAINT persons_pk PRIMARY KEY (first_name, last_name) @@ -803,3 +817,6 @@ JSON_OBJECT('x': NULL, 'y': 1 WITH UNIQUE KEYS) JSON_OBJECT('x': NULL, 'y': 1 ABSENT ON NULL WITH UNIQUE KEYS) JSON_OBJECT('x': 1 RETURNING VARCHAR(100)) JSON_OBJECT('x': 1 RETURNING VARBINARY FORMAT JSON ENCODING UTF8) +SELECT if.x +SELECT NEXT VALUE FOR db.schema.sequence_name +SELECT NEXT VALUE FOR db.schema.sequence_name OVER (ORDER BY foo), col diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 7582f3a..ccf2f16 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -11,7 +11,7 @@ SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; SELECT 1 + 3.2 AS "a" FROM "w" AS "w"; SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day; -SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' "day" AS "_col_0"; +SELECT CAST("2022-01-01" AS DATE) + INTERVAL '1' day AS "_col_0"; -------------------------------------- -- Ensure boolean predicates diff --git a/tests/fixtures/optimizer/normalize.sql b/tests/fixtures/optimizer/normalize.sql index a84fadf..803a474 100644 --- a/tests/fixtures/optimizer/normalize.sql +++ b/tests/fixtures/optimizer/normalize.sql @@ -39,3 +39,6 @@ A OR ((((B OR C) AND (B OR D)) OR C) AND (((B OR C) AND (B OR D)) OR D)); (A AND B) OR (C OR (D AND E)); (A OR C OR D) AND (A OR C OR E) AND (B OR C OR D) AND (B OR C OR E); + +SELECT * FROM x WHERE (A AND B) OR C; +SELECT * FROM x WHERE (A OR C) AND (B OR C); diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 74e2d0a..3013bba 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -131,6 +131,14 @@ SELECT DATE_TRUNC(x.a, MONTH) AS a FROM x AS x; SELECT x FROM READ_PARQUET('path.parquet', hive_partition=1); SELECT _q_0.x AS x FROM READ_PARQUET('path.parquet', hive_partition = 1) AS _q_0; +# execute: false +select * from (values (1, 2)); +SELECT _q_0._col_0 AS _col_0, _q_0._col_1 AS _col_1 FROM (VALUES (1, 2)) AS _q_0(_col_0, _col_1); + +# execute: false +select * from (values (1, 2)) x; +SELECT x._col_0 AS _col_0, x._col_1 AS _col_1 FROM (VALUES (1, 2)) AS x(_col_0, _col_1); + -------------------------------------- -- Derived tables -------------------------------------- @@ -317,6 +325,21 @@ SELECT COALESCE(y.b, z.b) AS b, COALESCE(y.c, z.c) AS c FROM y AS y JOIN z AS z SELECT * FROM y JOIN z USING(b, c) WHERE b = 2 AND c = 3; SELECT COALESCE(y.b, z.b) AS b, COALESCE(y.c, z.c) AS c FROM y AS y JOIN z AS z ON y.b = z.b AND y.c = z.c WHERE COALESCE(y.b, z.b) = 2 AND COALESCE(y.c, z.c) = 3; +-- We can safely convert `b` to `x.b` in the following two queries, because the original queries +-- would be invalid if `b` also existed in `t`'s schema (which we don't know), due to ambiguity. + +# execute: false +SELECT b FROM x JOIN t USING(a); +SELECT x.b AS b FROM x AS x JOIN t AS t ON x.a = t.a; + +# execute: false +SELECT b FROM t JOIN x USING(a); +SELECT x.b AS b FROM t AS t JOIN x AS x ON t.a = x.a; + +# execute: false +SELECT a FROM t1 JOIN t2 USING(a); +SELECT COALESCE(t1.a, t2.a) AS a FROM t1 AS t1 JOIN t2 AS t2 ON t1.a = t2.a; + -------------------------------------- -- Hint with table reference -------------------------------------- diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index 54ec64b..a2cd859 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -572,3 +572,6 @@ x > 3; 'a' < 'b'; TRUE; + +x = 2018 OR x <> 2018; +x <> 2018 OR x = 2018; \ No newline at end of file diff --git a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql index d9a06cc..9168508 100644 --- a/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql +++ b/tests/fixtures/optimizer/tpc-ds/tpc-ds.sql @@ -2500,7 +2500,7 @@ JOIN "date_dim" AS "date_dim" ON "catalog_sales"."cs_ship_date_sk" = "date_dim"."d_date_sk" AND "date_dim"."d_date" >= '2002-3-01' AND CAST("date_dim"."d_date" AS DATE) <= ( - CAST('2002-3-01' AS DATE) + INTERVAL '60' "day" + CAST('2002-3-01' AS DATE) + INTERVAL '60' day ) JOIN "customer_address" AS "customer_address" ON "catalog_sales"."cs_ship_addr_sk" = "customer_address"."ca_address_sk" @@ -9420,7 +9420,7 @@ JOIN "date_dim" AS "date_dim_2" AND "date_dim_2"."d_year" = 2002 JOIN "date_dim" AS "date_dim_3" ON "catalog_sales"."cs_ship_date_sk" = "date_dim_3"."d_date_sk" - AND "date_dim_3"."d_date" > CONCAT("date_dim_2"."d_date", INTERVAL '5' "day") + AND "date_dim_3"."d_date" > CONCAT("date_dim_2"."d_date", INTERVAL '5' day) LEFT JOIN "promotion" AS "promotion" ON "catalog_sales"."cs_promo_sk" = "promotion"."p_promo_sk" LEFT JOIN "catalog_returns" AS "catalog_returns" @@ -12200,7 +12200,7 @@ JOIN "date_dim" AS "date_dim" ON "date_dim"."d_date" >= '2000-3-01' AND "web_sales"."ws_ship_date_sk" = "date_dim"."d_date_sk" AND CAST("date_dim"."d_date" AS DATE) <= ( - CAST('2000-3-01' AS DATE) + INTERVAL '60' "day" + CAST('2000-3-01' AS DATE) + INTERVAL '60' day ) JOIN "customer_address" AS "customer_address" ON "customer_address"."ca_state" = 'MT' @@ -12295,7 +12295,7 @@ JOIN "date_dim" AS "date_dim" ON "date_dim"."d_date" >= '2000-4-01' AND "web_sales"."ws_ship_date_sk" = "date_dim"."d_date_sk" AND CAST("date_dim"."d_date" AS DATE) <= ( - CAST('2000-4-01' AS DATE) + INTERVAL '60' "day" + CAST('2000-4-01' AS DATE) + INTERVAL '60' day ) JOIN "customer_address" AS "customer_address" ON "customer_address"."ca_state" = 'IN' diff --git a/tests/test_build.py b/tests/test_build.py index 43707b0..c4b97ce 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -18,7 +18,59 @@ from sqlglot import ( class TestBuild(unittest.TestCase): def test_build(self): + x = condition("x") + for expression, sql, *dialect in [ + (lambda: x + 1, "x + 1"), + (lambda: 1 + x, "1 + x"), + (lambda: x - 1, "x - 1"), + (lambda: 1 - x, "1 - x"), + (lambda: x * 1, "x * 1"), + (lambda: 1 * x, "1 * x"), + (lambda: x / 1, "x / 1"), + (lambda: 1 / x, "1 / x"), + (lambda: x // 1, "CAST(x / 1 AS INT)"), + (lambda: 1 // x, "CAST(1 / x AS INT)"), + (lambda: x % 1, "x % 1"), + (lambda: 1 % x, "1 % x"), + (lambda: x**1, "POWER(x, 1)"), + (lambda: 1**x, "POWER(1, x)"), + (lambda: x & 1, "x AND 1"), + (lambda: 1 & x, "1 AND x"), + (lambda: x | 1, "x OR 1"), + (lambda: 1 | x, "1 OR x"), + (lambda: x < 1, "x < 1"), + (lambda: 1 < x, "x > 1"), + (lambda: x <= 1, "x <= 1"), + (lambda: 1 <= x, "x >= 1"), + (lambda: x > 1, "x > 1"), + (lambda: 1 > x, "x < 1"), + (lambda: x >= 1, "x >= 1"), + (lambda: 1 >= x, "x <= 1"), + (lambda: x.eq(1), "x = 1"), + (lambda: x.neq(1), "x <> 1"), + (lambda: x.isin(1, "2"), "x IN (1, '2')"), + (lambda: x.isin(query="select 1"), "x IN (SELECT 1)"), + (lambda: 1 + x + 2 + 3, "1 + x + 2 + 3"), + (lambda: 1 + x * 2 + 3, "1 + (x * 2) + 3"), + (lambda: x * 1 * 2 + 3, "(x * 1 * 2) + 3"), + (lambda: 1 + (x * 2) / 3, "1 + ((x * 2) / 3)"), + (lambda: x & "y", "x AND 'y'"), + (lambda: x | "y", "x OR 'y'"), + (lambda: -x, "-x"), + (lambda: ~x, "NOT x"), + (lambda: x[1], "x[1]"), + (lambda: x[1, 2], "x[1, 2]"), + (lambda: x["y"] + 1, "x['y'] + 1"), + (lambda: x.like("y"), "x LIKE 'y'"), + (lambda: x.ilike("y"), "x ILIKE 'y'"), + (lambda: x.rlike("y"), "REGEXP_LIKE(x, 'y')"), + ( + lambda: exp.Case().when("x = 1", "x").else_("bar"), + "CASE WHEN x = 1 THEN x ELSE bar END", + ), + (lambda: exp.func("COALESCE", "x", 1), "COALESCE(x, 1)"), + (lambda: select("x"), "SELECT x"), (lambda: select("x"), "SELECT x"), (lambda: select("x", "y"), "SELECT x, y"), (lambda: select("x").from_("tbl"), "SELECT x FROM tbl"), diff --git a/tests/test_expressions.py b/tests/test_expressions.py index b09b2ab..eb0cf56 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -244,11 +244,11 @@ class TestExpressions(unittest.TestCase): def test_function_building(self): self.assertEqual(exp.func("max", 1).sql(), "MAX(1)") self.assertEqual(exp.func("max", 1, 2).sql(), "MAX(1, 2)") - self.assertEqual(exp.func("bla", 1, "foo").sql(), "BLA(1, 'foo')") + self.assertEqual(exp.func("bla", 1, "foo").sql(), "BLA(1, foo)") self.assertEqual(exp.func("COUNT", exp.Star()).sql(), "COUNT(*)") self.assertEqual(exp.func("bloo").sql(), "BLOO()") self.assertEqual( - exp.func("locate", "x", "xo", dialect="hive").sql("hive"), "LOCATE('x', 'xo')" + exp.func("locate", "'x'", "'xo'", dialect="hive").sql("hive"), "LOCATE('x', 'xo')" ) self.assertIsInstance(exp.func("instr", "x", "b", dialect="mysql"), exp.StrPosition) @@ -528,6 +528,7 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) self.assertIsInstance(parse_one("HLL(a)"), exp.Hll) + self.assertIsInstance(parse_one("ARRAY(time, foo)"), exp.Array) def test_column(self): column = parse_one("a.b.c.d") diff --git a/tests/test_lineage.py b/tests/test_lineage.py index 7a48605..1d13dd3 100644 --- a/tests/test_lineage.py +++ b/tests/test_lineage.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from sqlglot.lineage import lineage @@ -9,12 +11,146 @@ class TestLineage(unittest.TestCase): def test_lineage(self) -> None: node = lineage( "a", - "SELECT a FROM y", + "SELECT a FROM z", schema={"x": {"a": "int"}}, - sources={"y": "SELECT * FROM x"}, + sources={"y": "SELECT * FROM x", "z": "SELECT a FROM y"}, ) self.assertEqual( node.source.sql(), + "SELECT z.a AS a FROM (SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */) AS z /* source: z */", + ) + self.assertEqual(node.alias, "") + + downstream = node.downstream[0] + self.assertEqual( + downstream.source.sql(), "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */", ) + self.assertEqual(downstream.alias, "z") + + downstream = downstream.downstream[0] + self.assertEqual( + downstream.source.sql(), + "SELECT x.a AS a FROM x AS x", + ) + self.assertEqual(downstream.alias, "y") self.assertGreater(len(node.to_html()._repr_html_()), 1000) + + def test_lineage_sql_with_cte(self) -> None: + node = lineage( + "a", + "WITH z AS (SELECT a FROM y) SELECT a FROM z", + schema={"x": {"a": "int"}}, + sources={"y": "SELECT * FROM x"}, + ) + self.assertEqual( + node.source.sql(), + "WITH z AS (SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */) SELECT z.a AS a FROM z", + ) + self.assertEqual(node.alias, "") + + # Node containing expanded CTE expression + downstream = node.downstream[0] + self.assertEqual( + downstream.source.sql(), + "SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y /* source: y */", + ) + self.assertEqual(downstream.alias, "") + + downstream = downstream.downstream[0] + self.assertEqual( + downstream.source.sql(), + "SELECT x.a AS a FROM x AS x", + ) + self.assertEqual(downstream.alias, "y") + + def test_lineage_source_with_cte(self) -> None: + node = lineage( + "a", + "SELECT a FROM z", + schema={"x": {"a": "int"}}, + sources={"z": "WITH y AS (SELECT * FROM x) SELECT a FROM y"}, + ) + self.assertEqual( + node.source.sql(), + "SELECT z.a AS a FROM (WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y) AS z /* source: z */", + ) + self.assertEqual(node.alias, "") + + downstream = node.downstream[0] + self.assertEqual( + downstream.source.sql(), + "WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y", + ) + self.assertEqual(downstream.alias, "z") + + downstream = downstream.downstream[0] + self.assertEqual( + downstream.source.sql(), + "SELECT x.a AS a FROM x AS x", + ) + self.assertEqual(downstream.alias, "") + + def test_lineage_source_with_star(self) -> None: + node = lineage( + "a", + "WITH y AS (SELECT * FROM x) SELECT a FROM y", + ) + self.assertEqual( + node.source.sql(), + "WITH y AS (SELECT * FROM x AS x) SELECT y.a AS a FROM y", + ) + self.assertEqual(node.alias, "") + + downstream = node.downstream[0] + self.assertEqual( + downstream.source.sql(), + "SELECT * FROM x AS x", + ) + self.assertEqual(downstream.alias, "") + + def test_lineage_external_col(self) -> None: + node = lineage( + "a", + "WITH y AS (SELECT * FROM x) SELECT a FROM y JOIN z USING (uid)", + ) + self.assertEqual( + node.source.sql(), + "WITH y AS (SELECT * FROM x AS x) SELECT a AS a FROM y JOIN z AS z ON y.uid = z.uid", + ) + self.assertEqual(node.alias, "") + + downstream = node.downstream[0] + self.assertEqual( + downstream.source.sql(), + "?", + ) + self.assertEqual(downstream.alias, "") + + def test_lineage_values(self) -> None: + node = lineage( + "a", + "SELECT a FROM y", + sources={"y": "SELECT a FROM (VALUES (1), (2)) AS t (a)"}, + ) + self.assertEqual( + node.source.sql(), + "SELECT y.a AS a FROM (SELECT t.a AS a FROM (VALUES (1), (2)) AS t(a)) AS y /* source: y */", + ) + self.assertEqual(node.alias, "") + + downstream = node.downstream[0] + self.assertEqual( + downstream.source.sql(), + "SELECT t.a AS a FROM (VALUES (1), (2)) AS t(a)", + ) + self.assertEqual(downstream.expression.sql(), "t.a AS a") + self.assertEqual(downstream.alias, "y") + + downstream = downstream.downstream[0] + self.assertEqual( + downstream.source.sql(), + "(VALUES (1), (2)) AS t(a)", + ) + self.assertEqual(downstream.expression.sql(), "a") + self.assertEqual(downstream.alias, "") diff --git a/tests/test_parser.py b/tests/test_parser.py index 07a5fd7..816471e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -22,7 +22,7 @@ class TestParser(unittest.TestCase): { "description": "Invalid expression / Unexpected token", "line": 1, - "col": 1, + "col": 7, "start_context": "", "highlight": "SELECT", "end_context": " 1;", @@ -40,7 +40,7 @@ class TestParser(unittest.TestCase): { "description": "Invalid expression / Unexpected token", "line": 1, - "col": 1, + "col": 7, "start_context": "", "highlight": "SELECT", "end_context": " 1;", @@ -49,7 +49,7 @@ class TestParser(unittest.TestCase): { "description": "Invalid expression / Unexpected token", "line": 1, - "col": 1, + "col": 7, "start_context": "", "highlight": "SELECT", "end_context": " 1;", @@ -112,6 +112,8 @@ class TestParser(unittest.TestCase): self.assertIsInstance(lambda_expr.this.this, exp.Dot) self.assertEqual(lambda_expr.sql(), "x -> x.id = id") + self.assertIsNone(parse_one("FILTER([], x -> x)").find(exp.Column)) + def test_transactions(self): expression = parse_one("BEGIN TRANSACTION") self.assertIsNone(expression.this) @@ -222,6 +224,7 @@ class TestParser(unittest.TestCase): self.assertEqual(parse_one("SELECT @x, @@x, @1").sql(), "SELECT @x, @@x, @1") def test_var(self): + self.assertIsInstance(parse_one("INTERVAL '1' DAY").args["unit"], exp.Var) self.assertEqual(parse_one("SELECT @JOIN, @'foo'").sql(), "SELECT @JOIN, @'foo'") def test_comments(self): @@ -374,3 +377,96 @@ class TestParser(unittest.TestCase): parse_one("ALTER TABLE foo RENAME TO bar").sql(), "ALTER TABLE foo RENAME TO bar", ) + + def test_pivot_columns(self): + nothing_aliased = """ + SELECT * FROM ( + SELECT partname, price FROM part + ) PIVOT (AVG(price) FOR partname IN ('prop', 'rudder')) + """ + + everything_aliased = """ + SELECT * FROM ( + SELECT partname, price FROM part + ) PIVOT (AVG(price) AS avg_price FOR partname IN ('prop' AS prop1, 'rudder' AS rudder1)) + """ + + only_pivot_columns_aliased = """ + SELECT * FROM ( + SELECT partname, price FROM part + ) PIVOT (AVG(price) FOR partname IN ('prop' AS prop1, 'rudder' AS rudder1)) + """ + + columns_partially_aliased = """ + SELECT * FROM ( + SELECT partname, price FROM part + ) PIVOT (AVG(price) FOR partname IN ('prop' AS prop1, 'rudder')) + """ + + multiple_aggregates_aliased = """ + SELECT * FROM ( + SELECT partname, price, quality FROM part + ) PIVOT (AVG(price) AS p, MAX(quality) AS q FOR partname IN ('prop' AS prop1, 'rudder')) + """ + + multiple_aggregates_not_aliased = """ + SELECT * FROM ( + SELECT partname, price, quality FROM part + ) PIVOT (AVG(price), MAX(quality) FOR partname IN ('prop' AS prop1, 'rudder')) + """ + + multiple_aggregates_not_aliased_with_quoted_identifier = """ + SELECT * FROM ( + SELECT partname, price, quality FROM part + ) PIVOT (AVG(`PrIcE`), MAX(quality) FOR partname IN ('prop' AS prop1, 'rudder')) + """ + + query_to_column_names = { + nothing_aliased: { + "bigquery": ["prop", "rudder"], + "redshift": ["prop", "rudder"], + "snowflake": ['"prop"', '"rudder"'], + "spark": ["prop", "rudder"], + }, + everything_aliased: { + "bigquery": ["avg_price_prop1", "avg_price_rudder1"], + "redshift": ["prop1_avg_price", "rudder1_avg_price"], + "spark": ["prop1", "rudder1"], + }, + only_pivot_columns_aliased: { + "bigquery": ["prop1", "rudder1"], + "redshift": ["prop1", "rudder1"], + "spark": ["prop1", "rudder1"], + }, + columns_partially_aliased: { + "bigquery": ["prop1", "rudder"], + "redshift": ["prop1", "rudder"], + "spark": ["prop1", "rudder"], + }, + multiple_aggregates_aliased: { + "bigquery": ["p_prop1", "q_prop1", "p_rudder", "q_rudder"], + "spark": ["prop1_p", "prop1_q", "rudder_p", "rudder_q"], + }, + multiple_aggregates_not_aliased: { + "spark": [ + "`prop1_avg(price)`", + "`prop1_max(quality)`", + "`rudder_avg(price)`", + "`rudder_max(quality)`", + ], + }, + multiple_aggregates_not_aliased_with_quoted_identifier: { + "spark": [ + "`prop1_avg(PrIcE)`", + "`prop1_max(quality)`", + "`rudder_avg(PrIcE)`", + "`rudder_max(quality)`", + ], + }, + } + + for query, dialect_columns in query_to_column_names.items(): + for dialect, expected_columns in dialect_columns.items(): + expr = parse_one(query, read=dialect) + columns = expr.args["from"].expressions[0].args["pivots"][0].args["columns"] + self.assertEqual(expected_columns, [col.sql(dialect=dialect) for col in columns]) diff --git a/tests/test_schema.py b/tests/test_schema.py index dc7e5b2..92cf04a 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -163,8 +163,8 @@ class TestSchema(unittest.TestCase): self.assertEqual(schema.column_names("test"), ["x", "y"]) def test_schema_get_column_type(self): - schema = MappingSchema({"a": {"b": "varchar"}}) - self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR) + schema = MappingSchema({"A": {"b": "varchar"}}) + self.assertEqual(schema.get_column_type("a", "B").this, exp.DataType.Type.VARCHAR) self.assertEqual( schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this, exp.DataType.Type.VARCHAR, @@ -213,3 +213,11 @@ class TestSchema(unittest.TestCase): # Clickhouse supports both `` and "" for identifier quotes; sqlglot uses "" when generating sql schema = MappingSchema(schema={"x": {"`y`": "INT"}}, dialect="clickhouse") self.assertEqual(schema.column_names(exp.Table(this="x")), ["y"]) + + # Check that add_table normalizes both the table and the column names to be added/updated + schema = MappingSchema() + schema.add_table("Foo", {"SomeColumn": "INT", '"SomeColumn"': "DOUBLE"}) + + table_foo = exp.Table(this="fOO") + + self.assertEqual(schema.column_names(table_foo), ["somecolumn", "SomeColumn"]) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 8481f4d..987c60b 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -14,6 +14,7 @@ class TestTokens(unittest.TestCase): ("foo", []), ("foo /*comment 1*/ /*comment 2*/", ["comment 1", "comment 2"]), ("foo\n-- comment", [" comment"]), + ("1 /*/2 */", ["/2 "]), ] for sql, comment in sql_comment: @@ -22,14 +23,17 @@ class TestTokens(unittest.TestCase): def test_token_line(self): tokens = Tokenizer().tokenize( """SELECT /* - line break - */ - 'x - y', - x""" +line break +*/ +'x + y', +x""" ) + self.assertEqual(tokens[1].line, 5) + self.assertEqual(tokens[1].col, 3) self.assertEqual(tokens[-1].line, 6) + self.assertEqual(tokens[-1].col, 1) def test_command(self): tokens = Tokenizer().tokenize("SHOW;") @@ -46,6 +50,10 @@ class TestTokens(unittest.TestCase): self.assertEqual(tokens[2].token_type, TokenType.SHOW) self.assertEqual(tokens[3].token_type, TokenType.SEMICOLON) + def test_error_msg(self): + with self.assertRaisesRegex(ValueError, "Error tokenizing 'select.*"): + Tokenizer().tokenize("select /*") + def test_jinja(self): tokenizer = Tokenizer() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 1e85b80..24d8c30 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -106,6 +106,11 @@ class TestTransforms(unittest.TestCase): "SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3 QUALIFY r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)", "SELECT c2, r FROM (SELECT c2, SUM(c3) OVER (PARTITION BY c2) AS r, c1 FROM t1 WHERE c3 < 4 GROUP BY c2, c3 HAVING SUM(c1) > 3) AS _t WHERE r IN (SELECT MIN(c1) FROM test GROUP BY c2 HAVING MIN(c1) > 3)", ) + self.validate( + eliminate_qualify, + "SELECT x FROM y QUALIFY ROW_NUMBER() OVER (PARTITION BY p)", + "SELECT x FROM (SELECT x, ROW_NUMBER() OVER (PARTITION BY p) AS _w, p FROM y) AS _t WHERE _w", + ) def test_remove_precision_parameterized_types(self): self.validate( diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 36e0aa6..d68f6f8 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -20,6 +20,9 @@ class TestTranspile(unittest.TestCase): self.assertEqual(transpile(sql, **kwargs)[0], target) def test_alias(self): + self.assertEqual(transpile("SELECT SUM(y) KEEP")[0], "SELECT SUM(y) AS KEEP") + self.assertEqual(transpile("SELECT 1 overwrite")[0], "SELECT 1 AS overwrite") + self.assertEqual(transpile("SELECT 1 is")[0], "SELECT 1 AS is") self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time") self.assertEqual( transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp" @@ -87,6 +90,7 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): + self.validate("SELECT 1 /*/2 */", "SELECT 1 /* /2 */") self.validate("SELECT */*comment*/", "SELECT * /* comment */") self.validate( "SELECT * FROM table /*comment 1*/ /*comment 2*/", @@ -200,6 +204,65 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", read="mysql", pretty=True, ) + self.validate( + """ + SELECT a FROM b + WHERE foo + -- comment 1 + AND bar + -- comment 2 + AND bla; + """, + "SELECT a FROM b WHERE foo AND /* comment 1 */ bar AND /* comment 2 */ bla", + ) + self.validate( + """ + SELECT a FROM b WHERE foo + -- comment 1 + """, + "SELECT a FROM b WHERE foo /* comment 1 */", + ) + self.validate( + """ + select a from b + where foo + -- comment 1 + and bar + -- comment 2 + and bla + """, + """SELECT + a +FROM b +WHERE + foo /* comment 1 */ AND bar AND bla /* comment 2 */""", + pretty=True, + ) + self.validate( + """ + -- test + WITH v AS ( + SELECT + 1 AS literal + ) + SELECT + * + FROM v + """, + """/* test */ +WITH v AS ( + SELECT + 1 AS literal +) +SELECT + * +FROM v""", + pretty=True, + ) + self.validate( + "(/* 1 */ 1 ) /* 2 */", + "(1) /* 1 */ /* 2 */", + ) def test_types(self): self.validate("INT 1", "CAST(1 AS INT)") @@ -288,7 +351,6 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", "ALTER TABLE integers ADD k INTEGER", "ALTER TABLE integers ADD COLUMN k INT", ) - self.validate("ALTER TABLE integers DROP k", "ALTER TABLE integers DROP COLUMN k") self.validate( "ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR", "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR", @@ -299,6 +361,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", ) def test_time(self): + self.validate("INTERVAL '1 day'", "INTERVAL '1' day") + self.validate("INTERVAL '1 days' * 5", "INTERVAL '1' days * 5") + self.validate("5 * INTERVAL '1 day'", "5 * INTERVAL '1' day") + self.validate("INTERVAL 1 day", "INTERVAL '1' day") + self.validate("INTERVAL 2 months", "INTERVAL '2' months") self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") self.validate("TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)") self.validate( @@ -431,6 +498,13 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) + self.validate("x[x - 1]", "x[x - 1]", write="presto", identity=False) + self.validate( + "x[array_size(y) - 1]", "x[CARDINALITY(y) - 1 + 1]", write="presto", identity=False + ) + self.validate("x[3 - 1]", "x[3]", write="presto", identity=False) + self.validate("MAP(a, b)[0]", "MAP(a, b)[0]", write="presto", identity=False) + def test_identify_lambda(self): self.validate("x(y -> y)", 'X("y" -> "y")', identify=True) @@ -467,14 +541,14 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", def test_error_level(self, logger): invalid = "x + 1. (" expected_messages = [ - "Required keyword: 'expressions' missing for . Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", - "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", + "Required keyword: 'expressions' missing for . Line 1, Col: 9.\n x + 1. \033[4m(\033[0m", + "Expecting ). Line 1, Col: 9.\n x + 1. \033[4m(\033[0m", ] expected_errors = [ { "description": "Required keyword: 'expressions' missing for ", "line": 1, - "col": 8, + "col": 9, "start_context": "x + 1. ", "highlight": "(", "end_context": "", @@ -483,7 +557,7 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", { "description": "Expecting )", "line": 1, - "col": 8, + "col": 9, "start_context": "x + 1. ", "highlight": "(", "end_context": "", @@ -507,16 +581,16 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", more_than_max_errors = "((((" expected_messages = ( - "Required keyword: 'this' missing for . Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" - "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" - "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "Required keyword: 'this' missing for . Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 5.\n (((\033[4m(\033[0m\n\n" "... and 2 more" ) expected_errors = [ { "description": "Required keyword: 'this' missing for ", "line": 1, - "col": 4, + "col": 5, "start_context": "(((", "highlight": "(", "end_context": "", @@ -525,7 +599,7 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", { "description": "Expecting )", "line": 1, - "col": 4, + "col": 5, "start_context": "(((", "highlight": "(", "end_context": "", -- cgit v1.2.3