summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py105
1 files changed, 78 insertions, 27 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 423cb84..2ae6da9 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -20,19 +20,20 @@ from tests.helpers import (
)
-def parse_and_optimize(func, sql, dialect, **kwargs):
- return func(parse_one(sql, read=dialect), **kwargs)
+def parse_and_optimize(func, sql, read_dialect, **kwargs):
+ return func(parse_one(sql, read=read_dialect), **kwargs)
def qualify_columns(expression, **kwargs):
- expression = optimizer.qualify_tables.qualify_tables(expression)
- expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ expression = optimizer.qualify.qualify(
+ expression, infer_schema=True, validate_qualify_columns=False, identify=False, **kwargs
+ )
return expression
def pushdown_projections(expression, **kwargs):
expression = optimizer.qualify_tables.qualify_tables(expression)
- expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ expression = optimizer.qualify_columns.qualify_columns(expression, infer_schema=True, **kwargs)
expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs)
return expression
@@ -98,7 +99,7 @@ class TestOptimizer(unittest.TestCase):
},
}
- def check_file(self, file, func, pretty=False, execute=False, **kwargs):
+ def check_file(self, file, func, pretty=False, execute=False, set_dialect=False, **kwargs):
with ProcessPoolExecutor() as pool:
results = {}
@@ -113,6 +114,9 @@ class TestOptimizer(unittest.TestCase):
if leave_tables_isolated is not None:
func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
+ if set_dialect and dialect:
+ func_kwargs["dialect"] = dialect
+
future = pool.submit(parse_and_optimize, func, sql, dialect, **func_kwargs)
results[future] = (
sql,
@@ -141,13 +145,24 @@ class TestOptimizer(unittest.TestCase):
assert_frame_equal(df1, df2)
def test_optimize(self):
+ self.assertEqual(optimizer.optimize("x = 1 + 1", identify=None).sql(), "x = 2")
+
schema = {
"x": {"a": "INT", "b": "INT"},
"y": {"b": "INT", "c": "INT"},
"z": {"a": "INT", "c": "INT"},
+ "u": {"f": "INT", "g": "INT", "h": "TEXT"},
}
- self.check_file("optimizer", optimizer.optimize, pretty=True, execute=True, schema=schema)
+ self.check_file(
+ "optimizer",
+ optimizer.optimize,
+ infer_schema=True,
+ pretty=True,
+ execute=True,
+ schema=schema,
+ set_dialect=True,
+ )
def test_isolate_table_selects(self):
self.check_file(
@@ -183,6 +198,15 @@ class TestOptimizer(unittest.TestCase):
self.check_file("normalize", normalize)
def test_qualify_columns(self):
+ self.assertEqual(
+ optimizer.qualify_columns.qualify_columns(
+ parse_one("select y from x"),
+ schema={},
+ infer_schema=False,
+ ).sql(),
+ "SELECT y AS y FROM x",
+ )
+
self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)
def test_qualify_columns__with_invisible(self):
@@ -198,8 +222,12 @@ class TestOptimizer(unittest.TestCase):
)
optimizer.qualify_columns.validate_qualify_columns(expression)
- def test_lower_identities(self):
- self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
+ def test_normalize_identifiers(self):
+ self.check_file(
+ "normalize_identifiers",
+ optimizer.normalize_identifiers.normalize_identifiers,
+ set_dialect=True,
+ )
def test_pushdown_projection(self):
self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)
@@ -221,24 +249,20 @@ class TestOptimizer(unittest.TestCase):
def test_pushdown_predicates(self):
self.check_file("pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates)
- def test_expand_laterals(self):
+ def test_expand_alias_refs(self):
# check order of lateral expansion with no schema
self.assertEqual(
- optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x " "").sql(),
- 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x"',
- )
-
- self.check_file(
- "expand_laterals",
- optimizer.expand_laterals.expand_laterals,
- pretty=True,
- execute=True,
+ optimizer.optimize("SELECT a + 1 AS d, d + 1 AS e FROM x WHERE e > 1 GROUP BY e").sql(),
+ 'SELECT "x"."a" + 1 AS "d", "x"."a" + 2 AS "e" FROM "x" AS "x" WHERE "x"."a" + 2 > 1 GROUP BY "x"."a" + 2',
)
- def test_expand_multi_table_selects(self):
- self.check_file(
- "expand_multi_table_selects",
- optimizer.expand_multi_table_selects.expand_multi_table_selects,
+ self.assertEqual(
+ optimizer.qualify_columns.qualify_columns(
+ parse_one("SELECT CAST(x AS INT) AS y FROM z AS z"),
+ schema={"l": {"c": "int"}},
+ infer_schema=False,
+ ).sql(),
+ "SELECT CAST(x AS INT) AS y FROM z AS z",
)
def test_optimize_joins(self):
@@ -280,8 +304,8 @@ class TestOptimizer(unittest.TestCase):
optimize = partial(
optimizer.optimize,
rules=[
- optimizer.qualify_tables.qualify_tables,
- optimizer.qualify_columns.qualify_columns,
+ optimizer.qualify.qualify,
+ optimizer.qualify_columns.quote_identifiers,
annotate_types,
optimizer.canonicalize.canonicalize,
],
@@ -396,7 +420,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
self.assertEqual(expression.type.this, exp.DataType.Type.TIMESTAMPTZ)
self.assertEqual(expression.this.type.this, exp.DataType.Type.VARCHAR)
self.assertEqual(expression.args["to"].type.this, exp.DataType.Type.TIMESTAMPTZ)
- self.assertEqual(expression.args["to"].expressions[0].type.this, exp.DataType.Type.INT)
+ self.assertEqual(expression.args["to"].expressions[0].this.type.this, exp.DataType.Type.INT)
expression = annotate_types(parse_one("ARRAY(1)::ARRAY<INT>"))
self.assertEqual(expression.type, parse_one("ARRAY<INT>", into=exp.DataType))
@@ -450,7 +474,7 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
expression.expressions[0].type.this, exp.DataType.Type.FLOAT
) # a.cola AS cola
- addition_alias = expression.args["from"].expressions[0].this.expressions[0]
+ addition_alias = expression.args["from"].this.this.expressions[0]
self.assertEqual(
addition_alias.type.this, exp.DataType.Type.FLOAT
) # x.cola + y.cola AS cola
@@ -663,3 +687,30 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
optimizer.optimize(parse_one("SELECT * FROM a"), schema=schema),
parse_one('SELECT "a"."b c" AS "b c", "a"."d e" AS "d e" FROM "a" AS "a"'),
)
+
+ def test_quotes(self):
+ schema = {
+ "example": {
+ '"source"': {
+ "id": "text",
+ '"name"': "text",
+ '"payload"': "text",
+ }
+ }
+ }
+
+ expected = parse_one(
+ """
+ SELECT
+ "source"."ID" AS "ID",
+ "source"."name" AS "name",
+ "source"."payload" AS "payload"
+ FROM "EXAMPLE"."source" AS "source"
+ """,
+ read="snowflake",
+ ).sql(pretty=True, dialect="snowflake")
+
+ for func in (optimizer.qualify.qualify, optimizer.optimize):
+ source_query = parse_one('SELECT * FROM example."source"', read="snowflake")
+ transformed = func(source_query, dialect="snowflake", schema=schema)
+ self.assertEqual(transformed.sql(pretty=True, dialect="snowflake"), expected)