summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-12 15:42:38 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-12 15:42:38 +0000
commitbea2635be022e272ddac349f5e396ec901fc37e5 (patch)
tree24dbe11c9d462ff55f9b3af4b4da4cd1ae02e8a3 /tests
parentReleasing debian version 10.1.3-1. (diff)
downloadsqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.tar.xz
sqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.zip
Merging upstream version 10.2.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'tests')
-rw-r--r--tests/dataframe/unit/dataframe_sql_validator.py5
-rw-r--r--tests/dataframe/unit/test_dataframe_writer.py34
-rw-r--r--tests/dataframe/unit/test_session.py4
-rw-r--r--tests/dialects/test_bigquery.py19
-rw-r--r--tests/dialects/test_dialect.py36
-rw-r--r--tests/dialects/test_hive.py4
-rw-r--r--tests/dialects/test_postgres.py4
-rw-r--r--tests/dialects/test_redshift.py26
-rw-r--r--tests/dialects/test_snowflake.py9
-rw-r--r--tests/fixtures/identity.sql1
-rw-r--r--tests/fixtures/optimizer/canonicalize.sql6
-rw-r--r--tests/fixtures/optimizer/simplify.sql180
-rw-r--r--tests/fixtures/optimizer/tpc-h/tpc-h.sql51
-rw-r--r--tests/test_executor.py21
-rw-r--r--tests/test_optimizer.py155
-rw-r--r--tests/test_schema.py18
-rw-r--r--tests/test_tokens.py47
17 files changed, 497 insertions, 123 deletions
diff --git a/tests/dataframe/unit/dataframe_sql_validator.py b/tests/dataframe/unit/dataframe_sql_validator.py
index 32ff8f2..2dcdb39 100644
--- a/tests/dataframe/unit/dataframe_sql_validator.py
+++ b/tests/dataframe/unit/dataframe_sql_validator.py
@@ -4,6 +4,7 @@ import unittest
from sqlglot.dataframe.sql import types
from sqlglot.dataframe.sql.dataframe import DataFrame
from sqlglot.dataframe.sql.session import SparkSession
+from sqlglot.helper import ensure_list
class DataFrameSQLValidator(unittest.TestCase):
@@ -33,9 +34,7 @@ class DataFrameSQLValidator(unittest.TestCase):
self, df: DataFrame, expected_statements: t.Union[str, t.List[str]], pretty=False
):
actual_sqls = df.sql(pretty=pretty)
- expected_statements = (
- [expected_statements] if isinstance(expected_statements, str) else expected_statements
- )
+ expected_statements = ensure_list(expected_statements)
self.assertEqual(len(expected_statements), len(actual_sqls))
for expected, actual in zip(expected_statements, actual_sqls):
self.assertEqual(expected, actual)
diff --git a/tests/dataframe/unit/test_dataframe_writer.py b/tests/dataframe/unit/test_dataframe_writer.py
index 7c646f5..042b915 100644
--- a/tests/dataframe/unit/test_dataframe_writer.py
+++ b/tests/dataframe/unit/test_dataframe_writer.py
@@ -10,37 +10,37 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_insertInto_full_path(self):
df = self.df_employee.write.insertInto("catalog.db.table_name")
- expected = "INSERT INTO catalog.db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO catalog.db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_db_table(self):
df = self.df_employee.write.insertInto("db.table_name")
- expected = "INSERT INTO db.table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO db.table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_table(self):
df = self.df_employee.write.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_overwrite(self):
df = self.df_employee.write.insertInto("table_name", overwrite=True)
- expected = "INSERT OVERWRITE TABLE table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT OVERWRITE TABLE table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
@mock.patch("sqlglot.schema", MappingSchema())
def test_insertInto_byName(self):
sqlglot.schema.add_table("table_name", {"employee_id": "INT"})
df = self.df_employee.write.byName.insertInto("table_name")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_insertInto_cache(self):
df = self.df_employee.cache().write.insertInto("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t37164",
- "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "INSERT INTO table_name SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+ "DROP VIEW IF EXISTS t12441",
+ "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "INSERT INTO table_name SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
@@ -50,39 +50,39 @@ class TestDataFrameWriter(DataFrameSQLValidator):
def test_saveAsTable_append(self):
df = self.df_employee.write.saveAsTable("table_name", mode="append")
- expected = "INSERT INTO table_name SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "INSERT INTO table_name SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_overwrite(self):
df = self.df_employee.write.saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_error(self):
df = self.df_employee.write.saveAsTable("table_name", mode="error")
- expected = "CREATE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_ignore(self):
df = self.df_employee.write.saveAsTable("table_name", mode="ignore")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_standalone(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name")
- expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE TABLE IF NOT EXISTS table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_mode_override(self):
df = self.df_employee.write.mode("ignore").saveAsTable("table_name", mode="overwrite")
- expected = "CREATE OR REPLACE TABLE table_name AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
+ expected = "CREATE OR REPLACE TABLE table_name AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
self.compare_sql(df, expected)
def test_saveAsTable_cache(self):
df = self.df_employee.cache().write.saveAsTable("table_name")
expected_statements = [
- "DROP VIEW IF EXISTS t37164",
- "CACHE LAZY TABLE t37164 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT CAST(`a1`.`employee_id` AS INT) AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, CAST(`a1`.`age` AS INT) AS `age`, CAST(`a1`.`store_id` AS INT) AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
- "CREATE TABLE table_name AS SELECT `t37164`.`employee_id` AS `employee_id`, `t37164`.`fname` AS `fname`, `t37164`.`lname` AS `lname`, `t37164`.`age` AS `age`, `t37164`.`store_id` AS `store_id` FROM `t37164` AS `t37164`",
+ "DROP VIEW IF EXISTS t12441",
+ "CACHE LAZY TABLE t12441 OPTIONS('storageLevel' = 'MEMORY_AND_DISK') AS SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)",
+ "CREATE TABLE table_name AS SELECT `t12441`.`employee_id` AS `employee_id`, `t12441`.`fname` AS `fname`, `t12441`.`lname` AS `lname`, `t12441`.`age` AS `age`, `t12441`.`store_id` AS `store_id` FROM `t12441` AS `t12441`",
]
self.compare_sql(df, expected_statements)
diff --git a/tests/dataframe/unit/test_session.py b/tests/dataframe/unit/test_session.py
index 55aa547..5213667 100644
--- a/tests/dataframe/unit/test_session.py
+++ b/tests/dataframe/unit/test_session.py
@@ -36,7 +36,7 @@ class TestDataframeSession(DataFrameSQLValidator):
def test_cdf_str_schema(self):
df = self.spark.createDataFrame([[1, "test"]], "cola: INT, colb: STRING")
- expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_basic(self):
@@ -47,7 +47,7 @@ class TestDataframeSession(DataFrameSQLValidator):
]
)
df = self.spark.createDataFrame([[1, "test"]], schema)
- expected = "SELECT CAST(`a2`.`cola` AS INT) AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
+ expected = "SELECT `a2`.`cola` AS `cola`, CAST(`a2`.`colb` AS STRING) AS `colb` FROM VALUES (1, 'test') AS `a2`(`cola`, `colb`)"
self.compare_sql(df, expected)
def test_typed_schema_nested(self):
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index cc44311..1d60ec6 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -7,6 +7,11 @@ class TestBigQuery(Validator):
def test_bigquery(self):
self.validate_all(
+ "REGEXP_CONTAINS('foo', '.*')",
+ read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
+ write={"mysql": "REGEXP_LIKE('foo', '.*')"},
+ ),
+ self.validate_all(
'"""x"""',
write={
"bigquery": "'x'",
@@ -94,6 +99,20 @@ class TestBigQuery(Validator):
"spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)",
},
)
+ self.validate_all(
+ "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)",
+ write={"bigquery": "SELECT ARRAY(SELECT x FROM UNNEST([0, 1]) AS x)"},
+ )
+ self.validate_all(
+ "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers",
+ write={
+ "bigquery": "SELECT ARRAY(SELECT DISTINCT x FROM UNNEST(some_numbers) AS x) AS unique_numbers"
+ },
+ )
+ self.validate_all(
+ "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)",
+ write={"bigquery": "SELECT ARRAY(SELECT * FROM foo JOIN bla ON x = y)"},
+ )
self.validate_all(
"x IS unknown",
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 6033570..ee67bf1 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -1318,3 +1318,39 @@ SELECT
"BEGIN IMMEDIATE TRANSACTION",
write={"sqlite": "BEGIN IMMEDIATE TRANSACTION"},
)
+
+ def test_merge(self):
+ self.validate_all(
+ """
+ MERGE INTO target USING source ON target.id = source.id
+ WHEN NOT MATCHED THEN INSERT (id) values (source.id)
+ """,
+ write={
+ "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+ "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+ "spark": "MERGE INTO target USING source ON target.id = source.id WHEN NOT MATCHED THEN INSERT (id) VALUES (source.id)",
+ },
+ )
+ self.validate_all(
+ """
+ MERGE INTO target USING source ON target.id = source.id
+ WHEN MATCHED AND source.is_deleted = 1 THEN DELETE
+ WHEN MATCHED THEN UPDATE SET val = source.val
+ WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)
+ """,
+ write={
+ "bigquery": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+ "snowflake": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+ "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED AND source.is_deleted = 1 THEN DELETE WHEN MATCHED THEN UPDATE SET val = source.val WHEN NOT MATCHED THEN INSERT (id, val) VALUES (source.id, source.val)",
+ },
+ )
+ self.validate_all(
+ """
+ MERGE INTO target USING source ON target.id = source.id
+ WHEN MATCHED THEN UPDATE *
+ WHEN NOT MATCHED THEN INSERT *
+ """,
+ write={
+ "spark": "MERGE INTO target USING source ON target.id = source.id WHEN MATCHED THEN UPDATE * WHEN NOT MATCHED THEN INSERT *",
+ },
+ )
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index 22d7bce..5ac8714 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -145,6 +145,10 @@ class TestHive(Validator):
},
)
+ self.validate_identity(
+ """CREATE EXTERNAL TABLE x (y INT) ROW FORMAT SERDE 'serde' ROW FORMAT DELIMITED FIELDS TERMINATED BY '1' WITH SERDEPROPERTIES ('input.regex'='')""",
+ )
+
def test_lateral_view(self):
self.validate_all(
"SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b",
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index cd6117c..962b28b 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -256,3 +256,7 @@ class TestPostgres(Validator):
"SELECT $$Dianne's horse$$",
write={"postgres": "SELECT 'Dianne''s horse'"},
)
+ self.validate_all(
+ "UPDATE MYTABLE T1 SET T1.COL = 13",
+ write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"},
+ )
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index 1943ee3..3034df5 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -56,8 +56,27 @@ class TestRedshift(Validator):
"redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1',
},
)
+ self.validate_all(
+ "DECODE(x, a, b, c, d)",
+ write={
+ "": "MATCHES(x, a, b, c, d)",
+ "oracle": "DECODE(x, a, b, c, d)",
+ "snowflake": "DECODE(x, a, b, c, d)",
+ },
+ )
+ self.validate_all(
+ "NVL(a, b, c, d)",
+ write={
+ "redshift": "COALESCE(a, b, c, d)",
+ "mysql": "COALESCE(a, b, c, d)",
+ "postgres": "COALESCE(a, b, c, d)",
+ },
+ )
def test_identity(self):
+ self.validate_identity(
+ "SELECT DECODE(COL1, 'replace_this', 'with_this', 'replace_that', 'with_that')"
+ )
self.validate_identity("CAST('bla' AS SUPER)")
self.validate_identity("CREATE TABLE real1 (realcol REAL)")
self.validate_identity("CAST('foo' AS HLLSKETCH)")
@@ -70,9 +89,9 @@ class TestRedshift(Validator):
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)
- self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE AUTO")
+ self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL")
self.validate_identity(
- "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid)"
+ "CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
)
self.validate_identity(
"COPY customer FROM 's3://mybucket/customer' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
@@ -80,3 +99,6 @@ class TestRedshift(Validator):
self.validate_identity(
"UNLOAD ('select * from venue') TO 's3://mybucket/unload/' IAM_ROLE 'arn:aws:iam::0123456789012:role/MyRedshiftRole'"
)
+ self.validate_identity(
+ "CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
+ )
diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py
index baca269..bca5aaa 100644
--- a/tests/dialects/test_snowflake.py
+++ b/tests/dialects/test_snowflake.py
@@ -500,3 +500,12 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') f, LATERAL F
},
pretty=True,
)
+
+ def test_minus(self):
+ self.validate_all(
+ "SELECT 1 EXCEPT SELECT 1",
+ read={
+ "oracle": "SELECT 1 MINUS SELECT 1",
+ "snowflake": "SELECT 1 MINUS SELECT 1",
+ },
+ )
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index 06ab96d..e12b673 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -75,6 +75,7 @@ ARRAY(1, 2)
ARRAY_CONTAINS(x, 1)
EXTRACT(x FROM y)
EXTRACT(DATE FROM y)
+EXTRACT(WEEK(monday) FROM created_at)
CONCAT_WS('-', 'a', 'b')
CONCAT_WS('-', 'a', 'b', 'c')
POSEXPLODE("x") AS ("a", "b")
diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql
index 7fcdbb8..8880881 100644
--- a/tests/fixtures/optimizer/canonicalize.sql
+++ b/tests/fixtures/optimizer/canonicalize.sql
@@ -3,3 +3,9 @@ SELECT CONCAT(w.d, w.e) AS c FROM w AS w;
SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w;
SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w;
+
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+SELECT CAST(1 AS VARCHAR) AS a FROM w AS w;
+
+SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w;
+SELECT 1 + 3.2 AS a FROM w AS w;
diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql
index d9c7779..cf4195d 100644
--- a/tests/fixtures/optimizer/simplify.sql
+++ b/tests/fixtures/optimizer/simplify.sql
@@ -79,14 +79,16 @@ NULL;
NULL = NULL;
NULL;
+-- Can't optimize this because different engines do different things
+-- mysql converts to 0 and 1 but tsql does true and false
NULL <=> NULL;
-TRUE;
+NULL IS NOT DISTINCT FROM NULL;
a IS NOT DISTINCT FROM a;
-TRUE;
+a IS NOT DISTINCT FROM a;
NULL IS DISTINCT FROM NULL;
-FALSE;
+NULL IS DISTINCT FROM NULL;
NOT (NOT TRUE);
TRUE;
@@ -239,10 +241,10 @@ TRUE;
FALSE;
((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3);
-TRUE;
+x = x;
((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2);
-TRUE;
+x = x;
(('a' = 'a') AND TRUE and NOT FALSE);
TRUE;
@@ -372,3 +374,171 @@ CAST('1998-12-01' AS DATE) - INTERVAL '90' foo;
date '1998-12-01' + interval '90' foo;
CAST('1998-12-01' AS DATE) + INTERVAL '90' foo;
+
+--------------------------------------
+-- Comparisons
+--------------------------------------
+x < 0 OR x > 1;
+x < 0 OR x > 1;
+
+x < 0 OR x > 0;
+x < 0 OR x > 0;
+
+x < 1 OR x > 0;
+x < 1 OR x > 0;
+
+x < 1 OR x >= 0;
+x < 1 OR x >= 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 1 OR x >= 0;
+x <= 1 OR x >= 0;
+
+x <= 1 AND x <= 0;
+x <= 0;
+
+x <= 1 AND x > 0;
+x <= 1 AND x > 0;
+
+x <= 1 OR x > 0;
+x <= 1 OR x > 0;
+
+x <= 0 OR x < 0;
+x <= 0;
+
+x >= 0 OR x > 0;
+x >= 0;
+
+x >= 0 OR x > 1;
+x >= 0;
+
+x <= 0 OR x >= 0;
+x <= 0 OR x >= 0;
+
+x <= 0 AND x >= 0;
+x <= 0 AND x >= 0;
+
+x < 1 AND x < 2;
+x < 1;
+
+x < 1 OR x < 2;
+x < 2;
+
+x < 2 AND x < 1;
+x < 1;
+
+x < 2 OR x < 1;
+x < 2;
+
+x < 1 AND x < 1;
+x < 1;
+
+x < 1 OR x < 1;
+x < 1;
+
+x <= 1 AND x < 1;
+x < 1;
+
+x <= 1 OR x < 1;
+x <= 1;
+
+x < 1 AND x <= 1;
+x < 1;
+
+x < 1 OR x <= 1;
+x <= 1;
+
+x > 1 AND x > 2;
+x > 2;
+
+x > 1 OR x > 2;
+x > 1;
+
+x > 2 AND x > 1;
+x > 2;
+
+x > 2 OR x > 1;
+x > 1;
+
+x > 1 AND x > 1;
+x > 1;
+
+x > 1 OR x > 1;
+x > 1;
+
+x >= 1 AND x > 1;
+x > 1;
+
+x >= 1 OR x > 1;
+x >= 1;
+
+x > 1 AND x >= 1;
+x > 1;
+
+x > 1 OR x >= 1;
+x >= 1;
+
+x > 1 AND x >= 2;
+x >= 2;
+
+x > 1 OR x >= 2;
+x > 1;
+
+x > 1 AND x >= 2 AND x > 3 AND x > 0;
+x > 3;
+
+(x > 1 AND x >= 2 AND x > 3 AND x > 0) OR x > 0;
+x > 0;
+
+x > 1 AND x < 2 AND x > 3;
+FALSE;
+
+x > 1 AND x < 1;
+FALSE;
+
+x < 2 AND x > 1;
+x < 2 AND x > 1;
+
+x = 1 AND x < 1;
+FALSE;
+
+x = 1 AND x < 1.1;
+x = 1;
+
+x = 1 AND x <= 1;
+x = 1;
+
+x = 1 AND x <= 0.9;
+FALSE;
+
+x = 1 AND x > 0.9;
+x = 1;
+
+x = 1 AND x > 1;
+FALSE;
+
+x = 1 AND x >= 1;
+x = 1;
+
+x = 1 AND x >= 2;
+FALSE;
+
+x = 1 AND x <> 2;
+x = 1;
+
+x <> 1 AND x = 1;
+FALSE;
+
+x BETWEEN 0 AND 5 AND x > 3;
+x <= 5 AND x > 3;
+
+x > 3 AND 5 > x AND x BETWEEN 0 AND 10;
+x < 5 AND x > 3;
+
+x > 3 AND 5 < x AND x BETWEEN 9 AND 10;
+x <= 10 AND x >= 9;
+
+1 < x AND 3 < x;
+x > 3;
diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
index 4893743..9c1f138 100644
--- a/tests/fixtures/optimizer/tpc-h/tpc-h.sql
+++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql
@@ -190,7 +190,7 @@ SELECT
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue",
- CAST("orders"."o_orderdate" AS TEXT) AS "o_orderdate",
+ "orders"."o_orderdate" AS "o_orderdate",
"orders"."o_shippriority" AS "o_shippriority"
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
@@ -326,7 +326,8 @@ SELECT
SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue"
FROM "lineitem" AS "lineitem"
WHERE
- "lineitem"."l_discount" BETWEEN 0.05 AND 0.07
+ "lineitem"."l_discount" <= 0.07
+ AND "lineitem"."l_discount" >= 0.05
AND "lineitem"."l_quantity" < 24
AND CAST("lineitem"."l_shipdate" AS DATE) < CAST('1995-01-01' AS DATE)
AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1994-01-01' AS DATE);
@@ -344,7 +345,7 @@ from
select
n1.n_name as supp_nation,
n2.n_name as cust_nation,
- extract(year from l_shipdate) as l_year,
+ extract(year from cast(l_shipdate as date)) as l_year,
l_extendedprice * (1 - l_discount) as volume
from
supplier,
@@ -384,13 +385,14 @@ WITH "n1" AS (
SELECT
"n1"."n_name" AS "supp_nation",
"n2"."n_name" AS "cust_nation",
- EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME)) AS "l_year",
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE)) AS "l_year",
SUM("lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
)) AS "revenue"
FROM "supplier" AS "supplier"
JOIN "lineitem" AS "lineitem"
- ON CAST("lineitem"."l_shipdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ ON CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+ AND CAST("lineitem"."l_shipdate" AS DATE) >= CAST('1995-01-01' AS DATE)
AND "supplier"."s_suppkey" = "lineitem"."l_suppkey"
JOIN "orders" AS "orders"
ON "orders"."o_orderkey" = "lineitem"."l_orderkey"
@@ -409,7 +411,7 @@ JOIN "n1" AS "n2"
GROUP BY
"n1"."n_name",
"n2"."n_name",
- EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATETIME))
+ EXTRACT(year FROM CAST("lineitem"."l_shipdate" AS DATE))
ORDER BY
"supp_nation",
"cust_nation",
@@ -427,7 +429,7 @@ select
from
(
select
- extract(year from o_orderdate) as o_year,
+ extract(year from cast(o_orderdate as date)) as o_year,
l_extendedprice * (1 - l_discount) as volume,
n2.n_name as nation
from
@@ -456,7 +458,7 @@ group by
order by
o_year;
SELECT
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
SUM(
CASE
WHEN "nation_2"."n_name" = 'BRAZIL'
@@ -477,7 +479,8 @@ JOIN "customer" AS "customer"
ON "customer"."c_nationkey" = "nation"."n_nationkey"
JOIN "orders" AS "orders"
ON "orders"."o_custkey" = "customer"."c_custkey"
- AND CAST("orders"."o_orderdate" AS DATE) BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) <= CAST('1996-12-31' AS DATE)
+ AND CAST("orders"."o_orderdate" AS DATE) >= CAST('1995-01-01' AS DATE)
JOIN "lineitem" AS "lineitem"
ON "lineitem"."l_orderkey" = "orders"."o_orderkey"
AND "part"."p_partkey" = "lineitem"."l_partkey"
@@ -488,7 +491,7 @@ JOIN "nation" AS "nation_2"
WHERE
"part"."p_type" = 'ECONOMY ANODIZED STEEL'
GROUP BY
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
ORDER BY
"o_year";
@@ -503,7 +506,7 @@ from
(
select
n_name as nation,
- extract(year from o_orderdate) as o_year,
+ extract(year from cast(o_orderdate as date)) as o_year,
l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount
from
part,
@@ -529,7 +532,7 @@ order by
o_year desc;
SELECT
"nation"."n_name" AS "nation",
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME)) AS "o_year",
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE)) AS "o_year",
SUM(
"lineitem"."l_extendedprice" * (
1 - "lineitem"."l_discount"
@@ -551,7 +554,7 @@ WHERE
"part"."p_name" LIKE '%green%'
GROUP BY
"nation"."n_name",
- EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATETIME))
+ EXTRACT(year FROM CAST("orders"."o_orderdate" AS DATE))
ORDER BY
"nation",
"o_year" DESC;
@@ -1016,7 +1019,7 @@ select
o_orderkey,
o_orderdate,
o_totalprice,
- sum(l_quantity)
+ sum(l_quantity) total_quantity
from
customer,
orders,
@@ -1060,7 +1063,7 @@ SELECT
"orders"."o_orderkey" AS "o_orderkey",
"orders"."o_orderdate" AS "o_orderdate",
"orders"."o_totalprice" AS "o_totalprice",
- SUM("lineitem"."l_quantity") AS "_col_5"
+ SUM("lineitem"."l_quantity") AS "total_quantity"
FROM "customer" AS "customer"
JOIN "orders" AS "orders"
ON "customer"."c_custkey" = "orders"."o_custkey"
@@ -1129,19 +1132,22 @@ JOIN "part" AS "part"
"part"."p_brand" = 'Brand#12'
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 5
+ AND "part"."p_size" <= 5
+ AND "part"."p_size" >= 1
)
OR (
"part"."p_brand" = 'Brand#23'
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 10
+ AND "part"."p_size" <= 10
+ AND "part"."p_size" >= 1
)
OR (
"part"."p_brand" = 'Brand#34'
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 15
+ AND "part"."p_size" <= 15
+ AND "part"."p_size" >= 1
)
WHERE
(
@@ -1152,7 +1158,8 @@ WHERE
AND "part"."p_brand" = 'Brand#12'
AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 5
+ AND "part"."p_size" <= 5
+ AND "part"."p_size" >= 1
)
OR (
"lineitem"."l_quantity" <= 20
@@ -1162,7 +1169,8 @@ WHERE
AND "part"."p_brand" = 'Brand#23'
AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 10
+ AND "part"."p_size" <= 10
+ AND "part"."p_size" >= 1
)
OR (
"lineitem"."l_quantity" <= 30
@@ -1172,7 +1180,8 @@ WHERE
AND "part"."p_brand" = 'Brand#34'
AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG')
AND "part"."p_partkey" = "lineitem"."l_partkey"
- AND "part"."p_size" BETWEEN 1 AND 15
+ AND "part"."p_size" <= 15
+ AND "part"."p_size" >= 1
);
--------------------------------------
diff --git a/tests/test_executor.py b/tests/test_executor.py
index 9d452e4..4fe6399 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -26,12 +26,12 @@ class TestExecutor(unittest.TestCase):
def setUpClass(cls):
cls.conn = duckdb.connect()
- for table in TPCH_SCHEMA:
+ for table, columns in TPCH_SCHEMA.items():
cls.conn.execute(
f"""
CREATE VIEW {table} AS
SELECT *
- FROM READ_CSV_AUTO('{DIR}{table}.csv.gz')
+ FROM READ_CSV('{DIR}{table}.csv.gz', delim='|', header=True, columns={columns})
"""
)
@@ -74,13 +74,13 @@ class TestExecutor(unittest.TestCase):
)
return expression
- for i, (sql, _) in enumerate(self.sqls[0:16]):
+ for i, (sql, _) in enumerate(self.sqls[0:18]):
with self.subTest(f"tpch-h {i + 1}"):
a = self.cached_execute(sql)
sql = parse_one(sql).transform(to_csv).sql(pretty=True)
table = execute(sql, TPCH_SCHEMA)
b = pd.DataFrame(table.rows, columns=table.columns)
- assert_frame_equal(a, b, check_dtype=False)
+ assert_frame_equal(a, b, check_dtype=False, check_index_type=False)
def test_execute_callable(self):
tables = {
@@ -456,11 +456,16 @@ class TestExecutor(unittest.TestCase):
("SELECT CONCAT('a', 'b') AS x", ["x"], [("ab",)]),
("SELECT 1 AS x, 2 AS y", ["x", "y"], [(1, 2)]),
("SELECT 'foo' LIMIT 1", ["_col_0"], [("foo",)]),
- ("SELECT SUM(x) FROM (SELECT 1 AS x WHERE FALSE)", ["_col_0"], [(0,)]),
+ (
+ "SELECT SUM(x), COUNT(x) FROM (SELECT 1 AS x WHERE FALSE)",
+ ["_col_0", "_col_1"],
+ [(None, 0)],
+ ),
]:
- result = execute(sql)
- self.assertEqual(result.columns, tuple(cols))
- self.assertEqual(result.rows, rows)
+ with self.subTest(sql):
+ result = execute(sql)
+ self.assertEqual(result.columns, tuple(cols))
+ self.assertEqual(result.rows, rows)
def test_aggregate_without_group_by(self):
result = execute("SELECT SUM(x) FROM t", tables={"t": [{"x": 1}, {"x": 2}]})
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index ecf581d..0c5f6cd 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -333,7 +333,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Literal).type, target_type)
+ self.assertEqual(expression.find(exp.Literal).type.this, target_type)
def test_boolean_type_annotation(self):
tests = {
@@ -343,31 +343,33 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for sql, target_type in tests.items():
expression = annotate_types(parse_one(sql))
- self.assertEqual(expression.find(exp.Boolean).type, target_type)
+ self.assertEqual(expression.find(exp.Boolean).type.this, target_type)
def test_cast_type_annotation(self):
expression = annotate_types(parse_one("CAST('2020-01-01' AS TIMESTAMPTZ(9))"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ)
+ self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT)
- self.assertEqual(expression.type, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.this.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(expression.args["to"].type, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.args["to"].expressions[0].type, exp.DataType.Type.INT)
+ expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
+ self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType))
def test_cache_annotation(self):
expression = annotate_types(
parse_one("CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1")
)
- self.assertEqual(expression.expression.expressions[0].type, exp.DataType.Type.INT)
+ self.assertEqual(expression.expression.expressions[0].type.this, exp.DataType.Type.INT)
def test_binary_annotation(self):
expression = annotate_types(parse_one("SELECT 0.0 + (2 + 3)")).expressions[0]
- self.assertEqual(expression.type, exp.DataType.Type.DOUBLE)
- self.assertEqual(expression.left.type, exp.DataType.Type.DOUBLE)
- self.assertEqual(expression.right.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.left.type, exp.DataType.Type.INT)
- self.assertEqual(expression.right.this.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.DOUBLE)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.left.type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.right.this.right.type.this, exp.DataType.Type.INT)
def test_derived_tables_column_annotation(self):
schema = {"x": {"cola": "INT"}, "y": {"cola": "FLOAT"}}
@@ -387,128 +389,169 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
"""
expression = annotate_types(parse_one(sql), schema=schema)
- self.assertEqual(expression.expressions[0].type, exp.DataType.Type.FLOAT) # a.cola AS cola
+ self.assertEqual(
+ expression.expressions[0].type.this, exp.DataType.Type.FLOAT
+ ) # a.cola AS cola
addition_alias = expression.args["from"].expressions[0].this.expressions[0]
- self.assertEqual(addition_alias.type, exp.DataType.Type.FLOAT) # x.cola + y.cola AS cola
+ self.assertEqual(
+ addition_alias.type.this, exp.DataType.Type.FLOAT
+ ) # x.cola + y.cola AS cola
addition = addition_alias.this
- self.assertEqual(addition.type, exp.DataType.Type.FLOAT)
- self.assertEqual(addition.this.type, exp.DataType.Type.INT)
- self.assertEqual(addition.expression.type, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.type.this, exp.DataType.Type.FLOAT)
+ self.assertEqual(addition.this.type.this, exp.DataType.Type.INT)
+ self.assertEqual(addition.expression.type.this, exp.DataType.Type.FLOAT)
def test_cte_column_annotation(self):
- schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT"}}
+ schema = {"x": {"cola": "CHAR"}, "y": {"colb": "TEXT", "colc": "BOOLEAN"}}
sql = """
WITH tbl AS (
- SELECT x.cola + 'bla' AS cola, y.colb AS colb
+ SELECT x.cola + 'bla' AS cola, y.colb AS colb, y.colc AS colc
FROM (
SELECT x.cola AS cola
FROM x AS x
) AS x
JOIN (
- SELECT y.colb AS colb
+ SELECT y.colb AS colb, y.colc AS colc
FROM y AS y
) AS y
)
SELECT tbl.cola + tbl.colb + 'foo' AS col
FROM tbl AS tbl
+ WHERE tbl.colc = True
"""
expression = annotate_types(parse_one(sql), schema=schema)
self.assertEqual(
- expression.expressions[0].type, exp.DataType.Type.TEXT
+ expression.expressions[0].type.this, exp.DataType.Type.TEXT
) # tbl.cola + tbl.colb + 'foo' AS col
outer_addition = expression.expressions[0].this # (tbl.cola + tbl.colb) + 'foo'
- self.assertEqual(outer_addition.type, exp.DataType.Type.TEXT)
- self.assertEqual(outer_addition.left.type, exp.DataType.Type.TEXT)
- self.assertEqual(outer_addition.right.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(outer_addition.type.this, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.left.type.this, exp.DataType.Type.TEXT)
+ self.assertEqual(outer_addition.right.type.this, exp.DataType.Type.VARCHAR)
inner_addition = expression.expressions[0].this.left # tbl.cola + tbl.colb
- self.assertEqual(inner_addition.left.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(inner_addition.right.type, exp.DataType.Type.TEXT)
+ self.assertEqual(inner_addition.left.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(inner_addition.right.type.this, exp.DataType.Type.TEXT)
+
+ # WHERE tbl.colc = True
+ self.assertEqual(expression.args["where"].this.type.this, exp.DataType.Type.BOOLEAN)
cte_select = expression.args["with"].expressions[0].this
self.assertEqual(
- cte_select.expressions[0].type, exp.DataType.Type.VARCHAR
+ cte_select.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola + 'bla' AS cola
- self.assertEqual(cte_select.expressions[1].type, exp.DataType.Type.TEXT) # y.colb AS colb
+ self.assertEqual(
+ cte_select.expressions[1].type.this, exp.DataType.Type.TEXT
+ ) # y.colb AS colb
+ self.assertEqual(
+ cte_select.expressions[2].type.this, exp.DataType.Type.BOOLEAN
+ ) # y.colc AS colc
cte_select_addition = cte_select.expressions[0].this # x.cola + 'bla'
- self.assertEqual(cte_select_addition.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(cte_select_addition.left.type, exp.DataType.Type.CHAR)
- self.assertEqual(cte_select_addition.right.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(cte_select_addition.left.type.this, exp.DataType.Type.CHAR)
+ self.assertEqual(cte_select_addition.right.type.this, exp.DataType.Type.VARCHAR)
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(
cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
):
- self.assertEqual(d.this.expressions[0].this.type, t)
+ self.assertEqual(d.this.expressions[0].this.type.this, t)
def test_function_annotation(self):
schema = {"x": {"cola": "VARCHAR", "colb": "CHAR"}}
sql = "SELECT x.cola || TRIM(x.colb) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.VARCHAR)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
- self.assertEqual(concat_expr.right.this.type, exp.DataType.Type.CHAR) # x.colb
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.VARCHAR) # TRIM(x.colb)
+ self.assertEqual(concat_expr.right.this.type.this, exp.DataType.Type.CHAR) # x.colb
sql = "SELECT CASE WHEN 1=1 THEN x.cola ELSE x.colb END AS col FROM x AS x"
case_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(case_expr_alias.type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr_alias.type.this, exp.DataType.Type.VARCHAR)
case_expr = case_expr_alias.this
- self.assertEqual(case_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(case_expr.args["default"].type, exp.DataType.Type.CHAR)
+ self.assertEqual(case_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_expr.args["default"].type.this, exp.DataType.Type.CHAR)
case_ifs_expr = case_expr.args["ifs"][0]
- self.assertEqual(case_ifs_expr.type, exp.DataType.Type.VARCHAR)
- self.assertEqual(case_ifs_expr.args["true"].type, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.type.this, exp.DataType.Type.VARCHAR)
+ self.assertEqual(case_ifs_expr.args["true"].type.this, exp.DataType.Type.VARCHAR)
def test_unknown_annotation(self):
schema = {"x": {"cola": "VARCHAR"}}
sql = "SELECT x.cola || SOME_ANONYMOUS_FUNC(x.cola) AS col FROM x AS x"
concat_expr_alias = annotate_types(parse_one(sql), schema=schema).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.UNKNOWN)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.UNKNOWN)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.VARCHAR) # x.cola
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.VARCHAR) # x.cola
self.assertEqual(
- concat_expr.right.type, exp.DataType.Type.UNKNOWN
+ concat_expr.right.type.this, exp.DataType.Type.UNKNOWN
) # SOME_ANONYMOUS_FUNC(x.cola)
self.assertEqual(
- concat_expr.right.expressions[0].type, exp.DataType.Type.VARCHAR
+ concat_expr.right.expressions[0].type.this, exp.DataType.Type.VARCHAR
) # x.cola (arg)
def test_null_annotation(self):
expression = annotate_types(parse_one("SELECT NULL + 2 AS col")).expressions[0].this
- self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type, exp.DataType.Type.INT)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.INT)
# NULL <op> UNKNOWN should yield NULL
sql = "SELECT NULL || SOME_ANONYMOUS_FUNC() AS result"
concat_expr_alias = annotate_types(parse_one(sql)).expressions[0]
- self.assertEqual(concat_expr_alias.type, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr_alias.type.this, exp.DataType.Type.NULL)
concat_expr = concat_expr_alias.this
- self.assertEqual(concat_expr.type, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.left.type, exp.DataType.Type.NULL)
- self.assertEqual(concat_expr.right.type, exp.DataType.Type.UNKNOWN)
+ self.assertEqual(concat_expr.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(concat_expr.right.type.this, exp.DataType.Type.UNKNOWN)
def test_nullable_annotation(self):
nullable = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
expression = annotate_types(parse_one("NULL AND FALSE"))
self.assertEqual(expression.type, nullable)
- self.assertEqual(expression.left.type, exp.DataType.Type.NULL)
- self.assertEqual(expression.right.type, exp.DataType.Type.BOOLEAN)
+ self.assertEqual(expression.left.type.this, exp.DataType.Type.NULL)
+ self.assertEqual(expression.right.type.this, exp.DataType.Type.BOOLEAN)
+
+ def test_predicate_annotation(self):
+ expression = annotate_types(parse_one("x BETWEEN a AND b"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+ expression = annotate_types(parse_one("x IN (a, b, c, d)"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.BOOLEAN)
+
+ def test_aggfunc_annotation(self):
+ schema = {"x": {"cola": "SMALLINT", "colb": "FLOAT", "colc": "TEXT", "cold": "DATE"}}
+
+ tests = {
+ ("AVG", "cola"): exp.DataType.Type.DOUBLE,
+ ("SUM", "cola"): exp.DataType.Type.BIGINT,
+ ("SUM", "colb"): exp.DataType.Type.DOUBLE,
+ ("MIN", "cola"): exp.DataType.Type.SMALLINT,
+ ("MIN", "colb"): exp.DataType.Type.FLOAT,
+ ("MAX", "colc"): exp.DataType.Type.TEXT,
+ ("MAX", "cold"): exp.DataType.Type.DATE,
+ ("COUNT", "colb"): exp.DataType.Type.BIGINT,
+ ("STDDEV", "cola"): exp.DataType.Type.DOUBLE,
+ }
+
+ for (func, col), target_type in tests.items():
+ expression = annotate_types(
+ parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+ )
+ self.assertEqual(expression.expressions[0].type.this, target_type)
diff --git a/tests/test_schema.py b/tests/test_schema.py
index cc0e3d1..f1e12a2 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -151,31 +151,33 @@ class TestSchema(unittest.TestCase):
def test_schema_get_column_type(self):
schema = MappingSchema({"a": {"b": "varchar"}})
- self.assertEqual(schema.get_column_type("a", "b"), exp.DataType.Type.VARCHAR)
+ self.assertEqual(schema.get_column_type("a", "b").this, exp.DataType.Type.VARCHAR)
self.assertEqual(
- schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")),
+ schema.get_column_type(exp.Table(this="a"), exp.Column(this="b")).this,
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
- schema.get_column_type("a", exp.Column(this="b")), exp.DataType.Type.VARCHAR
+ schema.get_column_type("a", exp.Column(this="b")).this, exp.DataType.Type.VARCHAR
)
self.assertEqual(
- schema.get_column_type(exp.Table(this="a"), "b"), exp.DataType.Type.VARCHAR
+ schema.get_column_type(exp.Table(this="a"), "b").this, exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": "varchar"}}})
self.assertEqual(
- schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")),
+ schema.get_column_type(exp.Table(this="b", db="a"), exp.Column(this="c")).this,
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
- schema.get_column_type(exp.Table(this="b", db="a"), "c"), exp.DataType.Type.VARCHAR
+ schema.get_column_type(exp.Table(this="b", db="a"), "c").this, exp.DataType.Type.VARCHAR
)
schema = MappingSchema({"a": {"b": {"c": {"d": "varchar"}}}})
self.assertEqual(
- schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")),
+ schema.get_column_type(
+ exp.Table(this="c", db="b", catalog="a"), exp.Column(this="d")
+ ).this,
exp.DataType.Type.VARCHAR,
)
self.assertEqual(
- schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d"),
+ schema.get_column_type(exp.Table(this="c", db="b", catalog="a"), "d").this,
exp.DataType.Type.VARCHAR,
)
diff --git a/tests/test_tokens.py b/tests/test_tokens.py
index 1d1b966..1376849 100644
--- a/tests/test_tokens.py
+++ b/tests/test_tokens.py
@@ -1,6 +1,6 @@
import unittest
-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Tokenizer, TokenType
class TestTokens(unittest.TestCase):
@@ -17,3 +17,48 @@ class TestTokens(unittest.TestCase):
for sql, comment in sql_comment:
self.assertEqual(tokenizer.tokenize(sql)[0].comments, comment)
+
+ def test_jinja(self):
+ tokenizer = Tokenizer()
+
+ tokens = tokenizer.tokenize(
+ """
+ SELECT
+ {{ x }},
+ {{- x -}},
+ {% for x in y -%}
+ a {{+ b }}
+ {% endfor %};
+ """
+ )
+
+ tokens = [(token.token_type, token.text) for token in tokens]
+
+ self.assertEqual(
+ tokens,
+ [
+ (TokenType.SELECT, "SELECT"),
+ (TokenType.BLOCK_START, "{{"),
+ (TokenType.VAR, "x"),
+ (TokenType.BLOCK_END, "}}"),
+ (TokenType.COMMA, ","),
+ (TokenType.BLOCK_START, "{{-"),
+ (TokenType.VAR, "x"),
+ (TokenType.BLOCK_END, "-}}"),
+ (TokenType.COMMA, ","),
+ (TokenType.BLOCK_START, "{%"),
+ (TokenType.FOR, "for"),
+ (TokenType.VAR, "x"),
+ (TokenType.IN, "in"),
+ (TokenType.VAR, "y"),
+ (TokenType.BLOCK_END, "-%}"),
+ (TokenType.VAR, "a"),
+ (TokenType.BLOCK_START, "{{+"),
+ (TokenType.VAR, "b"),
+ (TokenType.BLOCK_END, "}}"),
+ (TokenType.BLOCK_START, "{%"),
+ (TokenType.VAR, "endfor"),
+ (TokenType.BLOCK_END, "%}"),
+ (TokenType.SEMICOLON, ";"),
+ ],
+ )