summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py276
1 files changed, 276 insertions, 0 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
new file mode 100644
index 0000000..40540b3
--- /dev/null
+++ b/tests/test_optimizer.py
@@ -0,0 +1,276 @@
+import unittest
+
+from sqlglot import optimizer, parse_one, table
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.schema import MappingSchema, ensure_schema
+from sqlglot.optimizer.scope import traverse_scope
+from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures
+
+
+class TestOptimizer(unittest.TestCase):
+ maxDiff = None
+
+ def setUp(self):
+ self.schema = {
+ "x": {
+ "a": "INT",
+ "b": "INT",
+ },
+ "y": {
+ "b": "INT",
+ "c": "INT",
+ },
+ "z": {
+ "b": "INT",
+ "c": "INT",
+ },
+ }
+
+ def check_file(self, file, func, pretty=False, **kwargs):
+ for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"):
+ dialect = meta.get("dialect")
+ with self.subTest(sql):
+ self.assertEqual(
+ func(parse_one(sql, read=dialect), **kwargs).sql(
+ pretty=pretty, dialect=dialect
+ ),
+ expected,
+ )
+
+ def test_optimize(self):
+ schema = {
+ "x": {"a": "INT", "b": "INT"},
+ "y": {"a": "INT", "b": "INT"},
+ "z": {"a": "INT", "c": "INT"},
+ }
+
+ self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema)
+
+ def test_isolate_table_selects(self):
+ self.check_file(
+ "isolate_table_selects",
+ optimizer.isolate_table_selects.isolate_table_selects,
+ )
+
+ def test_qualify_tables(self):
+ self.check_file(
+ "qualify_tables",
+ optimizer.qualify_tables.qualify_tables,
+ db="db",
+ catalog="c",
+ )
+
+ def test_normalize(self):
+ self.assertEqual(
+ optimizer.normalize.normalize(
+ parse_one("x AND (y OR z)"),
+ dnf=True,
+ ).sql(),
+ "(x AND y) OR (x AND z)",
+ )
+
+ self.check_file(
+ "normalize",
+ optimizer.normalize.normalize,
+ )
+
+ def test_qualify_columns(self):
+ def qualify_columns(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ return expression
+
+ self.check_file("qualify_columns", qualify_columns, schema=self.schema)
+
+ def test_qualify_columns__invalid(self):
+ for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
+ with self.subTest(sql):
+ with self.assertRaises(OptimizeError):
+ optimizer.qualify_columns.qualify_columns(
+ parse_one(sql), schema=self.schema
+ )
+
+ def test_quote_identities(self):
+ self.check_file("quote_identities", optimizer.quote_identities.quote_identities)
+
+ def test_pushdown_projection(self):
+ def pushdown_projections(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ expression = optimizer.pushdown_projections.pushdown_projections(expression)
+ return expression
+
+ self.check_file(
+ "pushdown_projections", pushdown_projections, schema=self.schema
+ )
+
+ def test_simplify(self):
+ self.check_file("simplify", optimizer.simplify.simplify)
+
+ def test_unnest_subqueries(self):
+ self.check_file(
+ "unnest_subqueries",
+ optimizer.unnest_subqueries.unnest_subqueries,
+ pretty=True,
+ )
+
+ def test_pushdown_predicates(self):
+ self.check_file(
+ "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates
+ )
+
+ def test_expand_multi_table_selects(self):
+ self.check_file(
+ "expand_multi_table_selects",
+ optimizer.expand_multi_table_selects.expand_multi_table_selects,
+ )
+
+ def test_optimize_joins(self):
+ self.check_file(
+ "optimize_joins",
+ optimizer.optimize_joins.optimize_joins,
+ )
+
+ def test_eliminate_subqueries(self):
+ self.check_file(
+ "eliminate_subqueries",
+ optimizer.eliminate_subqueries.eliminate_subqueries,
+ pretty=True,
+ )
+
+ def test_tpch(self):
+ self.check_file(
+ "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True
+ )
+
+ def test_schema(self):
+ schema = ensure_schema(
+ {
+ "x": {
+ "a": "uint64",
+ }
+ }
+ )
+ self.assertEqual(
+ schema.column_names(
+ table(
+ "x",
+ )
+ ),
+ ["a"],
+ )
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db", catalog="c"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x2"))
+
+ schema = ensure_schema(
+ {
+ "db": {
+ "x": {
+ "a": "uint64",
+ }
+ }
+ }
+ )
+ self.assertEqual(schema.column_names(table("x", db="db")), ["a"])
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db", catalog="c"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db2"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x2", db="db"))
+
+ schema = ensure_schema(
+ {
+ "c": {
+ "db": {
+ "x": {
+ "a": "uint64",
+ }
+ }
+ }
+ }
+ )
+ self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"])
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db", catalog="c2"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x", db="db2"))
+ with self.assertRaises(ValueError):
+ schema.column_names(table("x2", db="db"))
+
+ schema = ensure_schema(
+ MappingSchema(
+ {
+ "x": {
+ "a": "uint64",
+ }
+ }
+ )
+ )
+ self.assertEqual(schema.column_names(table("x")), ["a"])
+
+ with self.assertRaises(OptimizeError):
+ ensure_schema({})
+
+ def test_file_schema(self):
+ expression = parse_one(
+ """
+ SELECT *
+ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
+ """
+ )
+ self.assertEqual(
+ """
+SELECT
+ "_q_0"."n_nationkey" AS "n_nationkey",
+ "_q_0"."n_name" AS "n_name",
+ "_q_0"."n_regionkey" AS "n_regionkey",
+ "_q_0"."n_comment" AS "n_comment"
+FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0"
+""".strip(),
+ optimizer.optimize(expression).sql(pretty=True),
+ )
+
+ def test_scope(self):
+ sql = """
+ WITH q AS (
+ SELECT x.b FROM x
+ ), r AS (
+ SELECT y.b FROM y
+ )
+ SELECT
+ r.b,
+ s.b
+ FROM r
+ JOIN (
+ SELECT y.c AS b FROM y
+ ) s
+ ON s.b = r.b
+ WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b)
+ """
+ scopes = traverse_scope(parse_one(sql))
+ self.assertEqual(len(scopes), 5)
+ self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x")
+ self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y")
+ self.assertEqual(
+ scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b"
+ )
+ self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b FROM y")
+ self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql())
+
+ self.assertEqual(set(scopes[4].sources), {"q", "r", "s"})
+ self.assertEqual(len(scopes[4].columns), 6)
+ self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"})
+ self.assertEqual(scopes[4].source_columns("q"), [])
+ self.assertEqual(len(scopes[4].source_columns("r")), 2)
+ self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"})