diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index 8d82b2d..6df36af 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -4,7 +4,6 @@ import logging from sqlglot import exp from sqlglot.errors import OptimizeError -from sqlglot.generator import cached_generator from sqlglot.helper import while_changing from sqlglot.optimizer.scope import find_all_in_scope from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort @@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - generate = cached_generator() - for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): @@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) + while_changing(node, lambda e: distributive_law(e, dnf, max_distance)) ) except OptimizeError as e: logger.info(e) @@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf): return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance, generate): +def distributive_law(expression, dnf, max_distance): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate): if distance > max_distance: raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): @@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func, generate) - return _distribute(b, a, from_func, to_func, generate) + return _distribute(a, b, from_func, to_func) + return _distribute(b, a, from_func, to_func) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func, generate) + return _distribute(b, a, from_func, to_func) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func, generate) + return _distribute(a, b, from_func, to_func) return expression -def _distribute(a, b, from_func, to_func, generate): +def _distribute(a, b, from_func, to_func): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left)), generate), - uniq_sort(flatten(from_func(c, b.right)), generate), + uniq_sort(flatten(from_func(c, b.left))), + uniq_sort(flatten(from_func(c, b.right))), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left)), generate), - uniq_sort(flatten(from_func(a, b.right)), generate), + uniq_sort(flatten(from_func(a, b.left))), + uniq_sort(flatten(from_func(a, b.right))), copy=False, ) |