summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/normalize.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/normalize.py')
-rw-r--r--sqlglot/optimizer/normalize.py136
1 files changed, 136 insertions, 0 deletions
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
new file mode 100644
index 0000000..2c9f89c
--- /dev/null
+++ b/sqlglot/optimizer/normalize.py
@@ -0,0 +1,136 @@
+from sqlglot import exp
+from sqlglot.helper import while_changing
+from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+
+
+def normalize(expression, dnf=False, max_distance=128):
+ """
+ Rewrite sqlglot AST into conjunctive normal form.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("(x AND y) OR z")
+ >>> normalize(expression).sql()
+ '(x OR z) AND (y OR z)'
+
+ Args:
+ expression (sqlglot.Expression): expression to normalize
+ dnf (bool): rewrite in disjunctive normal form instead
+ max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ Returns:
+ sqlglot.Expression: normalized expression
+ """
+ expression = simplify(expression)
+
+ expression = while_changing(
+ expression, lambda e: distributive_law(e, dnf, max_distance)
+ )
+ return simplify(expression)
+
+
+def normalized(expression, dnf=False):
+ ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+
+ return not any(
+ connector.find_ancestor(ancestor) for connector in expression.find_all(root)
+ )
+
+
+def normalization_distance(expression, dnf=False):
+ """
+ The difference in the number of predicates between the current expression and the normalized form.
+
+ This is used as an estimate of the cost of the conversion which is exponential in complexity.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
+ >>> normalization_distance(expression)
+ 4
+
+ Args:
+ expression (sqlglot.Expression): expression to compute distance
+ dnf (bool): compute to dnf distance instead
+ Returns:
+ int: difference
+ """
+ return sum(_predicate_lengths(expression, dnf)) - (
+ len(list(expression.find_all(exp.Connector))) + 1
+ )
+
+
+def _predicate_lengths(expression, dnf):
+ """
+ Returns a list of predicate lengths when expanded to normalized form.
+
+ (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
+ """
+ expression = expression.unnest()
+
+ if not isinstance(expression, exp.Connector):
+ return [1]
+
+ left, right = expression.args.values()
+
+ if isinstance(expression, exp.And if dnf else exp.Or):
+ x = [
+ a + b
+ for a in _predicate_lengths(left, dnf)
+ for b in _predicate_lengths(right, dnf)
+ ]
+ return x
+ return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
+
+
+def distributive_law(expression, dnf, max_distance):
+ """
+ x OR (y AND z) -> (x OR y) AND (x OR z)
+ (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
+ """
+ if isinstance(expression.unnest(), exp.Connector):
+ if normalization_distance(expression, dnf) > max_distance:
+ return expression
+
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+
+ if isinstance(expression, from_exp):
+ a, b = expression.unnest_operands()
+
+ from_func = exp.and_ if from_exp == exp.And else exp.or_
+ to_func = exp.and_ if to_exp == exp.And else exp.or_
+
+ if isinstance(a, to_exp) and isinstance(b, to_exp):
+ if len(tuple(a.find_all(exp.Connector))) > len(
+ tuple(b.find_all(exp.Connector))
+ ):
+ return _distribute(a, b, from_func, to_func)
+ return _distribute(b, a, from_func, to_func)
+ if isinstance(a, to_exp):
+ return _distribute(b, a, from_func, to_func)
+ if isinstance(b, to_exp):
+ return _distribute(a, b, from_func, to_func)
+
+ return expression
+
+
+def _distribute(a, b, from_func, to_func):
+ if isinstance(a, exp.Connector):
+ exp.replace_children(
+ a,
+ lambda c: to_func(
+ exp.paren(from_func(c, b.left)),
+ exp.paren(from_func(c, b.right)),
+ ),
+ )
+ else:
+ a = to_func(from_func(a, b.left), from_func(a, b.right))
+
+ return _simplify(a)
+
+
+def _simplify(node):
+ node = uniq_sort(flatten(node))
+ exp.replace_children(node, _simplify)
+ return node