diff options
Diffstat (limited to '')
-rw-r--r-- | tests/dataframe/unit/dataframe_sql_validator.py | 5 | ||||
-rw-r--r-- | tests/dataframe/unit/test_dataframe_writer.py | 34 | ||||
-rw-r--r-- | tests/dataframe/unit/test_session.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 19 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 36 | ||||
-rw-r--r-- | tests/dialects/test_hive.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_postgres.py | 4 | ||||
-rw-r--r-- | tests/dialects/test_redshift.py | 26 | ||||
-rw-r--r-- | tests/dialects/test_snowflake.py | 9 | ||||
-rw-r--r-- | tests/fixtures/identity.sql | 1 | ||||
-rw-r--r-- | tests/fixtures/optimizer/canonicalize.sql | 6 | ||||
-rw-r--r-- | tests/fixtures/optimizer/simplify.sql | 180 | ||||
-rw-r--r-- | tests/fixtures/optimizer/tpc-h/tpc-h.sql | 51 | ||||
-rw-r--r-- | tests/test_executor.py | 21 | ||||
-rw-r--r-- | tests/test_optimizer.py | 155 | ||||
-rw-r--r-- | tests/test_schema.py | 18 | ||||
-rw-r--r-- | tests/test_tokens.py | 47 |
17 files changed, 497 insertions, 123 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py index 32ff8f2..2dcdb39 100644 --- a/tests/dataframe/unit/dataframe_sql_validator.py +++ b/tests/dataframe/unit/dataframe_sql_validator.py @@ -4,6 +4,7 @@ import unittest from sqlglot.dataframe.sql import types from sqlglot.dataframe.sql.dataframe import DataFrame from sqlglot.dataframe.sql.session import SparkSession +from sqlglot.helper import ensure_list class DataFrameSQLValidator(unittest.TestCase): @@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase): self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False ): actual_sqls = df.sql(pretty=pretty) - expected_statements = ( - [expected_statements] if isinstance(expected_statements, str) else expected_statements - ) + expected_statements = ensure_list(expected_statements) self.assertEqual(len(expected_statements), len(actual_sqls)) for expected, actual in zip(expected_statements, actual_sqls): self.assertEqual(expected, actual) diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py index 7c646f5..042b915 100644 --- a/tests/dataframe/unit/test_dataframe_writer.py +++ b/tests/dataframe/unit/test_dataframe_writer.py @@ -10,37 +10,37 @@ class TestDataFrameWriter(DataFrameSQLValidator): def test_insertInto_full_path(self): df = self.df_employee.write.insertInto("catalog.db.table_name") - expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_db_table(self): df = self.df_employee.write.insertInto("db.table_name") - expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_table(self): df = self.df_employee.write.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_overwrite(self): df = self.df_employee.write.insertInto("table_name", overwrite=True) - expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) @mock.patch("sqlglot.schema", MappingSchema()) def test_insertInto_byName(self): sqlglot.schema.add_table("table_name", {"employee_id": "INT"}) df = self.df_employee.write.byName.insertInto("table_name") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_insertInto_cache(self): df = self.df_employee.cache().write.insertInto("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t37164", - "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", + "DROP VIEW IF EXISTS t12441", + "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`", ] self.compare_sql(df, expected_statements) @@ -50,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator): def test_saveAsTable_append(self): df = self.df_employee.write.saveAsTable("table_name", mode="append") - expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_overwrite(self): df = self.df_employee.write.saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_error(self): df = self.df_employee.write.saveAsTable("table_name", mode="error") - expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_ignore(self): df = self.df_employee.write.saveAsTable("table_name", mode="ignore") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_standalone(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name") - expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_mode_override(self): df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite") - expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" + expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)" self.compare_sql(df, expected) def test_saveAsTable_cache(self): df = self.df_employee.cache().write.saveAsTable("table_name") expected_statements = [ - "DROP VIEW IF EXISTS t37164", - "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", - "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`", + "DROP VIEW IF EXISTS t12441", + "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)", + "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`", ] self.compare_sql(df, expected_statements) diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py index 55aa547..5213667 100644 --- a/tests/dataframe/unit/test_session.py +++ b/tests/dataframe/unit/test_session.py @@ -36,7 +36,7 @@ class TestDataframeSession(DataFrameSQLValidator): def test_cdf_str_schema(self): df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING") - expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_basic(self): @@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator): ] ) df = self.spark.createDataFrame([[1, "test"]], schema) - expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" + expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)" self.compare_sql(df, expected) def test_typed_schema_nested(self): diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index cc44311..1d60ec6 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -7,6 +7,11 @@ class TestBigQuery(Validator): def test_bigquery(self): self.validate_all( + "REGEXP_CONTAINS('foo', '.*')", + read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"}, + write={"mysql": "REGEXP_LIKE('foo', '.*')"}, + ), + self.validate_all( '"""x"""', write={ "bigquery": "'x'", @@ -94,6 +99,20 @@ class TestBigQuery(Validator): "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", }, ) + self.validate_all( + "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)", + write={"bigquery": "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"}, + ) + self.validate_all( + "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers", + write={ + "bigquery": "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers" + }, + ) + self.validate_all( + "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)", + write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"}, + ) self.validate_all( "x IS unknown", diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 6033570..ee67bf1 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -1318,3 +1318,39 @@ SELECT "BEGIN IMMEDIATE TRANSACTION", write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"}, ) + + def test_merge(self): + self.validate_all( + """ + MERGE INTO target USING source ON target.id = source.id + WHEN NOT MATCHED THEN INSERT (id) values (source.id) + """, + write={ + "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)", + "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)", + "spark": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)", + }, + ) + self.validate_all( + """ + MERGE INTO target USING source ON target.id = source.id + WHEN MATCHED AND source.is_deleted = 1 THEN DELETE + WHEN MATCHED THEN UPDATE SET val = source.val + WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val) + """, + write={ + "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)", + "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)", + "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)", + }, + ) + self.validate_all( + """ + MERGE INTO target USING source ON target.id = source.id + WHEN MATCHED THEN UPDATE * + WHEN NOT MATCHED THEN INSERT * + """, + write={ + "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *", + }, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 22d7bce..5ac8714 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -145,6 +145,10 @@ class TestHive(Validator): }, ) + self.validate_identity( + """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""", + ) + def test_lateral_view(self): self.validate_all( "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index cd6117c..962b28b 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -256,3 +256,7 @@ class TestPostgres(Validator): "SELECT $$Dianne's horse$$", write={"postgres": "SELECT 'Dianne''s horse'"}, ) + self.validate_all( + "UPDATE MYTABLE T1 SET T1.COL = 13", + write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"}, + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 1943ee3..3034df5 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -56,8 +56,27 @@ class TestRedshift(Validator): "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', }, ) + self.validate_all( + "DECODE(x, a, b, c, d)", + write={ + "": "MATCHES(x, a, b, c, d)", + "oracle": "DECODE(x, a, b, c, d)", + "snowflake": "DECODE(x, a, b, c, d)", + }, + ) + self.validate_all( + "NVL(a, b, c, d)", + write={ + "redshift": "COALESCE(a, b, c, d)", + "mysql": "COALESCE(a, b, c, d)", + "postgres": "COALESCE(a, b, c, d)", + }, + ) def test_identity(self): + self.validate_identity( + "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')" + ) self.validate_identity("CAST('bla' AS SUPER)") self.validate_identity("CREATE TABLE real1 (realcol REAL)") self.validate_identity("CAST('foo' AS HLLSKETCH)") @@ -70,9 +89,9 @@ class TestRedshift(Validator): self.validate_identity( "SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'" ) - self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO") + self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL") self.validate_identity( - "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)" + "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO" ) self.validate_identity( "COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" @@ -80,3 +99,6 @@ class TestRedshift(Validator): self.validate_identity( "UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'" ) + self.validate_identity( + "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)" + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index baca269..bca5aaa 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -500,3 +500,12 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL F }, pretty=True, ) + + def test_minus(self): + self.validate_all( + "SELECT 1 EXCEPT SELECT 1", + read={ + "oracle": "SELECT 1 MINUS SELECT 1", + "snowflake": "SELECT 1 MINUS SELECT 1", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index 06ab96d..e12b673 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -75,6 +75,7 @@ ARRAY(1, 2) ARRAY_CONTAINS(x, 1) EXTRACT(x FROM y) EXTRACT(DATE FROM y) +EXTRACT(WEEK(monday) FROM created_at) CONCAT_WS('-', 'a', 'b') CONCAT_WS('-', 'a', 'b', 'c') POSEXPLODE("x") AS ("a", "b") diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 7fcdbb8..8880881 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -3,3 +3,9 @@ SELECT CONCAT(w.d, w.e) AS c FROM w AS w; SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w; SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w; + +SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; +SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; + +SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; +SELECT 1 + 3.2 AS a FROM w AS w; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index d9c7779..cf4195d 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -79,14 +79,16 @@ NULL; NULL = NULL; NULL; +-- Can't optimize this because different engines do different things +-- mysql converts to 0 and 1 but tsql does true and false NULL <=> NULL; -TRUE; +NULL IS NOT DISTINCT FROM NULL; a IS NOT DISTINCT FROM a; -TRUE; +a IS NOT DISTINCT FROM a; NULL IS DISTINCT FROM NULL; -FALSE; +NULL IS DISTINCT FROM NULL; NOT (NOT TRUE); TRUE; @@ -239,10 +241,10 @@ TRUE; FALSE; ((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3); -TRUE; +x = x; ((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2); -TRUE; +x = x; (('a' = 'a') AND TRUE and NOT FALSE); TRUE; @@ -372,3 +374,171 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; date '1998-12-01' + interval '90' foo; CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; + +-------------------------------------- +-- Comparisons +-------------------------------------- +x < 0 OR x > 1; +x < 0 OR x > 1; + +x < 0 OR x > 0; +x < 0 OR x > 0; + +x < 1 OR x > 0; +x < 1 OR x > 0; + +x < 1 OR x >= 0; +x < 1 OR x >= 0; + +x <= 1 OR x > 0; +x <= 1 OR x > 0; + +x <= 1 OR x >= 0; +x <= 1 OR x >= 0; + +x <= 1 AND x <= 0; +x <= 0; + +x <= 1 AND x > 0; +x <= 1 AND x > 0; + +x <= 1 OR x > 0; +x <= 1 OR x > 0; + +x <= 0 OR x < 0; +x <= 0; + +x >= 0 OR x > 0; +x >= 0; + +x >= 0 OR x > 1; +x >= 0; + +x <= 0 OR x >= 0; +x <= 0 OR x >= 0; + +x <= 0 AND x >= 0; +x <= 0 AND x >= 0; + +x < 1 AND x < 2; +x < 1; + +x < 1 OR x < 2; +x < 2; + +x < 2 AND x < 1; +x < 1; + +x < 2 OR x < 1; +x < 2; + +x < 1 AND x < 1; +x < 1; + +x < 1 OR x < 1; +x < 1; + +x <= 1 AND x < 1; +x < 1; + +x <= 1 OR x < 1; +x <= 1; + +x < 1 AND x <= 1; +x < 1; + +x < 1 OR x <= 1; +x <= 1; + +x > 1 AND x > 2; +x > 2; + +x > 1 OR x > 2; +x > 1; + +x > 2 AND x > 1; +x > 2; + +x > 2 OR x > 1; +x > 1; + +x > 1 AND x > 1; +x > 1; + +x > 1 OR x > 1; +x > 1; + +x >= 1 AND x > 1; +x > 1; + +x >= 1 OR x > 1; +x >= 1; + +x > 1 AND x >= 1; +x > 1; + +x > 1 OR x >= 1; +x >= 1; + +x > 1 AND x >= 2; +x >= 2; + +x > 1 OR x >= 2; +x > 1; + +x > 1 AND x >= 2 AND x > 3 AND x > 0; +x > 3; + +(x > 1 AND x >= 2 AND x > 3 AND x > 0) OR x > 0; +x > 0; + +x > 1 AND x < 2 AND x > 3; +FALSE; + +x > 1 AND x < 1; +FALSE; + +x < 2 AND x > 1; +x < 2 AND x > 1; + +x = 1 AND x < 1; +FALSE; + +x = 1 AND x < 1.1; +x = 1; + +x = 1 AND x <= 1; +x = 1; + +x = 1 AND x <= 0.9; +FALSE; + +x = 1 AND x > 0.9; +x = 1; + +x = 1 AND x > 1; +FALSE; + +x = 1 AND x >= 1; +x = 1; + +x = 1 AND x >= 2; +FALSE; + +x = 1 AND x <> 2; +x = 1; + +x <> 1 AND x = 1; +FALSE; + +x BETWEEN 0 AND 5 AND x > 3; +x <= 5 AND x > 3; + +x > 3 AND 5 > x AND x BETWEEN 0 AND 10; +x < 5 AND x > 3; + +x > 3 AND 5 < x AND x BETWEEN 9 AND 10; +x <= 10 AND x >= 9; + +1 < x AND 3 < x; +x > 3; diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql index 4893743..9c1f138 100644 --- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -190,7 +190,7 @@ SELECT SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue", - CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate", + "orders"."o_orderdate" AS "o_orderdate", "orders"."o_shippriority" AS "o_shippriority" FROM "customer" AS "customer" JOIN "orders" AS "orders" @@ -326,7 +326,8 @@ SELECT SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue" FROM "lineitem" AS "lineitem" WHERE - "lineitem"."l_discount" BETWEEN 0.05 AND 0.07 + "lineitem"."l_discount" <= 0.07 + AND "lineitem"."l_discount" >= 0.05 AND "lineitem"."l_quantity" < 24 AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE) AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE); @@ -344,7 +345,7 @@ from select n1.n_name as supp_nation, n2.n_name as cust_nation, - extract(year from l_shipdate) as l_year, + extract(year from cast(l_shipdate as date)) as l_year, l_extendedprice * (1 - l_discount) as volume from supplier, @@ -384,13 +385,14 @@ WITH "n1" AS ( SELECT "n1"."n_name" AS "supp_nation", "n2"."n_name" AS "cust_nation", - EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year", + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year", SUM("lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" )) AS "revenue" FROM "supplier" AS "supplier" JOIN "lineitem" AS "lineitem" - ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ON CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1996-12-31' AS DATE) + AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-01-01' AS DATE) AND "supplier"."s_suppkey" = "lineitem"."l_suppkey" JOIN "orders" AS "orders" ON "orders"."o_orderkey" = "lineitem"."l_orderkey" @@ -409,7 +411,7 @@ JOIN "n1" AS "n2" GROUP BY "n1"."n_name", "n2"."n_name", - EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) + EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) ORDER BY "supp_nation", "cust_nation", @@ -427,7 +429,7 @@ select from ( select - extract(year from o_orderdate) as o_year, + extract(year from cast(o_orderdate as date)) as o_year, l_extendedprice * (1 - l_discount) as volume, n2.n_name as nation from @@ -456,7 +458,7 @@ group by order by o_year; SELECT - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", SUM( CASE WHEN "nation_2"."n_name" = 'BRAZIL' @@ -477,7 +479,8 @@ JOIN "customer" AS "customer" ON "customer"."c_nationkey" = "nation"."n_nationkey" JOIN "orders" AS "orders" ON "orders"."o_custkey" = "customer"."c_custkey" - AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) <= CAST('1996-12-31' AS DATE) + AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1995-01-01' AS DATE) JOIN "lineitem" AS "lineitem" ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "part"."p_partkey" = "lineitem"."l_partkey" @@ -488,7 +491,7 @@ JOIN "nation" AS "nation_2" WHERE "part"."p_type" = 'ECONOMY ANODIZED STEEL' GROUP BY - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) ORDER BY "o_year"; @@ -503,7 +506,7 @@ from ( select n_name as nation, - extract(year from o_orderdate) as o_year, + extract(year from cast(o_orderdate as date)) as o_year, l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount from part, @@ -529,7 +532,7 @@ order by o_year desc; SELECT "nation"."n_name" AS "nation", - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year", + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year", SUM( "lineitem"."l_extendedprice" * ( 1 - "lineitem"."l_discount" @@ -551,7 +554,7 @@ WHERE "part"."p_name" LIKE '%green%' GROUP BY "nation"."n_name", - EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) + EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) ORDER BY "nation", "o_year" DESC; @@ -1016,7 +1019,7 @@ select o_orderkey, o_orderdate, o_totalprice, - sum(l_quantity) + sum(l_quantity) total_quantity from customer, orders, @@ -1060,7 +1063,7 @@ SELECT "orders"."o_orderkey" AS "o_orderkey", "orders"."o_orderdate" AS "o_orderdate", "orders"."o_totalprice" AS "o_totalprice", - SUM("lineitem"."l_quantity") AS "_col_5" + SUM("lineitem"."l_quantity") AS "total_quantity" FROM "customer" AS "customer" JOIN "orders" AS "orders" ON "customer"."c_custkey" = "orders"."o_custkey" @@ -1129,19 +1132,22 @@ JOIN "part" AS "part" "part"."p_brand" = 'Brand#12' AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 5 + AND "part"."p_size" <= 5 + AND "part"."p_size" >= 1 ) OR ( "part"."p_brand" = 'Brand#23' AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 10 + AND "part"."p_size" <= 10 + AND "part"."p_size" >= 1 ) OR ( "part"."p_brand" = 'Brand#34' AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 15 + AND "part"."p_size" <= 15 + AND "part"."p_size" >= 1 ) WHERE ( @@ -1152,7 +1158,8 @@ WHERE AND "part"."p_brand" = 'Brand#12' AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 5 + AND "part"."p_size" <= 5 + AND "part"."p_size" >= 1 ) OR ( "lineitem"."l_quantity" <= 20 @@ -1162,7 +1169,8 @@ WHERE AND "part"."p_brand" = 'Brand#23' AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 10 + AND "part"."p_size" <= 10 + AND "part"."p_size" >= 1 ) OR ( "lineitem"."l_quantity" <= 30 @@ -1172,7 +1180,8 @@ WHERE AND "part"."p_brand" = 'Brand#34' AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') AND "part"."p_partkey" = "lineitem"."l_partkey" - AND "part"."p_size" BETWEEN 1 AND 15 + AND "part"."p_size" <= 15 + AND "part"."p_size" >= 1 ); -------------------------------------- diff --git a/tests/test_executor.py b/tests/test_executor.py index 9d452e4..4fe6399 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -26,12 +26,12 @@ class TestExecutor(unittest.TestCase): def setUpClass(cls): cls.conn = duckdb.connect() - for table in TPCH_SCHEMA: + for table, columns in TPCH_SCHEMA.items(): cls.conn.execute( f""" CREATE VIEW {table} AS SELECT * - FROM READ_CSV_AUTO('{DIR}{table}.csv.gz') + FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns}) """ ) @@ -74,13 +74,13 @@ class TestExecutor(unittest.TestCase): ) return expression - for i, (sql, _) in enumerate(self.sqls[0:16]): + for i, (sql, _) in enumerate(self.sqls[0:18]): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) table = execute(sql, TPCH_SCHEMA) b = pd.DataFrame(table.rows, columns=table.columns) - assert_frame_equal(a, b, check_dtype=False) + assert_frame_equal(a, b, check_dtype=False, check_index_type=False) def test_execute_callable(self): tables = { @@ -456,11 +456,16 @@ class TestExecutor(unittest.TestCase): ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]), ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]), ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]), - ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]), + ( + "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)", + ["_col_0", "_col_1"], + [(None, 0)], + ), ]: - result = execute(sql) - self.assertEqual(result.columns, tuple(cols)) - self.assertEqual(result.rows, rows) + with self.subTest(sql): + result = execute(sql) + self.assertEqual(result.columns, tuple(cols)) + self.assertEqual(result.rows, rows) def test_aggregate_without_group_by(self): result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]}) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index ecf581d..0c5f6cd 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for sql, target_type in tests.items(): expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Literal).type, target_type) + self.assertEqual(expression.find(exp.Literal).type.this, target_type) def test_boolean_type_annotation(self): tests = { @@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') for sql, target_type in tests.items(): expression = annotate_types(parse_one(sql)) - self.assertEqual(expression.find(exp.Boolean).type, target_type) + self.assertEqual(expression.find(exp.Boolean).type.this, target_type) def test_cast_type_annotation(self): expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))")) + self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ) + self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT) - self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR) - self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ) - self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT) + expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>")) + self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType)) def test_cache_annotation(self): expression = annotate_types( parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1") ) - self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT) + self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT) def test_binary_annotation(self): expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0] - self.assertEqual(expression.type, exp.DataType.Type.DOUBLE) - self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE) - self.assertEqual(expression.right.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT) - self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE) + self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT) + self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT) def test_derived_tables_column_annotation(self): schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}} @@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') """ expression = annotate_types(parse_one(sql), schema=schema) - self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola + self.assertEqual( + expression.expressions[0].type.this, exp.DataType.Type.FLOAT + ) # a.cola AS cola addition_alias = expression.args["from"].expressions[0].this.expressions[0] - self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola + self.assertEqual( + addition_alias.type.this, exp.DataType.Type.FLOAT + ) # x.cola + y.cola AS cola addition = addition_alias.this - self.assertEqual(addition.type, exp.DataType.Type.FLOAT) - self.assertEqual(addition.this.type, exp.DataType.Type.INT) - self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT) + self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT) + self.assertEqual(addition.this.type.this, exp.DataType.Type.INT) + self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT) def test_cte_column_annotation(self): - schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}} + schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}} sql = """ WITH tbl AS ( - SELECT x.cola + 'bla' AS cola, y.colb AS colb + SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc FROM ( SELECT x.cola AS cola FROM x AS x ) AS x JOIN ( - SELECT y.colb AS colb + SELECT y.colb AS colb, y.colc AS colc FROM y AS y ) AS y ) SELECT tbl.cola + tbl.colb + 'foo' AS col FROM tbl AS tbl + WHERE tbl.colc = True """ expression = annotate_types(parse_one(sql), schema=schema) self.assertEqual( - expression.expressions[0].type, exp.DataType.Type.TEXT + expression.expressions[0].type.this, exp.DataType.Type.TEXT ) # tbl.cola + tbl.colb + 'foo' AS col outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo' - self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT) - self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT) - self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR) + self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT) + self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR) inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb - self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR) - self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT) + self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT) + + # WHERE tbl.colc = True + self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN) cte_select = expression.args["with"].expressions[0].this self.assertEqual( - cte_select.expressions[0].type, exp.DataType.Type.VARCHAR + cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola + 'bla' AS cola - self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb + self.assertEqual( + cte_select.expressions[1].type.this, exp.DataType.Type.TEXT + ) # y.colb AS colb + self.assertEqual( + cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN + ) # y.colc AS colc cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla' - self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR) - self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR) - self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR) + self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR) # Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively for d, t in zip( cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT] ): - self.assertEqual(d.this.expressions[0].this.type, t) + self.assertEqual(d.this.expressions[0].this.type.this, t) def test_function_annotation(self): schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}} sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola - self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb) - self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb + self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb) + self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x" case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR) case_expr = case_expr_alias.this - self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR) + self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR) case_ifs_expr = case_expr.args["ifs"][0] - self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR) - self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR) + self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR) def test_unknown_annotation(self): schema = {"x": {"cola": "VARCHAR"}} sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x" concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola + self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola self.assertEqual( - concat_expr.right.type, exp.DataType.Type.UNKNOWN + concat_expr.right.type.this, exp.DataType.Type.UNKNOWN ) # SOME_ANONYMOUS_FUNC(x.cola) self.assertEqual( - concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR + concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR ) # x.cola (arg) def test_null_annotation(self): expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this - self.assertEqual(expression.left.type, exp.DataType.Type.NULL) - self.assertEqual(expression.right.type, exp.DataType.Type.INT) + self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type.this, exp.DataType.Type.INT) # NULL <op> UNKNOWN should yield NULL sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result" concat_expr_alias = annotate_types(parse_one(sql)).expressions[0] - self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL) + self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL) concat_expr = concat_expr_alias.this - self.assertEqual(concat_expr.type, exp.DataType.Type.NULL) - self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL) - self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN) + self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN) def test_nullable_annotation(self): nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) expression = annotate_types(parse_one("NULL AND FALSE")) self.assertEqual(expression.type, nullable) - self.assertEqual(expression.left.type, exp.DataType.Type.NULL) - self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN) + self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL) + self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN) + + def test_predicate_annotation(self): + expression = annotate_types(parse_one("x BETWEEN a AND b")) + self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN) + + expression = annotate_types(parse_one("x IN (a, b, c, d)")) + self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN) + + def test_aggfunc_annotation(self): + schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}} + + tests = { + ("AVG", "cola"): exp.DataType.Type.DOUBLE, + ("SUM", "cola"): exp.DataType.Type.BIGINT, + ("SUM", "colb"): exp.DataType.Type.DOUBLE, + ("MIN", "cola"): exp.DataType.Type.SMALLINT, + ("MIN", "colb"): exp.DataType.Type.FLOAT, + ("MAX", "colc"): exp.DataType.Type.TEXT, + ("MAX", "cold"): exp.DataType.Type.DATE, + ("COUNT", "colb"): exp.DataType.Type.BIGINT, + ("STDDEV", "cola"): exp.DataType.Type.DOUBLE, + } + + for (func, col), target_type in tests.items(): + expression = annotate_types( + parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema + ) + self.assertEqual(expression.expressions[0].type.this, target_type) diff --git a/tests/test_schema.py b/tests/test_schema.py index cc0e3d1..f1e12a2 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -151,31 +151,33 @@ class TestSchema(unittest.TestCase): def test_schema_get_column_type(self): schema = MappingSchema({"a": {"b": "varchar"}}) - self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR) + self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR) self.assertEqual( - schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")), + schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this, exp.DataType.Type.VARCHAR, ) self.assertEqual( - schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR + schema.get_column_type("a", exp.Column(this="b")).this, exp.DataType.Type.VARCHAR ) self.assertEqual( - schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR + schema.get_column_type(exp.Table(this="a"), "b").this, exp.DataType.Type.VARCHAR ) schema = MappingSchema({"a": {"b": {"c": "varchar"}}}) self.assertEqual( - schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")), + schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")).this, exp.DataType.Type.VARCHAR, ) self.assertEqual( - schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR + schema.get_column_type(exp.Table(this="b", db="a"), "c").this, exp.DataType.Type.VARCHAR ) schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}}) self.assertEqual( - schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")), + schema.get_column_type( + exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d") + ).this, exp.DataType.Type.VARCHAR, ) self.assertEqual( - schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"), + schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this, exp.DataType.Type.VARCHAR, ) diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 1d1b966..1376849 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -1,6 +1,6 @@ import unittest -from sqlglot.tokens import Tokenizer +from sqlglot.tokens import Tokenizer, TokenType class TestTokens(unittest.TestCase): @@ -17,3 +17,48 @@ class TestTokens(unittest.TestCase): for sql, comment in sql_comment: self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment) + + def test_jinja(self): + tokenizer = Tokenizer() + + tokens = tokenizer.tokenize( + """ + SELECT + {{ x }}, + {{- x -}}, + {% for x in y -%} + a {{+ b }} + {% endfor %}; + """ + ) + + tokens = [(token.token_type, token.text) for token in tokens] + + self.assertEqual( + tokens, + [ + (TokenType.SELECT, "SELECT"), + (TokenType.BLOCK_START, "{{"), + (TokenType.VAR, "x"), + (TokenType.BLOCK_END, "}}"), + (TokenType.COMMA, ","), + (TokenType.BLOCK_START, "{{-"), + (TokenType.VAR, "x"), + (TokenType.BLOCK_END, "-}}"), + (TokenType.COMMA, ","), + (TokenType.BLOCK_START, "{%"), + (TokenType.FOR, "for"), + (TokenType.VAR, "x"), + (TokenType.IN, "in"), + (TokenType.VAR, "y"), + (TokenType.BLOCK_END, "-%}"), + (TokenType.VAR, "a"), + (TokenType.BLOCK_START, "{{+"), + (TokenType.VAR, "b"), + (TokenType.BLOCK_END, "}}"), + (TokenType.BLOCK_START, "{%"), + (TokenType.VAR, "endfor"), + (TokenType.BLOCK_END, "%}"), + (TokenType.SEMICOLON, ";"), + ], + ) |