sqlglot.optimizer.normalize
1from __future__ import annotations 2 3import logging 4import typing as t 5 6from sqlglot import exp 7from sqlglot.errors import OptimizeError 8from sqlglot.helper import while_changing 9from sqlglot.optimizer.simplify import flatten, uniq_sort 10 11logger = logging.getLogger("sqlglot") 12 13 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): 15 """ 16 Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 17 18 Example: 19 >>> import sqlglot 20 >>> expression = sqlglot.parse_one("(x AND y) OR z") 21 >>> normalize(expression, dnf=False).sql() 22 '(x OR z) AND (y OR z)' 23 24 Args: 25 expression: expression to normalize 26 dnf: rewrite in disjunctive normal form instead. 27 max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion 28 Returns: 29 sqlglot.Expression: normalized expression 30 """ 31 cache: t.Dict[int, str] = {} 32 33 for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): 34 if isinstance(node, exp.Connector): 35 if normalized(node, dnf=dnf): 36 continue 37 38 distance = normalization_distance(node, dnf=dnf) 39 40 if distance > max_distance: 41 logger.info( 42 f"Skipping normalization because distance {distance} exceeds max {max_distance}" 43 ) 44 return expression 45 46 root = node is expression 47 original = node.copy() 48 try: 49 node = node.replace( 50 while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) 51 ) 52 except OptimizeError as e: 53 logger.info(e) 54 node.replace(original) 55 if root: 56 return original 57 return expression 58 59 if root: 60 expression = node 61 62 return expression 63 64 65def normalized(expression, dnf=False): 66 ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) 67 68 return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) 69 70 71def normalization_distance(expression, dnf=False): 72 """ 73 The difference in the number of 
predicates between the current expression and the normalized form. 74 75 This is used as an estimate of the cost of the conversion which is exponential in complexity. 76 77 Example: 78 >>> import sqlglot 79 >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") 80 >>> normalization_distance(expression) 81 4 82 83 Args: 84 expression (sqlglot.Expression): expression to compute distance 85 dnf (bool): compute to dnf distance instead 86 Returns: 87 int: difference 88 """ 89 return sum(_predicate_lengths(expression, dnf)) - ( 90 sum(1 for _ in expression.find_all(exp.Connector)) + 1 91 ) 92 93 94def _predicate_lengths(expression, dnf): 95 """ 96 Returns a list of predicate lengths when expanded to normalized form. 97 98 (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 99 """ 100 expression = expression.unnest() 101 102 if not isinstance(expression, exp.Connector): 103 return (1,) 104 105 left, right = expression.args.values() 106 107 if isinstance(expression, exp.And if dnf else exp.Or): 108 return tuple( 109 a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) 110 ) 111 return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) 112 113 114def distributive_law(expression, dnf, max_distance, cache=None): 115 """ 116 x OR (y AND z) -> (x OR y) AND (x OR z) 117 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) 118 """ 119 if normalized(expression, dnf=dnf): 120 return expression 121 122 distance = normalization_distance(expression, dnf=dnf) 123 124 if distance > max_distance: 125 raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") 126 127 exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) 128 to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) 129 130 if isinstance(expression, from_exp): 131 a, b = expression.unnest_operands() 132 133 from_func = exp.and_ if from_exp == exp.And else exp.or_ 134 to_func = 
exp.and_ if to_exp == exp.And else exp.or_ 135 136 if isinstance(a, to_exp) and isinstance(b, to_exp): 137 if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): 138 return _distribute(a, b, from_func, to_func, cache) 139 return _distribute(b, a, from_func, to_func, cache) 140 if isinstance(a, to_exp): 141 return _distribute(b, a, from_func, to_func, cache) 142 if isinstance(b, to_exp): 143 return _distribute(a, b, from_func, to_func, cache) 144 145 return expression 146 147 148def _distribute(a, b, from_func, to_func, cache): 149 if isinstance(a, exp.Connector): 150 exp.replace_children( 151 a, 152 lambda c: to_func( 153 uniq_sort(flatten(from_func(c, b.left)), cache), 154 uniq_sort(flatten(from_func(c, b.right)), cache), 155 ), 156 ) 157 else: 158 a = to_func( 159 uniq_sort(flatten(from_func(a, b.left)), cache), 160 uniq_sort(flatten(from_func(a, b.right)), cache), 161 ) 162 163 return a
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Convert ``expression`` to conjunctive normal form (default) or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Each outermost AND/OR connector is rewritten with the distributive law until it
    stops changing. Normalization can stop early in two cases:

    - A connector's estimated distance exceeds ``max_distance``. The expression is
      returned at once. Connectors rewritten earlier keep their new form, and the
      rest are left untouched.
    - A rewrite raises ``OptimizeError``. That connector is first restored from a
      copy taken beforehand.

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximum estimated distance from CNF/DNF at which
            conversion is still attempted
    Returns:
        sqlglot.Expression: normalized expression
    """
    cache: t.Dict[int, str] = {}

    def is_connector(node, *_):
        return isinstance(node, exp.Connector)

    def rewrite(node):
        return distributive_law(node, dnf, max_distance, cache)

    # Snapshot the pruned walk up front, because the loop replaces nodes as it goes.
    # Pruning at connectors keeps only the outermost AND/OR nodes.
    outermost = tuple(expression.walk(prune=is_connector))

    for connector, *_ in outermost:
        if not is_connector(connector) or normalized(connector, dnf=dnf):
            continue

        distance = normalization_distance(connector, dnf=dnf)
        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        is_root = connector is expression
        # The rewrite mutates the subtree in place, so keep a copy to roll back to.
        backup = connector.copy()
        try:
            connector = connector.replace(while_changing(connector, rewrite))
        except OptimizeError as error:
            logger.info(error)
            connector.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = connector

    return expression
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximum estimated distance from CNF/DNF at which conversion is still attempted
Returns:
sqlglot.Expression: normalized expression
def
normalized(expression, dnf=False):
def
normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    The difference in the number of predicates between the current expression and the normalized form.

    This is used as an estimate of the cost of the conversion, which is exponential in complexity.

    The estimate itself is computed in time linear in the size of ``expression``.
    ``_predicate_stats`` tracks only (clause count, total length) per subtree instead
    of materializing every expanded clause, which would cost as much as the
    conversion being estimated. The value equals
    ``sum(_predicate_lengths(expression, dnf))`` minus the node count below.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute to dnf distance instead
    Returns:
        int: difference
    """
    _, total_length = _predicate_stats(expression, dnf)
    return total_length - (sum(1 for _ in expression.find_all(exp.Connector)) + 1)


def _predicate_stats(expression, dnf):
    """
    Return ``(len(lengths), sum(lengths))`` where ``lengths = _predicate_lengths(expression, dnf)``.

    The exponentially long tuple is never built.

    - Distributed connector (OR for CNF, AND for DNF): every left clause is paired
      with every right clause. The counts multiply, and each side's total is repeated
      once per clause of the other side.
    - Other connector: the clause lists are concatenated, so counts and totals add.
    - Non-connector leaf: one clause of length 1.
    """
    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        return 1, 1

    left_count, left_total = _predicate_stats(expression.left, dnf)
    right_count, right_total = _predicate_stats(expression.right, dnf)

    if isinstance(expression, exp.And if dnf else exp.Or):
        return (
            left_count * right_count,
            left_total * right_count + right_total * left_count,
        )
    return left_count + right_count, left_total + right_total
The difference in the number of predicates between the current expression and the normalized form.
This is used as an estimate of the cost of the conversion, which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression (sqlglot.Expression): expression to compute distance
- dnf (bool): compute to dnf distance instead
Returns:
int: difference
def
distributive_law(expression, dnf, max_distance, cache=None):
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    Rewrite the children first, then apply the distributive law once at this node.

    CNF (``dnf=False``) pushes OR beneath AND. DNF pushes AND beneath OR.

        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the (sub)expression to rewrite. Its children are replaced in place.
        dnf: target disjunctive instead of conjunctive normal form.
        max_distance: abort when ``normalization_distance`` exceeds this value.
        cache: forwarded to ``_distribute``.
    Returns:
        The rewritten expression, which may be a different node than ``expression``.
    Raises:
        OptimizeError: if the estimated normalization distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    def recurse(child):
        return distributive_law(child, dnf, max_distance, cache)

    exp.replace_children(expression, recurse)

    if dnf:
        to_exp, from_exp = exp.Or, exp.And
    else:
        to_exp, from_exp = exp.And, exp.Or

    # Only a "from" connector (OR for CNF, AND for DNF) can be distributed.
    if not isinstance(expression, from_exp):
        return expression

    left, right = expression.unnest_operands()
    from_func = exp.and_ if from_exp == exp.And else exp.or_
    to_func = exp.and_ if to_exp == exp.And else exp.or_

    left_splits = isinstance(left, to_exp)
    right_splits = isinstance(right, to_exp)

    if left_splits and right_splits:
        # Both operands split: distribute into whichever has more connectors.
        left_size = len(tuple(left.find_all(exp.Connector)))
        right_size = len(tuple(right.find_all(exp.Connector)))
        if left_size > right_size:
            return _distribute(left, right, from_func, to_func, cache)
        return _distribute(right, left, from_func, to_func, cache)
    if left_splits:
        return _distribute(right, left, from_func, to_func, cache)
    if right_splits:
        return _distribute(left, right, from_func, to_func, cache)

    return expression
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)