Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.errors import OptimizeError
  8from sqlglot.helper import while_changing
  9from sqlglot.optimizer.simplify import flatten, uniq_sort
 10
 11logger = logging.getLogger("sqlglot")
 12
 13
 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 15    """
 16    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 21        >>> normalize(expression, dnf=False).sql()
 22        '(x OR z) AND (y OR z)'
 23
 24    Args:
 25        expression: expression to normalize
 26        dnf: rewrite in disjunctive normal form instead.
 27        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 28    Returns:
 29        sqlglot.Expression: normalized expression
 30    """
 31    cache: t.Dict[int, str] = {}
 32
 33    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
 34        if isinstance(node, exp.Connector):
 35            if normalized(node, dnf=dnf):
 36                continue
 37
 38            distance = normalization_distance(node, dnf=dnf)
 39
 40            if distance > max_distance:
 41                logger.info(
 42                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 43                )
 44                return expression
 45
 46            root = node is expression
 47            original = node.copy()
 48            try:
 49                node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
 50            except OptimizeError as e:
 51                logger.info(e)
 52                node.replace(original)
 53                if root:
 54                    return original
 55                return expression
 56
 57            if root:
 58                expression = node
 59
 60    return expression
 61
 62
 63def normalized(expression, dnf=False):
 64    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 65
 66    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
 67
 68
 69def normalization_distance(expression, dnf=False):
 70    """
 71    The difference in the number of predicates between the current expression and the normalized form.
 72
 73    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 74
 75    Example:
 76        >>> import sqlglot
 77        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 78        >>> normalization_distance(expression)
 79        4
 80
 81    Args:
 82        expression (sqlglot.Expression): expression to compute distance
 83        dnf (bool): compute to dnf distance instead
 84    Returns:
 85        int: difference
 86    """
 87    return sum(_predicate_lengths(expression, dnf)) - (
 88        sum(1 for _ in expression.find_all(exp.Connector)) + 1
 89    )
 90
 91
 92def _predicate_lengths(expression, dnf):
 93    """
 94    Returns a list of predicate lengths when expanded to normalized form.
 95
 96    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
 97    """
 98    expression = expression.unnest()
 99
100    if not isinstance(expression, exp.Connector):
101        return (1,)
102
103    left, right = expression.args.values()
104
105    if isinstance(expression, exp.And if dnf else exp.Or):
106        return tuple(
107            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
108        )
109    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
110
111
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    Rewrite the children of `expression`, then distribute at this node.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: expression to rewrite; it may be mutated in place.
        dnf: distribute AND over OR (towards DNF) instead of OR over AND.
        max_distance: largest normalization distance that is still attempted.
        cache: passed through to `uniq_sort`.
    Returns:
        The rewritten expression, or `expression` itself when nothing applies.
    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Children first, so the operands inspected below are already rewritten.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
    else:
        to_exp, from_exp = exp.And, exp.Or

    if not isinstance(expression, from_exp):
        return expression

    left, right = expression.unnest_operands()

    # Builders matching the connector being distributed and the one being produced.
    from_func = exp.and_ if from_exp == exp.And else exp.or_
    to_func = exp.and_ if to_exp == exp.And else exp.or_

    left_is_target = isinstance(left, to_exp)
    right_is_target = isinstance(right, to_exp)

    if left_is_target and right_is_target:
        left_size = len(tuple(left.find_all(exp.Connector)))
        right_size = len(tuple(right.find_all(exp.Connector)))
        # The operand with more connectors is the one rewritten in place.
        if left_size > right_size:
            return _distribute(left, right, from_func, to_func, cache)
        return _distribute(right, left, from_func, to_func, cache)

    if left_is_target:
        return _distribute(right, left, from_func, to_func, cache)

    if right_is_target:
        return _distribute(left, right, from_func, to_func, cache)

    return expression
144
145
def _distribute(a, b, from_func, to_func, cache):
    """
    Distribute `a` over the two operands of the connector `b`.

    If `a` is a connector, each of its children is replaced in place by the
    distributed form and `a` is returned; otherwise a new expression is built.
    """

    def expand(term):
        # term <from> (l <to> r)  ->  (term <from> l) <to> (term <from> r)
        with_left = uniq_sort(flatten(from_func(term, b.left)), cache)
        with_right = uniq_sort(flatten(from_func(term, b.right)), cache)
        return to_func(with_left, with_right)

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, lambda c: expand(c))
    return a
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Every outermost connector (AND/OR) found in the tree is normalized
    independently and the rewritten connector is spliced back into its parent.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression. If the estimated distance of any
        connector exceeds `max_distance`, normalization stops and the expression is
        returned with that connector left as it was.
    """
    cache: t.Dict[int, str] = {}

    # Materialize the walk first: the tree is rewritten while we iterate over it.
    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        distance = normalization_distance(node, dnf=dnf)

        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        root = node is expression
        # distributive_law mutates the subtree, so keep a pristine copy for rollback.
        original = node.copy()
        try:
            rewritten = while_changing(
                node, lambda e: distributive_law(e, dnf, max_distance, cache)
            )
        except OptimizeError as e:
            logger.info(e)
            node.replace(original)
            if root:
                return original
            return expression

        # distributive_law hands back either one of the node's operands or a newly
        # built connector, never `node` itself, so the result has to be put in
        # `node`'s place explicitly. Without this a non-root connector (e.g. a
        # WHERE condition) would keep pointing at the stale, partially mutated node.
        # NOTE(review): assumes while_changing does not reattach the result itself
        # and that replace() is harmless on a parentless node - the rollback path
        # above already calls it on the root.
        node.replace(rewritten)

        if root:
            expression = rewritten

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:

sqlglot.Expression: normalized expression

def normalized(expression, dnf=False):
def normalized(expression, dnf=False):
    """
    Tell whether `expression` is already in the requested normal form.

    For CNF (the default) this holds when no AND found in the expression has
    an OR ancestor; for DNF, when no OR has an AND ancestor.

    Args:
        expression: expression to inspect
        dnf: check for disjunctive normal form instead of conjunctive.
    Returns:
        bool: True if no offending nesting was found.
    """
    if dnf:
        outer, inner = exp.And, exp.Or
    else:
        outer, inner = exp.Or, exp.And

    for connector in expression.find_all(inner):
        # A single `inner` connector nested under an `outer` one breaks the form.
        if connector.find_ancestor(outer):
            return False

    return True
def normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    Estimate how far `expression` is from its normalized form.

    The value is the total predicate count of the expanded form minus the
    predicate count of the current form (number of connectors plus one).
    It serves as a cost estimate for the conversion, which is exponential.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute to dnf distance instead
    Returns:
        int: difference
    """
    expanded_size = sum(_predicate_lengths(expression, dnf))
    connector_count = len(tuple(expression.find_all(exp.Connector)))

    return expanded_size - (connector_count + 1)

The difference in the number of predicates between the current expression and the normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression (sqlglot.Expression): expression to compute distance
  • dnf (bool): compute to dnf distance instead
Returns:

int: difference

def distributive_law(expression, dnf, max_distance, cache=None):
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    Rewrite the children of `expression`, then distribute at this node.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: expression to rewrite; it may be mutated in place.
        dnf: distribute AND over OR (towards DNF) instead of OR over AND.
        max_distance: largest normalization distance that is still attempted.
        cache: passed through to `uniq_sort`.
    Returns:
        The rewritten expression, or `expression` itself when nothing applies.
    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Children first, so the operands inspected below are already rewritten.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache))

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
    else:
        to_exp, from_exp = exp.And, exp.Or

    if not isinstance(expression, from_exp):
        return expression

    left, right = expression.unnest_operands()

    # Builders matching the connector being distributed and the one being produced.
    from_func = exp.and_ if from_exp == exp.And else exp.or_
    to_func = exp.and_ if to_exp == exp.And else exp.or_

    left_is_target = isinstance(left, to_exp)
    right_is_target = isinstance(right, to_exp)

    if left_is_target and right_is_target:
        left_size = len(tuple(left.find_all(exp.Connector)))
        right_size = len(tuple(right.find_all(exp.Connector)))
        # The operand with more connectors is the one rewritten in place.
        if left_size > right_size:
            return _distribute(left, right, from_func, to_func, cache)
        return _distribute(right, left, from_func, to_func, cache)

    if left_is_target:
        return _distribute(right, left, from_func, to_func, cache)

    if right_is_target:
        return _distribute(left, right, from_func, to_func, cache)

    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)