diff options
Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index b013312..1db094e 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -1,12 +1,12 @@ from __future__ import annotations import logging -import typing as t from sqlglot import exp from sqlglot.errors import OptimizeError +from sqlglot.generator import cached_generator from sqlglot.helper import while_changing -from sqlglot.optimizer.simplify import flatten, uniq_sort +from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort logger = logging.getLogger("sqlglot") @@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = Returns: sqlglot.Expression: normalized expression """ - cache: t.Dict[int, str] = {} + generate = cached_generator() for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): if isinstance(node, exp.Connector): if normalized(node, dnf=dnf): continue + root = node is expression + original = node.copy() + node.transform(rewrite_between, copy=False) distance = normalization_distance(node, dnf=dnf) if distance > max_distance: @@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = ) return expression - root = node is expression - original = node.copy() try: node = node.replace( - while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) ) except OptimizeError as e: logger.info(e) @@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf): return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance, cache=None): +def distributive_law(expression, dnf, max_distance, generate): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) @@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None): if distance > max_distance: raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): @@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func, cache) - return _distribute(b, a, from_func, to_func, cache) + return _distribute(a, b, from_func, to_func, generate) + return _distribute(b, a, from_func, to_func, generate) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func, cache) + return _distribute(b, a, from_func, to_func, generate) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func, cache) + return _distribute(a, b, from_func, to_func, generate) return expression -def _distribute(a, b, from_func, to_func, cache): +def _distribute(a, b, from_func, to_func, generate): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - uniq_sort(flatten(from_func(c, b.left)), cache), - uniq_sort(flatten(from_func(c, b.right)), cache), + uniq_sort(flatten(from_func(c, b.left)), generate), + uniq_sort(flatten(from_func(c, b.right)), generate), copy=False, ), ) else: a = to_func( - uniq_sort(flatten(from_func(a, b.left)), cache), - uniq_sort(flatten(from_func(a, b.right)), cache), + uniq_sort(flatten(from_func(a, b.left)), generate), + uniq_sort(flatten(from_func(a, b.right)), generate), copy=False, ) |