Edit on GitHub

sqlglot.optimizer.normalize

  1from sqlglot import exp
  2from sqlglot.helper import while_changing
  3from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
  4
  5
  6def normalize(expression, dnf=False, max_distance=128):
  7    """
  8    Rewrite sqlglot AST into conjunctive normal form.
  9
 10    Example:
 11        >>> import sqlglot
 12        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 13        >>> normalize(expression).sql()
 14        '(x OR z) AND (y OR z)'
 15
 16    Args:
 17        expression (sqlglot.Expression): expression to normalize
 18        dnf (bool): rewrite in disjunctive normal form instead
 19        max_distance (int): the maximal estimated distance from cnf to attempt conversion
 20    Returns:
 21        sqlglot.Expression: normalized expression
 22    """
 23    expression = simplify(expression)
 24
 25    expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
 26    return simplify(expression)
 27
 28
 29def normalized(expression, dnf=False):
 30    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 31
 32    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
 33
 34
 35def normalization_distance(expression, dnf=False):
 36    """
 37    The difference in the number of predicates between the current expression and the normalized form.
 38
 39    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 40
 41    Example:
 42        >>> import sqlglot
 43        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 44        >>> normalization_distance(expression)
 45        4
 46
 47    Args:
 48        expression (sqlglot.Expression): expression to compute distance
 49        dnf (bool): compute to dnf distance instead
 50    Returns:
 51        int: difference
 52    """
 53    return sum(_predicate_lengths(expression, dnf)) - (
 54        len(list(expression.find_all(exp.Connector))) + 1
 55    )
 56
 57
 58def _predicate_lengths(expression, dnf):
 59    """
 60    Returns a list of predicate lengths when expanded to normalized form.
 61
 62    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
 63    """
 64    expression = expression.unnest()
 65
 66    if not isinstance(expression, exp.Connector):
 67        return [1]
 68
 69    left, right = expression.args.values()
 70
 71    if isinstance(expression, exp.And if dnf else exp.Or):
 72        return [
 73            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
 74        ]
 75    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
 76
 77
 78def distributive_law(expression, dnf, max_distance):
 79    """
 80    x OR (y AND z) -> (x OR y) AND (x OR z)
 81    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
 82    """
 83    if isinstance(expression.unnest(), exp.Connector):
 84        if normalization_distance(expression, dnf) > max_distance:
 85            return expression
 86
 87    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
 88
 89    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
 90
 91    if isinstance(expression, from_exp):
 92        a, b = expression.unnest_operands()
 93
 94        from_func = exp.and_ if from_exp == exp.And else exp.or_
 95        to_func = exp.and_ if to_exp == exp.And else exp.or_
 96
 97        if isinstance(a, to_exp) and isinstance(b, to_exp):
 98            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
 99                return _distribute(a, b, from_func, to_func)
100            return _distribute(b, a, from_func, to_func)
101        if isinstance(a, to_exp):
102            return _distribute(b, a, from_func, to_func)
103        if isinstance(b, to_exp):
104            return _distribute(a, b, from_func, to_func)
105
106    return expression
107
108
def _distribute(a, b, from_func, to_func):
    """
    Distribute `a` over the two operands of `b` and simplify the result.

    When `a` is itself a connector, each of its children is replaced in place
    by the pair (child, b.left) / (child, b.right); otherwise `a` as a whole is
    paired with `b.left` and `b.right`.

    Args:
        a (sqlglot.Expression): expression being distributed
        b (sqlglot.Expression): connector whose `left` and `right` are combined with `a`
        from_func (callable): builder for the inner pairs (`exp.and_` or `exp.or_`)
        to_func (callable): builder joining the two pairs (`exp.and_` or `exp.or_`)
    Returns:
        sqlglot.Expression: the distributed expression after `_simplify`
    """
    if not isinstance(a, exp.Connector):
        distributed = to_func(from_func(a, b.left), from_func(a, b.right))
        return _simplify(distributed)

    def _expand(child):
        with_left = exp.paren(from_func(child, b.left))
        with_right = exp.paren(from_func(child, b.right))
        return to_func(with_left, with_right)

    exp.replace_children(a, _expand)
    return _simplify(a)
122
123
def _simplify(node):
    """
    Apply `flatten` then `uniq_sort` to a node, then do the same to its children.

    Args:
        node (sqlglot.Expression): node to clean up
    Returns:
        sqlglot.Expression: the node returned by `uniq_sort(flatten(node))`,
        with its children processed recursively in place
    """
    cleaned = uniq_sort(flatten(node))
    exp.replace_children(cleaned, _simplify)
    return cleaned
def normalize(expression, dnf=False, max_distance=128):
 7def normalize(expression, dnf=False, max_distance=128):
 8    """
 9    Rewrite sqlglot AST into conjunctive normal form.
10
11    Example:
12        >>> import sqlglot
13        >>> expression = sqlglot.parse_one("(x AND y) OR z")
14        >>> normalize(expression).sql()
15        '(x OR z) AND (y OR z)'
16
17    Args:
18        expression (sqlglot.Expression): expression to normalize
19        dnf (bool): rewrite in disjunctive normal form instead
20        max_distance (int): the maximal estimated distance from cnf to attempt conversion
21    Returns:
22        sqlglot.Expression: normalized expression
23    """
24    expression = simplify(expression)
25
26    expression = while_changing(expression, lambda e: distributive_law(e, dnf, max_distance))
27    return simplify(expression)

Rewrite a sqlglot AST into conjunctive normal form (or disjunctive normal form when dnf is set).

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression (sqlglot.Expression): expression to normalize
  • dnf (bool): rewrite in disjunctive normal form instead
  • max_distance (int): the maximum estimated normalization distance for which conversion is attempted; expressions whose estimate exceeds it are left unconverted
Returns:

sqlglot.Expression: normalized expression

def normalized(expression, dnf=False):
def normalized(expression, dnf=False):
    """
    Report whether an expression is already in normal form.

    For CNF (the default) no AND node may have an OR ancestor; for DNF no OR
    node may have an AND ancestor.

    Args:
        expression (sqlglot.Expression): expression to inspect
        dnf (bool): check for disjunctive normal form instead of conjunctive
    Returns:
        bool: True when no offending nesting is found
    """
    if dnf:
        outer, inner = exp.And, exp.Or
    else:
        outer, inner = exp.Or, exp.And

    for connector in expression.find_all(inner):
        # A single `inner` node nested under an `outer` node breaks the form.
        if connector.find_ancestor(outer):
            return False
    return True
def normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    Estimate how many predicates the normalized form adds over the current one.

    The value is the total predicate count of the expanded form minus the
    predicate count of the current expression (number of connectors plus one).
    It serves as a cost estimate for the conversion, which is exponential in
    complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute the distance to dnf instead of cnf
    Returns:
        int: difference
    """
    expanded_size = sum(_predicate_lengths(expression, dnf))
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    return expanded_size - (connector_count + 1)

The difference in the number of predicates between the current expression and the normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression (sqlglot.Expression): expression to compute distance
  • dnf (bool): compute to dnf distance instead
Returns:

int: difference

def distributive_law(expression, dnf, max_distance):
 79def distributive_law(expression, dnf, max_distance):
 80    """
 81    x OR (y AND z) -> (x OR y) AND (x OR z)
 82    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
 83    """
 84    if isinstance(expression.unnest(), exp.Connector):
 85        if normalization_distance(expression, dnf) > max_distance:
 86            return expression
 87
 88    to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
 89
 90    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
 91
 92    if isinstance(expression, from_exp):
 93        a, b = expression.unnest_operands()
 94
 95        from_func = exp.and_ if from_exp == exp.And else exp.or_
 96        to_func = exp.and_ if to_exp == exp.And else exp.or_
 97
 98        if isinstance(a, to_exp) and isinstance(b, to_exp):
 99            if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
100                return _distribute(a, b, from_func, to_func)
101            return _distribute(b, a, from_func, to_func)
102        if isinstance(a, to_exp):
103            return _distribute(b, a, from_func, to_func)
104        if isinstance(b, to_exp):
105            return _distribute(a, b, from_func, to_func)
106
107    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)