diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-06-02 23:59:46 +0000 |
commit | 20739a12c39121a9e7ad3c9a2469ec5a6876199d (patch) | |
tree | c000de91c59fd29b2d9beecf9f93b84e69727f37 /sqlglot/optimizer/normalize.py | |
parent | Releasing debian version 12.2.0-1. (diff) | |
download | sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.tar.xz sqlglot-20739a12c39121a9e7ad3c9a2469ec5a6876199d.zip |
Merging upstream version 15.0.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 35 |
1 file changed, 18 insertions, 17 deletions
```diff
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index b013312..1db094e 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
 import logging
-import typing as t
 
 from sqlglot import exp
 from sqlglot.errors import OptimizeError
+from sqlglot.generator import cached_generator
 from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, uniq_sort
+from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 
 logger = logging.getLogger("sqlglot")
 
@@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
     Returns:
         sqlglot.Expression: normalized expression
     """
-    cache: t.Dict[int, str] = {}
+    generate = cached_generator()
 
     for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
         if isinstance(node, exp.Connector):
             if normalized(node, dnf=dnf):
                 continue
+            root = node is expression
+            original = node.copy()
 
+            node.transform(rewrite_between, copy=False)
             distance = normalization_distance(node, dnf=dnf)
 
             if distance > max_distance:
@@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
                 )
                 return expression
 
-            root = node is expression
-            original = node.copy()
             try:
                 node = node.replace(
-                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
                 )
             except OptimizeError as e:
                 logger.info(e)
@@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf):
     return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
 
 
-def distributive_law(expression, dnf, max_distance, cache=None):
+def distributive_law(expression, dnf, max_distance, generate):
     """
     x OR (y AND z) -> (x OR y) AND (x OR z)
     (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
     if distance > max_distance:
         raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
 
-    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
     to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
 
     if isinstance(expression, from_exp):
@@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None):
 
         if isinstance(a, to_exp) and isinstance(b, to_exp):
             if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
-                return _distribute(a, b, from_func, to_func, cache)
-            return _distribute(b, a, from_func, to_func, cache)
+                return _distribute(a, b, from_func, to_func, generate)
+            return _distribute(b, a, from_func, to_func, generate)
         if isinstance(a, to_exp):
-            return _distribute(b, a, from_func, to_func, cache)
+            return _distribute(b, a, from_func, to_func, generate)
         if isinstance(b, to_exp):
-            return _distribute(a, b, from_func, to_func, cache)
+            return _distribute(a, b, from_func, to_func, generate)
 
     return expression
 
 
-def _distribute(a, b, from_func, to_func, cache):
+def _distribute(a, b, from_func, to_func, generate):
     if isinstance(a, exp.Connector):
         exp.replace_children(
             a,
             lambda c: to_func(
-                uniq_sort(flatten(from_func(c, b.left)), cache),
-                uniq_sort(flatten(from_func(c, b.right)), cache),
+                uniq_sort(flatten(from_func(c, b.left)), generate),
+                uniq_sort(flatten(from_func(c, b.right)), generate),
                 copy=False,
             ),
         )
     else:
         a = to_func(
-            uniq_sort(flatten(from_func(a, b.left)), cache),
-            uniq_sort(flatten(from_func(a, b.right)), cache),
+            uniq_sort(flatten(from_func(a, b.left)), generate),
+            uniq_sort(flatten(from_func(a, b.right)), generate),
             copy=False,
         )
 
```