summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/normalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r-- sqlglot/optimizer/normalize.py 104
1 files changed, 69 insertions, 35 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index f16f519..f2df230 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -1,29 +1,63 @@
+from __future__ import annotations
+
+import logging
+import typing as t
+
from sqlglot import exp
+from sqlglot.errors import OptimizeError
from sqlglot.helper import while_changing
-from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+from sqlglot.optimizer.simplify import flatten, uniq_sort
+
+logger = logging.getLogger("sqlglot")
-def normalize(expression, dnf=False, max_distance=128):
+def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
"""
- Rewrite sqlglot AST into conjunctive normal form.
+ Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
- >>> normalize(expression).sql()
+ >>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Args:
- expression (sqlglot.Expression): expression to normalize
- dnf (bool): rewrite in disjunctive normal form instead
- max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ expression: expression to normalize
+ dnf: rewrite in disjunctive normal form instead.
+ max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
"""
- expression = simplify(expression)
+ cache: t.Dict[int, str] = {}
+
+ for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
+ if isinstance(node, exp.Connector):
+ if normalized(node, dnf=dnf):
+ continue
+
+ distance = normalization_distance(node, dnf=dnf)
+
+ if distance > max_distance:
+ logger.info(
+ f"Skipping normalization because distance {distance} exceeds max {max_distance}"
+ )
+ return expression
+
+ root = node is expression
+ original = node.copy()
+ try:
+ node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
+ except OptimizeError as e:
+ logger.info(e)
+ node.replace(original)
+ if root:
+ return original
+ return expression
+
+ if root:
+ expression = node
- expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
- return simplify(expression)
+ return expression
def normalized(expression, dnf=False):
@@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False):
int: difference
"""
return sum(_predicate_lengths(expression, dnf)) - (
- len(list(expression.find_all(exp.Connector))) + 1
+ sum(1 for _ in expression.find_all(exp.Connector)) + 1
)
@@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf):
expression = expression.unnest()
if not isinstance(expression, exp.Connector):
- return [1]
+ return (1,)
left, right = expression.args.values()
if isinstance(expression, exp.And if dnf else exp.Or):
- return [
+ return tuple(
a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
- ]
+ )
return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
-def distributive_law(expression, dnf, max_distance):
+def distributive_law(expression, dnf, max_distance, cache=None):
"""
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
"""
- if isinstance(expression.unnest(), exp.Connector):
- if normalization_distance(expression, dnf) > max_distance:
- return expression
+ if normalized(expression, dnf=dnf):
+ return expression
- to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+ distance = normalization_distance(expression, dnf=dnf)
- exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+ if distance > max_distance:
+ raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
if isinstance(expression, from_exp):
a, b = expression.unnest_operands()
@@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance):
if isinstance(a, to_exp) and isinstance(b, to_exp):
if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
- return _distribute(a, b, from_func, to_func)
- return _distribute(b, a, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(a, to_exp):
- return _distribute(b, a, from_func, to_func)
+ return _distribute(b, a, from_func, to_func, cache)
if isinstance(b, to_exp):
- return _distribute(a, b, from_func, to_func)
+ return _distribute(a, b, from_func, to_func, cache)
return expression
-def _distribute(a, b, from_func, to_func):
+def _distribute(a, b, from_func, to_func, cache):
if isinstance(a, exp.Connector):
exp.replace_children(
a,
lambda c: to_func(
- exp.paren(from_func(c, b.left)),
- exp.paren(from_func(c, b.right)),
+ uniq_sort(flatten(from_func(c, b.left)), cache),
+ uniq_sort(flatten(from_func(c, b.right)), cache),
),
)
else:
- a = to_func(from_func(a, b.left), from_func(a, b.right))
-
- return _simplify(a)
-
+ a = to_func(
+ uniq_sort(flatten(from_func(a, b.left)), cache),
+ uniq_sort(flatten(from_func(a, b.right)), cache),
+ )
-def _simplify(node):
- node = uniq_sort(flatten(node))
- exp.replace_children(node, _simplify)
- return node
+ return a