sqlglot.optimizer.normalize
1from __future__ import annotations 2 3import logging 4 5from sqlglot import exp 6from sqlglot.errors import OptimizeError 7from sqlglot.generator import cached_generator 8from sqlglot.helper import while_changing 9from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort 10 11logger = logging.getLogger("sqlglot") 12 13 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): 15 """ 16 Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 17 18 Example: 19 >>> import sqlglot 20 >>> expression = sqlglot.parse_one("(x AND y) OR z") 21 >>> normalize(expression, dnf=False).sql() 22 '(x OR z) AND (y OR z)' 23 24 Args: 25 expression: expression to normalize 26 dnf: rewrite in disjunctive normal form instead. 27 max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion 28 Returns: 29 sqlglot.Expression: normalized expression 30 """ 31 generate = cached_generator() 32 33 for node, *_ in tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector))): 34 if isinstance(node, exp.Connector): 35 if normalized(node, dnf=dnf): 36 continue 37 root = node is expression 38 original = node.copy() 39 40 node.transform(rewrite_between, copy=False) 41 distance = normalization_distance(node, dnf=dnf) 42 43 if distance > max_distance: 44 logger.info( 45 f"Skipping normalization because distance {distance} exceeds max {max_distance}" 46 ) 47 return expression 48 49 try: 50 node = node.replace( 51 while_changing(node, lambda e: distributive_law(e, dnf, max_distance, generate)) 52 ) 53 except OptimizeError as e: 54 logger.info(e) 55 node.replace(original) 56 if root: 57 return original 58 return expression 59 60 if root: 61 expression = node 62 63 return expression 64 65 66def normalized(expression, dnf=False): 67 ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) 68 69 return not any(connector.find_ancestor(ancestor) for connector in expression.find_all(root)) 70 71 
72def normalization_distance(expression, dnf=False): 73 """ 74 The difference in the number of predicates between the current expression and the normalized form. 75 76 This is used as an estimate of the cost of the conversion which is exponential in complexity. 77 78 Example: 79 >>> import sqlglot 80 >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") 81 >>> normalization_distance(expression) 82 4 83 84 Args: 85 expression (sqlglot.Expression): expression to compute distance 86 dnf (bool): compute to dnf distance instead 87 Returns: 88 int: difference 89 """ 90 return sum(_predicate_lengths(expression, dnf)) - ( 91 sum(1 for _ in expression.find_all(exp.Connector)) + 1 92 ) 93 94 95def _predicate_lengths(expression, dnf): 96 """ 97 Returns a list of predicate lengths when expanded to normalized form. 98 99 (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 100 """ 101 expression = expression.unnest() 102 103 if not isinstance(expression, exp.Connector): 104 return (1,) 105 106 left, right = expression.args.values() 107 108 if isinstance(expression, exp.And if dnf else exp.Or): 109 return tuple( 110 a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) 111 ) 112 return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) 113 114 115def distributive_law(expression, dnf, max_distance, generate): 116 """ 117 x OR (y AND z) -> (x OR y) AND (x OR z) 118 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) 119 """ 120 if normalized(expression, dnf=dnf): 121 return expression 122 123 distance = normalization_distance(expression, dnf=dnf) 124 125 if distance > max_distance: 126 raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") 127 128 exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance, generate)) 129 to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) 130 131 if isinstance(expression, from_exp): 132 a, b = 
expression.unnest_operands() 133 134 from_func = exp.and_ if from_exp == exp.And else exp.or_ 135 to_func = exp.and_ if to_exp == exp.And else exp.or_ 136 137 if isinstance(a, to_exp) and isinstance(b, to_exp): 138 if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): 139 return _distribute(a, b, from_func, to_func, generate) 140 return _distribute(b, a, from_func, to_func, generate) 141 if isinstance(a, to_exp): 142 return _distribute(b, a, from_func, to_func, generate) 143 if isinstance(b, to_exp): 144 return _distribute(a, b, from_func, to_func, generate) 145 146 return expression 147 148 149def _distribute(a, b, from_func, to_func, generate): 150 if isinstance(a, exp.Connector): 151 exp.replace_children( 152 a, 153 lambda c: to_func( 154 uniq_sort(flatten(from_func(c, b.left)), generate), 155 uniq_sort(flatten(from_func(c, b.right)), generate), 156 copy=False, 157 ), 158 ) 159 else: 160 a = to_func( 161 uniq_sort(flatten(from_func(a, b.left)), generate), 162 uniq_sort(flatten(from_func(a, b.right)), generate), 163 copy=False, 164 ) 165 166 return a
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Convert the boolean structure of a sqlglot AST to CNF (default) or DNF.

    Every outermost AND/OR chain reached by the walk is rewritten in place. Processing
    stops early, returning ``expression``, when a chain's estimated distance exceeds
    ``max_distance`` or when the rewrite raises ``OptimizeError``.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: produce disjunctive normal form instead of conjunctive normal form.
        max_distance (int): largest estimated distance from cnf/dnf at which conversion is attempted
    Returns:
        sqlglot.Expression: normalized expression
    """
    generate = cached_generator()

    def apply_law(e):
        return distributive_law(e, dnf, max_distance, generate)

    # Snapshot the walk first: the loop below mutates the tree in place.
    candidates = tuple(expression.walk(prune=lambda e, *_: isinstance(e, exp.Connector)))

    for chain, *_ in candidates:
        if not isinstance(chain, exp.Connector) or normalized(chain, dnf=dnf):
            continue

        is_root = chain is expression
        backup = chain.copy()

        chain.transform(rewrite_between, copy=False)
        distance = normalization_distance(chain, dnf=dnf)
        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            chain = chain.replace(while_changing(chain, apply_law))
        except OptimizeError as e:
            logger.info(e)
            # Restore the pre-rewrite copy before bailing out.
            chain.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = chain

    return expression
Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
Returns:
sqlglot.Expression: normalized expression
def
normalized(expression, dnf=False):
def
normalization_distance(expression, dnf=False):
def normalization_distance(expression, dnf=False):
    """
    Estimate how far ``expression`` is from its normal form.

    The estimate is the summed length of all predicates the fully expanded CNF (or DNF)
    would contain, minus one more than the number of connector nodes currently found in
    ``expression``. Because the expansion is exponential, this serves as a cost estimate
    for the conversion.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression (sqlglot.Expression): expression to compute distance
        dnf (bool): compute to dnf distance instead
    Returns:
        int: difference
    """
    expanded_size = sum(_predicate_lengths(expression, dnf))
    connector_count = len(tuple(expression.find_all(exp.Connector)))
    current_size = connector_count + 1
    return expanded_size - current_size
The difference in the number of predicates between the current expression and the normalized form.
This is used as an estimate of the cost of the conversion, which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression (sqlglot.Expression): expression to compute distance
- dnf (bool): compute to dnf distance instead
Returns:
int: difference
def
distributive_law(expression, dnf, max_distance, generate):
def distributive_law(expression, dnf, max_distance, generate):
    """
    Move ``expression`` one step closer to CNF (or DNF) by applying the distributive law.

    Examples of the rewrite for CNF:
        x OR (y AND z) -> (x OR y) AND (x OR z)
        (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Children are rewritten first (in place, via ``exp.replace_children``). The node itself
    is then distributed if it is the outer connector type (OR for CNF, AND for DNF) and at
    least one operand is the inner connector type.

    Args:
        expression: node to rewrite.
        dnf: target disjunctive normal form instead of conjunctive normal form.
        max_distance: abort threshold compared against ``normalization_distance``.
        generate: value from ``cached_generator()``, forwarded unchanged to ``_distribute``.
    Returns:
        The rewritten node (possibly ``expression`` itself).
    Raises:
        OptimizeError: if the estimated distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    exp.replace_children(
        expression, lambda child: distributive_law(child, dnf, max_distance, generate)
    )

    # CNF pushes OR below AND; DNF pushes AND below OR.
    if dnf:
        inner_type, outer_type = exp.Or, exp.And
        combine_outer, combine_inner = exp.and_, exp.or_
    else:
        inner_type, outer_type = exp.And, exp.Or
        combine_outer, combine_inner = exp.or_, exp.and_

    if not isinstance(expression, outer_type):
        return expression

    left, right = expression.unnest_operands()
    left_inner = isinstance(left, inner_type)
    right_inner = isinstance(right, inner_type)

    if left_inner and right_inner:
        left_size = sum(1 for _ in left.find_all(exp.Connector))
        right_size = sum(1 for _ in right.find_all(exp.Connector))
        # The operand with more connectors hosts the rewrite; the other is spread into it.
        host, donor = (left, right) if left_size > right_size else (right, left)
    elif left_inner:
        host, donor = right, left
    elif right_inner:
        host, donor = left, right
    else:
        return expression

    return _distribute(host, donor, combine_outer, combine_inner, generate)
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)