diff options
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r-- | tests/test_optimizer.py | 276 |
1 files changed, 276 insertions, 0 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..40540b3 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,276 @@ +import unittest + +from sqlglot import optimizer, parse_one, table +from sqlglot.errors import OptimizeError +from sqlglot.optimizer.schema import MappingSchema, ensure_schema +from sqlglot.optimizer.scope import traverse_scope +from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures + + +class TestOptimizer(unittest.TestCase): + maxDiff = None + + def setUp(self): + self.schema = { + "x": { + "a": "INT", + "b": "INT", + }, + "y": { + "b": "INT", + "c": "INT", + }, + "z": { + "b": "INT", + "c": "INT", + }, + } + + def check_file(self, file, func, pretty=False, **kwargs): + for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"): + dialect = meta.get("dialect") + with self.subTest(sql): + self.assertEqual( + func(parse_one(sql, read=dialect), **kwargs).sql( + pretty=pretty, dialect=dialect + ), + expected, + ) + + def test_optimize(self): + schema = { + "x": {"a": "INT", "b": "INT"}, + "y": {"a": "INT", "b": "INT"}, + "z": {"a": "INT", "c": "INT"}, + } + + self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema) + + def test_isolate_table_selects(self): + self.check_file( + "isolate_table_selects", + optimizer.isolate_table_selects.isolate_table_selects, + ) + + def test_qualify_tables(self): + self.check_file( + "qualify_tables", + optimizer.qualify_tables.qualify_tables, + db="db", + catalog="c", + ) + + def test_normalize(self): + self.assertEqual( + optimizer.normalize.normalize( + parse_one("x AND (y OR z)"), + dnf=True, + ).sql(), + "(x AND y) OR (x AND z)", + ) + + self.check_file( + "normalize", + optimizer.normalize.normalize, + ) + + def test_qualify_columns(self): + def qualify_columns(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + return expression + + self.check_file("qualify_columns", qualify_columns, schema=self.schema) + + def test_qualify_columns__invalid(self): + for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): + with self.subTest(sql): + with self.assertRaises(OptimizeError): + optimizer.qualify_columns.qualify_columns( + parse_one(sql), schema=self.schema + ) + + def test_quote_identities(self): + self.check_file("quote_identities", optimizer.quote_identities.quote_identities) + + def test_pushdown_projection(self): + def pushdown_projections(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.pushdown_projections.pushdown_projections(expression) + return expression + + self.check_file( + "pushdown_projections", pushdown_projections, schema=self.schema + ) + + def test_simplify(self): + self.check_file("simplify", optimizer.simplify.simplify) + + def test_unnest_subqueries(self): + self.check_file( + "unnest_subqueries", + optimizer.unnest_subqueries.unnest_subqueries, + pretty=True, + ) + + def test_pushdown_predicates(self): + self.check_file( + "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates + ) + + def test_expand_multi_table_selects(self): + self.check_file( + "expand_multi_table_selects", + optimizer.expand_multi_table_selects.expand_multi_table_selects, + ) + + def test_optimize_joins(self): + self.check_file( + "optimize_joins", + optimizer.optimize_joins.optimize_joins, + ) + + def test_eliminate_subqueries(self): + self.check_file( + "eliminate_subqueries", + optimizer.eliminate_subqueries.eliminate_subqueries, + pretty=True, + ) + + def test_tpch(self): + self.check_file( + "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True + ) + + def test_schema(self): + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual( + schema.column_names( + table( + "x", + ) + ), + ["a"], + ) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x2")) + + schema = ensure_schema( + { + "db": { + "x": { + "a": "uint64", + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + schema = ensure_schema( + { + "c": { + "db": { + "x": { + "a": "uint64", + } + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c2")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + schema = ensure_schema( + MappingSchema( + { + "x": { + "a": "uint64", + } + } + ) + ) + self.assertEqual(schema.column_names(table("x")), ["a"]) + + with self.assertRaises(OptimizeError): + ensure_schema({}) + + def test_file_schema(self): + expression = parse_one( + """ + SELECT * + FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') + """ + ) + self.assertEqual( + """ +SELECT + "_q_0"."n_nationkey" AS "n_nationkey", + "_q_0"."n_name" AS "n_name", + "_q_0"."n_regionkey" AS "n_regionkey", + "_q_0"."n_comment" AS "n_comment" +FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0" +""".strip(), + optimizer.optimize(expression).sql(pretty=True), + ) + + def test_scope(self): + sql = """ + WITH q AS ( + SELECT x.b FROM x + ), r AS ( + SELECT y.b FROM y + ) + SELECT + r.b, + s.b + FROM r + JOIN ( + SELECT y.c AS b FROM y + ) s + ON s.b = r.b + WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) + """ + scopes = traverse_scope(parse_one(sql)) + self.assertEqual(len(scopes), 5) + self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") + self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") + self.assertEqual( + scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b" + ) + self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y") + self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) + self.assertEqual(len(scopes[4].columns), 6) + self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) + self.assertEqual(scopes[4].source_columns("q"), []) + self.assertEqual(len(scopes[4].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) |