summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py52
1 files changed, 41 insertions, 11 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 41a5015..81b9731 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -29,7 +29,11 @@ def parse_and_optimize(func, sql, read_dialect, **kwargs):
def qualify_columns(expression, **kwargs):
expression = optimizer.qualify.qualify(
- expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs
+ expression,
+ infer_schema=True,
+ validate_qualify_columns=False,
+ identify=False,
+ **kwargs,
)
return expression
@@ -111,7 +115,14 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(
- self, file, func, pretty=False, execute=False, set_dialect=False, only=None, **kwargs
+ self,
+ file,
+ func,
+ pretty=False,
+ execute=False,
+ set_dialect=False,
+ only=None,
+ **kwargs,
):
with ProcessPoolExecutor() as pool:
results = {}
@@ -331,7 +342,11 @@ class TestOptimizer(unittest.TestCase):
)
self.check_file(
- "qualify_columns", qualify_columns, execute=True, schema=self.schema, set_dialect=True
+ "qualify_columns",
+ qualify_columns,
+ execute=True,
+ schema=self.schema,
+ set_dialect=True,
)
self.check_file(
"qualify_columns_ddl", qualify_columns, schema=self.schema, set_dialect=True
@@ -343,7 +358,8 @@ class TestOptimizer(unittest.TestCase):
def test_pushdown_cte_alias_columns(self):
self.check_file(
- "pushdown_cte_alias_columns", optimizer.qualify_columns.pushdown_cte_alias_columns
+ "pushdown_cte_alias_columns",
+ optimizer.qualify_columns.pushdown_cte_alias_columns,
)
def test_qualify_columns__invalid(self):
@@ -405,7 +421,8 @@ class TestOptimizer(unittest.TestCase):
self.assertEqual(optimizer.simplify.gen(query), optimizer.simplify.gen(query.copy()))
anon_unquoted_identifier = exp.Anonymous(
- this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
+ this=exp.to_identifier("anonymous"),
+ expressions=[exp.column("x"), exp.column("y")],
)
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS(x,y)")
@@ -416,7 +433,10 @@ class TestOptimizer(unittest.TestCase):
anon_invalid = exp.Anonymous(this=5)
optimizer.simplify.gen(anon_invalid)
- self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))
+ self.assertIn(
+ "Anonymous.this expects a str or an Identifier, got 'int'.",
+ str(e.exception),
+ )
sql = parse_one(
"""
@@ -906,7 +926,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
# Check that x.cola AS cola and y.colb AS colb have types CHAR and TEXT, respectively
for d, t in zip(
- cte_select.find_all(exp.Subquery), [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT]
+ cte_select.find_all(exp.Subquery),
+ [exp.DataType.Type.CHAR, exp.DataType.Type.TEXT],
):
self.assertEqual(d.this.expressions[0].this.type.this, t)
@@ -1020,7 +1041,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
for (func, col), target_type in tests.items():
expression = annotate_types(
- parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema
+ parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"),
+ schema=schema,
)
self.assertEqual(expression.expressions[0].type.this, target_type)
@@ -1035,7 +1057,13 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(exp.DataType.Type.INT, expression.selects[1].type.this)
def test_nested_type_annotation(self):
- schema = {"order": {"customer_id": "bigint", "item_id": "bigint", "item_price": "numeric"}}
+ schema = {
+ "order": {
+ "customer_id": "bigint",
+ "item_id": "bigint",
+ "item_price": "numeric",
+ }
+ }
sql = """
SELECT ARRAY_AGG(DISTINCT order.item_id) FILTER (WHERE order.item_price > 10) AS items,
FROM order AS order
@@ -1057,7 +1085,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")
self.assertEqual(
- expression.selects[1].type.sql(dialect="bigquery"), "ARRAY<STRUCT<`f` STRING>>"
+ expression.selects[1].type.sql(dialect="bigquery"),
+ "ARRAY<STRUCT<`f` STRING>>",
)
expression = annotate_types(
@@ -1206,7 +1235,8 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(
optimizer.optimize(
- parse_one("SELECT * FROM a"), schema=MappingSchema(schema, dialect="bigquery")
+ parse_one("SELECT * FROM a"),
+ schema=MappingSchema(schema, dialect="bigquery"),
),
parse_one('SELECT "a"."a" AS "a", "a"."b" AS "b" FROM "a" AS "a"'),
)