sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import exp
from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by `interval` / `date_floor` when a date/time unit is not supported."""


def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        # wraps keeps the rule's name/docstring visible (e.g. to documentation tools)
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <true> -> FALSE, NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold AND/OR connectors whose operands are constants, duplicates or comparable ranges."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _pick_same_bound(left, right, or_):
    """Choose between two same-direction bounds on the same column and the same value.

    `x < v AND x <= v` is `x < v` (the strict bound is tighter), while
    `x < v OR x <= v` is `x <= v` (the inclusive bound is looser).
    """
    left_is_strict = isinstance(left, (exp.LT, exp.GT))
    return left if left_is_strict != or_ else right


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against numeric or string constants.

    Returns the merged expression, FALSE for contradictions (AND only), or None
    (or `expression` itself) when nothing can be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv:
                        # same value: strictness decides which bound survives
                        return _pick_same_bound(left, right, or_)
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv:
                        return _pick_same_bound(left, right, or_)
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        complement = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return complement
    return expression


def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {gen(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    l, r = r, l
                else:
                    continue

                # remember which column node defined the constant so it is not replaced itself
                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold non-connector binary operations on literals and negations of numeric literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate `a <op> b` for constant operands; return None when it cannot be folded."""
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # division by zero is engine-defined (error or NULL) and would raise here
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Drop parentheses that do not affect evaluation order."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Unwrap single-argument COALESCE and expand comparisons of COALESCE against constants."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        # iterate over a snapshot: branches may be popped from "ifs" inside the loop
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) = date  ->  left >= floor AND left < floor + 1 unit."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """DATE_TRUNC(unit, left) <> date  ->  left < floor OR left >= floor + 1 unit."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # the value lies outside [min, max), i.e. below it OR at/above its end
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
# Each transform rewrites `DATE_TRUNC(u, l) <op> d` into a predicate on `l` itself.
# Truncated values are always unit-aligned, so e.g. `trunc < d` holds iff `l < ceil(d)`.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible `JOIN ... ON <true>` into CROSS JOINs."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero(expression):
    """True for a numeric literal whose value is zero (e.g. 0, 0.0)."""
    if not expression.is_number:
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:  # decimal.InvalidOperation for unparsable numbers
        return False


def always_true(expression):
    """TRUE, or a literal other than a numeric zero (which is falsy in a boolean context)."""
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def always_false(expression):
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate a comparison node on Python values; None if `expression` is not a comparison."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to a date (DATE target) or datetime (other temporal targets)."""
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Return the Python date/datetime of a (possibly nested) CAST / TS_OR_DS_TO_DATE literal."""
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    return extract_date(expression) is not None


def extract_interval(expression):
    """Return a relativedelta for an INTERVAL, or None if its amount/unit is not supported."""
    try:
        # non-integer amounts (e.g. INTERVAL '1.5' DAY) raise ValueError here
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """Return a `relativedelta` of `n` units; raises UnsupportedUnit for unknown units."""
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Truncate `d` to the start of its `unit`; raises UnsupportedUnit for unknown units."""
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Smallest `unit`-aligned date that is >= `d`."""
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of `expression`.

    Whenever `simplifier(expression, a, b)` yields a new node, it replaces both
    operands and is re-queued at the front. Returns a rebuilt chain if any
    operands were merged, otherwise `expression`.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # mutating the deque is safe because we break right away
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(gen(e) for e in expression)
    if not isinstance(expression, exp.Expression):
        return str(expression)

    etype = type(expression)
    if etype in GEN_MAP:
        return GEN_MAP[etype](expression)
    return f"{expression.key} {gen(expression.args.values())}"


GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.DPipe: lambda e: _binary(e, "||"),
    exp.SafeDPipe: lambda e: _binary(e, "||"),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}


def _binary(e: exp.Binary, op: str) -> str:
    return f"{gen(e.left)} {op} {gen(e.right)}"


def _unary(e: exp.Unary, op: str) -> str:
    return f"{op} {gen(e.this)}"
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression

    Note:
        The tree is modified in place (nodes are marked via `meta` and replaced), so
        callers should use the returned expression.
    """

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze any projection that contains a GROUP BY key somewhere in its subtree
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for HAVING, which may also reference the grouped expressions
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        """Apply every rule once to `expression` and (recursively) its children.

        `root` is True only for the top-level call; recursive calls on children pass False.
        """
        # Nodes marked FINAL above are left exactly as written
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Reattach to the original parent: later rules (e.g. simplify_parens) inspect
        # `node.parent` to decide what is safe.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    # NOTE(review): while_changing (sqlglot.helper) presumably reapplies _simplify until
    # the tree stops changing -- confirm against its definition.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function must take the expression being simplified as its first
    positional argument. If calling it raises one of `exceptions`, that argument is
    returned unchanged instead; any other exception propagates.

    Args:
        *exceptions: the exception types to swallow.

    Returns:
        A decorator that wraps a simplification function.
    """

    def decorator(func):
        # functools.wraps copies __name__, __doc__, __module__, etc. and sets __wrapped__,
        # so introspection and documentation tools see the real rule instead of `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` is raised: the input expression is then returned unchanged.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands LT/LTE/GT/GTE, so BETWEEN is expanded
    up front. Any non-BETWEEN expression is returned as-is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    # Each bound gets its own copy of the tested expression
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression; any other expression is returned unchanged.

    De Morgan's laws (applied when the negated operand is parenthesized):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds:
        NOT NULL -> NULL
        NOT <anything `always_true` accepts> -> FALSE
        NOT FALSE -> TRUE
        NOT NOT x -> x
    """
    if isinstance(expression, exp.Not):
        # NOT NULL is still NULL under SQL three-valued logic
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
            # Any other parenthesized operand falls through to the checks below
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression
De Morgan's laws:
NOT (x OR y) -> NOT x AND NOT y
NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Pull same-type connectors out of parentheses:

        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C

    Each direct operand of an AND/OR node that, once unwrapped from its parentheses,
    is a connector of the same type replaces the parenthesized operand. The node is
    modified in place and returned; non-connectors are returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C
A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR chains, following SQL three-valued logic.

    AND:
        x AND x -> x
        FALSE AND x -> FALSE
        NULL AND NULL -> NULL
        TRUE AND x -> x
    OR:
        x OR x -> x
        TRUE OR x -> TRUE
        NULL OR NULL, NULL OR FALSE -> NULL
        FALSE OR x -> x

    `NULL AND x` is deliberately not folded to NULL, because it is FALSE when x is FALSE.
    Pairs that match none of these rules are passed to `_simplify_comparison`.

    Args:
        expression: the expression to simplify; non-connectors are returned unchanged.
        root: forwarded to `_flat_simplify`.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # A single NULL does not decide an AND (NULL AND FALSE is FALSE), so only
            # fold when both sides are NULL; NULL AND TRUE is handled just below.
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its complement.

        A AND NOT A -> FALSE
        A OR NOT A -> TRUE

    Only connectors at the top of their chain (`root`, or when `expression.same_parent`
    is false) are considered; anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # Try every ordered pair of operands, since is_complement may be asymmetric
    has_complement = any(
        is_complement(first_op, second_op)
        for first_op, second_op in itertools.permutations(expression.flatten(), 2)
    )
    if not has_complement:
        return expression
    return exp.false() if isinstance(expression, exp.And) else exp.true()
Removing complements.
A AND NOT A -> FALSE
A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

        C AND A AND B AND B -> A AND B AND C

    Operands are compared by their `gen` pseudo-SQL. A new connector is built only when
    the operands are out of order or contain duplicates; otherwise `expression` itself
    is returned. Non-connectors, and connectors that are not at the top of their chain,
    are returned unchanged.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # A repeated key keeps its first position but maps to the last operand seen
    by_sql = {gen(operand): operand for operand in operands}
    keys = list(by_sql)

    # A AND C AND B -> A AND B AND C
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already sorted, but duplicates still need to go
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced by TRUE/FALSE or by a shared sub-operand. Those constants are
    left for later rules (e.g. simplify_connectors) to fold away. Only connectors at the
    top of their chain (`root`, or when `expression.same_parent` is false) are processed.
    The same `expression` object is always returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands we look into
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # permutations() materializes the operands first, so the in-place replacements
        # below do not disturb this iteration
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # The complemented operand is replaced by the identity of the inner
                # connector (FALSE inside OR, TRUE inside AND)
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's: `a` is redundant, so it becomes
                # the identity of the outer connector
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # Both sides share one operand and differ by a complemented one:
                    # both collapse to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
Absorption:
A AND (A OR B) -> A
A OR (A AND B) -> A
A AND (NOT A OR B) -> A AND B
A OR (NOT A AND B) -> A OR B
Elimination:
(A AND B) OR (A AND NOT B) -> A
(A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only runs on an AND node at the top of its chain (`root`, or when
    `expression.same_parent` is false) that `normalize.normalized(..., dnf=True)`
    accepts. Equalities of the form `column = literal` (on either side) are collected,
    and the other occurrences of that column found by `find_all_in_scope` are replaced
    by a copy of the literal, except inside `column IS NULL`. The tree is modified in
    place and the same `expression` is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id() of the column node in its defining EQ, literal it equals).
        # If a column is equated to several literals, the last one walked wins.
        constant_mapping = {}
        # IF subtrees are pruned from the walk (presumably because equalities there only
        # hold conditionally)
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # Normalize so that `l` is always the column
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                # Lookup is by column equality, not identity, so every reference matches
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # Keep the defining equality itself intact (it would become `5 = 5`)
                    and id(column) != column_id
                    # Leave `column IS NULL` checks untouched
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the decorated simplification rule. If it raises one of the tolerated
    # `exceptions`, fall back to the input expression unchanged. `func` and
    # `exceptions` are free variables bound by the enclosing `catch` decorator.
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        outcome = expression
    return outcome
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
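There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b

That is, l is the left-hand side of = (x + 1) and r is its right-hand side (3), while a and b are the left and right operands of + (x and 1).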
def simplify_literals(expression, root=True):
    """Fold literal operations.

    Non-connector binary nodes are delegated to `_flat_simplify` with `_simplify_binary`.
    A negation of a numeric literal becomes a single numeric literal; negating an
    already-negative literal strips its sign. Anything else is returned unchanged.
    """
    is_binary = isinstance(expression, exp.Binary)
    if is_binary and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    if digits[0] == "-":
        return exp.Literal.number(digits[1:])
    return exp.Literal.number(f"-{digits}")
def simplify_parens(expression):
    """
    Remove parentheses that do not change how the expression is grouped.

    Only `exp.Paren` nodes are considered; anything else is returned unchanged. A
    parenthesized SELECT (a subquery) always keeps its parentheses. Otherwise the
    parentheses are dropped when any of these holds:
      * the parent is not a Condition/Binary node, or is itself a Paren;
      * the inner node is not a Binary node -- except a NOT/IS directly inside a
        predicate, e.g. `(NOT a) = b`, where dropping them would print `NOT a = b`,
        which reads back as `NOT (a = b)`;
      * the inner node is a predicate and the parent is not;
      * nestings that print the same without parentheses: `+` in `+`, `*` in `*`,
        `*` in `+`/`-`.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            # NOT binds looser than comparisons, so `(NOT a) = b` must stay parenthesized
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    * COALESCE(x) -> x, unless the COALESCE's parent is a Hint (Spark partitioning hint).
    * In a comparison between a COALESCE and a constant, the COALESCE arguments after its
      first constant argument are dropped, and the comparison is split on whether the
      remaining prefix is NULL:

          COALESCE(x, c) op k -> (x IS NOT NULL AND COALESCE(x) op k) OR (x IS NULL AND c op k)

      The operand order of `op` is preserved, so `k op COALESCE(x, c)` yields `k op c`
      in the second branch. This matters for asymmetric comparisons such as `>`.

    Anything else is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg.
    # NOTE(review): dropping the arguments after it assumes CONSTANTS never matches NULL
    # -- confirm against its definition.
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Put `arg` where the COALESCE was, so `k > COALESCE(x, c)` becomes `k > c`, not `c > k`
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Runs of adjacent string literals are merged into one literal (joined with the
    separator for CONCAT_WS), e.g. CONCAT('a', 'b', x) -> CONCAT('ab', x). When everything
    collapses into a single string literal, that literal is returned. CONCAT_WS is only
    reduced when its first argument (the separator) is a string literal. Any other
    expression is returned unchanged.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        # (an argument-less CONCAT_WS has no separator at all)
        isinstance(expression, exp.ConcatWs)
        and (not expression.expressions or not expression.expressions[0].is_string)
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        # Only the values after the separator take part. Never fall back to flatten()
        # here: it would yield the separator itself as a value, so CONCAT_WS('-') would
        # wrongly become '-'.
        operands = expressions
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
        # Operator forms (presumably `a || b`) have no `expressions`; collect their
        # operands from the chain instead.
        operands = expressions or expression.flatten()

    new_args = []
    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    CASE:
        * `CASE x WHEN v THEN ...` branches are rewritten to `CASE WHEN x = v THEN ...`.
        * Branches whose condition is always false are removed; if no branch is left,
          the ELSE value (or NULL when there is none) is returned.
        * The first branch whose condition is always true ends the CASE. If every branch
          before it was removed, its THEN value is returned directly. Otherwise an
          earlier, undecided branch may still match, so the CASE is truncated: the
          undecided branches are kept and this branch's THEN value becomes the ELSE.

    IF (when it is not a branch of a CASE): returns the TRUE branch for an always-true
    condition and the FALSE branch (or NULL) for an always-false one.

    Any other expression is returned unchanged.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Branches seen so far whose condition could not be decided statically
        undecided = []
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided:
                    return case.args["true"]
                # CASE WHEN c1 THEN a WHEN TRUE THEN b ... END -> CASE WHEN c1 THEN a ELSE b END
                expression.set("ifs", undecided)
                expression.set("default", case.args["true"])
                return expression

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided.append(case)
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def wrapped(expression, *args, **kwargs):
    # Delegate to the decorated rule (`func`, bound by the enclosing `catch`).
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # One of the tolerated exception types: skip this rule for this node.
        pass
    return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn joins with an always-true ON into CROSS joins.

    A join is only rewritten when it has no USING and no join method, and its
    (side, kind) pair is in JOINS. The tree is modified in place; nothing is returned.
    """
    for where_clause in expression.find_all(exp.Where):
        if not always_true(where_clause.this):
            continue
        where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison node `expression` on the Python values `a` and `b`.

    EQ and IS both use `==`; NEQ, GT, GTE, LT and LTE use the matching Python operator.
    Returns the result as a boolean literal expression, or None for any other node type.
    """
    # Checked in order; each comparison runs only if its node type matches
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Coerce `value` to a `datetime.datetime`, or return None if it can't be parsed.

    A datetime is returned as-is and a plain date is promoted to midnight. Anything else
    is parsed with `datetime.datetime.fromisoformat`; a ValueError from parsing yields None.
    """
    # datetime is a subclass of date, so it must be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime.combine(value, datetime.time())
    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate `CAST(value AS to)` for date/time target types.

    Args:
        value: the value being cast, e.g. a literal's text or an already-evaluated
            date/datetime from a nested cast.
        to: the target data type.

    Returns:
        The result of `cast_as_date` for DATE targets, or of `cast_as_datetime` for other
        temporal types. Returns None when `value` is falsy or `to` is not a temporal
        type; the helpers may also return None if parsing fails.
    """
    if not value:
        return None
    # DATE is checked first: it is presumably also one of the TEMPORAL_TYPES
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a date-producing cast, or return None.

    Handles `CAST(<literal or nested cast> AS <type>)` and `TsOrDsToDate(...)`, which is
    treated as a cast to DATE. Nested casts are evaluated recursively, and the value is
    then converted with `cast_value`.

    Returns:
        The evaluated date/datetime. Returns None when `cast` is not one of the supported
        node types, its operand is neither a literal nor a supported cast, or the value
        cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Return a `dateutil.relativedelta.relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier applied to n)
    spans = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, multiplier in spans:
        if unit == name:
            return relativedelta(**{keyword: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of its year, quarter, month, week or day.

    Weeks start on Monday. Only date fields are adjusted, so a datetime keeps its time
    of day; for "day", `d` is returned unchanged.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "day":
        return d
    if unit == "week":
        # weekday() is 0 for Monday, so this steps back to that week's Monday
        return d - datetime.timedelta(days=d.weekday())
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    None renders as "_", iterables as comma-joined renderings of their items, and
    non-expressions via str(). Expression types with a GEN_MAP entry use that renderer;
    any other expression renders as its key followed by all of its args.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(map(gen, expression))
    if not isinstance(expression, exp.Expression):
        return str(expression)

    renderer = GEN_MAP.get(type(expression))
    if renderer is None:
        return f"{expression.key} {gen(expression.args.values())}"
    return renderer(expression)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.