summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/normalize.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/optimizer/normalize.py27
1 files changed, 12 insertions, 15 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index 8d82b2d..6df36af 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -4,7 +4,6 @@ import logging
from sqlglot import exp
from sqlglot.errors import OptimizeError
-from sqlglot.generator import cached_generator
from sqlglot.helper import while_changing
from sqlglot.optimizer.scope import find_all_in_scope
from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
@@ -29,8 +28,6 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
Returns:
sqlglot.Expression: normalized expression
"""
- generate = cached_generator()
-
for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
if isinstance(node, exp.Connector):
if normalized(node, dnf=dnf):
@@ -49,7 +46,7 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
try:
node = node.replace(
- while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
+ while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
)
except OptimizeError as e:
logger.info(e)
@@ -133,7 +130,7 @@ def _predicate_lengths(expression, dnf):
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance, generate):
+def distributive_law(expression, dnf, max_distance):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -146,7 +143,7 @@ def distributive_law(expression, dnf, max_distance, generate):
if distance > max_distance:
raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
@@ -157,30 +154,30 @@ def distributive_law(expression, dnf, max_distance, generate):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func, generate)
- return _distribute(b, a, from_func, to_func, generate)
+ return _distribute(a, b, from_func, to_func)
+ return _distribute(b, a, from_func, to_func)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func, generate)
+ return _distribute(b, a, from_func, to_func)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func, generate)
+ return _distribute(a, b, from_func, to_func)
return expression
-def _distribute(a, b, from_func, to_func, generate):
+def _distribute(a, b, from_func, to_func):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- uniq_sort(flatten(from_func(c, b.left)), generate),
- uniq_sort(flatten(from_func(c, b.right)), generate),
+ uniq_sort(flatten(from_func(c, b.left))),
+ uniq_sort(flatten(from_func(c, b.right))),
copy=False,
),
)
else:
a = to_func(
- uniq_sort(flatten(from_func(a, b.left)), generate),
- uniq_sort(flatten(from_func(a, b.right)), generate),
+ uniq_sort(flatten(from_func(a, b.left))),
+ uniq_sort(flatten(from_func(a, b.right))),
copy=False,
)