sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by the date helpers in this module for a date/time unit they cannot handle."""


def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # re-attach to the original parent so parent-aware rules (simplify_parens) work
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of `exceptions`, its first argument (the
    input expression) is returned unchanged; other exceptions propagate.
    """

    def decorator(func):
        # keep the rule's __name__/__doc__ instead of exposing `wrapped`
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE
    and NOT NOT x -> x.
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold TRUE/FALSE/NULL and duplicate operands of AND/OR and merge comparisons."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against number or string literals.

    `left` and `right` are joined by AND, or by OR when `or_` is set. Returns the
    replacement node; returns None (or `expression` itself) when nothing applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_compliments(expression, root=True):
    """
    Remove complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return compliment
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            a = l.this
            b = l.interval()
        else:
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations over constants and negate numeric literals under unary minus."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate `a <op> b` for the binary node `expression` when both sides are constants.

    Returns the replacement node, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # division by zero errors or yields NULL depending on the engine, and
            # Decimal would raise here, so leave the expression untouched
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Drop parentheses that do not change evaluation in the context of the parent node."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """COALESCE(x) -> x, and expand `COALESCE(...) <cmp> constant` into explicit NULL checks."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    new_args = []
    for is_string_group, group in itertools.groupby(
        expression.expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string("".join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021, 1, 1), 'year') == (date(2021, 1, 1), date(2022, 1, 1))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # DATE_TRUNC(unit, left) != date  <=>  left lies outside [drange[0], drange[1]).
    # The bounds must be OR-ed (both can never hold at once), and the OR is
    # parenthesized so it keeps its meaning if it ends up under an AND.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        )
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # DATE_TRUNC(x) < d also holds for the bucket containing d when d is not a bucket
    # boundary, so the exclusive upper bound is ceil(d), not floor(d).
    exp.LT: lambda l, d, u: l < date_literal(date_ceil(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            conditions = [_datetrunc_eq_expression(l, drange) for drange in ranges]
            disjunction = exp.or_(*conditions, copy=False)
            # several disjoint ranges form an OR, which needs parentheses to keep its
            # meaning if the result is placed under an AND
            return exp.paren(disjunction) if len(conditions) > 1 else disjunction

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible always-true joins into CROSS joins."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def _is_zero(expression: exp.Expression) -> bool:
    """Return True if `expression` is a numeric literal equal to zero (0, 0.0, ...)."""
    if not (isinstance(expression, exp.Literal) and expression.is_number):
        return False
    try:
        return Decimal(expression.name) == 0
    except ArithmeticError:
        # decimal.InvalidOperation: not a plain decimal number, so not known to be zero
        return False


def always_true(expression):
    """Return a truthy value for TRUE and for literals that are true in a boolean context.

    Numeric zero literals are excluded: `WHERE 0` filters every row and `x AND 0`
    is never true, so they must not be treated like TRUE.
    """
    return (isinstance(expression, exp.Boolean) and expression.this) or (
        isinstance(expression, exp.Literal) and not _is_zero(expression)
    )


def is_complement(a, b):
    """Return True if `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is the boolean literal FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is the NULL literal."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Apply the comparison represented by `expression` to Python values `a` and `b`.

    Returns TRUE/FALSE nodes, or None when `expression` is not a supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, datetime or ISO-8601 string to a date; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a date, datetime or ISO-8601 string to a datetime; None if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the temporal type `to`; None for empty values or non-temporal types."""
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Return the value of a (possibly nested) CAST / TsOrDsToDate of a literal, or None."""
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an INTERVAL node into a relativedelta, or None if that is not possible.

    Returns None when the count is not an integer (e.g. INTERVAL '1.5' HOUR), when
    the unit is unsupported, or when python-dateutil is not installed.
    """
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build CAST('<date>' AS DATE), or AS DATETIME when `date` is a datetime."""
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """Return a relativedelta of `n` `unit`s.

    Raises ModuleNotFoundError without python-dateutil and UnsupportedUnit for unknown units.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Truncate `d` to the start of its `unit` (weeks start on Monday).

    Every supported unit is a day or coarser, so datetimes also have their time of
    day reset to midnight. Raises UnsupportedUnit for other units.
    """
    if isinstance(d, datetime.datetime):
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary; `d` itself if it already is one."""
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE or FALSE node for a Python truth value."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Pairwise-simplify the flattened operands of `expression`'s binary operator chain.

    Each successful `simplifier(expression, a, b)` result replaces `a` and `b`; the chain
    is rebuilt from the survivors only if at least one pair was combined.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression):
    """
    Simplify a sqlglot AST by applying rewrite rules until it stops changing.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    generate = cached_generator()

    def _references_group_key(subtree, keys):
        # True as soon as any node of `subtree` is one of the GROUP BY keys
        return any(node in keys for node, *_ in subtree.walk())

    # A projection (or HAVING clause) built on a GROUP BY key must keep matching it
    # textually, e.g. SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1, so freeze it.
    for group in expression.find_all(exp.Group):
        parent_select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in parent_select.selects:
            if _references_group_key(projection, keys):
                projection.meta[FINAL] = True

        having_clause = parent_select.args.get("having")
        if having_clause and _references_group_key(having_clause, keys):
            having_clause.meta[FINAL] = True

    def _simplify(original, root=True):
        if original.meta.get(FINAL):
            return original

        # top-down rewrites, before the children are visited
        node = rewrite_between(original)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        exp.replace_children(node, lambda child: _simplify(child, False))

        # bottom-up rewrites, once the children are simplified
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # re-attach to the original parent so parent-aware rules (simplify_parens) work
        node.parent = original.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            original.replace(node)
        return node

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as `func(expression, *args, **kwargs)`. If it
    raises one of `exceptions`, the original `expression` is returned unchanged;
    any other exception propagates.

    Args:
        *exceptions: exception classes to swallow.
    Returns:
        A decorator for simplification functions.
    """
    import functools

    def decorator(func):
        # keep the rule's __name__/__doc__ so introspection and generated docs
        # show the rule itself rather than `wrapped`
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    Comparison simplification only understands LT/LTE/GT/GTE, so BETWEEN is
    rewritten into that form first. Any other node is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT node.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE, NOT FALSE -> TRUE
    and NOT NOT x -> x. Non-NOT nodes are returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # flip the connector and push the negation into both sides
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return combine(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this
    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Merge parenthesized sub-connectors of the same type into their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    The connector is modified in place and returned; other nodes pass through.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            # Drop the parentheses around a same-type sub-connector.
            operand.replace(unwrapped)
    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Fold boolean and NULL literals out of AND / OR connectors.

    `_flat_simplify` hands `_simplify_connectors` (left, right) operand pairs of the
    connector. Pairs with no foldable literal are delegated to `_simplify_comparison`.

    Rules (following SQL three-valued logic):
        x AND x / x OR x     -> x
        FALSE AND x          -> FALSE
        NULL AND NULL        -> NULL
        <always true> AND x  -> x
        <always true> OR x   -> TRUE
        FALSE OR x           -> x
        NULL OR NULL, NULL OR FALSE -> NULL

    `NULL AND x` is deliberately not folded to NULL: it is FALSE when x is FALSE.

    Args:
        expression: the node to simplify; non-connectors are returned unchanged.
        root: forwarded to `_flat_simplify`.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NULL AND x is FALSE when x is FALSE and NULL when x is TRUE/NULL, so it can
            # only be folded once both sides are NULL (the TRUE case is handled below).
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Complements are detected with `is_complement` over every ordered pair of the
    flattened operands. When `root` is False, connectors whose `same_parent` is set are
    left alone. NOTE(review): under SQL three-valued logic, A AND NOT A is NULL (not
    FALSE) when A is NULL; this rewrite treats the two as equivalent.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    operands = expression.flatten()
    has_complement = any(
        is_complement(left, right) for left, right in itertools.permutations(operands, 2)
    )
    if not has_complement:
        return expression
    return exp.false() if isinstance(expression, exp.And) else exp.true()
Remove complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: the node to process; anything other than a Connector is returned as-is.
        generate: callable mapping an expression to its SQL text. The text is used as
            both the deduplication key and the sort key.
        root: when False, a connector whose `same_parent` is set is skipped.

    Returns:
        A newly built connector if operands were reordered or dropped, else `expression`.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # A later duplicate replaces the stored node, but the key keeps its first-seen position.
    by_sql = {generate(operand): operand for operand in operands}
    keys = list(by_sql)

    # A AND C AND B -> A AND B AND C
    if any(current < previous for previous, current in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already ordered, but duplicates may still need to be dropped.
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
Deduplicate and sort the operands of a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten with `replace`, leaving identity literals (TRUE/FALSE)
    behind for later passes to fold away. When `root` is False, connectors whose
    `same_parent` is set are skipped. The (possibly mutated) `expression` is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The operand connector type these laws look for: the opposite of `expression`'s.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # itertools.permutations reads the whole iterable up front, so the operand list is
        # fixed even though the tree is mutated by the replacements below.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # b is the complement of one of a's operands: replace that operand with the
                # identity of `kind` (TRUE inside an AND, FALSE inside an OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands (or b itself) are a proper subset of a's operands, so a is
                # redundant: replace it with the identity of the outer connector
                # (FALSE inside an OR, TRUE inside an AND).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # a and b share one operand and their other operands are complements:
                    # both collapse to the shared operand.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
# Fragment: the inner closure created inside `catch`'s `decorator`. `func` and
# `exceptions` are free variables bound by the enclosing `catch(*exceptions)` and
# `decorator(func)` calls, so this definition only works when nested inside them.
def wrapped(expression, *args, **kwargs):
    # Run the simplification; if it raises one of the caught exception types,
    # fall back to returning the input expression unchanged.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
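Use the subtraction and addition properties of equality to simplify expressions, e.g. `x + 1 = 3` becomes `x = 2`.
There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b

That is, `l` is the left side of `=` (`x + 1`), `r` is its right side (`3`), and `a` and `b` are the operands of `+` (`x` and `1`).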
def simplify_literals(expression, root=True):
    """
    Simplify literal arithmetic.

    - Non-connector binary nodes are handed to `_simplify_binary` through `_flat_simplify`.
    - A unary minus applied to a numeric literal is folded into the literal itself:
      -(5) -> -5 and -(-5) -> 5.

    Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # Toggle the sign by stripping or prepending a leading "-".
    negated = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(negated)
def simplify_parens(expression):
    """
    Drop a pair of parentheses when the rules below consider them removable.

    Parentheses around a subquery (Select) are always kept. Otherwise the contents are
    returned when any of these holds:
    - the parent is neither a Condition nor a Binary, or is itself a Paren;
    - the contents are not a Binary;
    - the contents are a Predicate and the parent is not;
    - Add inside Add, Mul inside Mul, or Mul inside Add/Sub.

    Non-Paren nodes and parentheses that must stay are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    outer = expression.parent
    removable = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    - COALESCE(x) -> x, unless the call is used as a Spark partitioning hint.
    - A comparison between a COALESCE and a constant is expanded so that the first
      constant argument of the COALESCE can be folded against the other constant, e.g.

          COALESCE(x, 1) = 2
          -> (NOT x IS NULL AND COALESCE(x) = 2) OR (x IS NULL AND 1 = 2)

      which later passes simplify further. Each operand stays on the side of the
      comparison it originally occupied. This matters for ordering operators:
      `1 < COALESCE(x, 2)` yields the fallback `1 < 2`, not `2 < 1`.

    Note: the COALESCE node's argument list is truncated in place before the
    comparison is copied into the result.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Arguments after the first constant can never be reached.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # Compare the constant fallback against `other`, keeping each operand on the side it
    # occupied in the original comparison (otherwise <, <=, >, >= would be inverted).
    if coalesce_on_left:
        fallback = type(expression)(this=arg.copy(), expression=other.copy())
    else:
        fallback = type(expression)(this=other.copy(), expression=arg.copy())

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                fallback,
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """
    Merge each run of adjacent string literals in a concatenation into one literal.

    CONCAT('a', 'b', x, 'c') -> CONCAT('ab', x, 'c')

    CONCAT_WS and non-concatenation nodes are returned unchanged. If only a single
    argument remains after merging, that argument is returned instead of a call.
    """
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for is_literal, run in itertools.groupby(operands, key=lambda node: node.is_string):
        if is_literal:
            merged.append(exp.Literal.string("".join(node.name for node in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Keep the "safe" flavor of the original concatenation.
    concat_cls = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return concat_cls(expressions=merged)
Reduce each run of adjacent string literals to a single literal by concatenating them.
# Fragment (duplicate): the inner closure created inside `catch`'s `decorator`.
# `func` and `exceptions` are free variables bound by the enclosing
# `catch(*exceptions)` and `decorator(func)` calls, so this only works nested there.
def wrapped(expression, *args, **kwargs):
    # Run the simplification; if it raises one of the caught exception types,
    # fall back to returning the input expression unchanged.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """
    Remove always-true filters from `expression`, in place.

    - A WHERE clause whose condition is always true is removed from its parent.
    - A join whose ON condition is always true, with no USING and no method, and whose
      (side, kind) pair is listed in `JOINS`, becomes a CROSS join without ON or side.

    Returns None; the tree is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue
        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node on two already-computed Python values.

    EQ and IS use `==`; NEQ, GT, GTE, LT and LTE use the matching Python operator.
    The result is wrapped with `boolean_literal`. Any other node type yields None.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Coerce `value` into a `datetime.datetime`, returning None if that is not possible.

    - A `datetime.datetime` is returned as-is.
    - A `datetime.date` becomes midnight of that day.
    - Anything else is parsed with `datetime.datetime.fromisoformat`. Malformed
      strings and non-string values yield None instead of raising.
    """
    # `datetime` subclasses `date`, so it must be checked first.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: non-string input; ValueError: string not in ISO format.
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert a literal value into a Python date or datetime according to `to`.

    Args:
        value: the value being cast; falsy values (e.g. None or "") yield None.
        to: the target data type.

    Returns:
        For DATE, the result of `cast_as_date(value)`. For other temporal types, the
        result of `cast_as_datetime(value)`, which is None for unparseable input.
        For non-temporal types, None.
    """
    if not value:
        return None
    # DATE is checked before the broader temporal check so it maps to a plain date.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Statically evaluate a cast of a literal to a date or datetime.

    Supports `exp.Cast` (target type taken from `cast.to`) and `exp.TsOrDsToDate`
    (treated as a cast to DATE). The operand may be a literal or another supported
    cast, which is evaluated recursively.

    Returns:
        The value produced by `cast_value`, or None when `cast` is not a supported node
        or its operand is neither a literal nor a supported cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first, then re-cast its result.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """
    Build a `dateutil.relativedelta.relativedelta` spanning `n` of the given unit.

    Supported units: year, quarter (3 months), month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # (unit name, relativedelta keyword, multiplier applied to n)
    unit_table = (
        ("year", "years", 1),
        ("quarter", "months", 3),
        ("month", "months", 1),
        ("week", "weeks", 1),
        ("day", "days", 1),
        ("hour", "hours", 1),
        ("minute", "minutes", 1),
        ("second", "seconds", 1),
    )
    for name, keyword, multiplier in unit_table:
        if unit == name:
            return relativedelta(**{keyword: multiplier * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` down to the start of the given unit.

    Supported units: year, quarter, month, week (weeks start on Monday) and day.
    Only the date fields are adjusted. If `d` is a `datetime`, its time of day is left
    as-is. NOTE(review): confirm callers expect that for datetime inputs.

    Raises:
        UnsupportedUnit: if `unit` is not supported.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10.
        quarter_start_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=quarter_start_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday, so this steps back to the preceding Monday.
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")