Edit on GitHub

sqlglot.optimizer.normalize

  1from __future__ import annotations
  2
  3import logging
  4import typing as t
  5
  6from sqlglot import exp
  7from sqlglot.errors import OptimizeError
  8from sqlglot.helper import while_changing
  9from sqlglot.optimizer.simplify import flatten, uniq_sort
 10
 11logger = logging.getLogger("sqlglot")
 12
 13
 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
 15    """
 16    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("(x AND y) OR z")
 21        >>> normalize(expression, dnf=False).sql()
 22        '(x OR z) AND (y OR z)'
 23
 24    Args:
 25        expression: expression to normalize
 26        dnf: rewrite in disjunctive normal form instead.
 27        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
 28    Returns:
 29        sqlglot.Expression: normalized expression
 30    """
 31    cache: t.Dict[int, str] = {}
 32
 33    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
 34        if isinstance(node, exp.Connector):
 35            if normalized(node, dnf=dnf):
 36                continue
 37
 38            distance = normalization_distance(node, dnf=dnf)
 39
 40            if distance > max_distance:
 41                logger.info(
 42                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
 43                )
 44                return expression
 45
 46            root = node is expression
 47            original = node.copy()
 48            try:
 49                node = node.replace(
 50                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache))
 51                )
 52            except OptimizeError as e:
 53                logger.info(e)
 54                node.replace(original)
 55                if root:
 56                    return original
 57                return expression
 58
 59            if root:
 60                expression = node
 61
 62    return expression
 63
 64
 65def normalized(expression, dnf=False):
 66    ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
 67
 68    return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root))
 69
 70
 71def normalization_distance(expression, dnf=False):
 72    """
 73    The difference in the number of predicates between the current expression and the normalized form.
 74
 75    This is used as an estimate of the cost of the conversion which is exponential in complexity.
 76
 77    Example:
 78        >>> import sqlglot
 79        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
 80        >>> normalization_distance(expression)
 81        4
 82
 83    Args:
 84        expression (sqlglot.Expression): expression to compute distance
 85        dnf (bool): compute to dnf distance instead
 86    Returns:
 87        int: difference
 88    """
 89    return sum(_predicate_lengths(expression, dnf)) - (
 90        sum(1 for _ in expression.find_all(exp.Connector)) + 1
 91    )
 92
 93
 94def _predicate_lengths(expression, dnf):
 95    """
 96    Returns a list of predicate lengths when expanded to normalized form.
 97
 98    (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
 99    """
100    expression = expression.unnest()
101
102    if not isinstance(expression, exp.Connector):
103        return (1,)
104
105    left, right = expression.args.values()
106
107    if isinstance(expression, exp.And if dnf else exp.Or):
108        return tuple(
109            a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)
110        )
111    return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
112
113
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    Apply one bottom-up pass of the distributive law to ``expression``.

    CNF examples (``dnf=False``):
        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Children are rewritten first (recursively). Then, if this node is the connector being
    distributed (OR for CNF, AND for DNF) and at least one unnested operand is the other
    connector, the operands are combined via ``_distribute``.

    Args:
        expression: the expression to rewrite.
        dnf: distribute towards DNF instead of CNF.
        max_distance: limit on ``normalization_distance`` for this expression.
        cache: optional cache forwarded to ``_distribute`` (and from there to ``uniq_sort``).

    Returns:
        The rewritten expression. This may be a different node than ``expression``.

    Raises:
        OptimizeError: if the normalization distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    def rewrite_child(child):
        return distributive_law(child, dnf, max_distance, cache)

    exp.replace_children(expression, rewrite_child)

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
        from_func, to_func = exp.and_, exp.or_
    else:
        to_exp, from_exp = exp.And, exp.Or
        from_func, to_func = exp.or_, exp.and_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_is_target = isinstance(a, to_exp)
    b_is_target = isinstance(b, to_exp)

    if a_is_target and b_is_target:
        # Keep the operand with more connectors as the outer structure and split the other.
        a_size = len(tuple(a.find_all(exp.Connector)))
        b_size = len(tuple(b.find_all(exp.Connector)))
        if a_size > b_size:
            return _distribute(a, b, from_func, to_func, cache)
        return _distribute(b, a, from_func, to_func, cache)

    if a_is_target:
        return _distribute(b, a, from_func, to_func, cache)
    if b_is_target:
        return _distribute(a, b, from_func, to_func, cache)
    return expression
146
147
def _distribute(a, b, from_func, to_func, cache):
    """
    Distribute ``a`` over the two operands of the connector ``b``.

    Each operand ``x`` is replaced by ``to_func(from_func(x, b.left), from_func(x, b.right))``.
    Here ``x`` is each child of ``a`` when ``a`` is a connector, otherwise ``a`` itself.
    Each inner combination is passed through ``flatten`` and ``uniq_sort``.

    Args:
        a: the operand to distribute.
        b: a connector whose ``left`` and ``right`` operands are distributed over.
        from_func: builder for the connector being distributed (e.g. ``exp.or_`` for CNF).
        to_func: builder for the outer connector of the normal form (e.g. ``exp.and_``).
        cache: cache forwarded to ``uniq_sort``.

    Returns:
        ``a`` itself (its children replaced in place) when it is a connector, otherwise
        the newly built node.
    """

    def expand(operand):
        return to_func(
            uniq_sort(flatten(from_func(operand, b.left)), cache),
            uniq_sort(flatten(from_func(operand, b.right)), cache),
        )

    if isinstance(a, exp.Connector):
        exp.replace_children(a, expand)
        return a
    return expand(a)
def normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite a sqlglot AST into conjunctive normal form (CNF) or disjunctive normal form (DNF).

    Every connector (AND/OR node) found by walking ``expression`` is rewritten with the
    distributive law until it stops changing.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize.
        dnf: rewrite into disjunctive normal form instead of conjunctive normal form.
        max_distance (int): the largest estimated distance from CNF/DNF
            (see ``normalization_distance``) for which a conversion is attempted.

    Returns:
        sqlglot.Expression: the normalized expression.

        - If a connector's distance exceeds ``max_distance``, the expression is returned
          as it stands at that point, and any remaining connectors are not processed.
        - If the rewrite raises ``OptimizeError``, the offending connector is restored
          from a copy taken beforehand, and the (restored) result is returned.
    """
    # Passed to every distributive_law call, which forwards it to uniq_sort.
    cache: t.Dict[int, str] = {}

    def rewrite(e):
        return distributive_law(e, dnf, max_distance, cache)

    # Materialize the walk first: nodes are replaced inside the tree while we iterate.
    # The prune callback stops the walk at connectors, so (assuming walk's usual prune
    # semantics) each AND/OR tree is handled once, via its outermost connector.
    candidates = tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector)))

    for node, *_ in candidates:
        if not isinstance(node, exp.Connector) or normalized(node, dnf=dnf):
            continue

        distance = normalization_distance(node, dnf=dnf)
        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        is_root = node is expression
        # Snapshot so the tree can be restored if the rewrite gives up midway.
        backup = node.copy()

        try:
            node = node.replace(while_changing(node, rewrite))
        except OptimizeError as error:
            logger.info(error)
            node.replace(backup)
            return backup if is_root else expression

        # When the whole expression was the connector, its replacement is the new result.
        if is_root:
            expression = node

    return expression

Rewrite a sqlglot AST into conjunctive normal form (CNF) or disjunctive normal form (DNF).

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
  • expression: expression to normalize
  • dnf: rewrite in disjunctive normal form instead.
  • max_distance (int): the maximum estimated distance from CNF/DNF at which conversion is still attempted
Returns:

sqlglot.Expression: normalized expression

def normalized(expression, dnf=False):
def normalized(expression, dnf=False):
    """
    Check whether ``expression`` is already in the requested normal form.

    - CNF mode (``dnf=False``): normalized means no AND node inside ``expression`` has
      an OR ancestor.
    - DNF mode: normalized means no OR node has an AND ancestor.

    Note: the ancestor lookup is ``find_ancestor`` on each inner connector and is not
    bounded by ``expression``. Presumably, connectors above ``expression`` in the tree
    therefore count too — confirm against ``Expression.find_ancestor``.

    Args:
        expression: the expression to inspect.
        dnf: check for DNF instead of CNF.

    Returns:
        bool: True if no offending connector/ancestor pair was found.
    """
    if dnf:
        outer_type, inner_type = exp.And, exp.Or
    else:
        outer_type, inner_type = exp.Or, exp.And

    for connector in expression.find_all(inner_type):
        if connector.find_ancestor(outer_type):
            return False
    return True
def normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    Estimate how far ``expression`` is from its normalized (CNF or DNF) form.

    The estimate is the total number of predicates in the fully expanded normal form,
    minus the number of predicates in the current expression (connector count + 1).
    Callers use it as a proxy for the cost of the conversion, which grows exponentially.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to measure.
        dnf (bool): measure the distance to DNF instead of CNF.

    Returns:
        int: the estimated difference in predicate count.
    """
    expanded_size = sum(_predicate_lengths(expression, dnf))
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    current_size = connector_count + 1
    return expanded_size - current_size

The difference in the number of predicates between the current expression and the normalized form.

This is used to estimate the cost of the conversion, which is exponential in complexity.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
  • expression (sqlglot.Expression): expression to compute distance
  • dnf (bool): compute to dnf distance instead
Returns:

int: difference

def distributive_law(expression, dnf, max_distance, cache=None):
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    Apply one bottom-up pass of the distributive law to ``expression``.

    CNF examples (``dnf=False``):
        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Children are rewritten first (recursively). Then, if this node is the connector being
    distributed (OR for CNF, AND for DNF) and at least one unnested operand is the other
    connector, the operands are combined via ``_distribute``.

    Args:
        expression: the expression to rewrite.
        dnf: distribute towards DNF instead of CNF.
        max_distance: limit on ``normalization_distance`` for this expression.
        cache: optional cache forwarded to ``_distribute`` (and from there to ``uniq_sort``).

    Returns:
        The rewritten expression. This may be a different node than ``expression``.

    Raises:
        OptimizeError: if the normalization distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    def rewrite_child(child):
        return distributive_law(child, dnf, max_distance, cache)

    exp.replace_children(expression, rewrite_child)

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
        from_func, to_func = exp.and_, exp.or_
    else:
        to_exp, from_exp = exp.And, exp.Or
        from_func, to_func = exp.or_, exp.and_

    if not isinstance(expression, from_exp):
        return expression

    a, b = expression.unnest_operands()
    a_is_target = isinstance(a, to_exp)
    b_is_target = isinstance(b, to_exp)

    if a_is_target and b_is_target:
        # Keep the operand with more connectors as the outer structure and split the other.
        a_size = len(tuple(a.find_all(exp.Connector)))
        b_size = len(tuple(b.find_all(exp.Connector)))
        if a_size > b_size:
            return _distribute(a, b, from_func, to_func, cache)
        return _distribute(b, a, from_func, to_func, cache)

    if a_is_target:
        return _distribute(b, a, from_func, to_func, cache)
    if b_is_target:
        return _distribute(a, b, from_func, to_func, cache)
    return expression

x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)