Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4
  5from sqlglot import exp
  6from sqlglot.errors import OptimizeError
  7from sqlglot.generator import cached_generator
  8from sqlglot.helper import while_changing
  9from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
 10
 11logger = logging.getLogger("sqlglot")
 12
 13
 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 15    """
 16    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 21        >>> normalize(expression, dnf=False).sql()
 22        '(x OR z) AND (y OR z)'
 23
 24    Args:
 25        expression: expression to normalize
 26        dnf: rewrite in disjunctive normal form instead.
 27        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 28    Returns:
 29        sqlglot.Expression: normalized expression
 30    """
 31    generate = cached_generator()
 32
 33    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
 34        if isinstance(node, exp.Connector):
 35            if normalized(node, dnf=dnf):
 36                continue
 37            root = node is expression
 38            original = node.copy()
 39
 40            node.transform(rewrite_between, copy=False)
 41            distance = normalization_distance(node, dnf=dnf)
 42
 43            if distance > max_distance:
 44                logger.info(
 45                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 46                )
 47                return expression
 48
 49            try:
 50                node = node.replace(
 51                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate))
 52                )
 53            except OptimizeError as e:
 54                logger.info(e)
 55                node.replace(original)
 56                if root:
 57                    return original
 58                return expression
 59
 60            if root:
 61                expression = node
 62
 63    return expression
 64
 65
 66def normalized(expression, dnf=False):
 67    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 68
 69    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
 70
 71
 72def normalization_distance(expression, dnf=False):
 73    """
 74    The difference in the number of predicates between the current expression and the normalized form.
 75
 76    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 77
 78    Example:
 79        >>> import sqlglot
 80        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 81        >>> normalization_distance(expression)
 82        4
 83
 84    Args:
 85        expression (sqlglot.Expression): expression to compute distance
 86        dnf (bool): compute to dnf distance instead
 87    Returns:
 88        int: difference
 89    """
 90    return sum(_predicate_lengths(expression, dnf)) - (
 91        sum(1 for _ in expression.find_all(exp.Connector)) + 1
 92    )
 93
 94
 95def _predicate_lengths(expression, dnf):
 96    """
 97    Returns a list of predicate lengths when expanded to normalized form.
 98
 99    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
100    """
101    expression = expression.unnest()
102
103    if not isinstance(expression, exp.Connector):
104        return (1,)
105
106    left, right = expression.args.values()
107
108    if isinstance(expression, exp.And if dnf else exp.Or):
109        return tuple(
110            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
111        )
112    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
113
114
def distributive_law(expression, dnf, max_distance, generate):
    """
    Apply one pass of the distributive law to `expression` and its children.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: expression to rewrite.
        dnf: distribute AND over OR (towards dnf) instead of OR over AND (towards cnf).
        max_distance: upper bound on the estimated normalization distance.
        generate: generator callable forwarded to `uniq_sort` via `_distribute`.
    Returns:
        The rewritten expression, or `expression` itself when nothing applies.
    Raises:
        OptimizeError: if the estimated distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Children first, so operands are already distributed when inspected below.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
        to_func, from_func = exp.or_, exp.and_
    else:
        to_exp, from_exp = exp.And, exp.Or
        to_func, from_func = exp.and_, exp.or_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_is_target = isinstance(a, to_exp)
    b_is_target = isinstance(b, to_exp)

    if a_is_target and b_is_target:
        a_connectors = len(tuple(a.find_all(exp.Connector)))
        b_connectors = len(tuple(b.find_all(exp.Connector)))
        # The operand with more connectors is the one whose children get rewritten.
        receiver, distributed = (a, b) if a_connectors > b_connectors else (b, a)
    elif a_is_target:
        receiver, distributed = b, a
    elif b_is_target:
        receiver, distributed = a, b
    else:
        return expression

    return _distribute(receiver, distributed, from_func, to_func, generate)
147
148
def _distribute(a, b, from_func, to_func, generate):
    """
    Distribute `a` over the two operands of `b`.

    If `a` is a connector, each of its children is replaced in place by the
    distributed form and `a` is returned; otherwise a new expression built from
    `a` itself is returned.

    Args:
        a: expression being distributed (or whose children are).
        b: connector whose `left` and `right` operands receive the distribution.
        from_func: builder for the connector being pushed down.
        to_func: builder for the connector that ends up on top.
        generate: generator callable passed to `uniq_sort`.
    Returns:
        The distributed expression.
    """

    def expand(term):
        left_branch = uniq_sort(flatten(from_func(term, b.left)), generate)
        right_branch = uniq_sort(flatten(from_func(term, b.right)), generate)
        return to_func(left_branch, right_branch, copy=False)

    if not isinstance(a, exp.Connector):
        return expand(a)

    exp.replace_children(a, expand)
    return a
logger = <Logger sqlglot (WARNING)>
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Every top-most connector found while walking `expression` is handled
    independently: connectors that are already normalized are skipped, the rest
    are rewritten by repeatedly applying `distributive_law` until nothing changes.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression. When the estimated distance of a
        connector exceeds `max_distance`, normalization stops and the expression is
        returned without distributing that connector.
    """
    generate = cached_generator()

    def is_connector(candidate, *_):
        return isinstance(candidate, exp.Connector)

    def apply_distributive_law(e):
        return distributive_law(e, dnf, max_distance, generate)

    # Materialize the walk up front: the tree is mutated while we iterate.
    visited = tuple(expression.walk(prune=is_connector))

    for current, *_ in visited:
        if not isinstance(current, exp.Connector) or normalized(current, dnf=dnf):
            continue

        is_root = current is expression
        # Kept so the connector can be restored if the rewrite is aborted midway.
        backup = current.copy()

        current.transform(rewrite_between, copy=False)
        distance = normalization_distance(current, dnf=dnf)

        if distance > max_distance:
            # NOTE: rewrite_between has already been applied in place at this point.
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            rewritten = current.replace(while_changing(current, apply_distributive_law))
        except OptimizeError as error:
            logger.info(error)
            current.replace(backup)
            return backup if is_root else expression

        if is_root:
            # The root itself was swapped out, so track its replacement.
            expression = rewritten

    return expression

Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:

sqlglot.Expression: normalized expression

def normalized(expression, dnf=False):
def normalized(expression, dnf=False):
    """
    Check whether `expression` is already in the requested normal form.

    For cnf no AND may have an OR ancestor; for dnf no OR may have an AND ancestor.

    Args:
        expression: expression to inspect
        dnf: check for disjunctive normal form instead of conjunctive normal form.
    Returns:
        bool: True if no offending connector nesting was found.
    """
    if dnf:
        outer, inner = exp.And, exp.Or
    else:
        outer, inner = exp.Or, exp.And

    for connector in expression.find_all(inner):
        if connector.find_ancestor(outer):
            return False

    return True
def normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    Estimate how much larger the normalized form is than the current expression.

    The value is the total predicate count of the expanded normal form minus the
    number of predicates currently present (connectors + 1). It serves as a cost
    estimate for the conversion, which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute to dnf distance instead
    Returns:
        int: difference
    """
    expanded_size = sum(_predicate_lengths(expression, dnf))
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    current_size = connector_count + 1

    return expanded_size - current_size

The difference in the number of predicates between the current expression and the normalized form.

This is used as an estimate of the cost of the conversion which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression (sqlglot.Expression): expression to compute distance
  • dnf (bool): compute to dnf distance instead
Returns:

int: difference

def distributive_law(expression, dnf, max_distance, generate):
def distributive_law(expression, dnf, max_distance, generate):
    """
    Apply one pass of the distributive law to `expression` and its children.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: expression to rewrite.
        dnf: distribute AND over OR (towards dnf) instead of OR over AND (towards cnf).
        max_distance: upper bound on the estimated normalization distance.
        generate: generator callable forwarded to `uniq_sort` via `_distribute`.
    Returns:
        The rewritten expression, or `expression` itself when nothing applies.
    Raises:
        OptimizeError: if the estimated distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)

    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    # Children first, so operands are already distributed when inspected below.
    exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate))

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
        to_func, from_func = exp.or_, exp.and_
    else:
        to_exp, from_exp = exp.And, exp.Or
        to_func, from_func = exp.and_, exp.or_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_is_target = isinstance(a, to_exp)
    b_is_target = isinstance(b, to_exp)

    if a_is_target and b_is_target:
        a_connectors = len(tuple(a.find_all(exp.Connector)))
        b_connectors = len(tuple(b.find_all(exp.Connector)))
        # The operand with more connectors is the one whose children get rewritten.
        receiver, distributed = (a, b) if a_connectors > b_connectors else (b, a)
    elif a_is_target:
        receiver, distributed = b, a
    elif b_is_target:
        receiver, distributed = a, b
    else:
        return expression

    return _distribute(receiver, distributed, from_func, to_func, generate)

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)