sqlglot.optimizer.simplify
import datetime
import functools
import itertools
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # nodes flagged FINAL (see above) are returned untouched
        if expression.meta.get(FINAL):
            return expression
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        # simplify the children before the remaining rules look at this node
        exp.replace_children(node, lambda e: _simplify(e, False))
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        # the rules above may have produced a brand new node; give it the
        # original parent because simplify_parens inspects `parent`
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        if root:
            expression.replace(node)
        return node

    # re-run the whole pass until it stops changing the expression
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE
    and removes double negation. Any other expression is returned unchanged.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    Remove parentheses around operands that are the same connector type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """
    Fold constant operands (TRUE / FALSE / NULL / literals) and redundant
    comparisons out of an AND / OR connector.

    Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Returns the expression that replaces the `left`/`right` pair, or
        # None when the pair cannot be combined.
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
)

# comparison class to use when the two operands are swapped
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _select_bound(left, lv, right, rv, strict_cls, keep_smaller, keep_strict):
    """
    Choose which of two same-direction bounds on one column survives.

    Args:
        left, right: the two comparison nodes (both upper or both lower bounds).
        lv, rv: their already-extracted constant values.
        strict_cls: the strict comparison class for this direction (LT or GT).
        keep_smaller: keep the bound with the smaller constant when they differ.
        keep_strict: when the constants are equal, keep the strict bound (True)
            or the inclusive one (False).
    """
    if lv == rv:
        # Equal constants: strictness decides. If `left` does not have the
        # preferred strictness, `right` is at least as good a choice.
        if isinstance(left, strict_cls) == keep_strict:
            return left
        return right
    return left if (lv < rv) == keep_smaller else right


def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons of the same column against constants.

    Args:
        expression: the connector holding both comparisons.
        left, right: the two operands being considered.
        or_: True when the connector is an OR, False for AND.

    Returns:
        The replacement for the pair, or None when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # One side has no non-column operand (e.g. x = x). Report "no
                # simplification"; returning the connector itself would make
                # _flat_simplify treat it as a merged operand.
                return None

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Two upper bounds: AND keeps the tighter (smaller) one, OR the
            # looser one. On equal constants AND must keep the strict bound
            # (x <= 1 AND x < 1 -> x < 1) and OR the inclusive one.
            if isinstance(left, LT_LTE) and isinstance(right, LT_LTE):
                return _select_bound(
                    left, l, right, r, exp.LT, keep_smaller=not or_, keep_strict=not or_
                )
            # Two lower bounds: mirror image of the case above.
            if isinstance(left, GT_GTE) and isinstance(right, GT_GTE):
                return _select_bound(
                    left, l, right, r, exp.GT, keep_smaller=or_, keep_strict=not or_
                )

            # we can't ever shortcut to true because the column could be null
            if not or_:
                for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: the expression to normalize.
        generate: callable that turns an operand into the string used as the
            dedup / sort key (the cached SQL generator in `simplify`).
        root: whether `expression` is the node the current pass started from.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place; `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # the connector type the nested operands must have for a rule to apply
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def simplify_literals(expression, root=True):
    """
    Fold constant operands of non-connector binary expressions and push a
    unary minus into a numeric literal.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)
    elif isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            # -(-1) -> 1, -(1) -> -1
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate the binary `expression` for the operand pair `a`, `b`.

    Returns the folded literal / boolean / NULL / date cast, or None when the
    pair cannot be evaluated statically.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Dividing by a literal zero would raise a decimal exception here;
            # leave the expression for the engine to evaluate instead.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """
    Drop a Paren node when the conditions below say it is not needed.

    Parentheses around a Select are always kept. Non-Paren expressions are
    returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(this, exp.Predicate)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return expression.this
    return expression


def remove_where_true(expression):
    """
    Remove WHERE clauses whose condition is always true and turn joins whose
    ON condition is always true (and that have no USING / method) into CROSS
    joins. Mutates `expression` in place and returns nothing.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero_number(literal):
    """Return True when `literal` is a numeric literal whose value is zero."""
    if not literal.is_number:
        return False
    try:
        return Decimal(literal.name) == 0
    except ArithmeticError:
        # not parseable as a decimal; do not treat it as zero
        return False


def always_true(expression):
    """
    Return True when `expression` is the boolean TRUE or a literal that is
    treated as true.

    A numeric literal equal to zero is not considered true: folding `x AND 0`
    to `x` or `NOT 0` to FALSE would change the result on engines where 0 is
    falsy.
    """
    # NOTE(review): string literals are still treated as always true, as
    # before; confirm whether that is intended for all dialects.
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal):
        return not _is_zero_number(expression)
    return False


def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly the boolean FALSE node."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a NULL node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the plain values `a` and `b`.

    Returns a TRUE / FALSE expression, or None when `expression` is not one of
    the supported comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def extract_date(cast):
    """
    Return the date / datetime a `CAST(... AS DATE | DATETIME)` denotes, or
    None when the target type is something else or the value is not ISO
    formatted.
    """
    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
    # so in that case we can't extract the date.
    try:
        if cast.args["to"].this == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if cast.args["to"].this == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None


def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Returns None when dateutil is not installed, when the interval value is
    not an integer literal, or when the unit is not year / month / week / day.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # e.g. INTERVAL x DAY or INTERVAL '1.5' DAY: nothing to fold statically
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None


def date_literal(date):
    """Build `CAST('<date>' AS DATE)`, or `AS DATETIME` for datetime values."""
    return exp.cast(
        exp.Literal.string(date),
        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
    )


def boolean_literal(condition):
    """Return a TRUE expression when `condition` is truthy, else FALSE."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """
    Apply `simplifier` pairwise to the flattened operands of `expression`.

    Args:
        expression: the connector / binary expression to simplify.
        simplifier: called as `simplifier(expression, a, b)`; a truthy result
            replaces the pair `a`, `b`, anything else means "cannot combine".
        root: whether `expression` is the node the current pass started from.

    Returns:
        A new left-nested chain of the same class when at least one pair was
        combined, otherwise `expression` itself.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result:
                    # the merged result goes back to the front so it can be
                    # combined with the remaining operands as well
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
FINAL =
'final'
def
simplify(expression):
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    generate = cached_generator()

    def _mentions_group_key(tree, keys):
        # True as soon as any node of `tree` is one of the GROUP BY keys
        return any(node in keys for node, *_ in tree.walk())

    # GROUP BY keys must stay textually identical to the projections that use
    # them (select x + 1 + 1 FROM y GROUP BY x + 1 + 1), so the GROUP BY itself
    # and every projection / HAVING that contains a key is frozen.
    for group in expression.find_all(exp.Group):
        select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if _mentions_group_key(projection, keys):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _mentions_group_key(having, keys):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # rules applied to the node before its children are simplified
        node = rewrite_between(expression)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # rules applied once the children are in simplified form
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)

        # a freshly built node needs the original parent before the
        # parent-sensitive rules below run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    The comparison simplification only understands lt/lte/gt/gte, so BETWEEN
    is rewritten into that form first. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # each comparison gets its own copy of the tested operand
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def
simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT NULL becomes NULL, NOT of an always-true operand becomes FALSE,
    NOT FALSE becomes TRUE and NOT NOT x becomes x. Everything else is
    returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, (exp.And, exp.Or)):
            # negating swaps the connector and negates both sides
            combine = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return combine(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def
flatten(expression):
def flatten(expression):
    """
    Unwrap parenthesized operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
Flatten nested connectors of the same type: A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def
simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold constant operands and redundant comparisons out of an AND / OR.

    Non-connector expressions are returned unchanged.
    """

    def _simplify_connectors(connector, left, right):
        # Result replaces the (left, right) pair; None means "cannot combine".
        if left == right:
            return left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true, right_true = always_true(left), always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            left_null, right_null = is_null(left), is_null(right)
            if left_false and right_false:
                return exp.false()
            if (
                (left_null and right_null)
                or (left_null and right_false)
                or (left_false and right_null)
            ):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE =
(<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE =
(<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
COMPARISONS =
(<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.NEQ'>)
INVERSE_COMPARISONS =
{<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def
remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # nested connectors of the same type are handled from their topmost node
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
Remove complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def
uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    `generate` maps each operand to the string that is used both as the
    dedup key and as the sort key.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    entries = tuple(by_sql.items())
    keys = [sql for sql, _ in entries]

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))

    if out_of_order:
        expression = build(*(operand for _, operand in sorted(entries)), copy=False)
    elif len(by_sql) < len(operands):
        # already ordered, but duplicates still have to be dropped
        expression = build(*by_sql.values(), copy=False)

    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def
absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # connector type a nested operand must have for the laws to apply
    inner = exp.Or if isinstance(expression, exp.And) else exp.And

    # a fresh node is built for every replacement; nodes are never shared
    def _neutral():
        return exp.true() if inner == exp.And else exp.false()

    def _dominant():
        return exp.false() if inner == exp.And else exp.true()

    for group, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(group, inner):
            continue

        lhs, rhs = group.unnest_operands()

        # absorption
        if is_complement(other, lhs):
            lhs.replace(_neutral())
            continue
        if is_complement(other, rhs):
            rhs.replace(_neutral())
            continue

        other_terms = set(other.flatten()) if isinstance(other, inner) else {other}
        if other_terms < set(group.flatten()):
            group.replace(_dominant())
            continue

        if not isinstance(other, inner):
            continue

        # elimination
        other_operands = other.unnest_operands()
        other_lhs, other_rhs = other_operands

        if lhs in other_operands and (
            is_complement(rhs, other_lhs) or is_complement(rhs, other_rhs)
        ):
            group.replace(lhs)
            other.replace(lhs)
        elif rhs in other_operands and (
            is_complement(lhs, other_lhs) or is_complement(lhs, other_rhs)
        ):
            group.replace(rhs)
            other.replace(rhs)

    return expression
absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A
def
simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold constants in non-connector binary expressions and push a unary minus
    into a numeric literal. Other expressions are returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        # -(-1) -> 1, -(1) -> -1
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
def
simplify_parens(expression):
def simplify_parens(expression):
    """
    Return the wrapped expression when a Paren node is not needed, otherwise
    return the Paren itself. Parentheses around a Select are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def
remove_where_true(expression):
def remove_where_true(expression):
    """
    Drop WHERE clauses that are always true and convert joins whose ON
    condition is always true (without USING or a join method) into CROSS
    joins. Works in place and returns nothing.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args
        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def
always_true(expression):
def
is_complement(a, b):
def
eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the plain values `a` and `b`.

    Returns a TRUE / FALSE expression, or None when `expression` is not one of
    EQ, IS, NEQ, GT, GTE, LT or LTE.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        outcome = a == b
    elif isinstance(expression, exp.NEQ):
        outcome = a != b
    elif isinstance(expression, exp.GT):
        outcome = a > b
    elif isinstance(expression, exp.GTE):
        outcome = a >= b
    elif isinstance(expression, exp.LT):
        outcome = a < b
    elif isinstance(expression, exp.LTE):
        outcome = a <= b
    else:
        return None
    return boolean_literal(outcome)
def
extract_date(cast):
def extract_date(cast):
    """
    Return the date / datetime denoted by a cast to DATE or DATETIME.

    Returns None when the target type is neither of the two, or when the cast
    value cannot be parsed as an ISO date (for example when the cast is
    applied to an identifier).
    """
    target = cast.args["to"].this

    if target == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
def
extract_interval(interval):
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Args:
        interval: object exposing `.name` (the interval value as a string) and
            `.text("unit")` (the unit name).

    Returns:
        A relativedelta for year / month / week / day units, or None when the
        value is not an integer, the unit is not supported, or dateutil is
        not installed.
    """
    # The value may be a column or a non-integer string (INTERVAL x DAY,
    # INTERVAL '1.5' DAY); such intervals cannot be folded statically, so
    # report "no interval" instead of letting ValueError escape.
    try:
        n = int(interval.name)
    except ValueError:
        return None

    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
def
date_literal(date):
def
boolean_literal(condition):