Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r--  sqlglot/optimizer/normalize.py  |  35
1 file changed, 18 insertions, 17 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index b013312..1db094e 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,12 +1,12 @@
from __future__ import annotations
import logging
-import typing as t
from sqlglot import exp
from sqlglot.errors import OptimizeError
+from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, uniq_sort
+from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
logger = logging.getLogger("sqlglot")
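Note on the new import: cached_generator replaces the ad-hoc dict cache removed below. As a rough mental model, it returns a callable that renders an expression to SQL and memoizes the result. A minimal sketch of such a helper under that assumption (make_cached_generator is a hypothetical name, not sqlglot's implementation):

# Hypothetical sketch, not sqlglot's code: memoize rendered SQL per expression.
from sqlglot import exp
from sqlglot.generator import Generator

def make_cached_generator():
    cache: dict[int, str] = {}  # mirrors the old t.Dict[int, str] cache
    generator = Generator(normalize=True)

    def generate(expression: exp.Expression) -> str:
        key = id(expression)  # identity-based key; an assumption for illustration
        if key not in cache:
            cache[key] = generator.generate(expression)
        return cache[key]

    return generate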
@@ -28,13 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- cache: t.Dict[int, str] = {}
+ generate = cached_generator()
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
continue
+ root = node is expression
+ original = node.copy()
+ node.transform(rewrite_between, copy=False)
distance = normalization_distance(node, dnf=dnf)
if distance > max_distance:
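The guard above depends on normalization_distance, which estimates how many extra predicates the rewrite would introduce before committing to it. An illustrative check (the printed values are expectations derived from the distance formula, not asserted output):

from sqlglot import parse_one
from sqlglot.optimizer.normalize import normalization_distance

# Distance grows with how much the predicate must expand to reach CNF.
print(normalization_distance(parse_one("a OR (b AND c)"), dnf=False))          # expected: 1
print(normalization_distance(parse_one("(a AND b) OR (c AND d)"), dnf=False))  # expected: 4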
@@ -43,11 +46,9 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
)
return expression
- root = node is expression
- original = node.copy()
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
)
except OptimizeError as e:
logger.info(e)
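For context, the entry point touched here is typically driven as follows (illustrative usage, not part of the patch; the rendered output may vary slightly by version):

from sqlglot import parse_one
from sqlglot.optimizer.normalize import normalize

expression = parse_one("SELECT * FROM t WHERE a OR (b AND c)")

# Default target is conjunctive normal form; dnf=True distributes the other way.
print(normalize(expression).sql())
# expected, roughly: SELECT * FROM t WHERE (a OR b) AND (a OR c)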
@@ -111,7 +112,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance, cache=None):
+def distributive_law(expression, dnf, max_distance, generate):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -124,7 +125,7 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
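The replace_children call applies the law bottom-up before the current node is rewritten. Invoking the helper once on a small predicate shows the rewrite from the docstring (illustrative only; normalize() drives this through while_changing in practice):

from sqlglot import parse_one
from sqlglot.generator import cached_generator
from sqlglot.optimizer.normalize import distributive_law

node = parse_one("x OR (y AND z)")
result = distributive_law(node, dnf=False, max_distance=128, generate=cached_generator())
print(result.sql())  # expected, roughly: (x OR y) AND (x OR z)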
@@ -135,30 +136,30 @@ def distributive_law(expression, dnf, max_distance, cache=None):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, cache)
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, generate)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(a, b, from_func, to_func, generate)
return expression
-def _distribute(a, b, from_func, to_func, cache):
+def _distribute(a, b, from_func, to_func, generate):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), cache),
- uniq_sort(flatten(from_func(c, b.right)), cache),
+ uniq_sort(flatten(from_func(c, b.left)), generate),
+ uniq_sort(flatten(from_func(c, b.right)), generate),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), cache),
- uniq_sort(flatten(from_func(a, b.right)), cache),
+ uniq_sort(flatten(from_func(a, b.left)), generate),
+ uniq_sort(flatten(from_func(a, b.right)), generate),
copy=False,
)
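Threading generate all the way into uniq_sort is the point of the change: _distribute re-renders the same sub-expressions repeatedly while operands are deduplicated and ordered by their SQL text, so memoized generation avoids redundant work. A rough illustration of that dedupe-and-order-by-SQL idea, assuming (as the call sites above suggest) that the callable returned by cached_generator() maps an expression to its SQL string:

from sqlglot import parse_one
from sqlglot.generator import cached_generator

generate = cached_generator()
operands = [parse_one("b"), parse_one("a"), parse_one("b")]

# Deduplicate by rendered SQL, then order deterministically by it; this is the
# kind of work uniq_sort performs with the generate callable.
deduped = {generate(e): e for e in operands}
ordered = [deduped[key] for key in sorted(deduped)]
print([e.sql() for e in ordered])  # expected: ['a', 'b']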