sqlglot.optimizer.simplify
1from __future__ import annotations 2 3import datetime 4import functools 5import itertools 6import typing as t 7from collections import deque 8from decimal import Decimal 9from functools import reduce 10 11import sqlglot 12from sqlglot import Dialect, exp 13from sqlglot.helper import first, merge_ranges, while_changing 14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope 15 16if t.TYPE_CHECKING: 17 from sqlglot.dialects.dialect import DialectType 18 19 DateTruncBinaryTransform = t.Callable[ 20 [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression] 21 ] 22 23# Final means that an expression should not be simplified 24FINAL = "final" 25 26# Value ranges for byte-sized signed/unsigned integers 27TINYINT_MIN = -128 28TINYINT_MAX = 127 29UTINYINT_MIN = 0 30UTINYINT_MAX = 255 31 32 33class UnsupportedUnit(Exception): 34 pass 35 36 37def simplify( 38 expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None 39): 40 """ 41 Rewrite sqlglot AST to simplify expressions. 
42 43 Example: 44 >>> import sqlglot 45 >>> expression = sqlglot.parse_one("TRUE AND TRUE") 46 >>> simplify(expression).sql() 47 'TRUE' 48 49 Args: 50 expression (sqlglot.Expression): expression to simplify 51 constant_propagation: whether the constant propagation rule should be used 52 53 Returns: 54 sqlglot.Expression: simplified expression 55 """ 56 57 dialect = Dialect.get_or_raise(dialect) 58 59 def _simplify(expression, root=True): 60 if expression.meta.get(FINAL): 61 return expression 62 63 # group by expressions cannot be simplified, for example 64 # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 65 # the projection must exactly match the group by key 66 group = expression.args.get("group") 67 68 if group and hasattr(expression, "selects"): 69 groups = set(group.expressions) 70 group.meta[FINAL] = True 71 72 for e in expression.selects: 73 for node in e.walk(): 74 if node in groups: 75 e.meta[FINAL] = True 76 break 77 78 having = expression.args.get("having") 79 if having: 80 for node in having.walk(): 81 if node in groups: 82 having.meta[FINAL] = True 83 break 84 85 # Pre-order transformations 86 node = expression 87 node = rewrite_between(node) 88 node = uniq_sort(node, root) 89 node = absorb_and_eliminate(node, root) 90 node = simplify_concat(node) 91 node = simplify_conditionals(node) 92 93 if constant_propagation: 94 node = propagate_constants(node, root) 95 96 exp.replace_children(node, lambda e: _simplify(e, False)) 97 98 # Post-order transformations 99 node = simplify_not(node) 100 node = flatten(node) 101 node = simplify_connectors(node, root) 102 node = remove_complements(node, root) 103 node = simplify_coalesce(node) 104 node.parent = expression.parent 105 node = simplify_literals(node, root) 106 node = simplify_equality(node) 107 node = simplify_parens(node) 108 node = simplify_datetrunc(node, dialect) 109 node = sort_comparison(node) 110 node = simplify_startswith(node) 111 112 if root: 113 expression.replace(node) 114 return node 115 116 
expression = while_changing(expression, _simplify) 117 remove_where_true(expression) 118 return expression 119 120 121def catch(*exceptions): 122 """Decorator that ignores a simplification function if any of `exceptions` are raised""" 123 124 def decorator(func): 125 def wrapped(expression, *args, **kwargs): 126 try: 127 return func(expression, *args, **kwargs) 128 except exceptions: 129 return expression 130 131 return wrapped 132 133 return decorator 134 135 136def rewrite_between(expression: exp.Expression) -> exp.Expression: 137 """Rewrite x between y and z to x >= y AND x <= z. 138 139 This is done because comparison simplification is only done on lt/lte/gt/gte. 140 """ 141 if isinstance(expression, exp.Between): 142 negate = isinstance(expression.parent, exp.Not) 143 144 expression = exp.and_( 145 exp.GTE(this=expression.this.copy(), expression=expression.args["low"]), 146 exp.LTE(this=expression.this.copy(), expression=expression.args["high"]), 147 copy=False, 148 ) 149 150 if negate: 151 expression = exp.paren(expression, copy=False) 152 153 return expression 154 155 156COMPLEMENT_COMPARISONS = { 157 exp.LT: exp.GTE, 158 exp.GT: exp.LTE, 159 exp.LTE: exp.GT, 160 exp.GTE: exp.LT, 161 exp.EQ: exp.NEQ, 162 exp.NEQ: exp.EQ, 163} 164 165 166def simplify_not(expression): 167 """ 168 Demorgan's Law 169 NOT (x OR y) -> NOT x AND NOT y 170 NOT (x AND y) -> NOT x OR NOT y 171 """ 172 if isinstance(expression, exp.Not): 173 this = expression.this 174 if is_null(this): 175 return exp.null() 176 if this.__class__ in COMPLEMENT_COMPARISONS: 177 return COMPLEMENT_COMPARISONS[this.__class__]( 178 this=this.this, expression=this.expression 179 ) 180 if isinstance(this, exp.Paren): 181 condition = this.unnest() 182 if isinstance(condition, exp.And): 183 return exp.paren( 184 exp.or_( 185 exp.not_(condition.left, copy=False), 186 exp.not_(condition.right, copy=False), 187 copy=False, 188 ) 189 ) 190 if isinstance(condition, exp.Or): 191 return exp.paren( 192 exp.and_( 193 
exp.not_(condition.left, copy=False), 194 exp.not_(condition.right, copy=False), 195 copy=False, 196 ) 197 ) 198 if is_null(condition): 199 return exp.null() 200 if always_true(this): 201 return exp.false() 202 if is_false(this): 203 return exp.true() 204 if isinstance(this, exp.Not): 205 # double negation 206 # NOT NOT x -> x 207 return this.this 208 return expression 209 210 211def flatten(expression): 212 """ 213 A AND (B AND C) -> A AND B AND C 214 A OR (B OR C) -> A OR B OR C 215 """ 216 if isinstance(expression, exp.Connector): 217 for node in expression.args.values(): 218 child = node.unnest() 219 if isinstance(child, expression.__class__): 220 node.replace(child) 221 return expression 222 223 224def simplify_connectors(expression, root=True): 225 def _simplify_connectors(expression, left, right): 226 if left == right: 227 if isinstance(expression, exp.Xor): 228 return exp.false() 229 return left 230 if isinstance(expression, exp.And): 231 if is_false(left) or is_false(right): 232 return exp.false() 233 if is_null(left) or is_null(right): 234 return exp.null() 235 if always_true(left) and always_true(right): 236 return exp.true() 237 if always_true(left): 238 return right 239 if always_true(right): 240 return left 241 return _simplify_comparison(expression, left, right) 242 elif isinstance(expression, exp.Or): 243 if always_true(left) or always_true(right): 244 return exp.true() 245 if is_false(left) and is_false(right): 246 return exp.false() 247 if ( 248 (is_null(left) and is_null(right)) 249 or (is_null(left) and is_false(right)) 250 or (is_false(left) and is_null(right)) 251 ): 252 return exp.null() 253 if is_false(left): 254 return right 255 if is_false(right): 256 return left 257 return _simplify_comparison(expression, left, right, or_=True) 258 259 if isinstance(expression, exp.Connector): 260 return _flat_simplify(expression, _simplify_connectors, root) 261 return expression 262 263 264LT_LTE = (exp.LT, exp.LTE) 265GT_GTE = (exp.GT, exp.GTE) 266 
267COMPARISONS = ( 268 *LT_LTE, 269 *GT_GTE, 270 exp.EQ, 271 exp.NEQ, 272 exp.Is, 273) 274 275INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 276 exp.LT: exp.GT, 277 exp.GT: exp.LT, 278 exp.LTE: exp.GTE, 279 exp.GTE: exp.LTE, 280} 281 282NONDETERMINISTIC = (exp.Rand, exp.Randn) 283 284 285def _simplify_comparison(expression, left, right, or_=False): 286 if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): 287 ll, lr = left.args.values() 288 rl, rr = right.args.values() 289 290 largs = {ll, lr} 291 rargs = {rl, rr} 292 293 matching = largs & rargs 294 columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} 295 296 if matching and columns: 297 try: 298 l = first(largs - columns) 299 r = first(rargs - columns) 300 except StopIteration: 301 return expression 302 303 if l.is_number and r.is_number: 304 l = float(l.name) 305 r = float(r.name) 306 elif l.is_string and r.is_string: 307 l = l.name 308 r = r.name 309 else: 310 l = extract_date(l) 311 if not l: 312 return None 313 r = extract_date(r) 314 if not r: 315 return None 316 # python won't compare date and datetime, but many engines will upcast 317 l, r = cast_as_datetime(l), cast_as_datetime(r) 318 319 for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): 320 if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): 321 return left if (av > bv if or_ else av <= bv) else right 322 if isinstance(a, GT_GTE) and isinstance(b, GT_GTE): 323 return left if (av < bv if or_ else av >= bv) else right 324 325 # we can't ever shortcut to true because the column could be null 326 if not or_: 327 if isinstance(a, exp.LT) and isinstance(b, GT_GTE): 328 if av <= bv: 329 return exp.false() 330 elif isinstance(a, exp.GT) and isinstance(b, LT_LTE): 331 if av >= bv: 332 return exp.false() 333 elif isinstance(a, exp.EQ): 334 if isinstance(b, exp.LT): 335 return exp.false() if av >= bv else a 336 if isinstance(b, exp.LTE): 337 return exp.false() 
if av > bv else a 338 if isinstance(b, exp.GT): 339 return exp.false() if av <= bv else a 340 if isinstance(b, exp.GTE): 341 return exp.false() if av < bv else a 342 if isinstance(b, exp.NEQ): 343 return exp.false() if av == bv else a 344 return None 345 346 347def remove_complements(expression, root=True): 348 """ 349 Removing complements. 350 351 A AND NOT A -> FALSE 352 A OR NOT A -> TRUE 353 """ 354 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 355 complement = exp.false() if isinstance(expression, exp.And) else exp.true() 356 357 for a, b in itertools.permutations(expression.flatten(), 2): 358 if is_complement(a, b): 359 return complement 360 return expression 361 362 363def uniq_sort(expression, root=True): 364 """ 365 Uniq and sort a connector. 366 367 C AND A AND B AND B -> A AND B AND C 368 """ 369 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 370 flattened = tuple(expression.flatten()) 371 372 if isinstance(expression, exp.Xor): 373 result_func = exp.xor 374 # Do not deduplicate XOR as A XOR A != A if A == True 375 deduped = None 376 arr = tuple((gen(e), e) for e in flattened) 377 else: 378 result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ 379 deduped = {gen(e): e for e in flattened} 380 arr = tuple(deduped.items()) 381 382 # check if the operands are already sorted, if not sort them 383 # A AND C AND B -> A AND B AND C 384 for i, (sql, e) in enumerate(arr[1:]): 385 if sql < arr[i][0]: 386 expression = result_func(*(e for _, e in sorted(arr)), copy=False) 387 break 388 else: 389 # we didn't have to sort but maybe we need to dedup 390 if deduped and len(deduped) < len(flattened): 391 expression = result_func(*deduped.values(), copy=False) 392 393 return expression 394 395 396def absorb_and_eliminate(expression, root=True): 397 """ 398 absorption: 399 A AND (A OR B) -> A 400 A OR (A AND B) -> A 401 A AND (NOT A OR B) -> A AND B 402 A OR (NOT A AND B) -> A OR B 403 
elimination: 404 (A AND B) OR (A AND NOT B) -> A 405 (A OR B) AND (A OR NOT B) -> A 406 """ 407 if isinstance(expression, exp.Connector) and (root or not expression.same_parent): 408 kind = exp.Or if isinstance(expression, exp.And) else exp.And 409 410 for a, b in itertools.permutations(expression.flatten(), 2): 411 if isinstance(a, kind): 412 aa, ab = a.unnest_operands() 413 414 # absorb 415 if is_complement(b, aa): 416 aa.replace(exp.true() if kind == exp.And else exp.false()) 417 elif is_complement(b, ab): 418 ab.replace(exp.true() if kind == exp.And else exp.false()) 419 elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()): 420 a.replace(exp.false() if kind == exp.And else exp.true()) 421 elif isinstance(b, kind): 422 # eliminate 423 rhs = b.unnest_operands() 424 ba, bb = rhs 425 426 if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): 427 a.replace(aa) 428 b.replace(aa) 429 elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): 430 a.replace(ab) 431 b.replace(ab) 432 433 return expression 434 435 436def propagate_constants(expression, root=True): 437 """ 438 Propagate constants for conjunctions in DNF: 439 440 SELECT * FROM t WHERE a = b AND b = 5 becomes 441 SELECT * FROM t WHERE a = 5 AND b = 5 442 443 Reference: https://www.sqlite.org/optoverview.html 444 """ 445 446 if ( 447 isinstance(expression, exp.And) 448 and (root or not expression.same_parent) 449 and sqlglot.optimizer.normalize.normalized(expression, dnf=True) 450 ): 451 constant_mapping = {} 452 for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)): 453 if isinstance(expr, exp.EQ): 454 l, r = expr.left, expr.right 455 456 # TODO: create a helper that can be used to detect nested literal expressions such 457 # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too 458 if isinstance(l, exp.Column) and isinstance(r, exp.Literal): 459 constant_mapping[l] = (id(l), r) 460 461 if constant_mapping: 462 
for column in find_all_in_scope(expression, exp.Column): 463 parent = column.parent 464 column_id, constant = constant_mapping.get(column) or (None, None) 465 if ( 466 column_id is not None 467 and id(column) != column_id 468 and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null)) 469 ): 470 column.replace(constant.copy()) 471 472 return expression 473 474 475INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 476 exp.DateAdd: exp.Sub, 477 exp.DateSub: exp.Add, 478 exp.DatetimeAdd: exp.Sub, 479 exp.DatetimeSub: exp.Add, 480} 481 482INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { 483 **INVERSE_DATE_OPS, 484 exp.Add: exp.Sub, 485 exp.Sub: exp.Add, 486} 487 488 489def _is_number(expression: exp.Expression) -> bool: 490 return expression.is_number 491 492 493def _is_interval(expression: exp.Expression) -> bool: 494 return isinstance(expression, exp.Interval) and extract_interval(expression) is not None 495 496 497@catch(ModuleNotFoundError, UnsupportedUnit) 498def simplify_equality(expression: exp.Expression) -> exp.Expression: 499 """ 500 Use the subtraction and addition properties of equality to simplify expressions: 501 502 x + 1 = 3 becomes x = 2 503 504 There are two binary operations in the above expression: + and = 505 Here's how we reference all the operands in the code below: 506 507 l r 508 x + 1 = 3 509 a b 510 """ 511 if isinstance(expression, COMPARISONS): 512 l, r = expression.left, expression.right 513 514 if l.__class__ not in INVERSE_OPS: 515 return expression 516 517 if r.is_number: 518 a_predicate = _is_number 519 b_predicate = _is_number 520 elif _is_date_literal(r): 521 a_predicate = _is_date_literal 522 b_predicate = _is_interval 523 else: 524 return expression 525 526 if l.__class__ in INVERSE_DATE_OPS: 527 l = t.cast(exp.IntervalOp, l) 528 a = l.this 529 b = l.interval() 530 else: 531 l = t.cast(exp.Binary, l) 532 a, b = l.left, l.right 533 534 if not a_predicate(a) and 
b_predicate(b): 535 pass 536 elif not a_predicate(b) and b_predicate(a): 537 a, b = b, a 538 else: 539 return expression 540 541 return expression.__class__( 542 this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) 543 ) 544 return expression 545 546 547def simplify_literals(expression, root=True): 548 if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): 549 return _flat_simplify(expression, _simplify_binary, root) 550 551 if isinstance(expression, exp.Neg): 552 this = expression.this 553 if this.is_number: 554 value = this.name 555 if value[0] == "-": 556 return exp.Literal.number(value[1:]) 557 return exp.Literal.number(f"-{value}") 558 559 if type(expression) in INVERSE_DATE_OPS: 560 return _simplify_binary(expression, expression.this, expression.interval()) or expression 561 562 return expression 563 564 565NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ) 566 567 568def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression: 569 if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast): 570 this = _simplify_integer_cast(expr.this) 571 else: 572 this = expr.this 573 574 if isinstance(expr, exp.Cast) and this.is_int: 575 num = int(this.name) 576 577 # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. 
Downcasts on any 578 # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is 579 # engine-dependent 580 if ( 581 TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES 582 ) or ( 583 UTINYINT_MIN <= num <= UTINYINT_MAX 584 and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES 585 ): 586 return this 587 588 return expr 589 590 591def _simplify_binary(expression, a, b): 592 if isinstance(expression, COMPARISONS): 593 a = _simplify_integer_cast(a) 594 b = _simplify_integer_cast(b) 595 596 if isinstance(expression, exp.Is): 597 if isinstance(b, exp.Not): 598 c = b.this 599 not_ = True 600 else: 601 c = b 602 not_ = False 603 604 if is_null(c): 605 if isinstance(a, exp.Literal): 606 return exp.true() if not_ else exp.false() 607 if is_null(a): 608 return exp.false() if not_ else exp.true() 609 elif isinstance(expression, NULL_OK): 610 return None 611 elif is_null(a) or is_null(b): 612 return exp.null() 613 614 if a.is_number and b.is_number: 615 num_a = int(a.name) if a.is_int else Decimal(a.name) 616 num_b = int(b.name) if b.is_int else Decimal(b.name) 617 618 if isinstance(expression, exp.Add): 619 return exp.Literal.number(num_a + num_b) 620 if isinstance(expression, exp.Mul): 621 return exp.Literal.number(num_a * num_b) 622 623 # We only simplify Sub, Div if a and b have the same parent because they're not associative 624 if isinstance(expression, exp.Sub): 625 return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None 626 if isinstance(expression, exp.Div): 627 # engines have differing int div behavior so intdiv is not safe 628 if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent: 629 return None 630 return exp.Literal.number(num_a / num_b) 631 632 boolean = eval_boolean(expression, num_a, num_b) 633 634 if boolean: 635 return boolean 636 elif a.is_string and b.is_string: 637 boolean = eval_boolean(expression, a.this, b.this) 638 639 if boolean: 640 
return boolean 641 elif _is_date_literal(a) and isinstance(b, exp.Interval): 642 date, b = extract_date(a), extract_interval(b) 643 if date and b: 644 if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)): 645 return date_literal(date + b, extract_type(a)) 646 if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)): 647 return date_literal(date - b, extract_type(a)) 648 elif isinstance(a, exp.Interval) and _is_date_literal(b): 649 a, date = extract_interval(a), extract_date(b) 650 # you cannot subtract a date from an interval 651 if a and b and isinstance(expression, exp.Add): 652 return date_literal(a + date, extract_type(b)) 653 elif _is_date_literal(a) and _is_date_literal(b): 654 if isinstance(expression, exp.Predicate): 655 a, b = extract_date(a), extract_date(b) 656 boolean = eval_boolean(expression, a, b) 657 if boolean: 658 return boolean 659 660 return None 661 662 663def simplify_parens(expression): 664 if not isinstance(expression, exp.Paren): 665 return expression 666 667 this = expression.this 668 parent = expression.parent 669 parent_is_predicate = isinstance(parent, exp.Predicate) 670 671 if ( 672 not isinstance(this, exp.Select) 673 and not isinstance(parent, exp.SubqueryPredicate) 674 and ( 675 not isinstance(parent, (exp.Condition, exp.Binary)) 676 or isinstance(parent, exp.Paren) 677 or ( 678 not isinstance(this, exp.Binary) 679 and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate) 680 ) 681 or (isinstance(this, exp.Predicate) and not parent_is_predicate) 682 or (isinstance(this, exp.Add) and isinstance(parent, exp.Add)) 683 or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul)) 684 or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub))) 685 ) 686 ): 687 return this 688 return expression 689 690 691def _is_nonnull_constant(expression: exp.Expression) -> bool: 692 return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression) 693 694 695def 
_is_constant(expression: exp.Expression) -> bool: 696 return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression) 697 698 699def simplify_coalesce(expression): 700 # COALESCE(x) -> x 701 if ( 702 isinstance(expression, exp.Coalesce) 703 and (not expression.expressions or _is_nonnull_constant(expression.this)) 704 # COALESCE is also used as a Spark partitioning hint 705 and not isinstance(expression.parent, exp.Hint) 706 ): 707 return expression.this 708 709 if not isinstance(expression, COMPARISONS): 710 return expression 711 712 if isinstance(expression.left, exp.Coalesce): 713 coalesce = expression.left 714 other = expression.right 715 elif isinstance(expression.right, exp.Coalesce): 716 coalesce = expression.right 717 other = expression.left 718 else: 719 return expression 720 721 # This transformation is valid for non-constants, 722 # but it really only does anything if they are both constants. 723 if not _is_constant(other): 724 return expression 725 726 # Find the first constant arg 727 for arg_index, arg in enumerate(coalesce.expressions): 728 if _is_constant(arg): 729 break 730 else: 731 return expression 732 733 coalesce.set("expressions", coalesce.expressions[:arg_index]) 734 735 # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, 736 # since we already remove COALESCE at the top of this function. 
737 coalesce = coalesce if coalesce.expressions else coalesce.this 738 739 # This expression is more complex than when we started, but it will get simplified further 740 return exp.paren( 741 exp.or_( 742 exp.and_( 743 coalesce.is_(exp.null()).not_(copy=False), 744 expression.copy(), 745 copy=False, 746 ), 747 exp.and_( 748 coalesce.is_(exp.null()), 749 type(expression)(this=arg.copy(), expression=other.copy()), 750 copy=False, 751 ), 752 copy=False, 753 ) 754 ) 755 756 757CONCATS = (exp.Concat, exp.DPipe) 758 759 760def simplify_concat(expression): 761 """Reduces all groups that contain string literals by concatenating them.""" 762 if not isinstance(expression, CONCATS) or ( 763 # We can't reduce a CONCAT_WS call if we don't statically know the separator 764 isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string 765 ): 766 return expression 767 768 if isinstance(expression, exp.ConcatWs): 769 sep_expr, *expressions = expression.expressions 770 sep = sep_expr.name 771 concat_type = exp.ConcatWs 772 args = {} 773 else: 774 expressions = expression.expressions 775 sep = "" 776 concat_type = exp.Concat 777 args = { 778 "safe": expression.args.get("safe"), 779 "coalesce": expression.args.get("coalesce"), 780 } 781 782 new_args = [] 783 for is_string_group, group in itertools.groupby( 784 expressions or expression.flatten(), lambda e: e.is_string 785 ): 786 if is_string_group: 787 new_args.append(exp.Literal.string(sep.join(string.name for string in group))) 788 else: 789 new_args.extend(group) 790 791 if len(new_args) == 1 and new_args[0].is_string: 792 return new_args[0] 793 794 if concat_type is exp.ConcatWs: 795 new_args = [sep_expr] + new_args 796 elif isinstance(expression, exp.DPipe): 797 return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args) 798 799 return concat_type(expressions=new_args, **args) 800 801 802def simplify_conditionals(expression): 803 """Simplifies expressions like IF, CASE if their condition is statically 
known.""" 804 if isinstance(expression, exp.Case): 805 this = expression.this 806 for case in expression.args["ifs"]: 807 cond = case.this 808 if this: 809 # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ... 810 cond = cond.replace(this.pop().eq(cond)) 811 812 if always_true(cond): 813 return case.args["true"] 814 815 if always_false(cond): 816 case.pop() 817 if not expression.args["ifs"]: 818 return expression.args.get("default") or exp.null() 819 elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case): 820 if always_true(expression.this): 821 return expression.args["true"] 822 if always_false(expression.this): 823 return expression.args.get("false") or exp.null() 824 825 return expression 826 827 828def simplify_startswith(expression: exp.Expression) -> exp.Expression: 829 """ 830 Reduces a prefix check to either TRUE or FALSE if both the string and the 831 prefix are statically known. 832 833 Example: 834 >>> from sqlglot import parse_one 835 >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 836 'TRUE' 837 """ 838 if ( 839 isinstance(expression, exp.StartsWith) 840 and expression.this.is_string 841 and expression.expression.is_string 842 ): 843 return exp.convert(expression.name.startswith(expression.expression.name)) 844 845 return expression 846 847 848DateRange = t.Tuple[datetime.date, datetime.date] 849 850 851def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]: 852 """ 853 Get the date range for a DATE_TRUNC equality comparison: 854 855 Example: 856 _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) 857 Returns: 858 tuple of [min, max) or None if a value can never be equal to `date` for `unit` 859 """ 860 floor = date_floor(date, unit, dialect) 861 862 if date != floor: 863 # This will always be False, except for NULL values. 
864 return None 865 866 return floor, floor + interval(unit) 867 868 869def _datetrunc_eq_expression( 870 left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType] 871) -> exp.Expression: 872 """Get the logical expression for a date range""" 873 return exp.and_( 874 left >= date_literal(drange[0], target_type), 875 left < date_literal(drange[1], target_type), 876 copy=False, 877 ) 878 879 880def _datetrunc_eq( 881 left: exp.Expression, 882 date: datetime.date, 883 unit: str, 884 dialect: Dialect, 885 target_type: t.Optional[exp.DataType], 886) -> t.Optional[exp.Expression]: 887 drange = _datetrunc_range(date, unit, dialect) 888 if not drange: 889 return None 890 891 return _datetrunc_eq_expression(left, drange, target_type) 892 893 894def _datetrunc_neq( 895 left: exp.Expression, 896 date: datetime.date, 897 unit: str, 898 dialect: Dialect, 899 target_type: t.Optional[exp.DataType], 900) -> t.Optional[exp.Expression]: 901 drange = _datetrunc_range(date, unit, dialect) 902 if not drange: 903 return None 904 905 return exp.and_( 906 left < date_literal(drange[0], target_type), 907 left >= date_literal(drange[1], target_type), 908 copy=False, 909 ) 910 911 912DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { 913 exp.LT: lambda l, dt, u, d, t: l 914 < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t), 915 exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t), 916 exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t), 917 exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t), 918 exp.EQ: _datetrunc_eq, 919 exp.NEQ: _datetrunc_neq, 920} 921DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} 922DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc) 923 924 925def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: 926 return isinstance(left, DATETRUNCS) 
and _is_date_literal(right) 927 928 929@catch(ModuleNotFoundError, UnsupportedUnit) 930def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression: 931 """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" 932 comparison = expression.__class__ 933 934 if isinstance(expression, DATETRUNCS): 935 this = expression.this 936 trunc_type = extract_type(this) 937 date = extract_date(this) 938 if date and expression.unit: 939 return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type) 940 elif comparison not in DATETRUNC_COMPARISONS: 941 return expression 942 943 if isinstance(expression, exp.Binary): 944 l, r = expression.left, expression.right 945 946 if not _is_datetrunc_predicate(l, r): 947 return expression 948 949 l = t.cast(exp.DateTrunc, l) 950 trunc_arg = l.this 951 unit = l.unit.name.lower() 952 date = extract_date(r) 953 954 if not date: 955 return expression 956 957 return ( 958 DATETRUNC_BINARY_COMPARISONS[comparison]( 959 trunc_arg, date, unit, dialect, extract_type(trunc_arg, r) 960 ) 961 or expression 962 ) 963 964 if isinstance(expression, exp.In): 965 l = expression.this 966 rs = expression.expressions 967 968 if rs and all(_is_datetrunc_predicate(l, r) for r in rs): 969 l = t.cast(exp.DateTrunc, l) 970 unit = l.unit.name.lower() 971 972 ranges = [] 973 for r in rs: 974 date = extract_date(r) 975 if not date: 976 return expression 977 drange = _datetrunc_range(date, unit, dialect) 978 if drange: 979 ranges.append(drange) 980 981 if not ranges: 982 return expression 983 984 ranges = merge_ranges(ranges) 985 target_type = extract_type(l, *rs) 986 987 return exp.or_( 988 *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False 989 ) 990 991 return expression 992 993 994def sort_comparison(expression: exp.Expression) -> exp.Expression: 995 if expression.__class__ in COMPLEMENT_COMPARISONS: 996 l, r = expression.this, expression.expression 997 
l_column = isinstance(l, exp.Column) 998 r_column = isinstance(r, exp.Column) 999 l_const = _is_constant(l) 1000 r_const = _is_constant(r) 1001 1002 if (l_column and not r_column) or (r_const and not l_const): 1003 return expression 1004 if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)): 1005 return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)( 1006 this=r, expression=l 1007 ) 1008 return expression 1009 1010 1011# CROSS joins result in an empty table if the right table is empty. 1012# So we can only simplify certain types of joins to CROSS. 1013# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x 1014JOINS = { 1015 ("", ""), 1016 ("", "INNER"), 1017 ("RIGHT", ""), 1018 ("RIGHT", "OUTER"), 1019} 1020 1021 1022def remove_where_true(expression): 1023 for where in expression.find_all(exp.Where): 1024 if always_true(where.this): 1025 where.pop() 1026 for join in expression.find_all(exp.Join): 1027 if ( 1028 always_true(join.args.get("on")) 1029 and not join.args.get("using") 1030 and not join.args.get("method") 1031 and (join.side, join.kind) in JOINS 1032 ): 1033 join.args["on"].pop() 1034 join.set("side", None) 1035 join.set("kind", "CROSS") 1036 1037 1038def always_true(expression): 1039 return (isinstance(expression, exp.Boolean) and expression.this) or isinstance( 1040 expression, exp.Literal 1041 ) 1042 1043 1044def always_false(expression): 1045 return is_false(expression) or is_null(expression) 1046 1047 1048def is_complement(a, b): 1049 return isinstance(b, exp.Not) and b.this == a 1050 1051 1052def is_false(a: exp.Expression) -> bool: 1053 return type(a) is exp.Boolean and not a.this 1054 1055 1056def is_null(a: exp.Expression) -> bool: 1057 return type(a) is exp.Null 1058 1059 1060def eval_boolean(expression, a, b): 1061 if isinstance(expression, (exp.EQ, exp.Is)): 1062 return boolean_literal(a == b) 1063 if isinstance(expression, exp.NEQ): 1064 return boolean_literal(a != b) 1065 if 
isinstance(expression, exp.GT): 1066 return boolean_literal(a > b) 1067 if isinstance(expression, exp.GTE): 1068 return boolean_literal(a >= b) 1069 if isinstance(expression, exp.LT): 1070 return boolean_literal(a < b) 1071 if isinstance(expression, exp.LTE): 1072 return boolean_literal(a <= b) 1073 return None 1074 1075 1076def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: 1077 if isinstance(value, datetime.datetime): 1078 return value.date() 1079 if isinstance(value, datetime.date): 1080 return value 1081 try: 1082 return datetime.datetime.fromisoformat(value).date() 1083 except ValueError: 1084 return None 1085 1086 1087def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: 1088 if isinstance(value, datetime.datetime): 1089 return value 1090 if isinstance(value, datetime.date): 1091 return datetime.datetime(year=value.year, month=value.month, day=value.day) 1092 try: 1093 return datetime.datetime.fromisoformat(value) 1094 except ValueError: 1095 return None 1096 1097 1098def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1099 if not value: 1100 return None 1101 if to.is_type(exp.DataType.Type.DATE): 1102 return cast_as_date(value) 1103 if to.is_type(*exp.DataType.TEMPORAL_TYPES): 1104 return cast_as_datetime(value) 1105 return None 1106 1107 1108def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: 1109 if isinstance(cast, exp.Cast): 1110 to = cast.to 1111 elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"): 1112 to = exp.DataType.build(exp.DataType.Type.DATE) 1113 else: 1114 return None 1115 1116 if isinstance(cast.this, exp.Literal): 1117 value: t.Any = cast.this.name 1118 elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): 1119 value = extract_date(cast.this) 1120 else: 1121 return None 1122 return cast_value(value, to) 1123 1124 1125def _is_date_literal(expression: exp.Expression) -> bool: 1126 return 
extract_date(expression) is not None 1127 1128 1129def extract_interval(expression): 1130 try: 1131 n = int(expression.name) 1132 unit = expression.text("unit").lower() 1133 return interval(unit, n) 1134 except (UnsupportedUnit, ModuleNotFoundError, ValueError): 1135 return None 1136 1137 1138def extract_type(*expressions): 1139 target_type = None 1140 for expression in expressions: 1141 target_type = expression.to if isinstance(expression, exp.Cast) else expression.type 1142 if target_type: 1143 break 1144 1145 return target_type 1146 1147 1148def date_literal(date, target_type=None): 1149 if not target_type or not target_type.is_type(*exp.DataType.TEMPORAL_TYPES): 1150 target_type = ( 1151 exp.DataType.Type.DATETIME 1152 if isinstance(date, datetime.datetime) 1153 else exp.DataType.Type.DATE 1154 ) 1155 1156 return exp.cast(exp.Literal.string(date), target_type) 1157 1158 1159def interval(unit: str, n: int = 1): 1160 from dateutil.relativedelta import relativedelta 1161 1162 if unit == "year": 1163 return relativedelta(years=1 * n) 1164 if unit == "quarter": 1165 return relativedelta(months=3 * n) 1166 if unit == "month": 1167 return relativedelta(months=1 * n) 1168 if unit == "week": 1169 return relativedelta(weeks=1 * n) 1170 if unit == "day": 1171 return relativedelta(days=1 * n) 1172 if unit == "hour": 1173 return relativedelta(hours=1 * n) 1174 if unit == "minute": 1175 return relativedelta(minutes=1 * n) 1176 if unit == "second": 1177 return relativedelta(seconds=1 * n) 1178 1179 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1180 1181 1182def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1183 if unit == "year": 1184 return d.replace(month=1, day=1) 1185 if unit == "quarter": 1186 if d.month <= 3: 1187 return d.replace(month=1, day=1) 1188 elif d.month <= 6: 1189 return d.replace(month=4, day=1) 1190 elif d.month <= 9: 1191 return d.replace(month=7, day=1) 1192 else: 1193 return d.replace(month=10, day=1) 1194 if unit == 
"month": 1195 return d.replace(month=d.month, day=1) 1196 if unit == "week": 1197 # Assuming week starts on Monday (0) and ends on Sunday (6) 1198 return d - datetime.timedelta(days=d.weekday() - dialect.WEEK_OFFSET) 1199 if unit == "day": 1200 return d 1201 1202 raise UnsupportedUnit(f"Unsupported unit: {unit}") 1203 1204 1205def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date: 1206 floor = date_floor(d, unit, dialect) 1207 1208 if floor == d: 1209 return d 1210 1211 return floor + interval(unit) 1212 1213 1214def boolean_literal(condition): 1215 return exp.true() if condition else exp.false() 1216 1217 1218def _flat_simplify(expression, simplifier, root=True): 1219 if root or not expression.same_parent: 1220 operands = [] 1221 queue = deque(expression.flatten(unnest=False)) 1222 size = len(queue) 1223 1224 while queue: 1225 a = queue.popleft() 1226 1227 for b in queue: 1228 result = simplifier(expression, a, b) 1229 1230 if result and result is not expression: 1231 queue.remove(b) 1232 queue.appendleft(result) 1233 break 1234 else: 1235 operands.append(a) 1236 1237 if len(operands) < size: 1238 return functools.reduce( 1239 lambda a, b: expression.__class__(this=a, expression=b), operands 1240 ) 1241 return expression 1242 1243 1244def gen(expression: t.Any) -> str: 1245 """Simple pseudo sql generator for quickly generating sortable and uniq strings. 1246 1247 Sorting and deduping sql is a necessary step for optimization. Calling the actual 1248 generator is expensive so we have a bare minimum sql generator here. 
1249 """ 1250 return Gen().gen(expression) 1251 1252 1253class Gen: 1254 def __init__(self): 1255 self.stack = [] 1256 self.sqls = [] 1257 1258 def gen(self, expression: exp.Expression) -> str: 1259 self.stack = [expression] 1260 self.sqls.clear() 1261 1262 while self.stack: 1263 node = self.stack.pop() 1264 1265 if isinstance(node, exp.Expression): 1266 exp_handler_name = f"{node.key}_sql" 1267 1268 if hasattr(self, exp_handler_name): 1269 getattr(self, exp_handler_name)(node) 1270 elif isinstance(node, exp.Func): 1271 self._function(node) 1272 else: 1273 key = node.key.upper() 1274 self.stack.append(f"{key} " if self._args(node) else key) 1275 elif type(node) is list: 1276 for n in reversed(node): 1277 if n is not None: 1278 self.stack.extend((n, ",")) 1279 if node: 1280 self.stack.pop() 1281 else: 1282 if node is not None: 1283 self.sqls.append(str(node)) 1284 1285 return "".join(self.sqls) 1286 1287 def add_sql(self, e: exp.Add) -> None: 1288 self._binary(e, " + ") 1289 1290 def alias_sql(self, e: exp.Alias) -> None: 1291 self.stack.extend( 1292 ( 1293 e.args.get("alias"), 1294 " AS ", 1295 e.args.get("this"), 1296 ) 1297 ) 1298 1299 def and_sql(self, e: exp.And) -> None: 1300 self._binary(e, " AND ") 1301 1302 def anonymous_sql(self, e: exp.Anonymous) -> None: 1303 this = e.this 1304 if isinstance(this, str): 1305 name = this.upper() 1306 elif isinstance(this, exp.Identifier): 1307 name = this.this 1308 name = f'"{name}"' if this.quoted else name.upper() 1309 else: 1310 raise ValueError( 1311 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 
1312 ) 1313 1314 self.stack.extend( 1315 ( 1316 ")", 1317 e.expressions, 1318 "(", 1319 name, 1320 ) 1321 ) 1322 1323 def between_sql(self, e: exp.Between) -> None: 1324 self.stack.extend( 1325 ( 1326 e.args.get("high"), 1327 " AND ", 1328 e.args.get("low"), 1329 " BETWEEN ", 1330 e.this, 1331 ) 1332 ) 1333 1334 def boolean_sql(self, e: exp.Boolean) -> None: 1335 self.stack.append("TRUE" if e.this else "FALSE") 1336 1337 def bracket_sql(self, e: exp.Bracket) -> None: 1338 self.stack.extend( 1339 ( 1340 "]", 1341 e.expressions, 1342 "[", 1343 e.this, 1344 ) 1345 ) 1346 1347 def column_sql(self, e: exp.Column) -> None: 1348 for p in reversed(e.parts): 1349 self.stack.extend((p, ".")) 1350 self.stack.pop() 1351 1352 def datatype_sql(self, e: exp.DataType) -> None: 1353 self._args(e, 1) 1354 self.stack.append(f"{e.this.name} ") 1355 1356 def div_sql(self, e: exp.Div) -> None: 1357 self._binary(e, " / ") 1358 1359 def dot_sql(self, e: exp.Dot) -> None: 1360 self._binary(e, ".") 1361 1362 def eq_sql(self, e: exp.EQ) -> None: 1363 self._binary(e, " = ") 1364 1365 def from_sql(self, e: exp.From) -> None: 1366 self.stack.extend((e.this, "FROM ")) 1367 1368 def gt_sql(self, e: exp.GT) -> None: 1369 self._binary(e, " > ") 1370 1371 def gte_sql(self, e: exp.GTE) -> None: 1372 self._binary(e, " >= ") 1373 1374 def identifier_sql(self, e: exp.Identifier) -> None: 1375 self.stack.append(f'"{e.this}"' if e.quoted else e.this) 1376 1377 def ilike_sql(self, e: exp.ILike) -> None: 1378 self._binary(e, " ILIKE ") 1379 1380 def in_sql(self, e: exp.In) -> None: 1381 self.stack.append(")") 1382 self._args(e, 1) 1383 self.stack.extend( 1384 ( 1385 "(", 1386 " IN ", 1387 e.this, 1388 ) 1389 ) 1390 1391 def intdiv_sql(self, e: exp.IntDiv) -> None: 1392 self._binary(e, " DIV ") 1393 1394 def is_sql(self, e: exp.Is) -> None: 1395 self._binary(e, " IS ") 1396 1397 def like_sql(self, e: exp.Like) -> None: 1398 self._binary(e, " Like ") 1399 1400 def literal_sql(self, e: exp.Literal) -> None: 
1401 self.stack.append(f"'{e.this}'" if e.is_string else e.this) 1402 1403 def lt_sql(self, e: exp.LT) -> None: 1404 self._binary(e, " < ") 1405 1406 def lte_sql(self, e: exp.LTE) -> None: 1407 self._binary(e, " <= ") 1408 1409 def mod_sql(self, e: exp.Mod) -> None: 1410 self._binary(e, " % ") 1411 1412 def mul_sql(self, e: exp.Mul) -> None: 1413 self._binary(e, " * ") 1414 1415 def neg_sql(self, e: exp.Neg) -> None: 1416 self._unary(e, "-") 1417 1418 def neq_sql(self, e: exp.NEQ) -> None: 1419 self._binary(e, " <> ") 1420 1421 def not_sql(self, e: exp.Not) -> None: 1422 self._unary(e, "NOT ") 1423 1424 def null_sql(self, e: exp.Null) -> None: 1425 self.stack.append("NULL") 1426 1427 def or_sql(self, e: exp.Or) -> None: 1428 self._binary(e, " OR ") 1429 1430 def paren_sql(self, e: exp.Paren) -> None: 1431 self.stack.extend( 1432 ( 1433 ")", 1434 e.this, 1435 "(", 1436 ) 1437 ) 1438 1439 def sub_sql(self, e: exp.Sub) -> None: 1440 self._binary(e, " - ") 1441 1442 def subquery_sql(self, e: exp.Subquery) -> None: 1443 self._args(e, 2) 1444 alias = e.args.get("alias") 1445 if alias: 1446 self.stack.append(alias) 1447 self.stack.extend((")", e.this, "(")) 1448 1449 def table_sql(self, e: exp.Table) -> None: 1450 self._args(e, 4) 1451 alias = e.args.get("alias") 1452 if alias: 1453 self.stack.append(alias) 1454 for p in reversed(e.parts): 1455 self.stack.extend((p, ".")) 1456 self.stack.pop() 1457 1458 def tablealias_sql(self, e: exp.TableAlias) -> None: 1459 columns = e.columns 1460 1461 if columns: 1462 self.stack.extend((")", columns, "(")) 1463 1464 self.stack.extend((e.this, " AS ")) 1465 1466 def var_sql(self, e: exp.Var) -> None: 1467 self.stack.append(e.this) 1468 1469 def _binary(self, e: exp.Binary, op: str) -> None: 1470 self.stack.extend((e.expression, op, e.this)) 1471 1472 def _unary(self, e: exp.Unary, op: str) -> None: 1473 self.stack.extend((e.this, op)) 1474 1475 def _function(self, e: exp.Func) -> None: 1476 self.stack.extend( 1477 ( 1478 ")", 1479 
list(e.args.values()), 1480 "(", 1481 e.sql_name(), 1482 ) 1483 ) 1484 1485 def _args(self, node: exp.Expression, arg_index: int = 0) -> bool: 1486 kvs = [] 1487 arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types 1488 1489 for k in arg_types or arg_types: 1490 v = node.args.get(k) 1491 1492 if v is not None: 1493 kvs.append([f":{k}", v]) 1494 if kvs: 1495 self.stack.append(kvs) 1496 return True 1497 return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect (name or instance) resolved through `Dialect.get_or_raise`;
            the resolved dialect is handed to `simplify_datetrunc`

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched (children included).
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group keys.
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for the HAVING clause, if there is one.
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a brand new node; re-attach the parent
        # because later rules (e.g. simplify_parens) inspect `parent`.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-apply _simplify via while_changing, then drop always-true WHERE / JOIN ON clauses.
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    Args:
        *exceptions: exception classes to intercept.

    Returns:
        A decorator. The decorated function behaves like the original one, except
        that when one of `exceptions` is raised it returns its first argument
        (the expression) unchanged. Any other exception propagates.
    """

    def decorator(func):
        # Keep the decorated function's name, docstring and signature metadata so
        # that introspection and generated documentation show `func`, not `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that makes a simplification function return its input expression unchanged if any of `exceptions` are raised.
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    The comparison simplification rules only handle lt/lte/gt/gte, so BETWEEN
    is rewritten into that form first. When the BETWEEN sits directly under a
    NOT, the resulting conjunction is wrapped in parentheses.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # Read the parent before the operands are moved into the new comparisons.
    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if under_not else rewritten
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify a NOT expression.

    Besides constant folding, complementing comparisons and removing double
    negation, this applies De Morgan's laws to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    # NOT NULL -> NULL
    if is_null(operand):
        return exp.null()

    # NOT <comparison> -> the complementary comparison over the same operands
    operand_type = type(operand)
    if operand_type in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand_type]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and swap the connector
            combine = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return exp.paren(
                combine(
                    exp.not_(inner.left, copy=False),
                    exp.not_(inner.right, copy=False),
                    copy=False,
                )
            )

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    # double negation: NOT NOT x -> x
    if isinstance(operand, exp.Not):
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Drop redundant nesting between connectors of the same kind.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        # Only unwrap when the nested connector matches the outer one.
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """
    Simplify a connector (AND / OR / XOR) by combining pairs of its operands.

    Args:
        expression: the expression to simplify; anything that is not a
            connector is returned unchanged.
        root: forwarded to `_flat_simplify`.

    Returns:
        The result of `_flat_simplify` for connectors, otherwise `expression`.
    """

    def _simplify_connectors(expression, left, right):
        # Pairwise rule handed to _flat_simplify. Falling off the end returns
        # None, which happens for XOR operands that are not equal.
        if left == right:
            # x XOR x -> FALSE, x AND x / x OR x -> x
            if isinstance(expression, exp.Xor):
                return exp.false()
            return left
        if isinstance(expression, exp.And):
            # FALSE is checked before NULL: FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE AND x -> x
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE is checked first: TRUE OR NULL -> TRUE
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # NULL OR NULL, NULL OR FALSE, FALSE OR NULL -> NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE OR x -> x
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(first, second) for first, second in operand_pairs):
        # Only AND collapses to FALSE; every other connector collapses to TRUE.
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered (and, except for XOR, deduplicated) by the string
    produced by `gen`.

    Args:
        expression: the expression to normalize; only connectors are changed.
        root: whether `expression` is the root of the simplification. Non-root
            connectors are only processed when `same_parent` is false.

    Returns:
        A rebuilt connector when sorting or deduplication was needed, otherwise
        the original expression.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                # Sort on the generated SQL only. The sort is stable, so operands
                # with identical SQL (possible for XOR, which is not deduplicated)
                # keep their relative order and the expression objects themselves
                # are never compared with each other.
                ordered = sorted(arr, key=lambda pair: pair[0])
                expression = result_func(*(e for _, e in ordered), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing operands inside
    `expression`; the (mutated) expression itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # `kind` is the connector the nested operands must have:
        # OR when the outer connector is AND, AND otherwise.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # Every ordered pair of operands is inspected: `a` is the nested
        # connector candidate and `b` the operand it is compared against.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: replace it with the neutral element of `kind`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's: `a` is redundant,
                    # so replace it with the neutral element of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # One operand is shared and the other two are complements:
                    # both nested connectors reduce to the shared operand.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Columns are replaced in place; the (mutated) expression is returned.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in `column = literal`, literal).
        constant_mapping = {}
        # IF nodes are pruned from the walk, so equalities below them are not collected.
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the occurrence that defines the constant (same node identity)
                # and leave `column IS NULL` checks untouched.
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the wrapped simplification; fall back to the untouched input
    # expression when one of the intercepted exceptions is raised.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        result = expression
    return result
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Non-connector binary expressions are folded pairwise through
    `_flat_simplify`, a negated number literal is turned into a number literal
    with the sign flipped, and expressions whose type is in INVERSE_DATE_OPS
    are folded against their interval.
    """
    is_connector = isinstance(expression, exp.Connector)
    if isinstance(expression, exp.Binary) and not is_connector:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Strip an existing minus sign, otherwise prepend one.
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """
    Remove a pair of parentheses when the checks below deem it redundant.

    Returns the wrapped expression when the parentheses are dropped, otherwise
    the original `Paren` node. Non-`Paren` inputs are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Parenthesized SELECTs and operands of subquery predicates always keep their parens.
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # Any one of the following is enough to drop the parentheses:
            # the parent is neither a Condition nor a Binary
            not isinstance(parent, (exp.Condition, exp.Binary))
            # nested parentheses
            or isinstance(parent, exp.Paren)
            # the content is not a binary operation (NOT / IS under a predicate excepted)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # a predicate whose parent is not itself a predicate
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # (a + b) + c, (a * b) * c, (a * b) +/- c
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    Two rewrites are performed:
      * a COALESCE with no extra arguments, or whose first argument is a
        non-null constant, is replaced by its first argument;
      * a comparison between a COALESCE and a constant is expanded into an
        explicit null check, truncating the COALESCE at its first constant
        argument.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Locate the COALESCE operand (either side) and the operand it is compared to.
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Keep only the arguments that precede the constant one (mutates the COALESCE in place).
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # remaining COALESCE is not null: keep the original comparison
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # remaining COALESCE is null: compare the constant argument instead
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Expressions whose type is not in CONCATS, and CONCAT_WS calls whose
    separator is not a string literal, are returned unchanged.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator.
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        # Carry the original node's flags over to the rebuilt CONCAT.
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Group consecutive operands by whether they are string literals; when there
    # is no `expressions` list the flattened operands are used instead.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            # Merge a run of adjacent string literals into a single literal.
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    # Everything collapsed into one string literal: return it directly.
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild a left-nested chain of DPipe nodes.
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed in place and
    the first branch whose condition is always true is returned. For a
    standalone IF (one whose parent is not a CASE), the matching branch is
    returned. NULL is returned when no branch or default/false value remains.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # This branch can never match: drop it from the CASE.
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Evaluate a STARTSWITH call at simplification time when both of its
    arguments are string literals, yielding TRUE or FALSE.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    # Both the string and the prefix must be statically known.
    if not expression.this.is_string:
        return expression

    prefix = expression.expression
    if not prefix.is_string:
        return expression

    has_prefix = expression.name.startswith(prefix.name)
    return exp.convert(has_prefix)
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # Run the wrapped simplification; fall back to the untouched input
    # expression when one of the intercepted exceptions is raised.
    try:
        result = func(expression, *args, **kwargs)
    except exceptions:
        result = expression
    return result
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put the operands of a comparison into a canonical order.

    A column is kept on the left of a non-column and a constant on the right of
    a non-constant. Otherwise the operands are ordered by their `gen` string.
    When the operands are swapped, the comparison class is mapped through
    INVERSE_COMPARISONS (falling back to the same class).
    """
    comparison_type = expression.__class__
    if comparison_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # Already in the preferred orientation: nothing to do.
    if (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const):
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not needs_swap:
        return expression

    mirrored_type = INVERSE_COMPARISONS.get(comparison_type, comparison_type)
    return mirrored_type(this=rhs, expression=lhs)
def remove_where_true(expression):
    """
    Strip conditions that are always true, in place.

    WHERE clauses with an always-true condition are removed. Joins with an
    always-true ON condition, no USING, no method and a (side, kind) pair
    listed in JOINS are turned into CROSS joins.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values
    `a` and `b`.

    Returns a boolean literal expression, or None when `expression` is not one
    of the supported comparison types.
    """
    # (expression types, comparison); checked in order, first match wins.
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    for expression_types, compare in comparisons:
        if isinstance(expression, expression_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as they are and
    anything else is parsed as an ISO formatted string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so it has to be told apart here.
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned as they are, dates become midnight of that day and
    anything else is parsed as an ISO formatted string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so it has to be told apart here.
        if isinstance(value, datetime.datetime):
            return value
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None

    return parsed
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python date/datetime matching the data type `to`.

    Args:
        value: the value to convert; falsy values yield None.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is
        one of the other temporal types, and None when the value is falsy, the
        type is not temporal or the conversion fails.
    """
    if not value:
        return None
    # DATE is checked first so that it yields a plain date rather than a datetime.
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract the Python date/datetime represented by a cast of a literal.

    Args:
        cast: a `Cast`, or a `TsOrDsToDate` without a format, wrapping either a
            literal or another such expression.

    Returns:
        The value converted by `cast_value` to the cast's target type, or None
        when `cast` does not have one of the supported shapes.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # TsOrDsToDate without a format is handled as a cast to DATE.
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: resolve the inner one first, then convert its result.
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def date_literal(date, target_type=None):
    """
    Build a cast of the string form of `date` to a temporal type.

    `target_type` is used when it is a temporal type. Otherwise DATETIME is
    chosen for datetime values and DATE for everything else.
    """
    if not (target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)):
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` spanning `n` units.

    Supported units are year, quarter, month, week, day, hour, minute and
    second; any other unit raises UnsupportedUnit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round ``d`` down to the start of the ``unit`` that contains it.

    Args:
        d: the date to floor. Only the month/day fields are replaced (or whole
            days subtracted), so any other fields on ``d`` are left untouched.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides ``WEEK_OFFSET``, the shift of the first day of the week
            relative to Monday (``date.weekday() == 0``).

    Returns:
        The first day of the enclosing year/quarter/month/week, or ``d`` itself
        for "day".

    Raises:
        UnsupportedUnit: if ``unit`` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday. The distance back to
        # the start of the week must wrap modulo 7: without it a non-zero
        # WEEK_OFFSET can move the result a full week too far back (e.g. a Sunday
        # with offset -1) or even past ``d`` (a positive offset).
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Render ``expression`` as a cheap pseudo-SQL string usable for sorting and deduplication.

    Optimization needs to sort and deduplicate SQL. The real generator is
    expensive to call, so a fresh bare-minimum ``Gen`` instance is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo-SQL generator for quickly generating sortable and unique strings.
Sorting and deduplicating SQL is a necessary step for optimization. Calling the actual generator is expensive, so we have a bare-minimum SQL generator here.
class Gen:
    """Bare-minimum, stack-driven pseudo-SQL generator.

    ``gen`` walks an expression iteratively using ``self.stack`` as a LIFO work
    list and collects output fragments in ``self.sqls``. Because the stack is
    LIFO, every handler pushes its pieces in the reverse of the order in which
    they should appear in the output.

    Handlers are looked up by name as ``<node.key>_sql``. Expressions without a
    dedicated handler fall back to ``_function`` (for ``exp.Func``) or to a
    generic ``KEY :arg,value,...`` rendering built by ``_args``.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for ``expression``."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    key = node.key.upper()
                    # _args runs first and pushes the arguments, so the key
                    # pushed afterwards is emitted before them.
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would otherwise precede the first item.
                # Only pop when something was pushed: a non-empty list holding
                # nothing but None must not remove an unrelated stack entry.
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Push ``NAME(expressions)``; unquoted names are upper-cased, quoted ones kept verbatim."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the "." pushed after the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Arguments after the first arg type follow the type name.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # NOTE: the mixed-case token is kept as-is; changing it would alter the
        # generated strings.
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Arg types from index 2 onward are rendered generically, after the alias.
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Arg types from index 4 onward are rendered generically, after the alias.
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the "." pushed after the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push ``this op expression`` (in reverse, for LIFO emission)."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push ``op this`` (in reverse, for LIFO emission)."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push ``SQL_NAME(arg values...)`` using every value in ``e.args``."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push ``[":name", value]`` pairs for the node's non-None arguments.

        Args:
            node: the expression whose ``arg_types`` / ``args`` are inspected.
            arg_index: number of leading arg types to skip.

        Returns:
            True if at least one pair was pushed onto the stack, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
1259 def gen(self, expression: exp.Expression) -> str: 1260 self.stack = [expression] 1261 self.sqls.clear() 1262 1263 while self.stack: 1264 node = self.stack.pop() 1265 1266 if isinstance(node, exp.Expression): 1267 exp_handler_name = f"{node.key}_sql" 1268 1269 if hasattr(self, exp_handler_name): 1270 getattr(self, exp_handler_name)(node) 1271 elif isinstance(node, exp.Func): 1272 self._function(node) 1273 else: 1274 key = node.key.upper() 1275 self.stack.append(f"{key} " if self._args(node) else key) 1276 elif type(node) is list: 1277 for n in reversed(node): 1278 if n is not None: 1279 self.stack.extend((n, ",")) 1280 if node: 1281 self.stack.pop() 1282 else: 1283 if node is not None: 1284 self.sqls.append(str(node)) 1285 1286 return "".join(self.sqls)
1303 def anonymous_sql(self, e: exp.Anonymous) -> None: 1304 this = e.this 1305 if isinstance(this, str): 1306 name = this.upper() 1307 elif isinstance(this, exp.Identifier): 1308 name = this.this 1309 name = f'"{name}"' if this.quoted else name.upper() 1310 else: 1311 raise ValueError( 1312 f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'." 1313 ) 1314 1315 self.stack.extend( 1316 ( 1317 ")", 1318 e.expressions, 1319 "(", 1320 name, 1321 ) 1322 )