summaryrefslogtreecommitdiffstats
path: root/tests/test_optimizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_optimizer.py')
-rw-r--r--tests/test_optimizer.py97
1 files changed, 63 insertions, 34 deletions
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index e10d05e..597fa6f 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -1,4 +1,5 @@
import unittest
+from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
import duckdb
@@ -11,6 +12,7 @@ from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema
from tests.helpers import (
+ TPCDS_SCHEMA,
TPCH_SCHEMA,
load_sql_fixture_pairs,
load_sql_fixtures,
@@ -18,6 +20,28 @@ from tests.helpers import (
)
+def parse_and_optimize(func, sql, dialect, **kwargs):
+ return func(parse_one(sql, read=dialect), **kwargs)
+
+
+def qualify_columns(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ return expression
+
+
+def pushdown_projections(expression, **kwargs):
+ expression = optimizer.qualify_tables.qualify_tables(expression)
+ expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
+ expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs)
+ return expression
+
+
+def normalize(expression, **kwargs):
+ expression = optimizer.normalize.normalize(expression, dnf=False)
+ return optimizer.simplify.simplify(expression)
+
+
class TestOptimizer(unittest.TestCase):
maxDiff = None
@@ -74,29 +98,35 @@ class TestOptimizer(unittest.TestCase):
}
def check_file(self, file, func, pretty=False, execute=False, **kwargs):
- for i, (meta, sql, expected) in enumerate(
- load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
- ):
- title = meta.get("title") or f"{i}, {sql}"
- dialect = meta.get("dialect")
- leave_tables_isolated = meta.get("leave_tables_isolated")
+ with ProcessPoolExecutor() as pool:
+ results = {}
+
+ for i, (meta, sql, expected) in enumerate(
+ load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1
+ ):
+ title = meta.get("title") or f"{i}, {sql}"
+ dialect = meta.get("dialect")
+ execute = execute if meta.get("execute") is None else False
+ leave_tables_isolated = meta.get("leave_tables_isolated")
- func_kwargs = {**kwargs}
- if leave_tables_isolated is not None:
- func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
+ func_kwargs = {**kwargs}
+ if leave_tables_isolated is not None:
+ func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated)
+
+ future = pool.submit(parse_and_optimize, func, sql, dialect, **func_kwargs)
+ results[future] = (sql, title, expected, dialect, execute)
+
+ for future in as_completed(results):
+ optimized = future.result()
+ sql, title, expected, dialect, execute = results[future]
with self.subTest(title):
- optimized = func(parse_one(sql, read=dialect), **func_kwargs)
self.assertEqual(
expected,
optimized.sql(pretty=pretty, dialect=dialect),
)
- should_execute = meta.get("execute")
- if should_execute is None:
- should_execute = execute
-
- if string_to_bool(should_execute):
+ if string_to_bool(execute):
with self.subTest(f"(execute) {title}"):
df1 = self.conn.execute(
sqlglot.transpile(sql, read=dialect, write="duckdb")[0]
@@ -137,25 +167,19 @@ class TestOptimizer(unittest.TestCase):
"(x AND y) OR (x AND z)",
)
- self.check_file(
- "normalize",
- optimizer.normalize.normalize,
+ self.assertEqual(
+ optimizer.normalize.normalize(
+ parse_one("x AND (y OR z)"),
+ ).sql(),
+ "x AND (y OR z)",
)
- def test_qualify_columns(self):
- def qualify_columns(expression, **kwargs):
- expression = optimizer.qualify_tables.qualify_tables(expression)
- expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
- return expression
+ self.check_file("normalize", normalize)
+ def test_qualify_columns(self):
self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema)
def test_qualify_columns__with_invisible(self):
- def qualify_columns(expression, **kwargs):
- expression = optimizer.qualify_tables.qualify_tables(expression)
- expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
- return expression
-
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)
@@ -172,17 +196,15 @@ class TestOptimizer(unittest.TestCase):
self.check_file("lower_identities", optimizer.lower_identities.lower_identities)
def test_pushdown_projection(self):
- def pushdown_projections(expression, **kwargs):
- expression = optimizer.qualify_tables.qualify_tables(expression)
- expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs)
- expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs)
- return expression
-
self.check_file("pushdown_projections", pushdown_projections, schema=self.schema)
def test_simplify(self):
self.check_file("simplify", optimizer.simplify.simplify)
+ expression = parse_one("TRUE AND TRUE AND TRUE")
+ self.assertEqual(exp.true(), optimizer.simplify.simplify(expression))
+ self.assertEqual(exp.true(), optimizer.simplify.simplify(expression.this))
+
def test_unnest_subqueries(self):
self.check_file(
"unnest_subqueries",
@@ -257,6 +279,9 @@ class TestOptimizer(unittest.TestCase):
def test_tpch(self):
self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True)
+ def test_tpcds(self):
+ self.check_file("tpc-ds/tpc-ds", optimizer.optimize, schema=TPCDS_SCHEMA, pretty=True)
+
def test_file_schema(self):
expression = parse_one(
"""
@@ -578,6 +603,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|')
)
self.assertEqual(expression.expressions[0].type.this, target_type)
+ def test_concat_annotation(self):
+ expression = annotate_types(parse_one("CONCAT('A', 'B')"))
+ self.assertEqual(expression.type.this, exp.DataType.Type.VARCHAR)
+
def test_recursive_cte(self):
query = parse_one(
"""