sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9 10import sqlglot 11from sqlglot import Dialect, exp 12from sqlglot.helper import first, is_iterable, merge_ranges, while_changing 13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 14 15if t.TYPE_CHECKING: 16 from sqlglot.dialects.dialect import DialectType 17 18 DateTruncBinaryTransform = t.Callable[ 19 [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression] 20 ] 21 22# Final means that an expression should not be simplified 23FINAL = "final" 24 25 26class UnsupportedUnit(Exception): 27 pass 28 29 30def simplify( 31 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 32): 33 """ 34 Rewrite sqlglot AST to simplify expressions. 35 36 Example: 37 >>> import sqlglot 38 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 39 >>> simplify(expression).sql() 40 'TRUE' 41 42 Args: 43 expression (sqlglot.Expression): expression to simplify 44 constant_propagation: whether or not the constant propagation rule should be used 45 46 Returns: 47 sqlglot.Expression: simplified expression 48 """ 49 50 dialect = Dialect.get_or_raise(dialect) 51 52 def _simplify(expression, root=True): 53 if expression.meta.get(FINAL): 54 return expression 55 56 # group by expressions cannot be simplified, for example 57 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 58 # the projection must exactly match the group by key 59 group = expression.args.get("group") 60 61 if group and hasattr(expression, "selects"): 62 groups = set(group.expressions) 63 group.meta[FINAL] = True 64 65 for e in expression.selects: 66 for node, *_ in e.walk(): 67 if node in groups: 68 e.meta[FINAL] = True 69 break 70 71 having = expression.args.get("having") 72 if having: 73 for node, *_ in having.walk(): 74 if node in groups: 75 having.meta[FINAL] = True 76 break 77 78 # 
Pre-order transformations 79 node = expression 80 node = rewrite_between(node) 81 node = uniq_sort(node, root) 82 node = absorb_and_eliminate(node, root) 83 node = simplify_concat(node) 84 node = simplify_conditionals(node) 85 86 if constant_propagation: 87 node = propagate_constants(node, root) 88 89 exp.replace_children(node, lambda e: _simplify(e, False)) 90 91 # Post-order transformations 92 node = simplify_not(node) 93 node = flatten(node) 94 node = simplify_connectors(node, root) 95 node = remove_complements(node, root) 96 node = simplify_coalesce(node) 97 node.parent = expression.parent 98 node = simplify_literals(node, root) 99 node = simplify_equality(node) 100 node = simplify_parens(node) 101 node = simplify_datetrunc(node, dialect) 102 node = sort_comparison(node) 103 node = simplify_startswith(node) 104 105 if root: 106 expression.replace(node) 107 return node 108 109 expression = while_changing(expression, _simplify) 110 remove_where_true(expression) 111 return expression 112 113 114def catch(*exceptions): 115 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 116 117 def decorator(func): 118 def wrapped(expression, *args, **kwargs): 119 try: 120 return func(expression, *args, **kwargs) 121 except exceptions: 122 return expression 123 124 return wrapped 125 126 return decorator 127 128 129def rewrite_between(expression: exp.Expression) -> exp.Expression: 130 """Rewrite x between y and z to x >= y AND x <= z. 131 132 This is done because comparison simplification is only done on lt/lte/gt/gte. 
133 """ 134 if isinstance(expression, exp.Between): 135 negate = isinstance(expression.parent, exp.Not) 136 137 expression = exp.and_( 138 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 139 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 140 copy=False, 141 ) 142 143 if negate: 144 expression = exp.paren(expression, copy=False) 145 146 return expression 147 148 149COMPLEMENT_COMPARISONS = { 150 exp.LT: exp.GTE, 151 exp.GT: exp.LTE, 152 exp.LTE: exp.GT, 153 exp.GTE: exp.LT, 154 exp.EQ: exp.NEQ, 155 exp.NEQ: exp.EQ, 156} 157 158 159def simplify_not(expression): 160 """ 161 Demorgan's Law 162 NOT (x OR y) -> NOT x AND NOT y 163 NOT (x AND y) -> NOT x OR NOT y 164 """ 165 if isinstance(expression, exp.Not): 166 this = expression.this 167 if is_null(this): 168 return exp.null() 169 if this.__class__ in COMPLEMENT_COMPARISONS: 170 return COMPLEMENT_COMPARISONS[this.__class__]( 171 this=this.this, expression=this.expression 172 ) 173 if isinstance(this, exp.Paren): 174 condition = this.unnest() 175 if isinstance(condition, exp.And): 176 return exp.paren( 177 exp.or_( 178 exp.not_(condition.left, copy=False), 179 exp.not_(condition.right, copy=False), 180 copy=False, 181 ) 182 ) 183 if isinstance(condition, exp.Or): 184 return exp.paren( 185 exp.and_( 186 exp.not_(condition.left, copy=False), 187 exp.not_(condition.right, copy=False), 188 copy=False, 189 ) 190 ) 191 if is_null(condition): 192 return exp.null() 193 if always_true(this): 194 return exp.false() 195 if is_false(this): 196 return exp.true() 197 if isinstance(this, exp.Not): 198 # double negation 199 # NOT NOT x -> x 200 return this.this 201 return expression 202 203 204def flatten(expression): 205 """ 206 A AND (B AND C) -> A AND B AND C 207 A OR (B OR C) -> A OR B OR C 208 """ 209 if isinstance(expression, exp.Connector): 210 for node in expression.args.values(): 211 child = node.unnest() 212 if isinstance(child, expression.__class__): 213 node.replace(child) 
214 return expression 215 216 217def simplify_connectors(expression, root=True): 218 def _simplify_connectors(expression, left, right): 219 if left == right: 220 return left 221 if isinstance(expression, exp.And): 222 if is_false(left) or is_false(right): 223 return exp.false() 224 if is_null(left) or is_null(right): 225 return exp.null() 226 if always_true(left) and always_true(right): 227 return exp.true() 228 if always_true(left): 229 return right 230 if always_true(right): 231 return left 232 return _simplify_comparison(expression, left, right) 233 elif isinstance(expression, exp.Or): 234 if always_true(left) or always_true(right): 235 return exp.true() 236 if is_false(left) and is_false(right): 237 return exp.false() 238 if ( 239 (is_null(left) and is_null(right)) 240 or (is_null(left) and is_false(right)) 241 or (is_false(left) and is_null(right)) 242 ): 243 return exp.null() 244 if is_false(left): 245 return right 246 if is_false(right): 247 return left 248 return _simplify_comparison(expression, left, right, or_=True) 249 250 if isinstance(expression, exp.Connector): 251 return _flat_simplify(expression, _simplify_connectors, root) 252 return expression 253 254 255LT_LTE = (exp.LT, exp.LTE) 256GT_GTE = (exp.GT, exp.GTE) 257 258COMPARISONS = ( 259 *LT_LTE, 260 *GT_GTE, 261 exp.EQ, 262 exp.NEQ, 263 exp.Is, 264) 265 266INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 267 exp.LT: exp.GT, 268 exp.GT: exp.LT, 269 exp.LTE: exp.GTE, 270 exp.GTE: exp.LTE, 271} 272 273NONDETERMINISTIC = (exp.Rand, exp.Randn) 274 275 276def _simplify_comparison(expression, left, right, or_=False): 277 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 278 ll, lr = left.args.values() 279 rl, rr = right.args.values() 280 281 largs = {ll, lr} 282 rargs = {rl, rr} 283 284 matching = largs & rargs 285 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 286 287 if matching and columns: 288 try: 289 l = 
first(largs - columns) 290 r = first(rargs - columns) 291 except StopIteration: 292 return expression 293 294 if l.is_number and r.is_number: 295 l = float(l.name) 296 r = float(r.name) 297 elif l.is_string and r.is_string: 298 l = l.name 299 r = r.name 300 else: 301 l = extract_date(l) 302 if not l: 303 return None 304 r = extract_date(r) 305 if not r: 306 return None 307 308 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 309 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 310 return left if (av > bv if or_ else av <= bv) else right 311 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 312 return left if (av < bv if or_ else av >= bv) else right 313 314 # we can't ever shortcut to true because the column could be null 315 if not or_: 316 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 317 if av <= bv: 318 return exp.false() 319 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 320 if av >= bv: 321 return exp.false() 322 elif isinstance(a, exp.EQ): 323 if isinstance(b, exp.LT): 324 return exp.false() if av >= bv else a 325 if isinstance(b, exp.LTE): 326 return exp.false() if av > bv else a 327 if isinstance(b, exp.GT): 328 return exp.false() if av <= bv else a 329 if isinstance(b, exp.GTE): 330 return exp.false() if av < bv else a 331 if isinstance(b, exp.NEQ): 332 return exp.false() if av == bv else a 333 return None 334 335 336def remove_complements(expression, root=True): 337 """ 338 Removing complements. 339 340 A AND NOT A -> FALSE 341 A OR NOT A -> TRUE 342 """ 343 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 344 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 345 346 for a, b in itertools.permutations(expression.flatten(), 2): 347 if is_complement(a, b): 348 return complement 349 return expression 350 351 352def uniq_sort(expression, root=True): 353 """ 354 Uniq and sort a connector. 
355 356 C AND A AND B AND B -> A AND B AND C 357 """ 358 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 359 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 360 flattened = tuple(expression.flatten()) 361 deduped = {gen(e): e for e in flattened} 362 arr = tuple(deduped.items()) 363 364 # check if the operands are already sorted, if not sort them 365 # A AND C AND B -> A AND B AND C 366 for i, (sql, e) in enumerate(arr[1:]): 367 if sql < arr[i][0]: 368 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 369 break 370 else: 371 # we didn't have to sort but maybe we need to dedup 372 if len(deduped) < len(flattened): 373 expression = result_func(*deduped.values(), copy=False) 374 375 return expression 376 377 378def absorb_and_eliminate(expression, root=True): 379 """ 380 absorption: 381 A AND (A OR B) -> A 382 A OR (A AND B) -> A 383 A AND (NOT A OR B) -> A AND B 384 A OR (NOT A AND B) -> A OR B 385 elimination: 386 (A AND B) OR (A AND NOT B) -> A 387 (A OR B) AND (A OR NOT B) -> A 388 """ 389 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 390 kind = exp.Or if isinstance(expression, exp.And) else exp.And 391 392 for a, b in itertools.permutations(expression.flatten(), 2): 393 if isinstance(a, kind): 394 aa, ab = a.unnest_operands() 395 396 # absorb 397 if is_complement(b, aa): 398 aa.replace(exp.true() if kind == exp.And else exp.false()) 399 elif is_complement(b, ab): 400 ab.replace(exp.true() if kind == exp.And else exp.false()) 401 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 402 a.replace(exp.false() if kind == exp.And else exp.true()) 403 elif isinstance(b, kind): 404 # eliminate 405 rhs = b.unnest_operands() 406 ba, bb = rhs 407 408 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 409 a.replace(aa) 410 b.replace(aa) 411 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 412 a.replace(ab) 
413 b.replace(ab) 414 415 return expression 416 417 418def propagate_constants(expression, root=True): 419 """ 420 Propagate constants for conjunctions in DNF: 421 422 SELECT * FROM t WHERE a = b AND b = 5 becomes 423 SELECT * FROM t WHERE a = 5 AND b = 5 424 425 Reference: https://www.sqlite.org/optoverview.html 426 """ 427 428 if ( 429 isinstance(expression, exp.And) 430 and (root or not expression.same_parent) 431 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 432 ): 433 constant_mapping = {} 434 for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)): 435 if isinstance(expr, exp.EQ): 436 l, r = expr.left, expr.right 437 438 # TODO: create a helper that can be used to detect nested literal expressions such 439 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 440 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 441 constant_mapping[l] = (id(l), r) 442 443 if constant_mapping: 444 for column in find_all_in_scope(expression, exp.Column): 445 parent = column.parent 446 column_id, constant = constant_mapping.get(column) or (None, None) 447 if ( 448 column_id is not None 449 and id(column) != column_id 450 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 451 ): 452 column.replace(constant.copy()) 453 454 return expression 455 456 457INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 458 exp.DateAdd: exp.Sub, 459 exp.DateSub: exp.Add, 460 exp.DatetimeAdd: exp.Sub, 461 exp.DatetimeSub: exp.Add, 462} 463 464INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 465 **INVERSE_DATE_OPS, 466 exp.Add: exp.Sub, 467 exp.Sub: exp.Add, 468} 469 470 471def _is_number(expression: exp.Expression) -> bool: 472 return expression.is_number 473 474 475def _is_interval(expression: exp.Expression) -> bool: 476 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 477 478 
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # Only addition commutes: `1 + x = 3` may be treated as `x + 1 = 3`, but
            # `1 - x = 3` is not `x - 1 = 3`, so subtraction is left untouched here.
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations, negations and date arithmetic whose operands are literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


def _simplify_binary(expression, a, b):
    """
    Evaluate `a <op> b` when both operands are statically known.

    Returns the folded literal / boolean / NULL node, or None when the pair
    cannot (or must not) be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Dividing a Decimal by zero raises instead of returning a value; leave the
            # expression alone so the target engine decides what `x / 0` means.
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Drop a Paren node when the checks below show the grouping is not needed."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for literals, booleans and date literals (NULL excluded)."""
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Return True for literals, booleans, NULL and date literals."""
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """
    Remove redundant COALESCE calls and split `COALESCE(x, const) <cmp> const`
    into a NULL / NOT NULL case analysis that later passes can fold.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # The fallback constant must take the COALESCE's place on the same side of the
    # comparison, otherwise `0 < COALESCE(x, 1)` would turn into `1 < 0` for NULL x.
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a range check, or None if not applicable."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as "outside the range", or None if not applicable."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # A value is outside [min, max) when it is below min OR at/above max. Joining the
    # two halves with AND can never be satisfied and would fold the predicate to FALSE.
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True when `left` is a date truncation and `right` is a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # truncation of a literal date can be evaluated right away
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Normalize operand order of a comparison: columns to the left, constants to the right."""
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            # swapping the operands requires mirroring the operator (< becomes >, ...)
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Remove always-true WHERE clauses and turn always-true joins into CROSS joins.

    The tree is modified in place and nothing is returned. A join is only
    rewritten when it has no USING clause, no join method and its
    (side, kind) pair is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a TRUE boolean node or for any literal node."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def always_false(expression):
    """Return True if `expression` is the FALSE boolean or NULL."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an exp.Boolean whose value is falsy."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an exp.Null node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal node, or None if `expression` is not one of the
    supported comparison types.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a date, or return None if it cannot be converted."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: not an ISO string; TypeError: not a string at all
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime, or return None if it cannot be converted."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: not an ISO string; TypeError: not a string at all
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python type matching the data type `to`.

    Returns a date for DATE, a datetime for the other temporal types, and
    None for a falsy value or any other target type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Return the date/datetime encoded by a (possibly nested) cast of a literal.

    Only exp.Cast and format-less exp.TsOrDsToDate nodes are understood; any
    other shape yields None.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # nested cast: resolve the inner value first, then cast it again
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if `extract_date` can resolve `expression` to a value."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Build the delta for an interval node, or None if it cannot be built.

    None is returned for a non-integer amount, an unsupported unit, or when
    the module needed by `interval` is not installed.
    """
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Wrap `date` in a CAST to DATETIME (for datetimes) or DATE (otherwise)."""
    return exp.cast(
        exp.Literal.string(date),
        (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        ),
    )


def interval(unit: str, n: int = 1):
    """Return a relativedelta of `n` times `unit`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled unit names.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Raises:
        UnsupportedUnit: if `unit` is not year, quarter, month, week or day.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # weekday() is Monday-based (0..6); WEEK_OFFSET shifts the first day
        # of the week. The modulo keeps the step back within 0..6 days so the
        # result is never after `d` nor a full week too early.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next `unit` boundary (or return it if already on one)."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE node if `condition` is truthy, else the FALSE node."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of `expression`.

    `simplifier(expression, a, b)` is tried on operand pairs; a truthy result
    that is not `expression` itself replaces both operands. If anything was
    merged, the remaining operands are rebuilt into a chain of the same
    expression class; otherwise `expression` is returned unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # a and b collapsed into result: revisit it against the rest
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # a could not be combined with any remaining operand
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization.  Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    if expression is None:
        return "_"
    if is_iterable(expression):
        return ",".join(gen(e) for e in expression)
    if not isinstance(expression, exp.Expression):
        return str(expression)

    etype = type(expression)
    if etype in GEN_MAP:
        return GEN_MAP[etype](expression)
    # fallback for types without a dedicated entry: key followed by all args
    return f"{expression.key} {gen(expression.args.values())}"


# Per-type string builders used by gen(); keyed by the exact expression class.
GEN_MAP = {
    exp.Add: lambda e: _binary(e, "+"),
    exp.And: lambda e: _binary(e, "AND"),
    exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
    exp.Div: lambda e: _binary(e, "/"),
    exp.Dot: lambda e: _binary(e, "."),
    exp.EQ: lambda e: _binary(e, "="),
    exp.GT: lambda e: _binary(e, ">"),
    exp.GTE: lambda e: _binary(e, ">="),
    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
    exp.ILike: lambda e: _binary(e, "ILIKE"),
    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
    exp.Is: lambda e: _binary(e, "IS"),
    exp.Like: lambda e: _binary(e, "LIKE"),
    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
    exp.LT: lambda e: _binary(e, "<"),
    exp.LTE: lambda e: _binary(e, "<="),
    exp.Mod: lambda e: _binary(e, "%"),
    exp.Mul: lambda e: _binary(e, "*"),
    exp.Neg: lambda e: _unary(e, "-"),
    exp.NEQ: lambda e: _binary(e, "<>"),
    exp.Not: lambda e: _unary(e, "NOT"),
    exp.Null: lambda e: "NULL",
    exp.Or: lambda e: _binary(e, "OR"),
    exp.Paren: lambda e: f"({gen(e.this)})",
    exp.Sub: lambda e: _binary(e, "-"),
    exp.Subquery: lambda e: f"({gen(e.args.values())})",
    exp.Table: lambda e: gen(e.args.values()),
    exp.Var: lambda e: e.name,
}


def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary node as `<left> <op> <right>` using gen()."""
    return f"{gen(e.left)} {op} {gen(e.right)}"


def _unary(e: exp.Unary, op: str) -> str:
    """Render a unary node as `<op> <operand>` using gen()."""
    return f"{op} {gen(e.this)}"
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and handed to the
            DATE_TRUNC simplification step

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched.
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group keys.
            for e in expression.selects:
                for node, *_ in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for a HAVING clause that references a group key.
            having = expression.args.get("having")
            if having:
                for node, *_ in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Re-attach the parent link before running the remaining rules.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Apply _simplify repeatedly via while_changing, then clean up WHERE/JOIN conditions.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function returns its first positional
        argument (`expression`) unchanged whenever one of `exceptions` is
        raised; any other exception propagates.
    """

    def decorator(func):
        # functools.wraps keeps the wrapped function's name, docstring and
        # __wrapped__ reference so decorated rules stay introspectable.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    Only lt/lte/gt/gte comparisons are handled by comparison simplification,
    so BETWEEN is rewritten into that form first. When the BETWEEN sits
    directly under a NOT, the conjunction is wrapped in parentheses.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(conjunction, copy=False) if under_not else conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT node.

    Handles NOT NULL, complemented comparisons, De Morgan's laws on a
    parenthesized AND/OR, constant operands and double negation:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        complement_type = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement_type(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: the connector flips and both sides are negated
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Unwrap operands of a connector that are the same kind of connector.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold constant and duplicate operands of AND / OR connectors.

    Non-connector expressions are returned unchanged; connectors are reduced
    pair by pair through `_flat_simplify`.
    """

    def _reduce_pair(connector, lhs, rhs):
        # Returns the replacement for the pair, or None when nothing applies.
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs):
                return exp.true() if always_true(rhs) else rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()

            lhs_false, rhs_false = is_false(lhs), is_false(rhs)
            lhs_null, rhs_null = is_null(lhs), is_null(rhs)

            if lhs_false and rhs_false:
                return exp.false()
            if (lhs_null and (rhs_null or rhs_false)) or (lhs_false and rhs_null):
                return exp.null()
            if lhs_false:
                return rhs
            if rhs_false:
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # is_complement is not symmetric, so ordered pairs are required
    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(first, second) for first, second in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # keyed by the pseudo-sql string; later duplicates overwrite earlier ones
    by_sql = {gen(operand): operand for operand in operands}
    keys = list(by_sql)
    already_sorted = all(prev <= cur for prev, cur in zip(keys, keys[1:]))

    if not already_sorted:
        # A AND C AND B -> A AND B AND C
        expression = combine(*(node for _, node in sorted(by_sql.items())), copy=False)
    elif len(by_sql) < len(operands):
        # order was fine, but duplicates were dropped
        expression = combine(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place with `replace`; the (same)
    `expression` object is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the opposite connector: the type of the nested operands we look for
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: replace it with the constant for `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's operands
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Substitute columns with the literal they are equated to, for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node inside its `col = literal` predicate, literal)
    literal_by_column = {}
    for candidate, *_ in walk_in_scope(
        expression, prune=lambda node, *_: isinstance(node, exp.If)
    ):
        if not isinstance(candidate, exp.EQ):
            continue

        lhs, rhs = candidate.left, candidate.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            literal_by_column[lhs] = (id(lhs), rhs)

    if not literal_by_column:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        owner = column.parent
        match = literal_by_column.get(column)
        if match is None:
            continue

        defining_id, literal = match
        if id(column) == defining_id:
            # this is the column of the defining predicate itself
            continue
        if isinstance(owner, exp.Is) and isinstance(owner.expression, exp.Null):
            # `col IS NULL` checks are left alone
            continue

        column.replace(literal.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
# Wrapper shown as the source of the function documented below.
# `func` and `exceptions` are free variables bound by an enclosing scope
# that is not part of this snippet.
def wrapped(expression, *args, **kwargs):
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # fall back to the untouched input when one of `exceptions` is raised
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:
      l       r
    x + 1  =  3
    a   b
def simplify_literals(expression, root=True):
    """Fold literal operands of binary expressions, negations and date operations.

    Connectors are excluded from the binary case. Expressions that match none
    of the handled shapes are returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # negating flips the sign: strip an existing "-" or prepend one
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Drop a pair of parentheses when the rules below consider it redundant.

    Returns the wrapped expression if the parentheses are removed, otherwise
    the original Paren node. Parenthesized SELECTs are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    # parent is not a condition/binary, or is itself a Paren
    if not isinstance(outer, (exp.Condition, exp.Binary)) or isinstance(outer, exp.Paren):
        return inner
    # wrapped node is not a binary operation
    if not isinstance(inner, exp.Binary):
        return inner
    # predicate inside a non-predicate parent
    if isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate):
        return inner
    # addition nested in addition
    if isinstance(inner, exp.Add) and isinstance(outer, exp.Add):
        return inner
    # multiplication nested in multiplication, addition or subtraction
    if isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)):
        return inner

    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    A COALESCE with no fallback arguments, or whose first argument is a
    non-null constant, is replaced by its first argument. A comparison of a
    COALESCE against a constant is expanded into an OR of two branches, one
    for the non-null case and one for the null case.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE on either side of the comparison.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the fallback arguments that precede the first constant one.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge runs of adjacent string literals inside a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    has_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if has_separator and not expression.expressions[0].is_string:
        return expression

    if has_separator:
        sep_expr, *operands = expression.expressions
        joiner = sep_expr.name
        result_type = exp.ConcatWs
        extra_args = {}
    else:
        operands = expression.expressions
        joiner = ""
        result_type = exp.Concat
        extra_args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if all_strings:
            text = joiner.join(literal.name for literal in run)
            merged.append(exp.Literal.string(text))
        else:
            merged.extend(run)

    # everything folded into a single string literal
    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if result_type is exp.ConcatWs:
        merged = [sep_expr] + merged

    return result_type(expressions=merged, **extra_args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, the first branch whose condition is always true yields that
    branch's result; branches that are always false are removed, and when no
    branch is left the default (or NULL) is returned. For a standalone IF, the
    matching branch is returned directly.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # this branch can never match: drop it
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # IF nodes whose parent is a CASE are handled by the branch above
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into TRUE or FALSE when both the string and the
    prefix are string literals.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression
    if not expression.this.is_string:
        return expression

    prefix = expression.expression
    if not prefix.is_string:
        return expression

    return exp.convert(expression.name.startswith(prefix.name))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
# Wrapper shown as the source of the function documented below.
# `func` and `exceptions` are free variables bound by an enclosing scope
# that is not part of this snippet.
def wrapped(expression, *args, **kwargs):
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # fall back to the untouched input when one of `exceptions` is raised
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    Columns are kept on the left and constants on the right; when neither rule
    decides, operands are ordered by their `gen` string. Swapping the operands
    also swaps the operator through INVERSE_COMPARISONS where an inverse is
    registered.
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # already canonical: column on the left or constant on the right
    if (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const):
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if needs_swap:
        swapped_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
        return swapped_type(this=rhs, expression=lhs)

    return expression
def remove_where_true(expression):
    """Strip WHERE clauses and join conditions whose predicate is always true.

    Works in place and returns nothing. A join with an always-true ON
    condition becomes a CROSS join, but only if it has no USING clause, no
    join method and its (side, kind) pair is listed in JOINS.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join_node in expression.find_all(exp.Join):
        if not always_true(join_node.args.get("on")):
            continue
        if join_node.args.get("using") or join_node.args.get("method"):
            continue
        if (join_node.side, join_node.kind) not in JOINS:
            continue

        join_node.set("on", None)
        join_node.set("side", None)
        join_node.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison node `expression` on the Python values `a` and `b`.

    Returns a boolean literal node for EQ/IS, NEQ, GT, GTE, LT and LTE, and
    None for any other expression type.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Args:
        value: a datetime (its date part is used), a date (returned as is) or
            an ISO-format string.

    Returns:
        The date, or None if `value` is not an ISO-format string or is of an
        unsupported type.
    """
    # datetime must be checked first: it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: `value` is not a string
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of
            that day) or an ISO-format string.

    Returns:
        The datetime, or None if `value` is not an ISO-format string or is of
        an unsupported type.
    """
    # datetime must be checked first: it is a subclass of date
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # ValueError: malformed string; TypeError: `value` is not a string
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python type matching the data type `to`.

    Args:
        value: the value to convert; falsy values are not converted.
        to: the target data type.

    Returns:
        The result of `cast_as_date` when `to` is DATE, the result of
        `cast_as_datetime` when `to` is any other temporal type, and None for
        a falsy `value` or any other target type.
    """
    if not value:
        return None
    # DATE is tested before the general temporal check so it maps to a date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Return the date or datetime encoded by a (possibly nested) cast of a literal.

    Args:
        cast: the expression to inspect. Only exp.Cast nodes and
            exp.TsOrDsToDate nodes without a format are understood.

    Returns:
        The value produced by `cast_value`, or None when `cast` has any other
        shape.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # nested cast: resolve the inner value first, then cast it again
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Build a `relativedelta` spanning `n` times `unit`.

    Supported units are year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: for any other unit name.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    delta_spec = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    spec = delta_spec.get(unit)
    if spec is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spec
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the `unit` that contains it.

    Args:
        d: the date to round.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only its `WEEK_OFFSET` attribute is read, and only for the
            "week" unit.

    Returns:
        The first day of the enclosing year, quarter, month or week, or `d`
        itself for "day".

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is Monday-based (0..6); WEEK_OFFSET shifts the first day
        # of the week. The modulo keeps the step back within 0..6 days so the
        # result is never after `d` nor a full week too early.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Cheap pseudo-SQL rendering used to produce sortable, unique strings.

    Sorting and deduplicating operands needs a string form of each
    expression; the real SQL generator is expensive, so this minimal one is
    used instead.
    """
    if expression is None:
        return "_"

    if is_iterable(expression):
        return ",".join(gen(item) for item in expression)

    if not isinstance(expression, exp.Expression):
        return str(expression)

    # dedicated renderer for this exact type, if one is registered
    render = GEN_MAP.get(type(expression))
    if render is not None:
        return render(expression)

    # generic fallback: the expression key followed by all of its args
    rendered_args = gen(expression.args.values())
    return f"{expression.key} {rendered_args}"
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.