sqlglot.optimizer.normalize
1from __future__ import annotations 2 3import logging 4import typing as t 5 6from sqlglot import exp 7from sqlglot.errors import OptimizeError 8from sqlglot.helper import while_changing 9from sqlglot.optimizer.simplify import flatten, uniq_sort 10 11logger = logging.getLogger("sqlglot") 12 13 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): 15 """ 16 Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 17 18 Example: 19 >>> import sqlglot 20 >>> expression = sqlglot.parse_one("(x AND y) OR z") 21 >>> normalize(expression, dnf=False).sql() 22 '(x OR z) AND (y OR z)' 23 24 Args: 25 expression: expression to normalize 26 dnf: rewrite in disjunctive normal form instead. 27 max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion 28 Returns: 29 sqlglot.Expression: normalized expression 30 """ 31 cache: t.Dict[int, str] = {} 32 33 for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): 34 if isinstance(node, exp.Connector): 35 if normalized(node, dnf=dnf): 36 continue 37 38 distance = normalization_distance(node, dnf=dnf) 39 40 if distance > max_distance: 41 logger.info( 42 f"Skipping normalization because distance {distance} exceeds max {max_distance}" 43 ) 44 return expression 45 46 root = node is expression 47 original = node.copy() 48 try: 49 node = while_changing(node, lambda e: distributive_law(e, dnf, max_distance, cache)) 50 except OptimizeError as e: 51 logger.info(e) 52 node.replace(original) 53 if root: 54 return original 55 return expression 56 57 if root: 58 expression = node 59 60 return expression 61 62 63def normalized(expression, dnf=False): 64 ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) 65 66 return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) 67 68 69def normalization_distance(expression, dnf=False): 70 """ 71 The difference in the number of predicates between the 
current expression and the normalized form. 72 73 This is used as an estimate of the cost of the conversion which is exponential in complexity. 74 75 Example: 76 >>> import sqlglot 77 >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") 78 >>> normalization_distance(expression) 79 4 80 81 Args: 82 expression (sqlglot.Expression): expression to compute distance 83 dnf (bool): compute to dnf distance instead 84 Returns: 85 int: difference 86 """ 87 return sum(_predicate_lengths(expression, dnf)) - ( 88 sum(1 for _ in expression.find_all(exp.Connector)) + 1 89 ) 90 91 92def _predicate_lengths(expression, dnf): 93 """ 94 Returns a list of predicate lengths when expanded to normalized form. 95 96 (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 97 """ 98 expression = expression.unnest() 99 100 if not isinstance(expression, exp.Connector): 101 return (1,) 102 103 left, right = expression.args.values() 104 105 if isinstance(expression, exp.And if dnf else exp.Or): 106 return tuple( 107 a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) 108 ) 109 return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) 110 111 112def distributive_law(expression, dnf, max_distance, cache=None): 113 """ 114 x OR (y AND z) -> (x OR y) AND (x OR z) 115 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) 116 """ 117 if normalized(expression, dnf=dnf): 118 return expression 119 120 distance = normalization_distance(expression, dnf=dnf) 121 122 if distance > max_distance: 123 raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") 124 125 exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, cache)) 126 to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) 127 128 if isinstance(expression, from_exp): 129 a, b = expression.unnest_operands() 130 131 from_func = exp.and_ if from_exp == exp.And else exp.or_ 132 to_func = exp.and_ if to_exp == 
exp.And else exp.or_ 133 134 if isinstance(a, to_exp) and isinstance(b, to_exp): 135 if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): 136 return _distribute(a, b, from_func, to_func, cache) 137 return _distribute(b, a, from_func, to_func, cache) 138 if isinstance(a, to_exp): 139 return _distribute(b, a, from_func, to_func, cache) 140 if isinstance(b, to_exp): 141 return _distribute(a, b, from_func, to_func, cache) 142 143 return expression 144 145 146def _distribute(a, b, from_func, to_func, cache): 147 if isinstance(a, exp.Connector): 148 exp.replace_children( 149 a, 150 lambda c: to_func( 151 uniq_sort(flatten(from_func(c, b.left)), cache), 152 uniq_sort(flatten(from_func(c, b.right)), cache), 153 ), 154 ) 155 else: 156 a = to_func( 157 uniq_sort(flatten(from_func(a, b.left)), cache), 158 uniq_sort(flatten(from_func(a, b.right)), cache), 159 ) 160 161 return a
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximum estimated distance from CNF/DNF at which
            conversion is still attempted; if a connector exceeds it, the expression
            is returned without further normalization.
    Returns:
        sqlglot.Expression: normalized expression
    """
    cache: t.Dict[int, str] = {}

    # The prune callback stops the walk at every connector, so only connectors that
    # have no connector ancestor (within `expression`) are visited here.
    for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))):
        if isinstance(node, exp.Connector):
            if normalized(node, dnf=dnf):
                continue

            distance = normalization_distance(node, dnf=dnf)

            if distance > max_distance:
                logger.info(
                    f"Skipping normalization because distance {distance} exceeds max {max_distance}"
                )
                return expression

            root = node is expression
            original = node.copy()
            try:
                rewritten = while_changing(
                    node, lambda e: distributive_law(e, dnf, max_distance, cache)
                )
            except OptimizeError as e:
                logger.info(e)
                # Undo any partial in-place rewriting by restoring the pristine copy.
                node.replace(original)
                if root:
                    return original
                return expression

            # distributive_law may return a freshly built node that is not attached to
            # the tree, so splice it in where the original connector was. When `node`
            # has no parent (the root case), replace() just returns `rewritten`.
            node = node.replace(rewritten)

            if root:
                expression = node

    return expression
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximum estimated distance from CNF/DNF at which conversion is still attempted
Returns:
sqlglot.Expression: normalized expression
def
normalized(expression, dnf=False):
def
normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    Estimate how far `expression` is from conjunctive (or disjunctive) normal form.

    The estimate is the number of predicates the fully normalized form would contain
    minus the connector count of the current expression plus one. Because the
    conversion can grow exponentially, callers use this value as a cost estimate.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute the distance to DNF instead
    Returns:
        int: difference
    """
    normalized_size = sum(_predicate_lengths(expression, dnf))
    connector_count = sum(1 for _ in expression.find_all(exp.Connector))
    # For a tree of binary connectors, connectors + 1 equals the number of leaf predicates.
    current_size = connector_count + 1
    return normalized_size - current_size
The difference in the number of predicates between the current expression and the normalized form.
This is used as an estimate of the cost of the conversion, which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression (sqlglot.Expression): expression to compute distance
- dnf (bool): compute the distance to DNF instead
Returns:
int: difference
def
distributive_law(expression, dnf, max_distance, cache=None):
def distributive_law(expression, dnf, max_distance, cache=None):
    """
    Apply the distributive law to `expression`, rewriting its children first.

    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Args:
        expression: the boolean expression to rewrite.
        dnf: distribute towards disjunctive normal form instead of conjunctive normal form.
        max_distance: the largest normalization distance that is still attempted.
        cache: optional cache forwarded to `_distribute`.

    Returns:
        The rewritten expression (possibly a new node), or `expression` itself when it
        is already normalized or no distribution applies.

    Raises:
        OptimizeError: if the normalization distance exceeds `max_distance`.
    """
    if normalized(expression, dnf=dnf):
        return expression

    cost = normalization_distance(expression, dnf=dnf)
    if cost > max_distance:
        raise OptimizeError(f"Normalization distance {cost} exceeds max {max_distance}")

    # Normalize the operands first so distribution below works on rewritten children.
    exp.replace_children(
        expression, lambda child: distributive_law(child, dnf, max_distance, cache)
    )

    # `source` is the connector pushed down; `target` is the one lifted to the top.
    if dnf:
        target, source = exp.Or, exp.And
    else:
        target, source = exp.And, exp.Or

    if not isinstance(expression, source):
        return expression

    left, right = expression.unnest_operands()

    source_func = exp.and_ if source == exp.And else exp.or_
    target_func = exp.and_ if target == exp.And else exp.or_

    left_is_target = isinstance(left, target)
    right_is_target = isinstance(right, target)

    if left_is_target and right_is_target:
        left_size = len(tuple(left.find_all(exp.Connector)))
        right_size = len(tuple(right.find_all(exp.Connector)))
        # Distribute the operand with more connectors over the other; ties favor `right`.
        if left_size > right_size:
            return _distribute(left, right, source_func, target_func, cache)
        return _distribute(right, left, source_func, target_func, cache)
    if left_is_target:
        return _distribute(right, left, source_func, target_func, cache)
    if right_is_target:
        return _distribute(left, right, source_func, target_func, cache)

    return expression
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)