sqlglot.optimizer.normalize
1from __future__ import annotations 2 3import logging 4 5from sqlglot import exp 6from sqlglot.errors import OptimizeError 7from sqlglot.helper import while_changing 8from sqlglot.optimizer.scope import find_all_in_scope 9from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort 10 11logger = logging.getLogger("sqlglot") 12 13 14def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128): 15 """ 16 Rewrite sqlglot AST into conjunctive normal form or disjunctive normal form. 17 18 Example: 19 >>> import sqlglot 20 >>> expression = sqlglot.parse_one("(x AND y) OR z") 21 >>> normalize(expression, dnf=False).sql() 22 '(x OR z) AND (y OR z)' 23 24 Args: 25 expression: expression to normalize 26 dnf: rewrite in disjunctive normal form instead. 27 max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion 28 Returns: 29 sqlglot.Expression: normalized expression 30 """ 31 for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))): 32 if isinstance(node, exp.Connector): 33 if normalized(node, dnf=dnf): 34 continue 35 root = node is expression 36 original = node.copy() 37 38 node.transform(rewrite_between, copy=False) 39 distance = normalization_distance(node, dnf=dnf) 40 41 if distance > max_distance: 42 logger.info( 43 f"Skipping normalization because distance {distance} exceeds max {max_distance}" 44 ) 45 return expression 46 47 try: 48 node = node.replace( 49 while_changing(node, lambda e: distributive_law(e, dnf, max_distance)) 50 ) 51 except OptimizeError as e: 52 logger.info(e) 53 node.replace(original) 54 if root: 55 return original 56 return expression 57 58 if root: 59 expression = node 60 61 return expression 62 63 64def normalized(expression: exp.Expression, dnf: bool = False) -> bool: 65 """ 66 Checks whether a given expression is in a normal form of interest. 
67 68 Example: 69 >>> from sqlglot import parse_one 70 >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True) 71 True 72 >>> normalized(parse_one("(a OR b) AND c")) # Checks CNF by default 73 True 74 >>> normalized(parse_one("a AND (b OR c)"), dnf=True) 75 False 76 77 Args: 78 expression: The expression to check if it's normalized. 79 dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). 80 Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). 81 """ 82 ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) 83 return not any( 84 connector.find_ancestor(ancestor) for connector in find_all_in_scope(expression, root) 85 ) 86 87 88def normalization_distance(expression: exp.Expression, dnf: bool = False) -> int: 89 """ 90 The difference in the number of predicates between a given expression and its normalized form. 91 92 This is used as an estimate of the cost of the conversion which is exponential in complexity. 93 94 Example: 95 >>> import sqlglot 96 >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") 97 >>> normalization_distance(expression) 98 4 99 100 Args: 101 expression: The expression to compute the normalization distance for. 102 dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). 103 Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF). 104 105 Returns: 106 The normalization distance. 107 """ 108 return sum(_predicate_lengths(expression, dnf)) - ( 109 sum(1 for _ in expression.find_all(exp.Connector)) + 1 110 ) 111 112 113def _predicate_lengths(expression, dnf): 114 """ 115 Returns a list of predicate lengths when expanded to normalized form. 116 117 (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
118 """ 119 expression = expression.unnest() 120 121 if not isinstance(expression, exp.Connector): 122 return (1,) 123 124 left, right = expression.args.values() 125 126 if isinstance(expression, exp.And if dnf else exp.Or): 127 return tuple( 128 a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) 129 ) 130 return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) 131 132 133def distributive_law(expression, dnf, max_distance): 134 """ 135 x OR (y AND z) -> (x OR y) AND (x OR z) 136 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) 137 """ 138 if normalized(expression, dnf=dnf): 139 return expression 140 141 distance = normalization_distance(expression, dnf=dnf) 142 143 if distance > max_distance: 144 raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}") 145 146 exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) 147 to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) 148 149 if isinstance(expression, from_exp): 150 a, b = expression.unnest_operands() 151 152 from_func = exp.and_ if from_exp == exp.And else exp.or_ 153 to_func = exp.and_ if to_exp == exp.And else exp.or_ 154 155 if isinstance(a, to_exp) and isinstance(b, to_exp): 156 if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))): 157 return _distribute(a, b, from_func, to_func) 158 return _distribute(b, a, from_func, to_func) 159 if isinstance(a, to_exp): 160 return _distribute(b, a, from_func, to_func) 161 if isinstance(b, to_exp): 162 return _distribute(a, b, from_func, to_func) 163 164 return expression 165 166 167def _distribute(a, b, from_func, to_func): 168 if isinstance(a, exp.Connector): 169 exp.replace_children( 170 a, 171 lambda c: to_func( 172 uniq_sort(flatten(from_func(c, b.left))), 173 uniq_sort(flatten(from_func(c, b.right))), 174 copy=False, 175 ), 176 ) 177 else: 178 a = to_func( 179 uniq_sort(flatten(from_func(a, 
b.left))), 180 uniq_sort(flatten(from_func(a, b.right))), 181 copy=False, 182 ) 183 184 return a
logger =
<Logger sqlglot (WARNING)>
def
normalize( expression: sqlglot.expressions.Expression, dnf: bool = False, max_distance: int = 128):
def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int = 128):
    """
    Convert a sqlglot AST to conjunctive (CNF) or disjunctive (DNF) normal form.

    Each top-level AND/OR tree found in ``expression`` is normalized on its own. Trees
    already in the requested form are left as they are.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(x AND y) OR z")
        >>> normalize(expression, dnf=False).sql()
        '(x OR z) AND (y OR z)'

    Args:
        expression: expression to normalize
        dnf: rewrite in disjunctive normal form instead.
        max_distance (int): the maximal estimated distance from cnf/dnf to attempt conversion
    Returns:
        sqlglot.Expression: normalized expression
    """
    candidates = tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector)))

    for candidate in candidates:
        # Only connectors that are not yet normalized need work.
        if not isinstance(candidate, exp.Connector) or normalized(candidate, dnf=dnf):
            continue

        is_root = candidate is expression
        # Snapshot used to restore the tree if distribution is aborted.
        backup = candidate.copy()

        # NOTE(review): applied in place and not undone if the distance check bails out.
        candidate.transform(rewrite_between, copy=False)

        distance = normalization_distance(candidate, dnf=dnf)
        if distance > max_distance:
            logger.info(
                f"Skipping normalization because distance {distance} exceeds max {max_distance}"
            )
            return expression

        try:
            rewritten = while_changing(
                candidate, lambda e: distributive_law(e, dnf, max_distance)
            )
            candidate = candidate.replace(rewritten)
        except OptimizeError as error:
            logger.info(error)
            candidate.replace(backup)
            return backup if is_root else expression

        if is_root:
            expression = candidate

    return expression
Rewrite a sqlglot AST into conjunctive normal form (CNF) or disjunctive normal form (DNF).
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(x AND y) OR z")
>>> normalize(expression, dnf=False).sql()
'(x OR z) AND (y OR z)'
Arguments:
- expression: expression to normalize
- dnf: rewrite in disjunctive normal form instead.
- max_distance (int): the maximum estimated normalization distance from CNF/DNF for which conversion is still attempted
Returns:
sqlglot.Expression: normalized expression
def normalized(expression: exp.Expression, dnf: bool = False) -> bool:
    """
    Report whether ``expression`` is already in the requested normal form.

    For CNF (the default) the check fails if any AND connector has an OR ancestor. For
    DNF it fails if any OR connector has an AND ancestor. Connectors are located with
    ``find_all_in_scope``.

    Example:
        >>> from sqlglot import parse_one
        >>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
        True
        >>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
        True
        >>> normalized(parse_one("a AND (b OR c)"), dnf=True)
        False

    Args:
        expression: The expression to check if it's normalized.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
    """
    if dnf:
        inner, outer = exp.Or, exp.And
    else:
        inner, outer = exp.And, exp.Or

    for connector in find_all_in_scope(expression, inner):
        if connector.find_ancestor(outer):
            return False
    return True
Checks whether a given expression is in a normal form of interest.
Example:
>>> from sqlglot import parse_one
>>> normalized(parse_one("(a AND b) OR c OR (d AND e)"), dnf=True)
True
>>> normalized(parse_one("(a OR b) AND c"))  # Checks CNF by default
True
>>> normalized(parse_one("a AND (b OR c)"), dnf=True)
False
Arguments:
- expression: The expression to check if it's normalized.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
def normalization_distance(
    expression: exp.Expression, dnf: bool = False, max_: float = float("inf")
) -> int:
    """
    The difference in the number of predicates between a given expression and its normalized form.

    This is used as an estimate of the cost of the conversion, which is exponential in complexity.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
        >>> normalization_distance(expression)
        4

    Args:
        expression: The expression to compute the normalization distance for.
        dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF).
            Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
        max_: Optional bound. As soon as the running total exceeds it, the partial total
            (guaranteed to be greater than ``max_``) is returned without enumerating the
            rest of the expansion. Defaults to infinity, i.e. the exact distance.

    Returns:
        The normalization distance, or a value greater than ``max_`` if it exceeds ``max_``.
    """
    total = -(sum(1 for _ in expression.find_all(exp.Connector)) + 1)

    # Every predicate length is >= 1, so the total only grows; once it passes max_
    # the final answer is known to exceed max_ as well.
    for length in _iter_predicate_lengths(expression, dnf):
        total += length
        if total > max_:
            break

    return total


def _iter_predicate_lengths(expression, dnf):
    """
    Lazily yield the predicate lengths of ``expression`` expanded to normalized form.

    (A AND B) OR C -> 2, 2 because len(A OR C), len(B OR C).
    Values are produced on demand, so a caller can stop early instead of materializing
    a collection whose size can grow exponentially with the number of connectors.
    """
    expression = expression.unnest()

    if not isinstance(expression, exp.Connector):
        yield 1
        return

    left, right = expression.args.values()

    if isinstance(expression, exp.And if dnf else exp.Or):
        # Distributing this connector pairs every left clause with every right clause.
        for a in _iter_predicate_lengths(left, dnf):
            for b in _iter_predicate_lengths(right, dnf):
                yield a + b
    else:
        # Same-kind connectors just concatenate their clauses.
        yield from _iter_predicate_lengths(left, dnf)
        yield from _iter_predicate_lengths(right, dnf)
The difference in the number of predicates between a given expression and its normalized form.
This is used as an estimate of the cost of the conversion, which is exponential in complexity.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
>>> normalization_distance(expression)
4
Arguments:
- expression: The expression to compute the normalization distance for.
- dnf: Whether to check if the expression is in Disjunctive Normal Form (DNF). Default: False, i.e. we check if it's in Conjunctive Normal Form (CNF).
Returns:
The normalization distance.
def
distributive_law(expression, dnf, max_distance):
def distributive_law(expression, dnf, max_distance):
    """
    x OR (y AND z) -> (x OR y) AND (x OR z)
    (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)

    Children are rewritten first (recursively), then the distributive law is applied at
    this node if it is the connector being distributed away (OR for CNF, AND for DNF).

    Raises:
        OptimizeError: if the estimated normalization distance exceeds ``max_distance``.
    """
    if normalized(expression, dnf=dnf):
        return expression

    distance = normalization_distance(expression, dnf=dnf)
    if distance > max_distance:
        raise OptimizeError(f"Normalization distance {distance} exceeds max {max_distance}")

    exp.replace_children(expression, lambda child: distributive_law(child, dnf, max_distance))

    # target: connector kind that should end up on top; source: kind pushed below it.
    if dnf:
        target, source = exp.Or, exp.And
        combine_inner, combine_outer = exp.and_, exp.or_
    else:
        target, source = exp.And, exp.Or
        combine_inner, combine_outer = exp.or_, exp.and_

    if not isinstance(expression, source):
        return expression

    left, right = expression.unnest_operands()
    left_is_target = isinstance(left, target)
    right_is_target = isinstance(right, target)

    if left_is_target and right_is_target:
        # Distribute the operand with more connectors over the one with fewer.
        left_size = len(tuple(left.find_all(exp.Connector)))
        right_size = len(tuple(right.find_all(exp.Connector)))
        if left_size > right_size:
            return _distribute(left, right, combine_inner, combine_outer)
        return _distribute(right, left, combine_inner, combine_outer)
    if left_is_target:
        return _distribute(right, left, combine_inner, combine_outer)
    if right_is_target:
        return _distribute(left, right, combine_inner, combine_outer)
    return expression
x OR (y AND z) -> (x OR y) AND (x OR z)
(x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)