sqlglot.optimizer.simplify
import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import exp
from sqlglot.generator import cached_generator
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised by the date helpers when a date/interval unit is not handled."""


def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # simplify_parens (below) inspects the parent, so the possibly-new node
        # is given the original's parent before the remaining rules run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        # keep the wrapped simplifier's __name__/__doc__ for debugging and docs
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        return exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )
    return expression


def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y
    """
    if isinstance(expression, exp.Not):
        if is_null(expression.this):
            return exp.null()
        if isinstance(expression.this, exp.Paren):
            condition = expression.this.unnest()
            if isinstance(condition, exp.And):
                return exp.or_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if isinstance(condition, exp.Or):
                return exp.and_(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            if is_null(condition):
                return exp.null()
        if always_true(expression.this):
            return exp.false()
        if is_false(expression.this):
            return exp.true()
        if isinstance(expression.this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return expression.this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold constant operands (TRUE / FALSE / NULL) and redundant comparisons in AND / OR chains."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}


def _simplify_comparison(expression, left, right, or_=False):
    """Merge two comparisons of the same column against literals.

    Returns the replacement expression, `None` when nothing can be merged, or
    `expression` itself when no non-column operand exists (which
    `_flat_simplify` treats as "no change").
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        complement = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return complement
    return expression


def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # operands are keyed by their generated SQL, so duplicates collapse here
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    l, r = r, l
                else:
                    continue

                # remember the defining column's identity so it is not replaced itself
                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    """Return whether `expression` is a numeric literal."""
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    """Return whether `expression` is an INTERVAL that `extract_interval` can convert."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ in INVERSE_OPS:
            pass
        elif r.__class__ in INVERSE_OPS:
            l, r = r, l
        else:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif not a_predicate(b) and b_predicate(a):
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold binary operations over literals and negations of numeric literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    return expression


def _simplify_binary(expression, a, b):
    """Evaluate the binary `expression` over operands `a` and `b` when both are constant.

    Returns the folded expression, or `None` when the operation cannot be
    evaluated statically.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # division by zero must not raise inside the optimizer; leave the
            # expression untouched so the target engine decides what happens
            if not b:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None


def simplify_parens(expression):
    """Remove a pair of parentheses when the checks below deem it redundant."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(this, exp.Binary)
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def simplify_coalesce(expression):
    """Unwrap single-argument COALESCE and expand comparisons of COALESCE against a constant."""
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs)
        and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    new_args = []
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a half-open range check on `left`."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) <> date` as `left` lying outside the unit's range."""
    drange = _datetrunc_range(date, unit)
    if not drange:
        return None

    # The negation of `lo <= left AND left < hi` is a disjunction; joining the
    # two bounds with AND would produce a predicate that can never be true.
    return exp.or_(
        left < date_literal(drange[0]),
        left >= date_literal(drange[1]),
        copy=False,
    )


DateTruncBinaryTransform = t.Callable[
    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
]
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return whether `left` is a DATE_TRUNC/TIMESTAMP_TRUNC call and `right` a date literal."""
    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # operands are swapped, so the comparison direction flips as well
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn eligible `JOIN ... ON TRUE` into CROSS joins, in place."""
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            and (join.side, join.kind) in JOINS
        ):
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for the boolean TRUE and for any literal."""
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def always_false(expression):
    """Return whether `expression` is FALSE or NULL."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return whether `b` is `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return whether `a` is the boolean FALSE."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return whether `a` is a NULL node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison `expression` over Python values; `None` if it is not a comparison."""
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert a date, datetime or ISO-format string to a date; `None` if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except ValueError:
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert a date, datetime or ISO-format string to a datetime; `None` if the string is invalid."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except ValueError:
        return None


def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[datetime.date]:
    """Cast `value` to a date (DATE) or datetime (other temporal types); `None` otherwise.

    `datetime.datetime` is a subclass of `datetime.date`, so the annotation
    covers both possible results.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(cast: exp.Expression) -> t.Optional[datetime.date]:
    """Extract the Python date/datetime from a (possibly nested) cast of a literal, else `None`."""
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return whether `extract_date` can evaluate `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Convert an INTERVAL expression to a `relativedelta`.

    Returns `None` when the unit is unsupported, `dateutil` is not installed,
    or the interval value is not an integer.
    """
    try:
        # int() sits inside the try so a non-integer value means "can't simplify"
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Build a CAST of `date` to DATETIME (for datetimes) or DATE."""
    return exp.cast(
        exp.Literal.string(date),
        exp.DataType.Type.DATETIME
        if isinstance(date, datetime.datetime)
        else exp.DataType.Type.DATE,
    )


def interval(unit: str, n: int = 1):
    """Return a `relativedelta` of `n` units.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled units.
        ModuleNotFoundError: if `dateutil` is not installed.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of its `unit` (year, quarter, month, week or day).

    Only the month/day fields are adjusted; any time-of-day on a datetime
    input is left as is.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` up to the next `unit` boundary, or return it unchanged if already on one."""
    floor = date_floor(d, unit)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return the TRUE or FALSE expression matching the truthiness of `condition`."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Apply the pairwise `simplifier` across the flattened operands of `expression`.

    Whenever `simplifier(expression, a, b)` yields a replacement, `a` and `b`
    are swapped for it and processing continues; the chain is rebuilt only if
    at least one pair was merged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    generate = cached_generator()

    # A projection or HAVING clause that contains a GROUP BY key must match that
    # key exactly (e.g. select x + 1 + 1 FROM y GROUP BY x + 1 + 1), so such
    # nodes, and the GROUP BY itself, are marked as final and left alone.
    for group in expression.find_all(exp.Group):
        select = group.parent
        group_keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if any(node in group_keys for node, *_ in projection.walk()):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and any(node in group_keys for node, *_ in having.walk()):
            having.meta[FINAL] = True

    def _simplify(original, root=True):
        if original.meta.get(FINAL):
            return original

        # Rules applied before descending into the children
        current = rewrite_between(original)
        current = uniq_sort(current, generate, root)
        current = absorb_and_eliminate(current, root)
        current = simplify_concat(current)
        current = simplify_conditionals(current)

        if constant_propagation:
            current = propagate_constants(current, root)

        exp.replace_children(current, lambda child: _simplify(child, False))

        # Rules applied after the children have been simplified
        current = simplify_not(current)
        current = flatten(current)
        current = simplify_connectors(current, root)
        current = remove_complements(current, root)
        current = simplify_coalesce(current)
        current.parent = original.parent
        current = simplify_literals(current, root)
        current = simplify_equality(current)
        current = simplify_parens(current)
        current = simplify_datetrunc_predicate(current)

        if root:
            original.replace(current)

        return current

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether or not the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function behaves like the original, except
        that when one of `exceptions` is raised it returns its first argument
        (`expression`) unchanged instead of propagating the error.
    """

    def decorator(func):
        # Preserve the wrapped function's name, docstring and signature metadata,
        # so introspection and generated docs show `func` rather than `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # Best effort: the simplification does not apply, keep the input.
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN low AND high` into `x >= low AND x <= high`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    rewritten into that form first. Any other expression is returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # The tested operand appears on both sides, so each side gets its own copy;
    # the bounds themselves are reused without copying.
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws (applied to a parenthesized operand):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT NULL -> NULL, NOT <always true> -> FALSE,
    NOT FALSE -> TRUE and NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # Each connector is replaced by its dual, with both sides negated.
        for connector, dual in ((exp.And, exp.or_), (exp.Or, exp.and_)):
            if isinstance(inner, connector):
                return dual(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop parentheses around a nested connector of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only merge when the nested connector matches the outer one.
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Fold AND / OR connectors whose operands are boolean or NULL constants.

    Non-connector expressions are returned unchanged. For connectors the
    pairwise reducer below is handed to `_flat_simplify`.
    """

    def _reduce_pair(connector, lhs, rhs):
        # X AND X -> X, X OR X -> X
        if lhs == rhs:
            return lhs

        if isinstance(connector, exp.And):
            if is_false(lhs) or is_false(rhs):
                return exp.false()
            if is_null(lhs) or is_null(rhs):
                return exp.null()
            if always_true(lhs):
                # TRUE AND TRUE -> TRUE, TRUE AND x -> x
                return exp.true() if always_true(rhs) else rhs
            if always_true(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs)

        if isinstance(connector, exp.Or):
            if always_true(lhs) or always_true(rhs):
                return exp.true()
            if is_false(lhs) and is_false(rhs):
                return exp.false()
            # NULL OR NULL, NULL OR FALSE and FALSE OR NULL are all NULL
            if (is_null(lhs) and (is_null(rhs) or is_false(rhs))) or (
                is_false(lhs) and is_null(rhs)
            ):
                return exp.null()
            if is_false(lhs):
                return rhs
            if is_false(rhs):
                return lhs
            return _simplify_comparison(connector, lhs, rhs, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _reduce_pair, root)
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its complement.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    # Nested connectors that share their parent's type are only handled from the top.
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE; A OR NOT A -> TRUE
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and put them in sorted order.

    C AND A AND B AND B -> A AND B AND C

    `generate` maps an operand to the string used both as the dedup key and
    as the sort key.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    pairs = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(pairs, pairs[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(pairs)), copy=False)

    # Already sorted, but duplicates may still have to be dropped.
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten via `replace`; the (same) `expression` object is
    returned. Operands are replaced by TRUE/FALSE constants or by a shared
    sub-operand rather than removed here.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the dual connector: the type nested operands must have
        # for either law to apply.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands is tried, so each operand gets to play
        # both the nested side `a` and the other side `b`.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # b contradicts aa, so aa becomes the neutral element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's, so a is implied by b:
                    # replace a with the neutral element of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # One operand is shared and the other two are complements:
                    # both nested connectors collapse to the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B
elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only applied to an AND expression that `normalized(..., dnf=True)` accepts.
    Columns are replaced in place and the same `expression` object is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in its defining equality, literal).
        constant_mapping = {}
        # Subtrees rooted at an IF node are not searched for equalities.
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # normalize to `column = literal`
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # leave the recorded `column = literal` occurrence itself alone
                    and id(column) != column_id
                    # do not rewrite `column IS NULL`
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Closure body of the `catch` decorator: `func` and `exceptions` are free
    # variables supplied by the enclosing scopes, not defined in this fragment.
    # NOTE(review): this listing appears where the decorated function's own
    # source would be expected (the adjacent text describes equality
    # simplification) -- confirm.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # The simplification could not be applied; hand back the input unchanged.
        return expression
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """Fold literal arithmetic and negated number literals.

    Binary expressions that are not connectors are delegated to
    `_flat_simplify` with `_simplify_binary`. A negated number literal is
    replaced by a number literal with the sign flipped. Anything else is
    returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # -(-5) -> 5, -(5) -> -5
    flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(flipped)
def simplify_parens(expression):
    """Strip a pair of parentheses when they are not needed.

    Returns the wrapped expression when the parentheses are redundant,
    otherwise the Paren node itself. Parenthesized SELECTs are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        # a product inside a product, sum or difference
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls.

    Two rewrites are performed:

    * `COALESCE(x)` with no further arguments becomes `x` (unless its parent
      is a hint).
    * A comparison between a COALESCE and a constant, where the COALESCE has a
      constant argument, becomes
      `(c IS NOT NULL AND <comparison>) OR (c IS NULL AND <constant arg> <op> <other>)`,
      where `c` is the COALESCE truncated before its first constant argument.

    Any other expression is returned unchanged. Note that the second rewrite
    mutates the COALESCE node in place before building the result.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # The COALESCE may sit on either side of the comparison.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    # (`arg_index` and `arg` are reused after the loop)
    for arg_index, arg in enumerate(coalesce.expressions):
        if isinstance(arg, CONSTANTS):
            break
    else:
        return expression

    # Drop the constant and everything after it from the COALESCE, in place.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                # copied after the truncation above, so it holds the shortened COALESCE
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                # same comparison type, with the constant argument as the left operand
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal.

    If everything collapses into a single string literal, that literal is
    returned; otherwise a new concat node of the matching type is built.
    """
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    if with_separator:
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        if not expression.expressions[0].is_string:
            return expression
        separator_node, *operands = expression.expressions
        separator = separator_node.name
        result_type = exp.ConcatWs
    else:
        operands = expression.expressions
        separator = ""
        result_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat

    merged = []
    # Fall back to the flattened operands when there is no explicit argument list.
    runs = itertools.groupby(operands or expression.flatten(), lambda e: e.is_string)
    for all_strings, run in runs:
        if not all_strings:
            merged.extend(run)
            continue
        joined = separator.join(literal.name for literal in run)
        merged.append(exp.Literal.string(joined))

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if with_separator:
        merged = [separator_node] + merged

    return result_type(expressions=merged)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches are examined in order: the result of the first branch
    whose condition is always true is returned, and branches whose condition is
    always false are removed; once no branch is left, the default (or NULL) is
    returned. A standalone IF (one whose parent is not a CASE) is replaced by
    its true or false result when its condition is statically known.
    Otherwise the (possibly modified) input expression is returned.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                # (the operand is detached from the CASE and the branch condition
                # is replaced in the tree by the equality)
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # this branch can never be taken
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def wrapped(expression, *args, **kwargs):
    # Closure body of the `catch` decorator: `func` and `exceptions` are free
    # variables supplied by the enclosing scopes, not defined in this fragment.
    # NOTE(review): this listing appears where the decorated function's own
    # source would be expected (the adjacent text describes DATE_TRUNC
    # predicate simplification) -- confirm.
    try:
        return func(expression, *args, **kwargs)
    except exceptions:
        # The simplification could not be applied; hand back the input unchanged.
        return expression
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def remove_where_true(expression):
    """Remove conditions that are always true, in place.

    * A WHERE clause whose condition is always true is dropped.
    * A JOIN whose ON condition is always true, which has no USING and no
      method, and whose (side, kind) pair is listed in JOINS, is turned into a
      CROSS join without an ON condition.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison represented by `expression` on the values `a` and `b`.

    Returns `boolean_literal(<result>)` for EQ/IS, NEQ, GT, GTE, LT and LTE
    expressions, and None for any other expression type.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for kinds, compare in comparisons:
        if isinstance(expression, kinds):
            return boolean_literal(compare(a, b))

    return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, if possible.

    Args:
        value: a datetime (returned as is), a date (converted to midnight of
            that day) or an ISO-8601 string.

    Returns:
        The corresponding datetime, or None when `value` cannot be converted
        (malformed string or unsupported type).
    """
    # datetime must be tested first: it is a subclass of date.
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (ValueError, TypeError):
        # ValueError: not a valid ISO string; TypeError: not a string at all.
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to the python equivalent of the SQL type `to`.

    Args:
        value: the value to cast; falsy values are not cast.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is one of the other
        temporal types, and None when `value` is falsy or `to` is not a
        temporal type. The cast helpers may themselves return None.
    """
    if not value:
        return None
    # DATE is checked before the generic temporal types so that it yields a
    # date rather than a datetime.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Evaluate a (possibly nested) cast of a literal to a date or datetime.

    Args:
        cast: a Cast or TsOrDsToDate expression whose operand is either a
            literal or another such expression.

    Returns:
        The value produced by `cast_value` for the target type, or None when
        `cast` has any other shape.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        # TsOrDsToDate is treated as a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # nested cast: evaluate the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Build a `relativedelta` spanning `n` units of `unit`.

    Supported units: year, quarter, month, week, day, hour, minute, second.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, per_unit = span
    return relativedelta(**{keyword: per_unit * n})
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """Round `d` down to the start of the `unit` it falls in.

    Supported units: year, quarter, month, week (starting on Monday) and day.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")