summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/dataframe/unit/test_dataframe.py20
-rw-r--r--tests/dataframe/unit/test_dataframe_writer.py36
-rw-r--r--tests/dataframe/unit/test_session.py17
-rw-r--r--tests/dialects/test_bigquery.py4
-rw-r--r--tests/dialects/test_dialect.py81
-rw-r--r--tests/dialects/test_drill.py53
-rw-r--r--tests/dialects/test_mysql.py10
-rw-r--r--tests/dialects/test_presto.py75
-rw-r--r--tests/dialects/test_snowflake.py11
-rw-r--r--tests/fixtures/identity.sql21
-rw-r--r--tests/fixtures/optimizer/canonicalize.sql5
-rw-r--r--tests/fixtures/optimizer/optimizer.sql4
-rw-r--r--tests/fixtures/optimizer/tpc-h/tpc-h.sql50
-rw-r--r--tests/fixtures/pretty.sql7
-rw-r--r--tests/helpers.py64
-rw-r--r--tests/test_executor.py403
-rw-r--r--tests/test_expressions.py13
-rw-r--r--tests/test_optimizer.py19
-rw-r--r--tests/test_parser.py57
-rw-r--r--tests/test_tokens.py1
-rw-r--r--tests/test_transpile.py11
21 files changed, 837 insertions, 125 deletions
diff --git a/tests/dataframe/unit/test_dataframe.py b/tests/dataframe/unit/test_dataframe.py
index e36667b..24850bc 100644
--- a/tests/dataframe/unit/test_dataframe.py
+++ b/tests/dataframe/unit/test_dataframe.py
@@ -4,6 +4,8 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframe(DataFrameSQLValidator):
+ maxDiff = None
+
def test_hash_select_expression(self):
expression = exp.select("cola").from_("table")
self.assertEqual("t17051", DataFrame._create_hash_from_expression(expression))
@@ -16,26 +18,26 @@ class TestDataframe(DataFrameSQLValidator):
def test_cache(self):
df = self.df_employee.select("fname").cache()
expected_statements = [
- "DROP VIEW IF EXISTS t11623",
- "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+ "DROP VIEW IF EXISTS t31563",
+ "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_default(self):
df = self.df_employee.select("fname").persist()
expected_statements = [
- "DROP VIEW IF EXISTS t11623",
- "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+ "DROP VIEW IF EXISTS t31563",
+ "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'MEMORY_AND_DISK_SER') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
def test_persist_storagelevel(self):
df = self.df_employee.select("fname").persist("DISK_ONLY_2")
expected_statements = [
- "DROP VIEW IF EXISTS t11623",
- "CACHE LAZY TABLE t11623 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS string) AS `fname` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "SELECT `t11623`.`fname` AS `fname` FROM `t11623` AS `t11623`",
+ "DROP VIEW IF EXISTS t31563",
+ "CACHE LAZY TABLE t31563 OPTIONS('storageLevel' = 'DISK_ONLY_2') AS SELECT CAST(`a1`.`fname` AS STRING) AS `fname` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "SELECT `t31563`.`fname` AS `fname` FROM `t31563` AS `t31563`",
]
self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
index 14b4a0a..7c646f5 100644
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -6,39 +6,41 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataFrameWriter(DataFrameSQLValidator):
+ maxDiff = None
+
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
- expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
- expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
- expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t35612",
- "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "INSERT INTO table_name SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
+ "DROP VIEW IF EXISTS t37164",
+ "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
]
self.compare_sql(df, expected_statements)
@@ -48,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
- expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t35612",
- "CACHE LAZY TABLE t35612 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS int) AS `employee_id`, CAST(`a1`.`fname` AS string) AS `fname`, CAST(`a1`.`lname` AS string) AS `lname`, CAST(`a1`.`age` AS int) AS `age`, CAST(`a1`.`store_id` AS int) AS `store_id` FROM (VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100)) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "CREATE TABLE table_name AS SELECT `t35612`.`employee_id` AS `employee_id`, `t35612`.`fname` AS `fname`, `t35612`.`lname` AS `lname`, `t35612`.`age` AS `age`, `t35612`.`store_id` AS `store_id` FROM `t35612` AS `t35612`",
+ "DROP VIEW IF EXISTS t37164",
+ "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
]
self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 7e8bfad..55aa547 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -11,32 +11,32 @@ from tests.dataframe.unit.dataframe_sql_validator import DataFrameSQLValidator
class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_one_row(self):
df = self.spark.createDataFrame([[1, 2]], ["cola", "colb"])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2)) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_multiple_rows(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]], ["cola", "colb"])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_no_schema(self):
df = self.spark.createDataFrame([[1, 2], [3, 4], [None, 6]])
- expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM (VALUES (1, 2), (3, 4), (NULL, 6)) AS `a2`(`_1`, `_2`)"
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2` FROM VALUES (1, 2), (3, 4), (NULL, 6) AS `a2`(`_1`, `_2`)"
self.compare_sql(df, expected)
def test_cdf_row_mixed_primitives(self):
df = self.spark.createDataFrame([[1, 10.1, "test", False, None]])
- expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM (VALUES (1, 10.1, 'test', FALSE, NULL)) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
+ expected = "SELECT `a2`.`_1` AS `_1`, `a2`.`_2` AS `_2`, `a2`.`_3` AS `_3`, `a2`.`_4` AS `_4`, `a2`.`_5` AS `_5` FROM VALUES (1, 10.1, 'test', FALSE, NULL) AS `a2`(`_1`, `_2`, `_3`, `_4`, `_5`)"
self.compare_sql(df, expected)
def test_cdf_dict_rows(self):
df = self.spark.createDataFrame([{"cola": 1, "colb": "test"}, {"cola": 2, "colb": "test2"}])
- expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM (VALUES (1, 'test'), (2, 'test2')) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, `a2`.`colb` AS `colb` FROM VALUES (1, 'test'), (2, 'test2') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
- expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
- expected = "SELECT CAST(`a2`.`cola` AS int) AS `cola`, CAST(`a2`.`colb` AS string) AS `colb` FROM (VALUES (1, 'test')) AS `a2`(`cola`, `colb`)"
+ expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
@@ -65,7 +65,8 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[{"sub_cola": 1, "sub_colb": "test"}]], schema)
- expected = "SELECT CAST(`a2`.`cola` AS struct<sub_cola:int, sub_colb:string>) AS `cola` FROM (VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`))) AS `a2`(`cola`)"
+ expected = "SELECT CAST(`a2`.`cola` AS STRUCT<`sub_cola`: INT, `sub_colb`: STRING>) AS `cola` FROM VALUES (STRUCT(1 AS `sub_cola`, 'test' AS `sub_colb`)) AS `a2`(`cola`)"
+
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index a0ebc45..cc44311 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -286,6 +286,10 @@ class TestBigQuery(Validator):
"bigquery": "SELECT * FROM (SELECT a, b, c FROM test) PIVOT(SUM(b) AS d, COUNT(*) AS e FOR c IN ('x', 'y'))",
},
)
+ self.validate_identity("BEGIN A B C D E F")
+ self.validate_identity("BEGIN TRANSACTION")
+ self.validate_identity("COMMIT TRANSACTION")
+ self.validate_identity("ROLLBACK TRANSACTION")
def test_user_defined_functions(self):
self.validate_identity(
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 1913f53..1b2f9c1 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -69,6 +69,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS STRING)",
"clickhouse": "CAST(a AS TEXT)",
+ "drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
@@ -86,6 +87,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS BINARY(4))",
"clickhouse": "CAST(a AS BINARY(4))",
+ "drill": "CAST(a AS VARBINARY(4))",
"duckdb": "CAST(a AS BINARY(4))",
"mysql": "CAST(a AS BINARY(4))",
"hive": "CAST(a AS BINARY(4))",
@@ -146,6 +148,7 @@ class TestDialect(Validator):
"CAST(a AS STRING)",
write={
"bigquery": "CAST(a AS STRING)",
+ "drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS TEXT)",
"hive": "CAST(a AS STRING)",
@@ -162,6 +165,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR)",
write={
"bigquery": "CAST(a AS STRING)",
+ "drill": "CAST(a AS VARCHAR)",
"duckdb": "CAST(a AS TEXT)",
"mysql": "CAST(a AS VARCHAR)",
"hive": "CAST(a AS STRING)",
@@ -178,6 +182,7 @@ class TestDialect(Validator):
"CAST(a AS VARCHAR(3))",
write={
"bigquery": "CAST(a AS STRING(3))",
+ "drill": "CAST(a AS VARCHAR(3))",
"duckdb": "CAST(a AS TEXT(3))",
"mysql": "CAST(a AS VARCHAR(3))",
"hive": "CAST(a AS VARCHAR(3))",
@@ -194,6 +199,7 @@ class TestDialect(Validator):
"CAST(a AS SMALLINT)",
write={
"bigquery": "CAST(a AS INT64)",
+ "drill": "CAST(a AS INTEGER)",
"duckdb": "CAST(a AS SMALLINT)",
"mysql": "CAST(a AS SMALLINT)",
"hive": "CAST(a AS SMALLINT)",
@@ -215,6 +221,7 @@ class TestDialect(Validator):
},
write={
"duckdb": "TRY_CAST(a AS DOUBLE)",
+ "drill": "CAST(a AS DOUBLE)",
"postgres": "CAST(a AS DOUBLE PRECISION)",
"redshift": "CAST(a AS DOUBLE PRECISION)",
},
@@ -225,6 +232,7 @@ class TestDialect(Validator):
write={
"bigquery": "CAST(a AS FLOAT64)",
"clickhouse": "CAST(a AS Float64)",
+ "drill": "CAST(a AS DOUBLE)",
"duckdb": "CAST(a AS DOUBLE)",
"mysql": "CAST(a AS DOUBLE)",
"hive": "CAST(a AS DOUBLE)",
@@ -279,6 +287,7 @@ class TestDialect(Validator):
"duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')",
+ "drill": "TO_TIMESTAMP(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"redshift": "TO_TIMESTAMP(x, 'YYYY-MM-DDTHH:MI:SS')",
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')",
},
@@ -286,6 +295,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_TIME('2020-01-01', '%Y-%m-%d')",
write={
+ "drill": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"oracle": "TO_TIMESTAMP('2020-01-01', 'YYYY-MM-DD')",
@@ -298,6 +308,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_TIME(x, '%y')",
write={
+ "drill": "TO_TIMESTAMP(x, 'yy')",
"duckdb": "STRPTIME(x, '%y')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)",
"presto": "DATE_PARSE(x, '%y')",
@@ -319,6 +330,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_STR_TO_DATE('2020-01-01')",
write={
+ "drill": "CAST('2020-01-01' AS DATE)",
"duckdb": "CAST('2020-01-01' AS DATE)",
"hive": "TO_DATE('2020-01-01')",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@@ -328,6 +340,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_STR_TO_TIME('2020-01-01')",
write={
+ "drill": "CAST('2020-01-01' AS TIMESTAMP)",
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')",
@@ -344,6 +357,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_STR(x, '%Y-%m-%d')",
write={
+ "drill": "TO_CHAR(x, 'yyyy-MM-dd')",
"duckdb": "STRFTIME(x, '%Y-%m-%d')",
"hive": "DATE_FORMAT(x, 'yyyy-MM-dd')",
"oracle": "TO_CHAR(x, 'YYYY-MM-DD')",
@@ -355,6 +369,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_TIME_STR(x)",
write={
+ "drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
@@ -364,6 +379,7 @@ class TestDialect(Validator):
self.validate_all(
"TIME_TO_UNIX(x)",
write={
+ "drill": "UNIX_TIMESTAMP(x)",
"duckdb": "EPOCH(x)",
"hive": "UNIX_TIMESTAMP(x)",
"presto": "TO_UNIXTIME(x)",
@@ -425,6 +441,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_TO_DATE_STR(x)",
write={
+ "drill": "CAST(x AS VARCHAR)",
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
@@ -433,6 +450,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_TO_DI(x)",
write={
+ "drill": "CAST(TO_DATE(x, 'yyyyMMdd') AS INT)",
"duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)",
"hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)",
"presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)",
@@ -441,6 +459,7 @@ class TestDialect(Validator):
self.validate_all(
"DI_TO_DATE(x)",
write={
+ "drill": "TO_DATE(CAST(x AS VARCHAR), 'yyyyMMdd')",
"duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)",
"hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')",
"presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)",
@@ -463,6 +482,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+ "drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 day",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -477,6 +497,7 @@ class TestDialect(Validator):
"DATE_ADD(x, 1)",
write={
"bigquery": "DATE_ADD(x, INTERVAL 1 'day')",
+ "drill": "DATE_ADD(x, INTERVAL '1' DAY)",
"duckdb": "x + INTERVAL 1 DAY",
"hive": "DATE_ADD(x, 1)",
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
@@ -546,6 +567,7 @@ class TestDialect(Validator):
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
},
write={
+ "drill": "TO_DATE(x, 'yyyy-MM-dd''T''HH:mm:ss')",
"mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')",
"hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)",
@@ -556,6 +578,7 @@ class TestDialect(Validator):
self.validate_all(
"STR_TO_DATE(x, '%Y-%m-%d')",
write={
+ "drill": "CAST(x AS DATE)",
"mysql": "STR_TO_DATE(x, '%Y-%m-%d')",
"starrocks": "STR_TO_DATE(x, '%Y-%m-%d')",
"hive": "CAST(x AS DATE)",
@@ -566,6 +589,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_STR_TO_DATE(x)",
write={
+ "drill": "CAST(x AS DATE)",
"duckdb": "CAST(x AS DATE)",
"hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
@@ -575,6 +599,7 @@ class TestDialect(Validator):
self.validate_all(
"TS_OR_DS_ADD('2021-02-01', 1, 'DAY')",
write={
+ "drill": "DATE_ADD(CAST('2021-02-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD('2021-02-01', 1)",
"presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))",
@@ -584,6 +609,7 @@ class TestDialect(Validator):
self.validate_all(
"DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
write={
+ "drill": "DATE_ADD(CAST('2020-01-01' AS DATE), INTERVAL '1' DAY)",
"duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY",
"hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)",
"presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))",
@@ -593,6 +619,7 @@ class TestDialect(Validator):
self.validate_all(
"TIMESTAMP '2022-01-01'",
write={
+ "drill": "CAST('2022-01-01' AS TIMESTAMP)",
"mysql": "CAST('2022-01-01' AS TIMESTAMP)",
"starrocks": "CAST('2022-01-01' AS DATETIME)",
"hive": "CAST('2022-01-01' AS TIMESTAMP)",
@@ -614,6 +641,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
+ "drill",
"duckdb",
"mysql",
"presto",
@@ -624,6 +652,7 @@ class TestDialect(Validator):
dialect: f"{unit}(x)"
for dialect in (
"bigquery",
+ "drill",
"duckdb",
"mysql",
"presto",
@@ -649,6 +678,7 @@ class TestDialect(Validator):
write={
"bigquery": "ARRAY_LENGTH(x)",
"duckdb": "ARRAY_LENGTH(x)",
+ "drill": "REPEATED_COUNT(x)",
"presto": "CARDINALITY(x)",
"spark": "SIZE(x)",
},
@@ -736,6 +766,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)",
write={
+ "drill": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)",
"spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a",
},
@@ -743,6 +774,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)",
write={
+ "drill": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)",
"spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b",
},
@@ -775,6 +807,7 @@ class TestDialect(Validator):
},
write={
"bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b",
+ "drill": "SELECT * FROM a UNION SELECT * FROM b",
"duckdb": "SELECT * FROM a UNION SELECT * FROM b",
"presto": "SELECT * FROM a UNION SELECT * FROM b",
"spark": "SELECT * FROM a UNION SELECT * FROM b",
@@ -887,6 +920,7 @@ class TestDialect(Validator):
write={
"bigquery": "LOWER(x) LIKE '%y'",
"clickhouse": "x ILIKE '%y'",
+ "drill": "x `ILIKE` '%y'",
"duckdb": "x ILIKE '%y'",
"hive": "LOWER(x) LIKE '%y'",
"mysql": "LOWER(x) LIKE '%y'",
@@ -910,32 +944,38 @@ class TestDialect(Validator):
self.validate_all(
"POSITION(' ' in x)",
write={
+ "drill": "STRPOS(x, ' ')",
"duckdb": "STRPOS(x, ' ')",
"postgres": "STRPOS(x, ' ')",
"presto": "STRPOS(x, ' ')",
"spark": "LOCATE(' ', x)",
"clickhouse": "position(x, ' ')",
"snowflake": "POSITION(' ', x)",
+ "mysql": "LOCATE(' ', x)",
},
)
self.validate_all(
"STR_POSITION('a', x)",
write={
+ "drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')",
"postgres": "STRPOS(x, 'a')",
"presto": "STRPOS(x, 'a')",
"spark": "LOCATE('a', x)",
"clickhouse": "position(x, 'a')",
"snowflake": "POSITION('a', x)",
+ "mysql": "LOCATE('a', x)",
},
)
self.validate_all(
"POSITION('a', x, 3)",
write={
+ "drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
"spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)",
+ "mysql": "LOCATE('a', x, 3)",
},
)
self.validate_all(
@@ -960,6 +1000,7 @@ class TestDialect(Validator):
self.validate_all(
"IF(x > 1, 1, 0)",
write={
+ "drill": "`IF`(x > 1, 1, 0)",
"duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END",
"presto": "IF(x > 1, 1, 0)",
"hive": "IF(x > 1, 1, 0)",
@@ -970,6 +1011,7 @@ class TestDialect(Validator):
self.validate_all(
"CASE WHEN 1 THEN x ELSE 0 END",
write={
+ "drill": "CASE WHEN 1 THEN x ELSE 0 END",
"duckdb": "CASE WHEN 1 THEN x ELSE 0 END",
"presto": "CASE WHEN 1 THEN x ELSE 0 END",
"hive": "CASE WHEN 1 THEN x ELSE 0 END",
@@ -980,6 +1022,7 @@ class TestDialect(Validator):
self.validate_all(
"x[y]",
write={
+ "drill": "x[y]",
"duckdb": "x[y]",
"presto": "x[y]",
"hive": "x[y]",
@@ -1000,6 +1043,7 @@ class TestDialect(Validator):
'true or null as "foo"',
write={
"bigquery": "TRUE OR NULL AS `foo`",
+ "drill": "TRUE OR NULL AS `foo`",
"duckdb": 'TRUE OR NULL AS "foo"',
"presto": 'TRUE OR NULL AS "foo"',
"hive": "TRUE OR NULL AS `foo`",
@@ -1020,6 +1064,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(col1, col2)",
write={
"duckdb": "LEVENSHTEIN(col1, col2)",
+ "drill": "LEVENSHTEIN_DISTANCE(col1, col2)",
"presto": "LEVENSHTEIN_DISTANCE(col1, col2)",
"hive": "LEVENSHTEIN(col1, col2)",
"spark": "LEVENSHTEIN(col1, col2)",
@@ -1029,6 +1074,7 @@ class TestDialect(Validator):
"LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))",
write={
"duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
+ "drill": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))",
"hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
"spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))",
@@ -1152,6 +1198,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT a AS b FROM x GROUP BY b",
write={
+ "drill": "SELECT a AS b FROM x GROUP BY b",
"duckdb": "SELECT a AS b FROM x GROUP BY b",
"presto": "SELECT a AS b FROM x GROUP BY 1",
"hive": "SELECT a AS b FROM x GROUP BY 1",
@@ -1162,6 +1209,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT y x FROM my_table t",
write={
+ "drill": "SELECT y AS x FROM my_table AS t",
"hive": "SELECT y AS x FROM my_table AS t",
"oracle": "SELECT y AS x FROM my_table t",
"postgres": "SELECT y AS x FROM my_table AS t",
@@ -1230,3 +1278,36 @@ SELECT
},
pretty=True,
)
+
+ def test_transactions(self):
+ self.validate_all(
+ "BEGIN TRANSACTION",
+ write={
+ "bigquery": "BEGIN TRANSACTION",
+ "mysql": "BEGIN",
+ "postgres": "BEGIN",
+ "presto": "START TRANSACTION",
+ "trino": "START TRANSACTION",
+ "redshift": "BEGIN",
+ "snowflake": "BEGIN",
+ "sqlite": "BEGIN TRANSACTION",
+ },
+ )
+ self.validate_all(
+ "BEGIN",
+ read={
+ "presto": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
+ "trino": "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE",
+ },
+ )
+ self.validate_all(
+ "BEGIN",
+ read={
+ "presto": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
+ "trino": "START TRANSACTION ISOLATION LEVEL REPEATABLE READ",
+ },
+ )
+ self.validate_all(
+ "BEGIN IMMEDIATE TRANSACTION",
+ write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
+ )
diff --git a/tests/dialects/test_drill.py b/tests/dialects/test_drill.py
new file mode 100644
index 0000000..9819daa
--- /dev/null
+++ b/tests/dialects/test_drill.py
@@ -0,0 +1,53 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestDrill(Validator):
+ dialect = "drill"
+
+ def test_string_literals(self):
+ self.validate_all(
+ "SELECT '2021-01-01' + INTERVAL 1 MONTH",
+ write={
+ "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH",
+ },
+ )
+
+ def test_quotes(self):
+ self.validate_all(
+ "'\\''",
+ write={
+ "duckdb": "''''",
+ "presto": "''''",
+ "hive": "'\\''",
+ "spark": "'\\''",
+ },
+ )
+ self.validate_all(
+ "'\"x\"'",
+ write={
+ "duckdb": "'\"x\"'",
+ "presto": "'\"x\"'",
+ "hive": "'\"x\"'",
+ "spark": "'\"x\"'",
+ },
+ )
+ self.validate_all(
+ "'\\\\a'",
+ read={
+ "presto": "'\\a'",
+ },
+ write={
+ "duckdb": "'\\a'",
+ "presto": "'\\a'",
+ "hive": "'\\\\a'",
+ "spark": "'\\\\a'",
+ },
+ )
+
+ def test_table_function(self):
+ self.validate_all(
+ "SELECT * FROM table( dfs.`test_data.xlsx` (type => 'excel', sheetName => 'secondSheet'))",
+ write={
+ "drill": "SELECT * FROM table(dfs.`test_data.xlsx`(type => 'excel', sheetName => 'secondSheet'))",
+ },
+ )
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 1ba118b..af98249 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -58,6 +58,16 @@ class TestMySQL(Validator):
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")
self.validate_identity("SET NAMES utf8 COLLATE utf8_unicode_ci")
self.validate_identity("SET autocommit = ON")
+ self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL SERIALIZABLE")
+ self.validate_identity("SET TRANSACTION READ ONLY")
+ self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
+ self.validate_identity("SELECT SCHEMA()")
+
+ def test_canonical_functions(self):
+ self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
+ self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
+ self.validate_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')")
+ self.validate_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')")
def test_escape(self):
self.validate_all(
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index 098ad2b..8179cf7 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -178,6 +178,15 @@ class TestPresto(Validator):
},
)
self.validate_all(
+ "CREATE TABLE test STORED = 'PARQUET' AS SELECT 1",
+ write={
+ "duckdb": "CREATE TABLE test AS SELECT 1",
+ "presto": "CREATE TABLE test WITH (FORMAT='PARQUET') AS SELECT 1",
+ "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1",
+ "spark": "CREATE TABLE test USING PARQUET AS SELECT 1",
+ },
+ )
+ self.validate_all(
"CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1",
write={
"duckdb": "CREATE TABLE test AS SELECT 1",
@@ -427,3 +436,69 @@ class TestPresto(Validator):
"spark": UnsupportedError,
},
)
+ self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
+ self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
+
+ def test_encode_decode(self):
+ self.validate_all(
+ "TO_UTF8(x)",
+ write={
+ "spark": "ENCODE(x, 'utf-8')",
+ },
+ )
+ self.validate_all(
+ "FROM_UTF8(x)",
+ write={
+ "spark": "DECODE(x, 'utf-8')",
+ },
+ )
+ self.validate_all(
+ "ENCODE(x, 'utf-8')",
+ write={
+ "presto": "TO_UTF8(x)",
+ },
+ )
+ self.validate_all(
+ "DECODE(x, 'utf-8')",
+ write={
+ "presto": "FROM_UTF8(x)",
+ },
+ )
+ self.validate_all(
+ "ENCODE(x, 'invalid')",
+ write={
+ "presto": UnsupportedError,
+ },
+ )
+ self.validate_all(
+ "DECODE(x, 'invalid')",
+ write={
+ "presto": UnsupportedError,
+ },
+ )
+
+ def test_hex_unhex(self):
+ self.validate_all(
+ "TO_HEX(x)",
+ write={
+ "spark": "HEX(x)",
+ },
+ )
+ self.validate_all(
+ "FROM_HEX(x)",
+ write={
+ "spark": "UNHEX(x)",
+ },
+ )
+ self.validate_all(
+ "HEX(x)",
+ write={
+ "presto": "TO_HEX(x)",
+ },
+ )
+ self.validate_all(
+ "UNHEX(x)",
+ write={
+ "presto": "FROM_HEX(x)",
+ },
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index 1846b17..0e69f4e 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -169,6 +169,17 @@ class TestSnowflake(Validator):
"snowflake": "SELECT a FROM test AS unpivot",
},
)
+ self.validate_all(
+ "trim(date_column, 'UTC')",
+ write={
+ "snowflake": "TRIM(date_column, 'UTC')",
+ "postgres": "TRIM('UTC' FROM date_column)",
+ },
+ )
+ self.validate_all(
+ "trim(date_column)",
+ write={"snowflake": "TRIM(date_column)"},
+ )
def test_null_treatment(self):
self.validate_all(
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 836ab28..75bd25d 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -122,13 +122,6 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1
SET -v
-ADD JAR s3://bucket
-ADD JARS s3://bucket, c
-ADD FILE s3://file
-ADD FILES s3://file, s3://a
-ADD ARCHIVE s3://file
-ADD ARCHIVES s3://file, s3://a
-BEGIN IMMEDIATE TRANSACTION
COMMIT
USE db
NOT 1
@@ -278,6 +271,7 @@ SELECT CEIL(a, b) FROM test
SELECT COUNT(a) FROM test
SELECT COUNT(1) FROM test
SELECT COUNT(*) FROM test
+SELECT COUNT() FROM test
SELECT COUNT(DISTINCT a) FROM test
SELECT EXP(a) FROM test
SELECT FLOOR(a) FROM test
@@ -372,6 +366,8 @@ WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2
WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2
+WITH sub_query AS (SELECT a FROM table) (SELECT a FROM sub_query)
+WITH sub_query AS (SELECT a FROM table) ((((SELECT a FROM sub_query))))
(SELECT 1) UNION (SELECT 2)
(SELECT 1) UNION SELECT 2
SELECT 1 UNION (SELECT 2)
@@ -463,6 +459,7 @@ CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECI
CREATE TABLE z (a INT(11) DEFAULT UUID())
CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id')
CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1)
+CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1)
CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT)
CREATE TABLE z (a INT, PRIMARY KEY(a))
CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1
@@ -476,6 +473,9 @@ CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte))
CREATE TABLE z (a INT UNIQUE)
CREATE TABLE z (a INT AUTO_INCREMENT)
CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT)
+CREATE TABLE z (a INT REFERENCES parent(b, c))
+CREATE TABLE z (a INT PRIMARY KEY, b INT REFERENCES foo(id))
+CREATE TABLE z (a INT, FOREIGN KEY (a) REFERENCES parent(b, c))
CREATE TEMPORARY FUNCTION f
CREATE TEMPORARY FUNCTION f AS 'g'
CREATE FUNCTION f
@@ -514,17 +514,23 @@ DELETE FROM x WHERE y > 1
DELETE FROM y
DELETE FROM event USING sales WHERE event.eventid = sales.eventid
DELETE FROM event USING sales, USING bla WHERE event.eventid = sales.eventid
+DELETE FROM event USING sales AS s WHERE event.eventid = s.eventid
+PREPARE statement
+EXECUTE statement
DROP TABLE a
DROP TABLE a.b
DROP TABLE IF EXISTS a
DROP TABLE IF EXISTS a.b
+DROP TABLE a CASCADE
DROP VIEW a
DROP VIEW a.b
DROP VIEW IF EXISTS a
DROP VIEW IF EXISTS a.b
SHOW TABLES
USE db
+BEGIN
ROLLBACK
+ROLLBACK TO b
EXPLAIN SELECT * FROM x
INSERT INTO x SELECT * FROM y
INSERT INTO x (SELECT * FROM y)
@@ -581,3 +587,4 @@ SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
+SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql
new file mode 100644
index 0000000..7fcdbb8
--- /dev/null
+++ b/tests/fixtures/optimizer/canonicalize.sql
@@ -0,0 +1,5 @@
+SELECT w.d + w.e AS c FROM w AS w;
+SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
+
+SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
+SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql
index eb7e9cb..a1e531b 100644
--- a/tests/fixtures/optimizer/optimizer.sql
+++ b/tests/fixtures/optimizer/optimizer.sql
@@ -119,7 +119,7 @@ GROUP BY
LIMIT 1;
# title: Root subquery is union
-(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1;
+(SELECT b FROM x UNION SELECT b FROM y ORDER BY b) LIMIT 1;
(
SELECT
"x"."b" AS "b"
@@ -128,6 +128,8 @@ LIMIT 1;
SELECT
"y"."b" AS "b"
FROM "y" AS "y"
+ ORDER BY
+ "b"
)
LIMIT 1;
diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
index b91205c..8138b11 100644
--- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql
+++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
@@ -15,7 +15,7 @@ select
from
lineitem
where
- CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day
+ l_shipdate <= date '1998-12-01' - interval '90' day
group by
l_returnflag,
l_linestatus
@@ -250,8 +250,8 @@ FROM "orders" AS "orders"
LEFT JOIN "_u_0" AS "_u_0"
ON "_u_0"."l_orderkey" = "orders"."o_orderkey"
WHERE
- "orders"."o_orderdate" < CAST('1993-10-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE)
+ CAST("orders"."o_orderdate" AS DATE) < CAST('1993-10-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-07-01' AS DATE)
AND NOT "_u_0"."l_orderkey" IS NULL
GROUP BY
"orders"."o_orderpriority"
@@ -293,8 +293,8 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
- AND "orders"."o_orderdate" < CAST('1995-01-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1994-01-01' AS DATE)
JOIN "region" AS "region"
ON "region"."r_name" = 'ASIA'
JOIN "nation" AS "nation"
@@ -328,8 +328,8 @@ FROM "lineitem" AS "lineitem"
WHERE
"lineitem"."l_discount" BETWEEN 0.05 AND 0.07
AND "lineitem"."l_quantity" < 24
- AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE);
+ AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
--------------------------------------
-- TPC-H 7
@@ -384,13 +384,13 @@ WITH "n1" AS (
SELECT
"n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation",
- EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year",
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
- ON "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +409,7 @@ JOIN "n1" AS "n2"
GROUP BY
"n1"."n_name",
"n2"."n_name",
- EXTRACT(year FROM "lineitem"."l_shipdate")
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
ORDER BY
"supp_nation",
"cust_nation",
@@ -456,7 +456,7 @@ group by
order by
o_year;
SELECT
- EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM(
CASE
WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +477,7 @@ JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
- AND "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +488,7 @@ JOIN "nation" AS "nation_2"
WHERE
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY
- EXTRACT(year FROM "orders"."o_orderdate")
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY
"o_year";
@@ -529,7 +529,7 @@ order by
o_year desc;
SELECT
"nation"."n_name" AS "nation",
- EXTRACT(year FROM "orders"."o_orderdate") AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
SUM(
"lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
@@ -551,7 +551,7 @@ WHERE
"part"."p_name" LIKE '%green%'
GROUP BY
"nation"."n_name",
- EXTRACT(year FROM "orders"."o_orderdate")
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
ORDER BY
"nation",
"o_year" DESC;
@@ -606,8 +606,8 @@ SELECT
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
- AND "orders"."o_orderdate" < CAST('1994-01-01' AS DATE)
- AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) < CAST('1994-01-01' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1993-10-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey" AND "lineitem"."l_returnflag" = 'R'
JOIN "nation" AS "nation"
@@ -740,8 +740,8 @@ SELECT
FROM "orders" AS "orders"
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_commitdate" < "lineitem"."l_receiptdate"
- AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE)
+ AND CAST("lineitem"."l_receiptdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("lineitem"."l_receiptdate" AS DATE) >= CAST('1994-01-01' AS DATE)
AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate"
AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP')
AND "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -832,8 +832,8 @@ FROM "lineitem" AS "lineitem"
JOIN "part" AS "part"
ON "lineitem"."l_partkey" = "part"."p_partkey"
WHERE
- "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE);
+ CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-10-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-09-01' AS DATE);
--------------------------------------
-- TPC-H 15
@@ -876,8 +876,8 @@ WITH "revenue" AS (
)) AS "total_revenue"
FROM "lineitem" AS "lineitem"
WHERE
- "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE)
+ CAST("lineitem"."l_shipdate" AS DATE) < CAST('1996-04-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1996-01-01' AS DATE)
GROUP BY
"lineitem"."l_suppkey"
)
@@ -1220,8 +1220,8 @@ WITH "_u_0" AS (
"lineitem"."l_suppkey" AS "_u_2"
FROM "lineitem" AS "lineitem"
WHERE
- "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE)
- AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE)
+ CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE)
GROUP BY
"lineitem"."l_partkey",
"lineitem"."l_suppkey"
diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql
index 5e27b5e..067fe77 100644
--- a/tests/fixtures/pretty.sql
+++ b/tests/fixtures/pretty.sql
@@ -315,3 +315,10 @@ FROM (
WHERE
id = 1
) /* x */;
+SELECT * /* multi
+ line
+ comment */;
+SELECT
+ * /* multi
+ line
+ comment */;
diff --git a/tests/helpers.py b/tests/helpers.py
index dabaf1c..9abdaae 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -57,79 +57,79 @@ SKIP_INTEGRATION = string_to_bool(os.environ.get("SKIP_INTEGRATION", "0").lower(
TPCH_SCHEMA = {
"lineitem": {
- "l_orderkey": "uint64",
- "l_partkey": "uint64",
- "l_suppkey": "uint64",
- "l_linenumber": "uint64",
- "l_quantity": "float64",
- "l_extendedprice": "float64",
- "l_discount": "float64",
- "l_tax": "float64",
+ "l_orderkey": "bigint",
+ "l_partkey": "bigint",
+ "l_suppkey": "bigint",
+ "l_linenumber": "bigint",
+ "l_quantity": "double",
+ "l_extendedprice": "double",
+ "l_discount": "double",
+ "l_tax": "double",
"l_returnflag": "string",
"l_linestatus": "string",
- "l_shipdate": "date32",
- "l_commitdate": "date32",
- "l_receiptdate": "date32",
+ "l_shipdate": "string",
+ "l_commitdate": "string",
+ "l_receiptdate": "string",
"l_shipinstruct": "string",
"l_shipmode": "string",
"l_comment": "string",
},
"orders": {
- "o_orderkey": "uint64",
- "o_custkey": "uint64",
+ "o_orderkey": "bigint",
+ "o_custkey": "bigint",
"o_orderstatus": "string",
- "o_totalprice": "float64",
- "o_orderdate": "date32",
+ "o_totalprice": "double",
+ "o_orderdate": "string",
"o_orderpriority": "string",
"o_clerk": "string",
- "o_shippriority": "int32",
+ "o_shippriority": "int",
"o_comment": "string",
},
"customer": {
- "c_custkey": "uint64",
+ "c_custkey": "bigint",
"c_name": "string",
"c_address": "string",
- "c_nationkey": "uint64",
+ "c_nationkey": "bigint",
"c_phone": "string",
- "c_acctbal": "float64",
+ "c_acctbal": "double",
"c_mktsegment": "string",
"c_comment": "string",
},
"part": {
- "p_partkey": "uint64",
+ "p_partkey": "bigint",
"p_name": "string",
"p_mfgr": "string",
"p_brand": "string",
"p_type": "string",
- "p_size": "int32",
+ "p_size": "int",
"p_container": "string",
- "p_retailprice": "float64",
+ "p_retailprice": "double",
"p_comment": "string",
},
"supplier": {
- "s_suppkey": "uint64",
+ "s_suppkey": "bigint",
"s_name": "string",
"s_address": "string",
- "s_nationkey": "uint64",
+ "s_nationkey": "bigint",
"s_phone": "string",
- "s_acctbal": "float64",
+ "s_acctbal": "double",
"s_comment": "string",
},
"partsupp": {
- "ps_partkey": "uint64",
- "ps_suppkey": "uint64",
- "ps_availqty": "int32",
- "ps_supplycost": "float64",
+ "ps_partkey": "bigint",
+ "ps_suppkey": "bigint",
+ "ps_availqty": "int",
+ "ps_supplycost": "double",
"ps_comment": "string",
},
"nation": {
- "n_nationkey": "uint64",
+ "n_nationkey": "bigint",
"n_name": "string",
- "n_regionkey": "uint64",
+ "n_regionkey": "bigint",
"n_comment": "string",
},
"region": {
- "r_regionkey": "uint64",
+ "r_regionkey": "bigint",
"r_name": "string",
"r_comment": "string",
},
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 49805b9..2c4d7cd 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -1,12 +1,15 @@
import unittest
+from datetime import date
import duckdb
import pandas as pd
from pandas.testing import assert_frame_equal
from sqlglot import exp, parse_one
+from sqlglot.errors import ExecuteError
from sqlglot.executor import execute
from sqlglot.executor.python import Python
+from sqlglot.executor.table import Table, ensure_tables
from tests.helpers import (
FIXTURES_DIR,
SKIP_INTEGRATION,
@@ -67,13 +70,399 @@ class TestExecutor(unittest.TestCase):
def to_csv(expression):
if isinstance(expression, exp.Table):
return parse_one(
- f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}"
+ f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.alias_or_name}"
)
return expression
- for sql, _ in self.sqls[0:3]:
- a = self.cached_execute(sql)
- sql = parse_one(sql).transform(to_csv).sql(pretty=True)
- table = execute(sql, TPCH_SCHEMA)
- b = pd.DataFrame(table.rows, columns=table.columns)
- assert_frame_equal(a, b, check_dtype=False)
+ for i, (sql, _) in enumerate(self.sqls[0:7]):
+ with self.subTest(f"tpch-h {i + 1}"):
+ a = self.cached_execute(sql)
+ sql = parse_one(sql).transform(to_csv).sql(pretty=True)
+ table = execute(sql, TPCH_SCHEMA)
+ b = pd.DataFrame(table.rows, columns=table.columns)
+ assert_frame_equal(a, b, check_dtype=False)
+
+ def test_execute_callable(self):
+ tables = {
+ "x": [
+ {"a": "a", "b": "d"},
+ {"a": "b", "b": "e"},
+ {"a": "c", "b": "f"},
+ ],
+ "y": [
+ {"b": "d", "c": "g"},
+ {"b": "e", "c": "h"},
+ {"b": "f", "c": "i"},
+ ],
+ "z": [],
+ }
+ schema = {
+ "x": {
+ "a": "VARCHAR",
+ "b": "VARCHAR",
+ },
+ "y": {
+ "b": "VARCHAR",
+ "c": "VARCHAR",
+ },
+ "z": {"d": "VARCHAR"},
+ }
+
+ for sql, cols, rows in [
+ ("SELECT * FROM x", ["a", "b"], [("a", "d"), ("b", "e"), ("c", "f")]),
+ (
+ "SELECT * FROM x JOIN y ON x.b = y.b",
+ ["a", "b", "b", "c"],
+ [("a", "d", "d", "g"), ("b", "e", "e", "h"), ("c", "f", "f", "i")],
+ ),
+ (
+ "SELECT j.c AS d FROM x AS i JOIN y AS j ON i.b = j.b",
+ ["d"],
+ [("g",), ("h",), ("i",)],
+ ),
+ (
+ "SELECT CONCAT(x.a, y.c) FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+ ["_col_0"],
+ [("bh",)],
+ ),
+ (
+ "SELECT * FROM x JOIN y ON x.b = y.b WHERE y.b = 'e'",
+ ["a", "b", "b", "c"],
+ [("b", "e", "e", "h")],
+ ),
+ (
+ "SELECT * FROM z",
+ ["d"],
+ [],
+ ),
+ (
+ "SELECT d FROM z ORDER BY d",
+ ["d"],
+ [],
+ ),
+ (
+ "SELECT a FROM x WHERE x.a <> 'b'",
+ ["a"],
+ [("a",), ("c",)],
+ ),
+ (
+ "SELECT a AS i FROM x ORDER BY a",
+ ["i"],
+ [("a",), ("b",), ("c",)],
+ ),
+ (
+ "SELECT a AS i FROM x ORDER BY i",
+ ["i"],
+ [("a",), ("b",), ("c",)],
+ ),
+ (
+ "SELECT 100 - ORD(a) AS a, a AS i FROM x ORDER BY a",
+ ["a", "i"],
+ [(1, "c"), (2, "b"), (3, "a")],
+ ),
+ (
+ "SELECT a /* test */ FROM x LIMIT 1",
+ ["a"],
+ [("a",)],
+ ),
+ ]:
+ with self.subTest(sql):
+ result = execute(sql, schema=schema, tables=tables)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
+
+ def test_set_operations(self):
+ tables = {
+ "x": [
+ {"a": "a"},
+ {"a": "b"},
+ {"a": "c"},
+ ],
+ "y": [
+ {"a": "b"},
+ {"a": "c"},
+ {"a": "d"},
+ ],
+ }
+ schema = {
+ "x": {
+ "a": "VARCHAR",
+ },
+ "y": {
+ "a": "VARCHAR",
+ },
+ }
+
+ for sql, cols, rows in [
+ (
+ "SELECT a FROM x UNION ALL SELECT a FROM y",
+ ["a"],
+ [("a",), ("b",), ("c",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT a FROM x UNION SELECT a FROM y",
+ ["a"],
+ [("a",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT a FROM x EXCEPT SELECT a FROM y",
+ ["a"],
+ [("a",)],
+ ),
+ (
+ "SELECT a FROM x INTERSECT SELECT a FROM y",
+ ["a"],
+ [("b",), ("c",)],
+ ),
+ (
+ """SELECT i.a
+ FROM (
+ SELECT a FROM x UNION SELECT a FROM y
+ ) AS i
+ JOIN (
+ SELECT a FROM x UNION SELECT a FROM y
+ ) AS j
+ ON i.a = j.a""",
+ ["a"],
+ [("a",), ("b",), ("c",), ("d",)],
+ ),
+ (
+ "SELECT 1 AS a UNION SELECT 2 AS a UNION SELECT 3 AS a",
+ ["a"],
+ [(1,), (2,), (3,)],
+ ),
+ ]:
+ with self.subTest(sql):
+ result = execute(sql, schema=schema, tables=tables)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(set(result.rows), set(rows))
+
+ def test_execute_catalog_db_table(self):
+ tables = {
+ "catalog": {
+ "db": {
+ "x": [
+ {"a": "a"},
+ {"a": "b"},
+ {"a": "c"},
+ ],
+ }
+ }
+ }
+ schema = {
+ "catalog": {
+ "db": {
+ "x": {
+ "a": "VARCHAR",
+ }
+ }
+ }
+ }
+ result1 = execute("SELECT * FROM x", schema=schema, tables=tables)
+ result2 = execute("SELECT * FROM catalog.db.x", schema=schema, tables=tables)
+ assert result1.columns == result2.columns
+ assert result1.rows == result2.rows
+
+ def test_execute_tables(self):
+ tables = {
+ "sushi": [
+ {"id": 1, "price": 1.0},
+ {"id": 2, "price": 2.0},
+ {"id": 3, "price": 3.0},
+ ],
+ "order_items": [
+ {"sushi_id": 1, "order_id": 1},
+ {"sushi_id": 1, "order_id": 1},
+ {"sushi_id": 2, "order_id": 1},
+ {"sushi_id": 3, "order_id": 2},
+ ],
+ "orders": [
+ {"id": 1, "user_id": 1},
+ {"id": 2, "user_id": 2},
+ ],
+ }
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.user_id,
+ SUM(s.price) AS price
+ FROM orders o
+ JOIN order_items i
+ ON o.id = i.order_id
+ JOIN sushi s
+ ON i.sushi_id = s.id
+ GROUP BY o.user_id
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (1, 4.0),
+ (2, 3.0),
+ ],
+ )
+
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.id, x.*
+ FROM orders o
+ LEFT JOIN (
+ SELECT
+ 1 AS id, 'b' AS x
+ UNION ALL
+ SELECT
+ 3 AS id, 'c' AS x
+ ) x
+ ON o.id = x.id
+ """,
+ tables=tables,
+ ).rows,
+ [(1, 1, "b"), (2, None, None)],
+ )
+ self.assertEqual(
+ execute(
+ """
+ SELECT
+ o.id, x.*
+ FROM orders o
+ RIGHT JOIN (
+ SELECT
+ 1 AS id,
+ 'b' AS x
+ UNION ALL
+ SELECT
+ 3 AS id, 'c' AS x
+ ) x
+ ON o.id = x.id
+ """,
+ tables=tables,
+ ).rows,
+ [
+ (1, 1, "b"),
+ (None, 3, "c"),
+ ],
+ )
+
+ def test_table_depth_mismatch(self):
+ tables = {"table": []}
+ schema = {"db": {"table": {"col": "VARCHAR"}}}
+ with self.assertRaises(ExecuteError):
+ execute("SELECT * FROM table", schema=schema, tables=tables)
+
+ def test_tables(self):
+ tables = ensure_tables(
+ {
+ "catalog1": {
+ "db1": {
+ "t1": [
+ {"a": 1},
+ ],
+ "t2": [
+ {"a": 1},
+ ],
+ },
+ "db2": {
+ "t3": [
+ {"a": 1},
+ ],
+ "t4": [
+ {"a": 1},
+ ],
+ },
+ },
+ "catalog2": {
+ "db3": {
+ "t5": Table(columns=("a",), rows=[(1,)]),
+ "t6": Table(columns=("a",), rows=[(1,)]),
+ },
+ "db4": {
+ "t7": Table(columns=("a",), rows=[(1,)]),
+ "t8": Table(columns=("a",), rows=[(1,)]),
+ },
+ },
+ }
+ )
+
+ t1 = tables.find(exp.table_(table="t1", db="db1", catalog="catalog1"))
+ self.assertEqual(t1.columns, ("a",))
+ self.assertEqual(t1.rows, [(1,)])
+
+ t8 = tables.find(exp.table_(table="t8"))
+ self.assertEqual(t1.columns, t8.columns)
+ self.assertEqual(t1.rows, t8.rows)
+
+ def test_static_queries(self):
+ for sql, cols, rows in [
+ ("SELECT 1", ["_col_0"], [(1,)]),
+ ("SELECT 1 + 2 AS x", ["x"], [(3,)]),
+ ("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
+ ("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
+ ("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
+ ]:
+ result = execute(sql)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
+
+ def test_aggregate_without_group_by(self):
+ result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
+ self.assertEqual(result.columns, ("_col_0",))
+ self.assertEqual(result.rows, [(3,)])
+
+ def test_scalar_functions(self):
+ for sql, expected in [
+ ("CONCAT('a', 'b')", "ab"),
+ ("CONCAT('a', NULL)", None),
+ ("CONCAT_WS('_', 'a', 'b')", "a_b"),
+ ("STR_POSITION('bar', 'foobarbar')", 4),
+ ("STR_POSITION('bar', 'foobarbar', 5)", 7),
+ ("STR_POSITION(NULL, 'foobarbar')", None),
+ ("STR_POSITION('bar', NULL)", None),
+ ("UPPER('foo')", "FOO"),
+ ("UPPER(NULL)", None),
+ ("LOWER('FOO')", "foo"),
+ ("LOWER(NULL)", None),
+ ("IFNULL('a', 'b')", "a"),
+ ("IFNULL(NULL, 'b')", "b"),
+ ("IFNULL(NULL, NULL)", None),
+ ("SUBSTRING('12345')", "12345"),
+ ("SUBSTRING('12345', 3)", "345"),
+ ("SUBSTRING('12345', 3, 0)", ""),
+ ("SUBSTRING('12345', 3, 1)", "3"),
+ ("SUBSTRING('12345', 3, 2)", "34"),
+ ("SUBSTRING('12345', 3, 3)", "345"),
+ ("SUBSTRING('12345', 3, 4)", "345"),
+ ("SUBSTRING('12345', -3)", "345"),
+ ("SUBSTRING('12345', -3, 0)", ""),
+ ("SUBSTRING('12345', -3, 1)", "3"),
+ ("SUBSTRING('12345', -3, 2)", "34"),
+ ("SUBSTRING('12345', 0)", ""),
+ ("SUBSTRING('12345', 0, 1)", ""),
+ ("SUBSTRING(NULL)", None),
+ ("SUBSTRING(NULL, 1)", None),
+ ("CAST(1 AS TEXT)", "1"),
+ ("CAST('1' AS LONG)", 1),
+ ("CAST('1.1' AS FLOAT)", 1.1),
+ ("COALESCE(NULL)", None),
+ ("COALESCE(NULL, NULL)", None),
+ ("COALESCE(NULL, 'b')", "b"),
+ ("COALESCE('a', 'b')", "a"),
+ ("1 << 1", 2),
+ ("1 >> 1", 0),
+ ("1 & 1", 1),
+ ("1 | 1", 1),
+ ("1 < 1", False),
+ ("1 <= 1", True),
+ ("1 > 1", False),
+ ("1 >= 1", True),
+ ("1 + NULL", None),
+ ("IF(true, 1, 0)", 1),
+ ("IF(false, 1, 0)", 0),
+ ("CASE WHEN 0 = 1 THEN 'foo' ELSE 'bar' END", "bar"),
+ ("CAST('2022-01-01' AS DATE) + INTERVAL '1' DAY", date(2022, 1, 2)),
+ ]:
+ with self.subTest(sql):
+ result = execute(f"SELECT {sql}")
+ self.assertEqual(result.rows, [(expected,)])
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 63371d8..c0927ad 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -441,6 +441,9 @@ class TestExpressions(unittest.TestCase):
self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance)
self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop)
self.assertIsInstance(parse_one("YEAR(a)"), exp.Year)
+ self.assertIsInstance(parse_one("BEGIN DEFERRED TRANSACTION"), exp.Transaction)
+ self.assertIsInstance(parse_one("COMMIT"), exp.Commit)
+ self.assertIsInstance(parse_one("ROLLBACK"), exp.Rollback)
def test_column(self):
dot = parse_one("a.b.c")
@@ -479,9 +482,9 @@ class TestExpressions(unittest.TestCase):
self.assertEqual(column.text("expression"), "c")
self.assertEqual(column.text("y"), "")
self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x")
- self.assertEqual(parse_one("select *").text("this"), "")
- self.assertEqual(parse_one("1 + 1").text("this"), "1")
- self.assertEqual(parse_one("'a'").text("this"), "a")
+ self.assertEqual(parse_one("select *").name, "")
+ self.assertEqual(parse_one("1 + 1").name, "1")
+ self.assertEqual(parse_one("'a'").name, "a")
def test_alias(self):
self.assertEqual(alias("foo", "bar").sql(), "foo AS bar")
@@ -538,8 +541,8 @@ class TestExpressions(unittest.TestCase):
this=exp.Literal.string("TABLE_FORMAT"),
value=exp.to_identifier("test_format"),
),
- exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.NULL),
- exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.TRUE),
+ exp.EngineProperty(this=exp.Literal.string("ENGINE"), value=exp.null()),
+ exp.CollateProperty(this=exp.Literal.string("COLLATE"), value=exp.true()),
]
),
)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index a1b7e70..6637a1d 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -29,6 +29,7 @@ class TestOptimizer(unittest.TestCase):
CREATE TABLE x (a INT, b INT);
CREATE TABLE y (b INT, c INT);
CREATE TABLE z (b INT, c INT);
+ CREATE TABLE w (d TEXT, e TEXT);
INSERT INTO x VALUES (1, 1);
INSERT INTO x VALUES (2, 2);
@@ -47,6 +48,8 @@ class TestOptimizer(unittest.TestCase):
INSERT INTO y VALUES (4, 4);
INSERT INTO y VALUES (5, 5);
INSERT INTO y VALUES (null, null);
+
+ INSERT INTO w VALUES ('a', 'b');
"""
)
@@ -64,6 +67,10 @@ class TestOptimizer(unittest.TestCase):
"b": "INT",
"c": "INT",
},
+ "w": {
+ "d": "TEXT",
+ "e": "TEXT",
+ },
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
@@ -224,6 +231,18 @@ class TestOptimizer(unittest.TestCase):
def test_eliminate_subqueries(self):
self.check_file("eliminate_subqueries", optimizer.eliminate_subqueries.eliminate_subqueries)
+ def test_canonicalize(self):
+ optimize = partial(
+ optimizer.optimize,
+ rules=[
+ optimizer.qualify_tables.qualify_tables,
+ optimizer.qualify_columns.qualify_columns,
+ annotate_types,
+ optimizer.canonicalize.canonicalize,
+ ],
+ )
+ self.check_file("canonicalize", optimize, schema=self.schema)
+
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 04c20b1..c747ea3 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -41,12 +41,41 @@ class TestParser(unittest.TestCase):
)
def test_command(self):
- expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1")
+ expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1", read="hive")
self.assertEqual(len(expressions), 3)
self.assertEqual(expressions[0].sql(), "SET x = 1")
self.assertEqual(expressions[1].sql(), "ADD JAR s3://a")
self.assertEqual(expressions[2].sql(), "SELECT 1")
+ def test_transactions(self):
+ expression = parse_one("BEGIN TRANSACTION")
+ self.assertIsNone(expression.this)
+ self.assertEqual(expression.args["modes"], [])
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one("START TRANSACTION", read="mysql")
+ self.assertIsNone(expression.this)
+ self.assertEqual(expression.args["modes"], [])
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one("BEGIN DEFERRED TRANSACTION")
+ self.assertEqual(expression.this, "DEFERRED")
+ self.assertEqual(expression.args["modes"], [])
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one(
+ "START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE", read="presto"
+ )
+ self.assertIsNone(expression.this)
+ self.assertEqual(expression.args["modes"][0], "READ WRITE")
+ self.assertEqual(expression.args["modes"][1], "ISOLATION LEVEL SERIALIZABLE")
+ self.assertEqual(expression.sql(), "BEGIN")
+
+ expression = parse_one("BEGIN", read="bigquery")
+ self.assertNotIsInstance(expression, exp.Transaction)
+ self.assertIsNone(expression.expression)
+ self.assertEqual(expression.sql(), "BEGIN")
+
def test_identify(self):
expression = parse_one(
"""
@@ -55,14 +84,14 @@ class TestParser(unittest.TestCase):
"""
)
- assert expression.expressions[0].text("this") == "a"
- assert expression.expressions[1].text("this") == "b"
- assert expression.expressions[2].text("alias") == "c"
- assert expression.expressions[3].text("alias") == "D"
- assert expression.expressions[4].text("alias") == "y|z'"
+ assert expression.expressions[0].name == "a"
+ assert expression.expressions[1].name == "b"
+ assert expression.expressions[2].alias == "c"
+ assert expression.expressions[3].alias == "D"
+ assert expression.expressions[4].alias == "y|z'"
table = expression.args["from"].expressions[0]
- assert table.args["this"].args["this"] == "z"
- assert table.args["db"].args["this"] == "y"
+ assert table.this.name == "z"
+ assert table.args["db"].name == "y"
def test_multi(self):
expressions = parse(
@@ -72,8 +101,8 @@ class TestParser(unittest.TestCase):
)
assert len(expressions) == 2
- assert expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a"
- assert expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b"
+ assert expressions[0].args["from"].expressions[0].this.name == "a"
+ assert expressions[1].args["from"].expressions[0].this.name == "b"
def test_expression(self):
ignore = Parser(error_level=ErrorLevel.IGNORE)
@@ -200,7 +229,7 @@ class TestParser(unittest.TestCase):
@patch("sqlglot.parser.logger")
def test_comment_error_n(self, logger):
parse_one(
- """CREATE TABLE x
+ """SUM
(
-- test
)""",
@@ -208,19 +237,19 @@ class TestParser(unittest.TestCase):
)
assert_logger_contains(
- "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.",
+ "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 4, Col: 1.",
logger,
)
@patch("sqlglot.parser.logger")
def test_comment_error_r(self, logger):
parse_one(
- """CREATE TABLE x (-- test\r)""",
+ """SUM(-- test\r)""",
error_level=ErrorLevel.WARN,
)
assert_logger_contains(
- "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 2, Col: 1.",
+ "Required keyword: 'this' missing for <class 'sqlglot.expressions.Sum'>. Line 2, Col: 1.",
logger,
)
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 943c2b0..d4772ba 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -12,6 +12,7 @@ class TestTokens(unittest.TestCase):
("--comment\nfoo --test", "comment"),
("foo --comment", "comment"),
("foo", None),
+ ("foo /*comment 1*/ /*comment 2*/", "comment 1"),
]
for sql, comment in sql_comment:
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 942053e..1bd2527 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -20,6 +20,13 @@ class TestTranspile(unittest.TestCase):
self.assertEqual(transpile(sql, **kwargs)[0], target)
def test_alias(self):
+ self.assertEqual(transpile("SELECT 1 current_time")[0], "SELECT 1 AS current_time")
+ self.assertEqual(
+ transpile("SELECT 1 current_timestamp")[0], "SELECT 1 AS current_timestamp"
+ )
+ self.assertEqual(transpile("SELECT 1 current_date")[0], "SELECT 1 AS current_date")
+ self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime")
+
for key in ("union", "filter", "over", "from", "join"):
with self.subTest(f"alias {key}"):
self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}")
@@ -69,6 +76,10 @@ class TestTranspile(unittest.TestCase):
self.validate("SELECT 3>=3", "SELECT 3 >= 3")
def test_comments(self):
+ self.validate("SELECT */*comment*/", "SELECT * /* comment */")
+ self.validate(
+ "SELECT * FROM table /*comment 1*/ /*comment 2*/", "SELECT * FROM table /* comment 1 */"
+ )
self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo /* comment */")
self.validate("SELECT --+5\nx FROM foo", "/* +5 */ SELECT x FROM foo")
self.validate("SELECT --!5\nx FROM foo", "/* !5 */ SELECT x FROM foo")