From b38d717d5933fdae3fe85c87df7aee9a251fb58e Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Mon, 3 Apr 2023 09:31:54 +0200 Subject: Merging upstream version 11.4.5. Signed-off-by: Daniel Baumann --- tests/test_optimizer.py | 97 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 63 insertions(+), 34 deletions(-) (limited to 'tests/test_optimizer.py') diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index e10d05e..597fa6f 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -1,4 +1,5 @@ import unittest +from concurrent.futures import ProcessPoolExecutor, as_completed from functools import partial import duckdb @@ -11,6 +12,7 @@ from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope from sqlglot.schema import MappingSchema from tests.helpers import ( + TPCDS_SCHEMA, TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures, @@ -18,6 +20,28 @@ from tests.helpers import ( ) +def parse_and_optimize(func, sql, dialect, **kwargs): + return func(parse_one(sql, read=dialect), **kwargs) + + +def qualify_columns(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + return expression + + +def pushdown_projections(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs) + return expression + + +def normalize(expression, **kwargs): + expression = optimizer.normalize.normalize(expression, dnf=False) + return optimizer.simplify.simplify(expression) + + class TestOptimizer(unittest.TestCase): maxDiff = None @@ -74,29 +98,35 @@ class TestOptimizer(unittest.TestCase): } def check_file(self, file, func, pretty=False, execute=False, **kwargs): - for i, (meta, sql, expected) in enumerate( - load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1 - ): - title = meta.get("title") or f"{i}, {sql}" - dialect = meta.get("dialect") - leave_tables_isolated = meta.get("leave_tables_isolated") + with ProcessPoolExecutor() as pool: + results = {} + + for i, (meta, sql, expected) in enumerate( + load_sql_fixture_pairs(f"optimizer/{file}.sql"), start=1 + ): + title = meta.get("title") or f"{i}, {sql}" + dialect = meta.get("dialect") + execute = execute if meta.get("execute") is None else False + leave_tables_isolated = meta.get("leave_tables_isolated") - func_kwargs = {**kwargs} - if leave_tables_isolated is not None: - func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) + func_kwargs = {**kwargs} + if leave_tables_isolated is not None: + func_kwargs["leave_tables_isolated"] = string_to_bool(leave_tables_isolated) + + future = pool.submit(parse_and_optimize, func, sql, dialect, **func_kwargs) + results[future] = (sql, title, expected, dialect, execute) + + for future in as_completed(results): + optimized = future.result() + sql, title, expected, dialect, execute = results[future] with self.subTest(title): - optimized = func(parse_one(sql, read=dialect), **func_kwargs) self.assertEqual( expected, optimized.sql(pretty=pretty, dialect=dialect), ) - should_execute = meta.get("execute") - if should_execute is None: - should_execute = execute - - if string_to_bool(should_execute): + if string_to_bool(execute): with self.subTest(f"(execute) {title}"): df1 = self.conn.execute( sqlglot.transpile(sql, read=dialect, write="duckdb")[0] @@ -137,25 +167,19 @@ class TestOptimizer(unittest.TestCase): "(x AND y) OR (x AND z)", ) - self.check_file( - "normalize", - optimizer.normalize.normalize, + self.assertEqual( + optimizer.normalize.normalize( + parse_one("x AND (y OR z)"), + ).sql(), + "x AND (y OR z)", ) - def test_qualify_columns(self): - def qualify_columns(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) - return expression + self.check_file("normalize", normalize) + def test_qualify_columns(self): self.check_file("qualify_columns", qualify_columns, execute=True, schema=self.schema) def test_qualify_columns__with_invisible(self): - def qualify_columns(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) - return expression - schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}}) self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema) @@ -172,17 +196,15 @@ class TestOptimizer(unittest.TestCase): self.check_file("lower_identities", optimizer.lower_identities.lower_identities) def test_pushdown_projection(self): - def pushdown_projections(expression, **kwargs): - expression = optimizer.qualify_tables.qualify_tables(expression) - expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) - expression = optimizer.pushdown_projections.pushdown_projections(expression, **kwargs) - return expression - self.check_file("pushdown_projections", pushdown_projections, schema=self.schema) def test_simplify(self): self.check_file("simplify", optimizer.simplify.simplify) + expression = parse_one("TRUE AND TRUE AND TRUE") + self.assertEqual(exp.true(), optimizer.simplify.simplify(expression)) + self.assertEqual(exp.true(), optimizer.simplify.simplify(expression.this)) + def test_unnest_subqueries(self): self.check_file( "unnest_subqueries", @@ -257,6 +279,9 @@ class TestOptimizer(unittest.TestCase): def test_tpch(self): self.check_file("tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True) + def test_tpcds(self): + self.check_file("tpc-ds/tpc-ds", optimizer.optimize, schema=TPCDS_SCHEMA, pretty=True) + def test_file_schema(self): expression = parse_one( """ @@ -578,6 +603,10 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') ) self.assertEqual(expression.expressions[0].type.this, target_type) + def test_concat_annotation(self): + expression = annotate_types(parse_one("CONCAT('A', 'B')")) + self.assertEqual(expression.type.this, exp.DataType.Type.VARCHAR) + def test_recursive_cte(self): query = parse_one( """ -- cgit v1.2.3