diff options
Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r-- | sqlglot/optimizer/normalize.py | 104 |
1 files changed, 69 insertions, 35 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index f16f519..f2df230 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -1,29 +1,63 @@ +from __future__ import annotations + +import logging +import typing as t + from sqlglot import exp +from sqlglot.errors import OptimizeError from sqlglot.helper import while_changing -from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort +from sqlglot.optimizer.simplify import flatten, uniq_sort + +logger = logging.getLogger("sqlglot") -def normalize(expression, dnf=False, max_distance=128): +def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): """ - Rewrite sqlglot AST into conjunctive normal form. + Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. Example: >>> import sqlglot >>> expression = sqlglot.parse_one("(x AND y) OR z") - >>> normalize(expression).sql() + >>> normalize(expression, dnf=False).sql() '(x OR z) AND (y OR z)' Args: - expression (sqlglot.Expression): expression to normalize - dnf (bool): rewrite in disjunctive normal form instead - max_distance (int): the maximal estimated distance from cnf to attempt conversion + expression: expression to normalize + dnf: rewrite in disjunctive normal form instead. + max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion Returns: sqlglot.Expression: normalized expression """ - expression = simplify(expression) + cache: t.Dict[int, str] = {} + + for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): + if isinstance(node, exp.Connector): + if normalized(node, dnf=dnf): + continue + + distance = normalization_distance(node, dnf=dnf) + + if distance > max_distance: + logger.info( + f"Skipping normalization because distance {distance} exceeds max {max_distance}" + ) + return expression + + root = node is expression + original = node.copy() + try: + node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) + except OptimizeError as e: + logger.info(e) + node.replace(original) + if root: + return original + return expression + + if root: + expression = node - expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance)) - return simplify(expression) + return expression def normalized(expression, dnf=False): @@ -51,7 +85,7 @@ def normalization_distance(expression, dnf=False): int: difference """ return sum(_predicate_lengths(expression, dnf)) - ( - len(list(expression.find_all(exp.Connector))) + 1 + sum(1 for _ in expression.find_all(exp.Connector)) + 1 ) @@ -64,29 +98,32 @@ def _predicate_lengths(expression, dnf): expression = expression.unnest() if not isinstance(expression, exp.Connector): - return [1] + return (1,) left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - return [ + return tuple( a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) - ] + ) return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) -def distributive_law(expression, dnf, max_distance): +def distributive_law(expression, dnf, max_distance, cache=None): """ x OR (y AND z) -> (x OR y) AND (x OR z) (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) """ - if isinstance(expression.unnest(), exp.Connector): - if normalization_distance(expression, dnf) > max_distance: - return expression + if normalized(expression, dnf=dnf): + return expression - to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + distance = normalization_distance(expression, dnf=dnf) - exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + if distance > max_distance: + raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) if isinstance(expression, from_exp): a, b = expression.unnest_operands() @@ -96,32 +133,29 @@ def distributive_law(expression, dnf, max_distance): if isinstance(a, to_exp) and isinstance(b, to_exp): if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): - return _distribute(a, b, from_func, to_func) - return _distribute(b, a, from_func, to_func) + return _distribute(a, b, from_func, to_func, cache) + return _distribute(b, a, from_func, to_func, cache) if isinstance(a, to_exp): - return _distribute(b, a, from_func, to_func) + return _distribute(b, a, from_func, to_func, cache) if isinstance(b, to_exp): - return _distribute(a, b, from_func, to_func) + return _distribute(a, b, from_func, to_func, cache) return expression -def _distribute(a, b, from_func, to_func): +def _distribute(a, b, from_func, to_func, cache): if isinstance(a, exp.Connector): exp.replace_children( a, lambda c: to_func( - exp.paren(from_func(c, b.left)), - exp.paren(from_func(c, b.right)), + uniq_sort(flatten(from_func(c, b.left)), cache), + uniq_sort(flatten(from_func(c, b.right)), cache), ), ) else: - a = to_func(from_func(a, b.left), from_func(a, b.right)) - - return _simplify(a) - + a = to_func( + uniq_sort(flatten(from_func(a, b.left)), cache), + uniq_sort(flatten(from_func(a, b.right)), cache), + ) -def _simplify(node): - node = uniq_sort(flatten(node)) - exp.replace_children(node, _simplify) - return node + return a |