From f2981e8e4d28233864f1ca06ecec45ab80bf9eae Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 19 Nov 2022 15:50:39 +0100 Subject: Merging upstream version 10.0.8. Signed-off-by: Daniel Baumann --- tests/dataframe/unit/test_dataframe.py | 20 +- tests/dataframe/unit/test_dataframe_writer.py | 36 +-- tests/dataframe/unit/test_session.py | 17 +- tests/dialects/test_bigquery.py | 4 + tests/dialects/test_dialect.py | 81 ++++++ tests/dialects/test_drill.py | 53 ++++ tests/dialects/test_mysql.py | 10 + tests/dialects/test_presto.py | 75 +++++ tests/dialects/test_snowflake.py | 11 + tests/fixtures/identity.sql | 21 +- tests/fixtures/optimizer/canonicalize.sql | 5 + tests/fixtures/optimizer/optimizer.sql | 4 +- tests/fixtures/optimizer/tpc-h/tpc-h.sql | 50 ++-- tests/fixtures/pretty.sql | 7 + tests/helpers.py | 64 ++-- tests/test_executor.py | 403 +++++++++++++++++++++++++- tests/test_expressions.py | 13 +- tests/test_optimizer.py | 19 ++ tests/test_parser.py | 57 +++- tests/test_tokens.py | 1 + tests/test_transpile.py | 11 + 21 files changed, 837 insertions(+), 125 deletions(-) create mode 100644 tests/dialects/test_drill.py create mode 100644 tests/fixtures/optimizer/canonicalize.sql (limited to 'tests') diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py index e36667b..24850bc 100644 --- a/tests/dataframe/unit/test_dataframe.py +++ b/tests/dataframe/unit/test_dataframe.py @@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator class TestDataframe(DataFrameSQLValidator): + maxDiff = None + def test_hash_select_expression(self): expression = exp.select("cola").from_("table") self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression)) @@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator): def test_cache(self): df = self.df_employee.select("fname").cache() expected_statements = [ - "DROP VIEW IF EXISTS t11623", - "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + "DROP VIEW IF EXISTS t31563", + "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] self.compare_sql(df, expected_statements) def test_persist_default(self): df = self.df_employee.select("fname").persist() expected_statements = [ - "DROP VIEW IF EXISTS t11623", - "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + "DROP VIEW IF EXISTS t31563", + "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] self.compare_sql(df, expected_statements) def test_persist_storagelevel(self): df = self.df_employee.select("fname").persist("DISK_ONLY_2") expected_statements = [ - "DROP VIEW IF EXISTS t11623", - "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`", + "DROP VIEW IF EXISTS t31563", + "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`", ] self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 14b4a0a..7c646f5 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator class TestDataFrameWriter(DataFrameSQLValidator): + maxDiff = None + def test_insertInto_full_path(self): df = self.df_employee.write.insertInto("catalog.db.table_name") - expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_db_table(self): df = self.df_employee.write.insertInto("db.table_name") - expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_table(self): df = self.df_employee.write.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_overwrite(self): df = self.df_employee.write.insertInto("table_name", overwrite=True) - expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) @mock.patch("sqlglot.schema", MappingSchema()) def test_insertInto_byName(self): sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) df = self.df_employee.write.byName.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_cache(self): df = self.df_employee.cache().write.insertInto("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t35612", - "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + "DROP VIEW IF EXISTS t37164", + "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", ] self.compare_sql(df, expected_statements) @@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator): def test_saveAsTable_append(self): df = self.df_employee.write.saveAsTable("table_name", mode="append") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_overwrite(self): df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_error(self): df = self.df_employee.write.saveAsTable("table_name", mode="error") - expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_ignore(self): df = self.df_employee.write.saveAsTable("table_name", mode="ignore") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_standalone(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_override(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_cache(self): df = self.df_employee.cache().write.saveAsTable("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t35612", - "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`", + "DROP VIEW IF EXISTS t37164", + "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", ] self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 7e8bfad..55aa547 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator class TestDataframeSession(DataFrameSQLValidator): def test_cdf_one_row(self): df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"]) - expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_cdf_multiple_rows(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"]) - expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_cdf_no_schema(self): df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]]) - expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)" + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)" self.compare_sql(df, expected) def test_cdf_row_mixed_primitives(self): df = self.spark.createDataFrame([[1, 10.1, "test", False, None]]) - expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" + expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)" self.compare_sql(df, expected) def test_cdf_dict_rows(self): df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}]) - expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_cdf_str_schema(self): df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") - expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_basic(self): @@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator): ] ) df = self.spark.createDataFrame([[1, "test"]], schema) - expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)" + expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_nested(self): @@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator): ] ) df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema) - expected = "SELECT CAST(`a2`.`cola` AS struct) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)" + expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)" + self.compare_sql(df, expected) @mock.patch("sqlglot.schema", MappingSchema()) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index a0ebc45..cc44311 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -286,6 +286,10 @@ class TestBigQuery(Validator): "bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))", }, ) + self.validate_identity("BEGIN A B C D E F") + self.validate_identity("BEGIN TRANSACTION") + self.validate_identity("COMMIT TRANSACTION") + self.validate_identity("ROLLBACK TRANSACTION") def test_user_defined_functions(self): self.validate_identity( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 1913f53..1b2f9c1 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -69,6 +69,7 @@ class TestDialect(Validator): write={ "bigquery": "CAST(a AS STRING)", "clickhouse": "CAST(a AS TEXT)", + "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS TEXT)", "hive": "CAST(a AS STRING)", @@ -86,6 +87,7 @@ class TestDialect(Validator): write={ "bigquery": "CAST(a AS BINARY(4))", "clickhouse": "CAST(a AS BINARY(4))", + "drill": "CAST(a AS VARBINARY(4))", "duckdb": "CAST(a AS BINARY(4))", "mysql": "CAST(a AS BINARY(4))", "hive": "CAST(a AS BINARY(4))", @@ -146,6 +148,7 @@ class TestDialect(Validator): "CAST(a AS STRING)", write={ "bigquery": "CAST(a AS STRING)", + "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS TEXT)", "hive": "CAST(a AS STRING)", @@ -162,6 +165,7 @@ class TestDialect(Validator): "CAST(a AS VARCHAR)", write={ "bigquery": "CAST(a AS STRING)", + "drill": "CAST(a AS VARCHAR)", "duckdb": "CAST(a AS TEXT)", "mysql": "CAST(a AS VARCHAR)", "hive": "CAST(a AS STRING)", @@ -178,6 +182,7 @@ class TestDialect(Validator): "CAST(a AS VARCHAR(3))", write={ "bigquery": "CAST(a AS STRING(3))", + "drill": "CAST(a AS VARCHAR(3))", "duckdb": "CAST(a AS TEXT(3))", "mysql": "CAST(a AS VARCHAR(3))", "hive": "CAST(a AS VARCHAR(3))", @@ -194,6 +199,7 @@ class TestDialect(Validator): "CAST(a AS SMALLINT)", write={ "bigquery": "CAST(a AS INT64)", + "drill": "CAST(a AS INTEGER)", "duckdb": "CAST(a AS SMALLINT)", "mysql": "CAST(a AS SMALLINT)", "hive": "CAST(a AS SMALLINT)", @@ -215,6 +221,7 @@ class TestDialect(Validator): }, write={ "duckdb": "TRY_CAST(a AS DOUBLE)", + "drill": "CAST(a AS DOUBLE)", "postgres": "CAST(a AS DOUBLE PRECISION)", "redshift": "CAST(a AS DOUBLE PRECISION)", }, @@ -225,6 +232,7 @@ class TestDialect(Validator): write={ "bigquery": "CAST(a AS FLOAT64)", "clickhouse": "CAST(a AS Float64)", + "drill": "CAST(a AS DOUBLE)", "duckdb": "CAST(a AS DOUBLE)", "mysql": "CAST(a AS DOUBLE)", "hive": "CAST(a AS DOUBLE)", @@ -279,6 +287,7 @@ class TestDialect(Validator): "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')", "redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')", "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", }, @@ -286,6 +295,7 @@ class TestDialect(Validator): self.validate_all( "STR_TO_TIME('2020-01-01', '%Y-%m-%d')", write={ + "drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')", @@ -298,6 +308,7 @@ class TestDialect(Validator): self.validate_all( "STR_TO_TIME(x, '%y')", write={ + "drill": "TO_TIMESTAMP(x, 'yy')", "duckdb": "STRPTIME(x, '%y')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", "presto": "DATE_PARSE(x, '%y')", @@ -319,6 +330,7 @@ class TestDialect(Validator): self.validate_all( "TIME_STR_TO_DATE('2020-01-01')", write={ + "drill": "CAST('2020-01-01' AS DATE)", "duckdb": "CAST('2020-01-01' AS DATE)", "hive": "TO_DATE('2020-01-01')", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", @@ -328,6 +340,7 @@ class TestDialect(Validator): self.validate_all( "TIME_STR_TO_TIME('2020-01-01')", write={ + "drill": "CAST('2020-01-01' AS TIMESTAMP)", "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)", "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", @@ -344,6 +357,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_STR(x, '%Y-%m-%d')", write={ + "drill": "TO_CHAR(x, 'yyyy-MM-dd')", "duckdb": "STRFTIME(x, '%Y-%m-%d')", "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", "oracle": "TO_CHAR(x, 'YYYY-MM-DD')", @@ -355,6 +369,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_TIME_STR(x)", write={ + "drill": "CAST(x AS VARCHAR)", "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", @@ -364,6 +379,7 @@ class TestDialect(Validator): self.validate_all( "TIME_TO_UNIX(x)", write={ + "drill": "UNIX_TIMESTAMP(x)", "duckdb": "EPOCH(x)", "hive": "UNIX_TIMESTAMP(x)", "presto": "TO_UNIXTIME(x)", @@ -425,6 +441,7 @@ class TestDialect(Validator): self.validate_all( "DATE_TO_DATE_STR(x)", write={ + "drill": "CAST(x AS VARCHAR)", "duckdb": "CAST(x AS TEXT)", "hive": "CAST(x AS STRING)", "presto": "CAST(x AS VARCHAR)", @@ -433,6 +450,7 @@ class TestDialect(Validator): self.validate_all( "DATE_TO_DI(x)", write={ + "drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)", "duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)", "hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)", "presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)", @@ -441,6 +459,7 @@ class TestDialect(Validator): self.validate_all( "DI_TO_DATE(x)", write={ + "drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')", "duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)", "hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')", "presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)", @@ -463,6 +482,7 @@ class TestDialect(Validator): }, write={ "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "drill": "DATE_ADD(x, INTERVAL '1' DAY)", "duckdb": "x + INTERVAL 1 day", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", @@ -477,6 +497,7 @@ class TestDialect(Validator): "DATE_ADD(x, 1)", write={ "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "drill": "DATE_ADD(x, INTERVAL '1' DAY)", "duckdb": "x + INTERVAL 1 DAY", "hive": "DATE_ADD(x, 1)", "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", @@ -546,6 +567,7 @@ class TestDialect(Validator): "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", }, write={ + "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')", "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", @@ -556,6 +578,7 @@ class TestDialect(Validator): self.validate_all( "STR_TO_DATE(x, '%Y-%m-%d')", write={ + "drill": "CAST(x AS DATE)", "mysql": "STR_TO_DATE(x, '%Y-%m-%d')", "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')", "hive": "CAST(x AS DATE)", @@ -566,6 +589,7 @@ class TestDialect(Validator): self.validate_all( "DATE_STR_TO_DATE(x)", write={ + "drill": "CAST(x AS DATE)", "duckdb": "CAST(x AS DATE)", "hive": "TO_DATE(x)", "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", @@ -575,6 +599,7 @@ class TestDialect(Validator): self.validate_all( "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')", write={ + "drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)", "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY", "hive": "DATE_ADD('2021-02-01', 1)", "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))", @@ -584,6 +609,7 @@ class TestDialect(Validator): self.validate_all( "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", write={ + "drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)", "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", "hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", "presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))", @@ -593,6 +619,7 @@ class TestDialect(Validator): self.validate_all( "TIMESTAMP '2022-01-01'", write={ + "drill": "CAST('2022-01-01' AS TIMESTAMP)", "mysql": "CAST('2022-01-01' AS TIMESTAMP)", "starrocks": "CAST('2022-01-01' AS DATETIME)", "hive": "CAST('2022-01-01' AS TIMESTAMP)", @@ -614,6 +641,7 @@ class TestDialect(Validator): dialect: f"{unit}(x)" for dialect in ( "bigquery", + "drill", "duckdb", "mysql", "presto", @@ -624,6 +652,7 @@ class TestDialect(Validator): dialect: f"{unit}(x)" for dialect in ( "bigquery", + "drill", "duckdb", "mysql", "presto", @@ -649,6 +678,7 @@ class TestDialect(Validator): write={ "bigquery": "ARRAY_LENGTH(x)", "duckdb": "ARRAY_LENGTH(x)", + "drill": "REPEATED_COUNT(x)", "presto": "CARDINALITY(x)", "spark": "SIZE(x)", }, @@ -736,6 +766,7 @@ class TestDialect(Validator): self.validate_all( "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)", write={ + "drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", }, @@ -743,6 +774,7 @@ class TestDialect(Validator): self.validate_all( "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)", write={ + "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", }, @@ -775,6 +807,7 @@ class TestDialect(Validator): }, write={ "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "drill": "SELECT * FROM a UNION SELECT * FROM b", "duckdb": "SELECT * FROM a UNION SELECT * FROM b", "presto": "SELECT * FROM a UNION SELECT * FROM b", "spark": "SELECT * FROM a UNION SELECT * FROM b", @@ -887,6 +920,7 @@ class TestDialect(Validator): write={ "bigquery": "LOWER(x) LIKE '%y'", "clickhouse": "x ILIKE '%y'", + "drill": "x `ILIKE` '%y'", "duckdb": "x ILIKE '%y'", "hive": "LOWER(x) LIKE '%y'", "mysql": "LOWER(x) LIKE '%y'", @@ -910,32 +944,38 @@ class TestDialect(Validator): self.validate_all( "POSITION(' ' in x)", write={ + "drill": "STRPOS(x, ' ')", "duckdb": "STRPOS(x, ' ')", "postgres": "STRPOS(x, ' ')", "presto": "STRPOS(x, ' ')", "spark": "LOCATE(' ', x)", "clickhouse": "position(x, ' ')", "snowflake": "POSITION(' ', x)", + "mysql": "LOCATE(' ', x)", }, ) self.validate_all( "STR_POSITION('a', x)", write={ + "drill": "STRPOS(x, 'a')", "duckdb": "STRPOS(x, 'a')", "postgres": "STRPOS(x, 'a')", "presto": "STRPOS(x, 'a')", "spark": "LOCATE('a', x)", "clickhouse": "position(x, 'a')", "snowflake": "POSITION('a', x)", + "mysql": "LOCATE('a', x)", }, ) self.validate_all( "POSITION('a', x, 3)", write={ + "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", "spark": "LOCATE('a', x, 3)", "clickhouse": "position(x, 'a', 3)", "snowflake": "POSITION('a', x, 3)", + "mysql": "LOCATE('a', x, 3)", }, ) self.validate_all( @@ -960,6 +1000,7 @@ class TestDialect(Validator): self.validate_all( "IF(x > 1, 1, 0)", write={ + "drill": "`IF`(x > 1, 1, 0)", "duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END", "presto": "IF(x > 1, 1, 0)", "hive": "IF(x > 1, 1, 0)", @@ -970,6 +1011,7 @@ class TestDialect(Validator): self.validate_all( "CASE WHEN 1 THEN x ELSE 0 END", write={ + "drill": "CASE WHEN 1 THEN x ELSE 0 END", "duckdb": "CASE WHEN 1 THEN x ELSE 0 END", "presto": "CASE WHEN 1 THEN x ELSE 0 END", "hive": "CASE WHEN 1 THEN x ELSE 0 END", @@ -980,6 +1022,7 @@ class TestDialect(Validator): self.validate_all( "x[y]", write={ + "drill": "x[y]", "duckdb": "x[y]", "presto": "x[y]", "hive": "x[y]", @@ -1000,6 +1043,7 @@ class TestDialect(Validator): 'true or null as "foo"', write={ "bigquery": "TRUE OR NULL AS `foo`", + "drill": "TRUE OR NULL AS `foo`", "duckdb": 'TRUE OR NULL AS "foo"', "presto": 'TRUE OR NULL AS "foo"', "hive": "TRUE OR NULL AS `foo`", @@ -1020,6 +1064,7 @@ class TestDialect(Validator): "LEVENSHTEIN(col1, col2)", write={ "duckdb": "LEVENSHTEIN(col1, col2)", + "drill": "LEVENSHTEIN_DISTANCE(col1, col2)", "presto": "LEVENSHTEIN_DISTANCE(col1, col2)", "hive": "LEVENSHTEIN(col1, col2)", "spark": "LEVENSHTEIN(col1, col2)", @@ -1029,6 +1074,7 @@ class TestDialect(Validator): "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))", write={ "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", "hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", "spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", @@ -1152,6 +1198,7 @@ class TestDialect(Validator): self.validate_all( "SELECT a AS b FROM x GROUP BY b", write={ + "drill": "SELECT a AS b FROM x GROUP BY b", "duckdb": "SELECT a AS b FROM x GROUP BY b", "presto": "SELECT a AS b FROM x GROUP BY 1", "hive": "SELECT a AS b FROM x GROUP BY 1", @@ -1162,6 +1209,7 @@ class TestDialect(Validator): self.validate_all( "SELECT y x FROM my_table t", write={ + "drill": "SELECT y AS x FROM my_table AS t", "hive": "SELECT y AS x FROM my_table AS t", "oracle": "SELECT y AS x FROM my_table t", "postgres": "SELECT y AS x FROM my_table AS t", @@ -1230,3 +1278,36 @@ SELECT }, pretty=True, ) + + def test_transactions(self): + self.validate_all( + "BEGIN TRANSACTION", + write={ + "bigquery": "BEGIN TRANSACTION", + "mysql": "BEGIN", + "postgres": "BEGIN", + "presto": "START TRANSACTION", + "trino": "START TRANSACTION", + "redshift": "BEGIN", + "snowflake": "BEGIN", + "sqlite": "BEGIN TRANSACTION", + }, + ) + self.validate_all( + "BEGIN", + read={ + "presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", + "trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", + }, + ) + self.validate_all( + "BEGIN", + read={ + "presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ", + "trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ", + }, + ) + self.validate_all( + "BEGIN IMMEDIATE TRANSACTION", + write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"}, + ) diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py new file mode 100644 index 0000000..9819daa --- /dev/null +++ b/tests/dialects/test_drill.py @@ -0,0 +1,53 @@ +from tests.dialects.test_dialect import Validator + + +class TestDrill(Validator): + dialect = "drill" + + def test_string_literals(self): + self.validate_all( + "SELECT '2021-01-01' + INTERVAL 1 MONTH", + write={ + "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + }, + ) + + def test_quotes(self): + self.validate_all( + "'\\''", + write={ + "duckdb": "''''", + "presto": "''''", + "hive": "'\\''", + "spark": "'\\''", + }, + ) + self.validate_all( + "'\"x\"'", + write={ + "duckdb": "'\"x\"'", + "presto": "'\"x\"'", + "hive": "'\"x\"'", + "spark": "'\"x\"'", + }, + ) + self.validate_all( + "'\\\\a'", + read={ + "presto": "'\\a'", + }, + write={ + "duckdb": "'\\a'", + "presto": "'\\a'", + "hive": "'\\\\a'", + "spark": "'\\\\a'", + }, + ) + + def test_table_function(self): + self.validate_all( + "SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))", + write={ + "drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 1ba118b..af98249 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -58,6 +58,16 @@ class TestMySQL(Validator): self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'") self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci") self.validate_identity("SET autocommit = ON") + self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE") + self.validate_identity("SET TRANSACTION READ ONLY") + self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE") + self.validate_identity("SELECT SCHEMA()") + + def test_canonical_functions(self): + self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)") + self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')") + self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')") + self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')") def test_escape(self): self.validate_all( diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 098ad2b..8179cf7 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -177,6 +177,15 @@ class TestPresto(Validator): "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", }, ) + self.validate_all( + "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1", + write={ + "duckdb": "CREATE TABLE test AS SELECT 1", + "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING PARQUET AS SELECT 1", + }, + ) self.validate_all( "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", write={ @@ -427,3 +436,69 @@ class TestPresto(Validator): "spark": UnsupportedError, }, ) + self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE") + self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ") + + def test_encode_decode(self): + self.validate_all( + "TO_UTF8(x)", + write={ + "spark": "ENCODE(x, 'utf-8')", + }, + ) + self.validate_all( + "FROM_UTF8(x)", + write={ + "spark": "DECODE(x, 'utf-8')", + }, + ) + self.validate_all( + "ENCODE(x, 'utf-8')", + write={ + "presto": "TO_UTF8(x)", + }, + ) + self.validate_all( + "DECODE(x, 'utf-8')", + write={ + "presto": "FROM_UTF8(x)", + }, + ) + self.validate_all( + "ENCODE(x, 'invalid')", + write={ + "presto": UnsupportedError, + }, + ) + self.validate_all( + "DECODE(x, 'invalid')", + write={ + "presto": UnsupportedError, + }, + ) + + def test_hex_unhex(self): + self.validate_all( + "TO_HEX(x)", + write={ + "spark": "HEX(x)", + }, + ) + self.validate_all( + "FROM_HEX(x)", + write={ + "spark": "UNHEX(x)", + }, + ) + self.validate_all( + "HEX(x)", + write={ + "presto": "TO_HEX(x)", + }, + ) + self.validate_all( + "UNHEX(x)", + write={ + "presto": "FROM_HEX(x)", + }, + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 1846b17..0e69f4e 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -169,6 +169,17 @@ class TestSnowflake(Validator): "snowflake": "SELECT a FROM test AS unpivot", }, ) + self.validate_all( + "trim(date_column, 'UTC')", + write={ + "snowflake": "TRIM(date_column, 'UTC')", + "postgres": "TRIM('UTC' FROM date_column)", + }, + ) + self.validate_all( + "trim(date_column)", + write={"snowflake": "TRIM(date_column)"}, + ) def test_null_treatment(self): self.validate_all( diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 836ab28..75bd25d 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC' CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' SET x = 1 SET -v -ADD JAR s3://bucket -ADD JARS s3://bucket, c -ADD FILE s3://file -ADD FILES s3://file, s3://a -ADD ARCHIVE s3://file -ADD ARCHIVES s3://file, s3://a -BEGIN IMMEDIATE TRANSACTION COMMIT USE db NOT 1 @@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test SELECT COUNT(a) FROM test SELECT COUNT(1) FROM test SELECT COUNT(*) FROM test +SELECT COUNT() FROM test SELECT COUNT(DISTINCT a) FROM test SELECT EXP(a) FROM test SELECT FLOOR(a) FROM test @@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2 WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 +WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query) +WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query)))) (SELECT 1) UNION (SELECT 2) (SELECT 1) UNION SELECT 2 SELECT 1 UNION (SELECT 2) @@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI CREATE TABLE z (a INT(11) DEFAULT UUID()) CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) +CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1) CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) CREATE TABLE z (a INT, PRIMARY KEY(a)) CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 @@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte)) CREATE TABLE z (a INT UNIQUE) CREATE TABLE z (a INT AUTO_INCREMENT) CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) +CREATE TABLE z (a INT REFERENCES parent(b, c)) +CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id)) +CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c)) CREATE TEMPORARY FUNCTION f CREATE TEMPORARY FUNCTION f AS 'g' CREATE FUNCTION f @@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1 DELETE FROM y DELETE FROM event USING sales WHERE event.eventid = sales.eventid DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid +DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid +PREPARE statement +EXECUTE statement DROP TABLE a DROP TABLE a.b DROP TABLE IF EXISTS a DROP TABLE IF EXISTS a.b +DROP TABLE a CASCADE DROP VIEW a DROP VIEW a.b DROP VIEW IF EXISTS a DROP VIEW IF EXISTS a.b SHOW TABLES USE db +BEGIN ROLLBACK +ROLLBACK TO b EXPLAIN SELECT * FROM x INSERT INTO x SELECT * FROM y INSERT INTO x (SELECT * FROM y) @@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */ SELECT x FROM a.b.c /* x */, e.f.g /* x */ SELECT FOO(x /* c */) /* FOO */, b /* b */ SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */ +SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b' diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql new file mode 100644 index 0000000..7fcdbb8 --- /dev/null +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -0,0 +1,5 @@ +SELECT w.d + w.e AS c FROM w AS w; +SELECT CONCAT(w.d, w.e) AS c FROM w AS w; + +SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w; +SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index eb7e9cb..a1e531b 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -119,7 +119,7 @@ GROUP BY LIMIT 1; # title: Root subquery is union -(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1; +(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1; ( SELECT "x"."b" AS "b" @@ -128,6 +128,8 @@ LIMIT 1; SELECT "y"."b" AS "b" FROM "y" AS "y" + ORDER BY + "b" ) LIMIT 1; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index b91205c..8138b11 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -15,7 +15,7 @@ select from lineitem where - CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day + l_shipdate <= date '1998-12-01' - interval '90' day group by l_returnflag, l_linestatus @@ -250,8 +250,8 @@ FROM "orders" AS "orders" LEFT JOIN "_u_0" AS "_u_0" ON "_u_0"."l_orderkey" = "orders"."o_orderkey" WHERE - "orders"."o_orderdate" < CAST('1993-10-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE) + CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE) AND NOT "_u_0"."l_orderkey" IS NULL GROUP BY "orders"."o_orderpriority" @@ -293,8 +293,8 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" - AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE) JOIN "region" AS "region" ON "region"."r_name" = 'ASIA' JOIN "nation" AS "nation" @@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem" WHERE "lineitem"."l_discount" BETWEEN 0.05 AND 0.07 AND "lineitem"."l_quantity" < 24 - AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE); + AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE); -------------------------------------- -- TPC-H 7 @@ -384,13 +384,13 @@ WITH "n1" AS ( SELECT "n1"."n_name" AS "supp_nation", "n2"."n_name" AS "cust_nation", - EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year", SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" FROM "supplier" AS "supplier" JOIN "lineitem" AS "lineitem" - ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" JOIN "orders" AS "orders" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -409,7 +409,7 @@ JOIN "n1" AS "n2" GROUP BY "n1"."n_name", "n2"."n_name", - EXTRACT(year FROM "lineitem"."l_shipdate") + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) ORDER BY "supp_nation", "cust_nation", @@ -456,7 +456,7 @@ group by order by o_year; SELECT - EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", SUM( CASE WHEN "nation_2"."n_name" = 'BRAZIL' @@ -477,7 +477,7 @@ JOIN "customer" AS "customer" ON "customer"."c_nationkey" = "nation"."n_nationkey" JOIN "orders" AS "orders" ON "orders"."o_custkey" = "customer"."c_custkey" - AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "part"."p_partkey" = "lineitem"."l_partkey" @@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2" WHERE "part"."p_type" = 'ECONOMY ANODIZED STEEL' GROUP BY - EXTRACT(year FROM "orders"."o_orderdate") + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) ORDER BY "o_year"; @@ -529,7 +529,7 @@ order by o_year desc; SELECT "nation"."n_name" AS "nation", - EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", SUM( "lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" @@ -551,7 +551,7 @@ WHERE "part"."p_name" LIKE '%green%' GROUP BY "nation"."n_name", - EXTRACT(year FROM "orders"."o_orderdate") + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) ORDER BY "nation", "o_year" DESC; @@ -606,8 +606,8 @@ SELECT FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" - AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) - AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R' JOIN "nation" AS "nation" @@ -740,8 +740,8 @@ SELECT FROM "orders" AS "orders" JOIN "lineitem" AS "lineitem" ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" - AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) + AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE) AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') AND "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem" JOIN "part" AS "part" ON "lineitem"."l_partkey" = "part"."p_partkey" WHERE - "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE); + CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE); -------------------------------------- -- TPC-H 15 @@ -876,8 +876,8 @@ WITH "revenue" AS ( )) AS "total_revenue" FROM "lineitem" AS "lineitem" WHERE - "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE) + CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE) GROUP BY "lineitem"."l_suppkey" ) @@ -1220,8 +1220,8 @@ WITH "_u_0" AS ( "lineitem"."l_suppkey" AS "_u_2" FROM "lineitem" AS "lineitem" WHERE - "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) - AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) + CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE) GROUP BY "lineitem"."l_partkey", "lineitem"."l_suppkey" diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql index 5e27b5e..067fe77 100644 --- a/tests/fixtures/pretty.sql +++ b/tests/fixtures/pretty.sql @@ -315,3 +315,10 @@ FROM ( WHERE id = 1 ) /* x */; +SELECT * /* multi + line + comment */; +SELECT + * /* multi + line + comment */; diff --git a/tests/helpers.py b/tests/helpers.py index dabaf1c..9abdaae 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower( TPCH_SCHEMA = { "lineitem": { - "l_orderkey": "uint64", - "l_partkey": "uint64", - "l_suppkey": "uint64", - "l_linenumber": "uint64", - "l_quantity": "float64", - "l_extendedprice": "float64", - "l_discount": "float64", - "l_tax": "float64", + "l_orderkey": "bigint", + "l_partkey": "bigint", + "l_suppkey": "bigint", + "l_linenumber": "bigint", + "l_quantity": "double", + "l_extendedprice": "double", + "l_discount": "double", + "l_tax": "double", "l_returnflag": "string", "l_linestatus": "string", - "l_shipdate": "date32", - "l_commitdate": "date32", - "l_receiptdate": "date32", + "l_shipdate": "string", + "l_commitdate": "string", + "l_receiptdate": "string", "l_shipinstruct": "string", "l_shipmode": "string", "l_comment": "string", }, "orders": { - "o_orderkey": "uint64", - "o_custkey": "uint64", + "o_orderkey": "bigint", + "o_custkey": "bigint", "o_orderstatus": "string", - "o_totalprice": "float64", - "o_orderdate": "date32", + "o_totalprice": "double", + "o_orderdate": "string", "o_orderpriority": "string", "o_clerk": "string", - "o_shippriority": "int32", + "o_shippriority": "int", "o_comment": "string", }, "customer": { - "c_custkey": "uint64", + "c_custkey": "bigint", "c_name": "string", "c_address": "string", - "c_nationkey": "uint64", + "c_nationkey": "bigint", "c_phone": "string", - "c_acctbal": "float64", + "c_acctbal": "double", "c_mktsegment": "string", "c_comment": "string", }, "part": { - "p_partkey": "uint64", + "p_partkey": "bigint", "p_name": "string", "p_mfgr": "string", "p_brand": "string", "p_type": "string", - "p_size": "int32", + "p_size": "int", "p_container": "string", - "p_retailprice": "float64", + "p_retailprice": "double", "p_comment": "string", }, "supplier": { - "s_suppkey": "uint64", + "s_suppkey": "bigint", "s_name": "string", "s_address": "string", - "s_nationkey": "uint64", + "s_nationkey": "bigint", "s_phone": "string", - "s_acctbal": "float64", + "s_acctbal": "double", "s_comment": "string", }, "partsupp": { - "ps_partkey": "uint64", - "ps_suppkey": "uint64", - "ps_availqty": "int32", - "ps_supplycost": "float64", + "ps_partkey": "bigint", + "ps_suppkey": "bigint", + "ps_availqty": "int", + "ps_supplycost": "double", "ps_comment": "string", }, "nation": { - "n_nationkey": "uint64", + "n_nationkey": "bigint", "n_name": "string", - "n_regionkey": "uint64", + "n_regionkey": "bigint", "n_comment": "string", }, "region": { - "r_regionkey": "uint64", + "r_regionkey": "bigint", "r_name": "string", "r_comment": "string", }, diff --git a/tests/test_executor.py b/tests/test_executor.py index 49805b9..2c4d7cd 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -1,12 +1,15 @@ import unittest +from datetime import date import duckdb import pandas as pd from pandas.testing import assert_frame_equal from sqlglot import exp, parse_one +from sqlglot.errors import ExecuteError from sqlglot.executor import execute from sqlglot.executor.python import Python +from sqlglot.executor.table import Table, ensure_tables from tests.helpers import ( FIXTURES_DIR, SKIP_INTEGRATION, @@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase): def to_csv(expression): if isinstance(expression, exp.Table): return parse_one( - f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" + f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}" ) return expression - for sql, _ in self.sqls[0:3]: - a = self.cached_execute(sql) - sql = parse_one(sql).transform(to_csv).sql(pretty=True) - table = execute(sql, TPCH_SCHEMA) - b = pd.DataFrame(table.rows, columns=table.columns) - assert_frame_equal(a, b, check_dtype=False) + for i, (sql, _) in enumerate(self.sqls[0:7]): + with self.subTest(f"tpch-h {i + 1}"): + a = self.cached_execute(sql) + sql = parse_one(sql).transform(to_csv).sql(pretty=True) + table = execute(sql, TPCH_SCHEMA) + b = pd.DataFrame(table.rows, columns=table.columns) + assert_frame_equal(a, b, check_dtype=False) + + def test_execute_callable(self): + tables = { + "x": [ + {"a": "a", "b": "d"}, + {"a": "b", "b": "e"}, + {"a": "c", "b": "f"}, + ], + "y": [ + {"b": "d", "c": "g"}, + {"b": "e", "c": "h"}, + {"b": "f", "c": "i"}, + ], + "z": [], + } + schema = { + "x": { + "a": "VARCHAR", + "b": "VARCHAR", + }, + "y": { + "b": "VARCHAR", + "c": "VARCHAR", + }, + "z": {"d": "VARCHAR"}, + } + + for sql, cols, rows in [ + ("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]), + ( + "SELECT * FROM x JOIN y ON x.b = y.b", + ["a", "b", "b", "c"], + [("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")], + ), + ( + "SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b", + ["d"], + [("g",), ("h",), ("i",)], + ), + ( + "SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'", + ["_col_0"], + [("bh",)], + ), + ( + "SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'", + ["a", "b", "b", "c"], + [("b", "e", "e", "h")], + ), + ( + "SELECT * FROM z", + ["d"], + [], + ), + ( + "SELECT d FROM z ORDER BY d", + ["d"], + [], + ), + ( + "SELECT a FROM x WHERE x.a <> 'b'", + ["a"], + [("a",), ("c",)], + ), + ( + "SELECT a AS i FROM x ORDER BY a", + ["i"], + [("a",), ("b",), ("c",)], + ), + ( + "SELECT a AS i FROM x ORDER BY i", + ["i"], + [("a",), ("b",), ("c",)], + ), + ( + "SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a", + ["a", "i"], + [(1, "c"), (2, "b"), (3, "a")], + ), + ( + "SELECT a /* test */ FROM x LIMIT 1", + ["a"], + [("a",)], + ), + ]: + with self.subTest(sql): + result = execute(sql, schema=schema, tables=tables) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(result.rows, rows) + + def test_set_operations(self): + tables = { + "x": [ + {"a": "a"}, + {"a": "b"}, + {"a": "c"}, + ], + "y": [ + {"a": "b"}, + {"a": "c"}, + {"a": "d"}, + ], + } + schema = { + "x": { + "a": "VARCHAR", + }, + "y": { + "a": "VARCHAR", + }, + } + + for sql, cols, rows in [ + ( + "SELECT a FROM x UNION ALL SELECT a FROM y", + ["a"], + [("a",), ("b",), ("c",), ("b",), ("c",), ("d",)], + ), + ( + "SELECT a FROM x UNION SELECT a FROM y", + ["a"], + [("a",), ("b",), ("c",), ("d",)], + ), + ( + "SELECT a FROM x EXCEPT SELECT a FROM y", + ["a"], + [("a",)], + ), + ( + "SELECT a FROM x INTERSECT SELECT a FROM y", + ["a"], + [("b",), ("c",)], + ), + ( + """SELECT i.a + FROM ( + SELECT a FROM x UNION SELECT a FROM y + ) AS i + JOIN ( + SELECT a FROM x UNION SELECT a FROM y + ) AS j + ON i.a = j.a""", + ["a"], + [("a",), ("b",), ("c",), ("d",)], + ), + ( + "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a", + ["a"], + [(1,), (2,), (3,)], + ), + ]: + with self.subTest(sql): + result = execute(sql, schema=schema, tables=tables) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(set(result.rows), set(rows)) + + def test_execute_catalog_db_table(self): + tables = { + "catalog": { + "db": { + "x": [ + {"a": "a"}, + {"a": "b"}, + {"a": "c"}, + ], + } + } + } + schema = { + "catalog": { + "db": { + "x": { + "a": "VARCHAR", + } + } + } + } + result1 = execute("SELECT * FROM x", schema=schema, tables=tables) + result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables) + assert result1.columns == result2.columns + assert result1.rows == result2.rows + + def test_execute_tables(self): + tables = { + "sushi": [ + {"id": 1, "price": 1.0}, + {"id": 2, "price": 2.0}, + {"id": 3, "price": 3.0}, + ], + "order_items": [ + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 1, "order_id": 1}, + {"sushi_id": 2, "order_id": 1}, + {"sushi_id": 3, "order_id": 2}, + ], + "orders": [ + {"id": 1, "user_id": 1}, + {"id": 2, "user_id": 2}, + ], + } + + self.assertEqual( + execute( + """ + SELECT + o.user_id, + SUM(s.price) AS price + FROM orders o + JOIN order_items i + ON o.id = i.order_id + JOIN sushi s + ON i.sushi_id = s.id + GROUP BY o.user_id + """, + tables=tables, + ).rows, + [ + (1, 4.0), + (2, 3.0), + ], + ) + + self.assertEqual( + execute( + """ + SELECT + o.id, x.* + FROM orders o + LEFT JOIN ( + SELECT + 1 AS id, 'b' AS x + UNION ALL + SELECT + 3 AS id, 'c' AS x + ) x + ON o.id = x.id + """, + tables=tables, + ).rows, + [(1, 1, "b"), (2, None, None)], + ) + self.assertEqual( + execute( + """ + SELECT + o.id, x.* + FROM orders o + RIGHT JOIN ( + SELECT + 1 AS id, + 'b' AS x + UNION ALL + SELECT + 3 AS id, 'c' AS x + ) x + ON o.id = x.id + """, + tables=tables, + ).rows, + [ + (1, 1, "b"), + (None, 3, "c"), + ], + ) + + def test_table_depth_mismatch(self): + tables = {"table": []} + schema = {"db": {"table": {"col": "VARCHAR"}}} + with self.assertRaises(ExecuteError): + execute("SELECT * FROM table", schema=schema, tables=tables) + + def test_tables(self): + tables = ensure_tables( + { + "catalog1": { + "db1": { + "t1": [ + {"a": 1}, + ], + "t2": [ + {"a": 1}, + ], + }, + "db2": { + "t3": [ + {"a": 1}, + ], + "t4": [ + {"a": 1}, + ], + }, + }, + "catalog2": { + "db3": { + "t5": Table(columns=("a",), rows=[(1,)]), + "t6": Table(columns=("a",), rows=[(1,)]), + }, + "db4": { + "t7": Table(columns=("a",), rows=[(1,)]), + "t8": Table(columns=("a",), rows=[(1,)]), + }, + }, + } + ) + + t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1")) + self.assertEqual(t1.columns, ("a",)) + self.assertEqual(t1.rows, [(1,)]) + + t8 = tables.find(exp.table_(table="t8")) + self.assertEqual(t1.columns, t8.columns) + self.assertEqual(t1.rows, t8.rows) + + def test_static_queries(self): + for sql, cols, rows in [ + ("SELECT 1", ["_col_0"], [(1,)]), + ("SELECT 1 + 2 AS x", ["x"], [(3,)]), + ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), + ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), + ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), + ]: + result = execute(sql) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(result.rows, rows) + + def test_aggregate_without_group_by(self): + result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]}) + self.assertEqual(result.columns, ("_col_0",)) + self.assertEqual(result.rows, [(3,)]) + + def test_scalar_functions(self): + for sql, expected in [ + ("CONCAT('a', 'b')", "ab"), + ("CONCAT('a', NULL)", None), + ("CONCAT_WS('_', 'a', 'b')", "a_b"), + ("STR_POSITION('bar', 'foobarbar')", 4), + ("STR_POSITION('bar', 'foobarbar', 5)", 7), + ("STR_POSITION(NULL, 'foobarbar')", None), + ("STR_POSITION('bar', NULL)", None), + ("UPPER('foo')", "FOO"), + ("UPPER(NULL)", None), + ("LOWER('FOO')", "foo"), + ("LOWER(NULL)", None), + ("IFNULL('a', 'b')", "a"), + ("IFNULL(NULL, 'b')", "b"), + ("IFNULL(NULL, NULL)", None), + ("SUBSTRING('12345')", "12345"), + ("SUBSTRING('12345', 3)", "345"), + ("SUBSTRING('12345', 3, 0)", ""), + ("SUBSTRING('12345', 3, 1)", "3"), + ("SUBSTRING('12345', 3, 2)", "34"), + ("SUBSTRING('12345', 3, 3)", "345"), + ("SUBSTRING('12345', 3, 4)", "345"), + ("SUBSTRING('12345', -3)", "345"), + ("SUBSTRING('12345', -3, 0)", ""), + ("SUBSTRING('12345', -3, 1)", "3"), + ("SUBSTRING('12345', -3, 2)", "34"), + ("SUBSTRING('12345', 0)", ""), + ("SUBSTRING('12345', 0, 1)", ""), + ("SUBSTRING(NULL)", None), + ("SUBSTRING(NULL, 1)", None), + ("CAST(1 AS TEXT)", "1"), + ("CAST('1' AS LONG)", 1), + ("CAST('1.1' AS FLOAT)", 1.1), + ("COALESCE(NULL)", None), + ("COALESCE(NULL, NULL)", None), + ("COALESCE(NULL, 'b')", "b"), + ("COALESCE('a', 'b')", "a"), + ("1 << 1", 2), + ("1 >> 1", 0), + ("1 & 1", 1), + ("1 | 1", 1), + ("1 < 1", False), + ("1 <= 1", True), + ("1 > 1", False), + ("1 >= 1", True), + ("1 + NULL", None), + ("IF(true, 1, 0)", 1), + ("IF(false, 1, 0)", 0), + ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"), + ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)), + ]: + with self.subTest(sql): + result = execute(f"SELECT {sql}") + self.assertEqual(result.rows, [(expected,)]) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 63371d8..c0927ad 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase): self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) + self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction) + self.assertIsInstance(parse_one("COMMIT"), exp.Commit) + self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback) def test_column(self): dot = parse_one("a.b.c") @@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase): self.assertEqual(column.text("expression"), "c") self.assertEqual(column.text("y"), "") self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x") - self.assertEqual(parse_one("select *").text("this"), "") - self.assertEqual(parse_one("1 + 1").text("this"), "1") - self.assertEqual(parse_one("'a'").text("this"), "a") + self.assertEqual(parse_one("select *").name, "") + self.assertEqual(parse_one("1 + 1").name, "1") + self.assertEqual(parse_one("'a'").name, "a") def test_alias(self): self.assertEqual(alias("foo", "bar").sql(), "foo AS bar") @@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase): this=exp.Literal.string("TABLE_FORMAT"), value=exp.to_identifier("test_format"), ), - exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL), - exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE), + exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()), + exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()), ] ), ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a1b7e70..6637a1d 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase): CREATE TABLE x (a INT, b INT); CREATE TABLE y (b INT, c INT); CREATE TABLE z (b INT, c INT); + CREATE TABLE w (d TEXT, e TEXT); INSERT INTO x VALUES (1, 1); INSERT INTO x VALUES (2, 2); @@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase): INSERT INTO y VALUES (4, 4); INSERT INTO y VALUES (5, 5); INSERT INTO y VALUES (null, null); + + INSERT INTO w VALUES ('a', 'b'); """ ) @@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase): "b": "INT", "c": "INT", }, + "w": { + "d": "TEXT", + "e": "TEXT", + }, } def check_file(self, file, func, pretty=False, execute=False, **kwargs): @@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase): def test_eliminate_subqueries(self): self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries) + def test_canonicalize(self): + optimize = partial( + optimizer.optimize, + rules=[ + optimizer.qualify_tables.qualify_tables, + optimizer.qualify_columns.qualify_columns, + annotate_types, + optimizer.canonicalize.canonicalize, + ], + ) + self.check_file("canonicalize", optimize, schema=self.schema) + def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) diff --git a/tests/test_parser.py b/tests/test_parser.py index 04c20b1..c747ea3 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -41,12 +41,41 @@ class TestParser(unittest.TestCase): ) def test_command(self): - expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1") + expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive") self.assertEqual(len(expressions), 3) self.assertEqual(expressions[0].sql(), "SET x = 1") self.assertEqual(expressions[1].sql(), "ADD JAR s3://a") self.assertEqual(expressions[2].sql(), "SELECT 1") + def test_transactions(self): + expression = parse_one("BEGIN TRANSACTION") + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("START TRANSACTION", read="mysql") + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("BEGIN DEFERRED TRANSACTION") + self.assertEqual(expression.this, "DEFERRED") + self.assertEqual(expression.args["modes"], []) + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one( + "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto" + ) + self.assertIsNone(expression.this) + self.assertEqual(expression.args["modes"][0], "READ WRITE") + self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE") + self.assertEqual(expression.sql(), "BEGIN") + + expression = parse_one("BEGIN", read="bigquery") + self.assertNotIsInstance(expression, exp.Transaction) + self.assertIsNone(expression.expression) + self.assertEqual(expression.sql(), "BEGIN") + def test_identify(self): expression = parse_one( """ @@ -55,14 +84,14 @@ class TestParser(unittest.TestCase): """ ) - assert expression.expressions[0].text("this") == "a" - assert expression.expressions[1].text("this") == "b" - assert expression.expressions[2].text("alias") == "c" - assert expression.expressions[3].text("alias") == "D" - assert expression.expressions[4].text("alias") == "y|z'" + assert expression.expressions[0].name == "a" + assert expression.expressions[1].name == "b" + assert expression.expressions[2].alias == "c" + assert expression.expressions[3].alias == "D" + assert expression.expressions[4].alias == "y|z'" table = expression.args["from"].expressions[0] - assert table.args["this"].args["this"] == "z" - assert table.args["db"].args["this"] == "y" + assert table.this.name == "z" + assert table.args["db"].name == "y" def test_multi(self): expressions = parse( @@ -72,8 +101,8 @@ class TestParser(unittest.TestCase): ) assert len(expressions) == 2 - assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a" - assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b" + assert expressions[0].args["from"].expressions[0].this.name == "a" + assert expressions[1].args["from"].expressions[0].this.name == "b" def test_expression(self): ignore = Parser(error_level=ErrorLevel.IGNORE) @@ -200,7 +229,7 @@ class TestParser(unittest.TestCase): @patch("sqlglot.parser.logger") def test_comment_error_n(self, logger): parse_one( - """CREATE TABLE x + """SUM ( -- test )""", @@ -208,19 +237,19 @@ class TestParser(unittest.TestCase): ) assert_logger_contains( - "Required keyword: 'expressions' missing for . Line 4, Col: 1.", + "Required keyword: 'this' missing for . Line 4, Col: 1.", logger, ) @patch("sqlglot.parser.logger") def test_comment_error_r(self, logger): parse_one( - """CREATE TABLE x (-- test\r)""", + """SUM(-- test\r)""", error_level=ErrorLevel.WARN, ) assert_logger_contains( - "Required keyword: 'expressions' missing for . Line 2, Col: 1.", + "Required keyword: 'this' missing for . Line 2, Col: 1.", logger, ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 943c2b0..d4772ba 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase): ("--comment\nfoo --test", "comment"), ("foo --comment", "comment"), ("foo", None), + ("foo /*comment 1*/ /*comment 2*/", "comment 1"), ] for sql, comment in sql_comment: diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 942053e..1bd2527 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase): self.assertEqual(transpile(sql, **kwargs)[0], target) def test_alias(self): + self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time") + self.assertEqual( + transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp" + ) + self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date") + self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") + for key in ("union", "filter", "over", "from", "join"): with self.subTest(f"alias {key}"): self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") @@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase): self.validate("SELECT 3>=3", "SELECT 3 >= 3") def test_comments(self): + self.validate("SELECT */*comment*/", "SELECT * /* comment */") + self.validate( + "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */" + ) self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */") self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo") self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo") -- cgit v1.2.3