Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.helper import first, is_iterable, merge_ranges, while_changing
  11from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  12
  13# Final means that an expression should not be simplified
  14FINAL = "final"
  15
  16
  17class UnsupportedUnit(Exception):
  18    pass
  19
  20
  21def simplify(expression, constant_propagation=False):
  22    """
  23    Rewrite sqlglot AST to simplify expressions.
  24
  25    Example:
  26        >>> import sqlglot
  27        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  28        >>> simplify(expression).sql()
  29        'TRUE'
  30
  31    Args:
  32        expression (sqlglot.Expression): expression to simplify
  33        constant_propagation: whether or not the constant propagation rule should be used
  34
  35    Returns:
  36        sqlglot.Expression: simplified expression
  37    """
  38
  39    # group by expressions cannot be simplified, for example
  40    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  41    # the projection must exactly match the group by key
  42    for group in expression.find_all(exp.Group):
  43        select = group.parent
  44        groups = set(group.expressions)
  45        group.meta[FINAL] = True
  46
  47        for e in select.selects:
  48            for node, *_ in e.walk():
  49                if node in groups:
  50                    e.meta[FINAL] = True
  51                    break
  52
  53        having = select.args.get("having")
  54        if having:
  55            for node, *_ in having.walk():
  56                if node in groups:
  57                    having.meta[FINAL] = True
  58                    break
  59
  60    def _simplify(expression, root=True):
  61        if expression.meta.get(FINAL):
  62            return expression
  63
  64        # Pre-order transformations
  65        node = expression
  66        node = rewrite_between(node)
  67        node = uniq_sort(node, root)
  68        node = absorb_and_eliminate(node, root)
  69        node = simplify_concat(node)
  70        node = simplify_conditionals(node)
  71
  72        if constant_propagation:
  73            node = propagate_constants(node, root)
  74
  75        exp.replace_children(node, lambda e: _simplify(e, False))
  76
  77        # Post-order transformations
  78        node = simplify_not(node)
  79        node = flatten(node)
  80        node = simplify_connectors(node, root)
  81        node = remove_complements(node, root)
  82        node = simplify_coalesce(node)
  83        node.parent = expression.parent
  84        node = simplify_literals(node, root)
  85        node = simplify_equality(node)
  86        node = simplify_parens(node)
  87        node = simplify_datetrunc_predicate(node)
  88
  89        if root:
  90            expression.replace(node)
  91
  92        return node
  93
  94    expression = while_changing(expression, _simplify)
  95    remove_where_true(expression)
  96    return expression
  97
  98
  99def catch(*exceptions):
 100    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 101
 102    def decorator(func):
 103        def wrapped(expression, *args, **kwargs):
 104            try:
 105                return func(expression, *args, **kwargs)
 106            except exceptions:
 107                return expression
 108
 109        return wrapped
 110
 111    return decorator
 112
 113
 114def rewrite_between(expression: exp.Expression) -> exp.Expression:
 115    """Rewrite x between y and z to x >= y AND x <= z.
 116
 117    This is done because comparison simplification is only done on lt/lte/gt/gte.
 118    """
 119    if isinstance(expression, exp.Between):
 120        return exp.and_(
 121            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 122            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 123            copy=False,
 124        )
 125    return expression
 126
 127
 128def simplify_not(expression):
 129    """
 130    Demorgan's Law
 131    NOT (x OR y) -> NOT x AND NOT y
 132    NOT (x AND y) -> NOT x OR NOT y
 133    """
 134    if isinstance(expression, exp.Not):
 135        if is_null(expression.this):
 136            return exp.null()
 137        if isinstance(expression.this, exp.Paren):
 138            condition = expression.this.unnest()
 139            if isinstance(condition, exp.And):
 140                return exp.or_(
 141                    exp.not_(condition.left, copy=False),
 142                    exp.not_(condition.right, copy=False),
 143                    copy=False,
 144                )
 145            if isinstance(condition, exp.Or):
 146                return exp.and_(
 147                    exp.not_(condition.left, copy=False),
 148                    exp.not_(condition.right, copy=False),
 149                    copy=False,
 150                )
 151            if is_null(condition):
 152                return exp.null()
 153        if always_true(expression.this):
 154            return exp.false()
 155        if is_false(expression.this):
 156            return exp.true()
 157        if isinstance(expression.this, exp.Not):
 158            # double negation
 159            # NOT NOT x -> x
 160            return expression.this.this
 161    return expression
 162
 163
 164def flatten(expression):
 165    """
 166    A AND (B AND C) -> A AND B AND C
 167    A OR (B OR C) -> A OR B OR C
 168    """
 169    if isinstance(expression, exp.Connector):
 170        for node in expression.args.values():
 171            child = node.unnest()
 172            if isinstance(child, expression.__class__):
 173                node.replace(child)
 174    return expression
 175
 176
 177def simplify_connectors(expression, root=True):
 178    def _simplify_connectors(expression, left, right):
 179        if left == right:
 180            return left
 181        if isinstance(expression, exp.And):
 182            if is_false(left) or is_false(right):
 183                return exp.false()
 184            if is_null(left) or is_null(right):
 185                return exp.null()
 186            if always_true(left) and always_true(right):
 187                return exp.true()
 188            if always_true(left):
 189                return right
 190            if always_true(right):
 191                return left
 192            return _simplify_comparison(expression, left, right)
 193        elif isinstance(expression, exp.Or):
 194            if always_true(left) or always_true(right):
 195                return exp.true()
 196            if is_false(left) and is_false(right):
 197                return exp.false()
 198            if (
 199                (is_null(left) and is_null(right))
 200                or (is_null(left) and is_false(right))
 201                or (is_false(left) and is_null(right))
 202            ):
 203                return exp.null()
 204            if is_false(left):
 205                return right
 206            if is_false(right):
 207                return left
 208            return _simplify_comparison(expression, left, right, or_=True)
 209
 210    if isinstance(expression, exp.Connector):
 211        return _flat_simplify(expression, _simplify_connectors, root)
 212    return expression
 213
 214
 215LT_LTE = (exp.LT, exp.LTE)
 216GT_GTE = (exp.GT, exp.GTE)
 217
 218COMPARISONS = (
 219    *LT_LTE,
 220    *GT_GTE,
 221    exp.EQ,
 222    exp.NEQ,
 223    exp.Is,
 224)
 225
 226INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 227    exp.LT: exp.GT,
 228    exp.GT: exp.LT,
 229    exp.LTE: exp.GTE,
 230    exp.GTE: exp.LTE,
 231}
 232
 233
 234def _simplify_comparison(expression, left, right, or_=False):
 235    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 236        ll, lr = left.args.values()
 237        rl, rr = right.args.values()
 238
 239        largs = {ll, lr}
 240        rargs = {rl, rr}
 241
 242        matching = largs & rargs
 243        columns = {m for m in matching if isinstance(m, exp.Column)}
 244
 245        if matching and columns:
 246            try:
 247                l = first(largs - columns)
 248                r = first(rargs - columns)
 249            except StopIteration:
 250                return expression
 251
 252            # make sure the comparison is always of the form x > 1 instead of 1 < x
 253            if left.__class__ in INVERSE_COMPARISONS and l == ll:
 254                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
 255            if right.__class__ in INVERSE_COMPARISONS and r == rl:
 256                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
 257
 258            if l.is_number and r.is_number:
 259                l = float(l.name)
 260                r = float(r.name)
 261            elif l.is_string and r.is_string:
 262                l = l.name
 263                r = r.name
 264            else:
 265                return None
 266
 267            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 268                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 269                    return left if (av > bv if or_ else av <= bv) else right
 270                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 271                    return left if (av < bv if or_ else av >= bv) else right
 272
 273                # we can't ever shortcut to true because the column could be null
 274                if not or_:
 275                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 276                        if av <= bv:
 277                            return exp.false()
 278                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 279                        if av >= bv:
 280                            return exp.false()
 281                    elif isinstance(a, exp.EQ):
 282                        if isinstance(b, exp.LT):
 283                            return exp.false() if av >= bv else a
 284                        if isinstance(b, exp.LTE):
 285                            return exp.false() if av > bv else a
 286                        if isinstance(b, exp.GT):
 287                            return exp.false() if av <= bv else a
 288                        if isinstance(b, exp.GTE):
 289                            return exp.false() if av < bv else a
 290                        if isinstance(b, exp.NEQ):
 291                            return exp.false() if av == bv else a
 292    return None
 293
 294
 295def remove_complements(expression, root=True):
 296    """
 297    Removing complements.
 298
 299    A AND NOT A -> FALSE
 300    A OR NOT A -> TRUE
 301    """
 302    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 303        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 304
 305        for a, b in itertools.permutations(expression.flatten(), 2):
 306            if is_complement(a, b):
 307                return complement
 308    return expression
 309
 310
 311def uniq_sort(expression, root=True):
 312    """
 313    Uniq and sort a connector.
 314
 315    C AND A AND B AND B -> A AND B AND C
 316    """
 317    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 318        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 319        flattened = tuple(expression.flatten())
 320        deduped = {gen(e): e for e in flattened}
 321        arr = tuple(deduped.items())
 322
 323        # check if the operands are already sorted, if not sort them
 324        # A AND C AND B -> A AND B AND C
 325        for i, (sql, e) in enumerate(arr[1:]):
 326            if sql < arr[i][0]:
 327                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 328                break
 329        else:
 330            # we didn't have to sort but maybe we need to dedup
 331            if len(deduped) < len(flattened):
 332                expression = result_func(*deduped.values(), copy=False)
 333
 334    return expression
 335
 336
 337def absorb_and_eliminate(expression, root=True):
 338    """
 339    absorption:
 340        A AND (A OR B) -> A
 341        A OR (A AND B) -> A
 342        A AND (NOT A OR B) -> A AND B
 343        A OR (NOT A AND B) -> A OR B
 344    elimination:
 345        (A AND B) OR (A AND NOT B) -> A
 346        (A OR B) AND (A OR NOT B) -> A
 347    """
 348    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 349        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 350
 351        for a, b in itertools.permutations(expression.flatten(), 2):
 352            if isinstance(a, kind):
 353                aa, ab = a.unnest_operands()
 354
 355                # absorb
 356                if is_complement(b, aa):
 357                    aa.replace(exp.true() if kind == exp.And else exp.false())
 358                elif is_complement(b, ab):
 359                    ab.replace(exp.true() if kind == exp.And else exp.false())
 360                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
 361                    a.replace(exp.false() if kind == exp.And else exp.true())
 362                elif isinstance(b, kind):
 363                    # eliminate
 364                    rhs = b.unnest_operands()
 365                    ba, bb = rhs
 366
 367                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
 368                        a.replace(aa)
 369                        b.replace(aa)
 370                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
 371                        a.replace(ab)
 372                        b.replace(ab)
 373
 374    return expression
 375
 376
 377def propagate_constants(expression, root=True):
 378    """
 379    Propagate constants for conjunctions in DNF:
 380
 381    SELECT * FROM t WHERE a = b AND b = 5 becomes
 382    SELECT * FROM t WHERE a = 5 AND b = 5
 383
 384    Reference: https://www.sqlite.org/optoverview.html
 385    """
 386
 387    if (
 388        isinstance(expression, exp.And)
 389        and (root or not expression.same_parent)
 390        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 391    ):
 392        constant_mapping = {}
 393        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
 394            if isinstance(expr, exp.EQ):
 395                l, r = expr.left, expr.right
 396
 397                # TODO: create a helper that can be used to detect nested literal expressions such
 398                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 399                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 400                    pass
 401                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
 402                    l, r = r, l
 403                else:
 404                    continue
 405
 406                constant_mapping[l] = (id(l), r)
 407
 408        if constant_mapping:
 409            for column in find_all_in_scope(expression, exp.Column):
 410                parent = column.parent
 411                column_id, constant = constant_mapping.get(column) or (None, None)
 412                if (
 413                    column_id is not None
 414                    and id(column) != column_id
 415                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 416                ):
 417                    column.replace(constant.copy())
 418
 419    return expression
 420
 421
 422INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 423    exp.DateAdd: exp.Sub,
 424    exp.DateSub: exp.Add,
 425    exp.DatetimeAdd: exp.Sub,
 426    exp.DatetimeSub: exp.Add,
 427}
 428
 429INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
 430    **INVERSE_DATE_OPS,
 431    exp.Add: exp.Sub,
 432    exp.Sub: exp.Add,
 433}
 434
 435
 436def _is_number(expression: exp.Expression) -> bool:
 437    return expression.is_number
 438
 439
 440def _is_interval(expression: exp.Expression) -> bool:
 441    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 442
 443
 444@catch(ModuleNotFoundError, UnsupportedUnit)
 445def simplify_equality(expression: exp.Expression) -> exp.Expression:
 446    """
 447    Use the subtraction and addition properties of equality to simplify expressions:
 448
 449        x + 1 = 3 becomes x = 2
 450
 451    There are two binary operations in the above expression: + and =
 452    Here's how we reference all the operands in the code below:
 453
 454          l     r
 455        x + 1 = 3
 456        a   b
 457    """
 458    if isinstance(expression, COMPARISONS):
 459        l, r = expression.left, expression.right
 460
 461        if l.__class__ in INVERSE_OPS:
 462            pass
 463        elif r.__class__ in INVERSE_OPS:
 464            l, r = r, l
 465        else:
 466            return expression
 467
 468        if r.is_number:
 469            a_predicate = _is_number
 470            b_predicate = _is_number
 471        elif _is_date_literal(r):
 472            a_predicate = _is_date_literal
 473            b_predicate = _is_interval
 474        else:
 475            return expression
 476
 477        if l.__class__ in INVERSE_DATE_OPS:
 478            l = t.cast(exp.IntervalOp, l)
 479            a = l.this
 480            b = l.interval()
 481        else:
 482            l = t.cast(exp.Binary, l)
 483            a, b = l.left, l.right
 484
 485        if not a_predicate(a) and b_predicate(b):
 486            pass
 487        elif not a_predicate(b) and b_predicate(a):
 488            a, b = b, a
 489        else:
 490            return expression
 491
 492        return expression.__class__(
 493            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 494        )
 495    return expression
 496
 497
 498def simplify_literals(expression, root=True):
 499    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 500        return _flat_simplify(expression, _simplify_binary, root)
 501
 502    if isinstance(expression, exp.Neg):
 503        this = expression.this
 504        if this.is_number:
 505            value = this.name
 506            if value[0] == "-":
 507                return exp.Literal.number(value[1:])
 508            return exp.Literal.number(f"-{value}")
 509
 510    return expression
 511
 512
 513def _simplify_binary(expression, a, b):
 514    if isinstance(expression, exp.Is):
 515        if isinstance(b, exp.Not):
 516            c = b.this
 517            not_ = True
 518        else:
 519            c = b
 520            not_ = False
 521
 522        if is_null(c):
 523            if isinstance(a, exp.Literal):
 524                return exp.true() if not_ else exp.false()
 525            if is_null(a):
 526                return exp.false() if not_ else exp.true()
 527    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 528        return None
 529    elif is_null(a) or is_null(b):
 530        return exp.null()
 531
 532    if a.is_number and b.is_number:
 533        a = int(a.name) if a.is_int else Decimal(a.name)
 534        b = int(b.name) if b.is_int else Decimal(b.name)
 535
 536        if isinstance(expression, exp.Add):
 537            return exp.Literal.number(a + b)
 538        if isinstance(expression, exp.Sub):
 539            return exp.Literal.number(a - b)
 540        if isinstance(expression, exp.Mul):
 541            return exp.Literal.number(a * b)
 542        if isinstance(expression, exp.Div):
 543            # engines have differing int div behavior so intdiv is not safe
 544            if isinstance(a, int) and isinstance(b, int):
 545                return None
 546            return exp.Literal.number(a / b)
 547
 548        boolean = eval_boolean(expression, a, b)
 549
 550        if boolean:
 551            return boolean
 552    elif a.is_string and b.is_string:
 553        boolean = eval_boolean(expression, a.this, b.this)
 554
 555        if boolean:
 556            return boolean
 557    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 558        a, b = extract_date(a), extract_interval(b)
 559        if a and b:
 560            if isinstance(expression, exp.Add):
 561                return date_literal(a + b)
 562            if isinstance(expression, exp.Sub):
 563                return date_literal(a - b)
 564    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 565        a, b = extract_interval(a), extract_date(b)
 566        # you cannot subtract a date from an interval
 567        if a and b and isinstance(expression, exp.Add):
 568            return date_literal(a + b)
 569
 570    return None
 571
 572
 573def simplify_parens(expression):
 574    if not isinstance(expression, exp.Paren):
 575        return expression
 576
 577    this = expression.this
 578    parent = expression.parent
 579
 580    if not isinstance(this, exp.Select) and (
 581        not isinstance(parent, (exp.Condition, exp.Binary))
 582        or isinstance(parent, exp.Paren)
 583        or not isinstance(this, exp.Binary)
 584        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
 585        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 586        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 587        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 588    ):
 589        return this
 590    return expression
 591
 592
 593CONSTANTS = (
 594    exp.Literal,
 595    exp.Boolean,
 596    exp.Null,
 597)
 598
 599
 600def simplify_coalesce(expression):
 601    # COALESCE(x) -> x
 602    if (
 603        isinstance(expression, exp.Coalesce)
 604        and not expression.expressions
 605        # COALESCE is also used as a Spark partitioning hint
 606        and not isinstance(expression.parent, exp.Hint)
 607    ):
 608        return expression.this
 609
 610    if not isinstance(expression, COMPARISONS):
 611        return expression
 612
 613    if isinstance(expression.left, exp.Coalesce):
 614        coalesce = expression.left
 615        other = expression.right
 616    elif isinstance(expression.right, exp.Coalesce):
 617        coalesce = expression.right
 618        other = expression.left
 619    else:
 620        return expression
 621
 622    # This transformation is valid for non-constants,
 623    # but it really only does anything if they are both constants.
 624    if not isinstance(other, CONSTANTS):
 625        return expression
 626
 627    # Find the first constant arg
 628    for arg_index, arg in enumerate(coalesce.expressions):
 629        if isinstance(arg, CONSTANTS):
 630            break
 631    else:
 632        return expression
 633
 634    coalesce.set("expressions", coalesce.expressions[:arg_index])
 635
 636    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 637    # since we already remove COALESCE at the top of this function.
 638    coalesce = coalesce if coalesce.expressions else coalesce.this
 639
 640    # This expression is more complex than when we started, but it will get simplified further
 641    return exp.paren(
 642        exp.or_(
 643            exp.and_(
 644                coalesce.is_(exp.null()).not_(copy=False),
 645                expression.copy(),
 646                copy=False,
 647            ),
 648            exp.and_(
 649                coalesce.is_(exp.null()),
 650                type(expression)(this=arg.copy(), expression=other.copy()),
 651                copy=False,
 652            ),
 653            copy=False,
 654        )
 655    )
 656
 657
 658CONCATS = (exp.Concat, exp.DPipe)
 659SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 660
 661
 662def simplify_concat(expression):
 663    """Reduces all groups that contain string literals by concatenating them."""
 664    if not isinstance(expression, CONCATS) or (
 665        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 666        isinstance(expression, exp.ConcatWs)
 667        and not expression.expressions[0].is_string
 668    ):
 669        return expression
 670
 671    if isinstance(expression, exp.ConcatWs):
 672        sep_expr, *expressions = expression.expressions
 673        sep = sep_expr.name
 674        concat_type = exp.ConcatWs
 675    else:
 676        expressions = expression.expressions
 677        sep = ""
 678        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
 679
 680    new_args = []
 681    for is_string_group, group in itertools.groupby(
 682        expressions or expression.flatten(), lambda e: e.is_string
 683    ):
 684        if is_string_group:
 685            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 686        else:
 687            new_args.extend(group)
 688
 689    if len(new_args) == 1 and new_args[0].is_string:
 690        return new_args[0]
 691
 692    if concat_type is exp.ConcatWs:
 693        new_args = [sep_expr] + new_args
 694
 695    return concat_type(expressions=new_args)
 696
 697
 698def simplify_conditionals(expression):
 699    """Simplifies expressions like IF, CASE if their condition is statically known."""
 700    if isinstance(expression, exp.Case):
 701        this = expression.this
 702        for case in expression.args["ifs"]:
 703            cond = case.this
 704            if this:
 705                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 706                cond = cond.replace(this.pop().eq(cond))
 707
 708            if always_true(cond):
 709                return case.args["true"]
 710
 711            if always_false(cond):
 712                case.pop()
 713                if not expression.args["ifs"]:
 714                    return expression.args.get("default") or exp.null()
 715    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 716        if always_true(expression.this):
 717            return expression.args["true"]
 718        if always_false(expression.this):
 719            return expression.args.get("false") or exp.null()
 720
 721    return expression
 722
 723
 724DateRange = t.Tuple[datetime.date, datetime.date]
 725
 726
 727def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 728    """
 729    Get the date range for a DATE_TRUNC equality comparison:
 730
 731    Example:
 732        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 733    Returns:
 734        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 735    """
 736    floor = date_floor(date, unit)
 737
 738    if date != floor:
 739        # This will always be False, except for NULL values.
 740        return None
 741
 742    return floor, floor + interval(unit)
 743
 744
 745def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 746    """Get the logical expression for a date range"""
 747    return exp.and_(
 748        left >= date_literal(drange[0]),
 749        left < date_literal(drange[1]),
 750        copy=False,
 751    )
 752
 753
 754def _datetrunc_eq(
 755    left: exp.Expression, date: datetime.date, unit: str
 756) -> t.Optional[exp.Expression]:
 757    drange = _datetrunc_range(date, unit)
 758    if not drange:
 759        return None
 760
 761    return _datetrunc_eq_expression(left, drange)
 762
 763
 764def _datetrunc_neq(
 765    left: exp.Expression, date: datetime.date, unit: str
 766) -> t.Optional[exp.Expression]:
 767    drange = _datetrunc_range(date, unit)
 768    if not drange:
 769        return None
 770
 771    return exp.and_(
 772        left < date_literal(drange[0]),
 773        left >= date_literal(drange[1]),
 774        copy=False,
 775    )
 776
 777
 778DateTruncBinaryTransform = t.Callable[
 779    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 780]
 781DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 782    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 783    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 784    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 785    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 786    exp.EQ: _datetrunc_eq,
 787    exp.NEQ: _datetrunc_neq,
 788}
 789DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 790
 791
 792def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 793    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 794
 795
 796@catch(ModuleNotFoundError, UnsupportedUnit)
 797def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
 798    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 799    comparison = expression.__class__
 800
 801    if comparison not in DATETRUNC_COMPARISONS:
 802        return expression
 803
 804    if isinstance(expression, exp.Binary):
 805        l, r = expression.left, expression.right
 806
 807        if _is_datetrunc_predicate(l, r):
 808            pass
 809        elif _is_datetrunc_predicate(r, l):
 810            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
 811            l, r = r, l
 812        else:
 813            return expression
 814
 815        l = t.cast(exp.DateTrunc, l)
 816        unit = l.unit.name.lower()
 817        date = extract_date(r)
 818
 819        if not date:
 820            return expression
 821
 822        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
 823    elif isinstance(expression, exp.In):
 824        l = expression.this
 825        rs = expression.expressions
 826
 827        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 828            l = t.cast(exp.DateTrunc, l)
 829            unit = l.unit.name.lower()
 830
 831            ranges = []
 832            for r in rs:
 833                date = extract_date(r)
 834                if not date:
 835                    return expression
 836                drange = _datetrunc_range(date, unit)
 837                if drange:
 838                    ranges.append(drange)
 839
 840            if not ranges:
 841                return expression
 842
 843            ranges = merge_ranges(ranges)
 844
 845            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 846
 847    return expression
 848
 849
 850# CROSS joins result in an empty table if the right table is empty.
 851# So we can only simplify certain types of joins to CROSS.
 852# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
 853JOINS = {
 854    ("", ""),
 855    ("", "INNER"),
 856    ("RIGHT", ""),
 857    ("RIGHT", "OUTER"),
 858}
 859
 860
 861def remove_where_true(expression):
 862    for where in expression.find_all(exp.Where):
 863        if always_true(where.this):
 864            where.parent.set("where", None)
 865    for join in expression.find_all(exp.Join):
 866        if (
 867            always_true(join.args.get("on"))
 868            and not join.args.get("using")
 869            and not join.args.get("method")
 870            and (join.side, join.kind) in JOINS
 871        ):
 872            join.set("on", None)
 873            join.set("side", None)
 874            join.set("kind", "CROSS")
 875
 876
 877def always_true(expression):
 878    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 879        expression, exp.Literal
 880    )
 881
 882
 883def always_false(expression):
 884    return is_false(expression) or is_null(expression)
 885
 886
 887def is_complement(a, b):
 888    return isinstance(b, exp.Not) and b.this == a
 889
 890
 891def is_false(a: exp.Expression) -> bool:
 892    return type(a) is exp.Boolean and not a.this
 893
 894
 895def is_null(a: exp.Expression) -> bool:
 896    return type(a) is exp.Null
 897
 898
 899def eval_boolean(expression, a, b):
 900    if isinstance(expression, (exp.EQ, exp.Is)):
 901        return boolean_literal(a == b)
 902    if isinstance(expression, exp.NEQ):
 903        return boolean_literal(a != b)
 904    if isinstance(expression, exp.GT):
 905        return boolean_literal(a > b)
 906    if isinstance(expression, exp.GTE):
 907        return boolean_literal(a >= b)
 908    if isinstance(expression, exp.LT):
 909        return boolean_literal(a < b)
 910    if isinstance(expression, exp.LTE):
 911        return boolean_literal(a <= b)
 912    return None
 913
 914
 915def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 916    if isinstance(value, datetime.datetime):
 917        return value.date()
 918    if isinstance(value, datetime.date):
 919        return value
 920    try:
 921        return datetime.datetime.fromisoformat(value).date()
 922    except ValueError:
 923        return None
 924
 925
 926def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 927    if isinstance(value, datetime.datetime):
 928        return value
 929    if isinstance(value, datetime.date):
 930        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 931    try:
 932        return datetime.datetime.fromisoformat(value)
 933    except ValueError:
 934        return None
 935
 936
 937def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 938    if not value:
 939        return None
 940    if to.is_type(exp.DataType.Type.DATE):
 941        return cast_as_date(value)
 942    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 943        return cast_as_datetime(value)
 944    return None
 945
 946
 947def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 948    if isinstance(cast, exp.Cast):
 949        to = cast.to
 950    elif isinstance(cast, exp.TsOrDsToDate):
 951        to = exp.DataType.build(exp.DataType.Type.DATE)
 952    else:
 953        return None
 954
 955    if isinstance(cast.this, exp.Literal):
 956        value: t.Any = cast.this.name
 957    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 958        value = extract_date(cast.this)
 959    else:
 960        return None
 961    return cast_value(value, to)
 962
 963
 964def _is_date_literal(expression: exp.Expression) -> bool:
 965    return extract_date(expression) is not None
 966
 967
 968def extract_interval(expression):
 969    n = int(expression.name)
 970    unit = expression.text("unit").lower()
 971
 972    try:
 973        return interval(unit, n)
 974    except (UnsupportedUnit, ModuleNotFoundError):
 975        return None
 976
 977
 978def date_literal(date):
 979    return exp.cast(
 980        exp.Literal.string(date),
 981        exp.DataType.Type.DATETIME
 982        if isinstance(date, datetime.datetime)
 983        else exp.DataType.Type.DATE,
 984    )
 985
 986
 987def interval(unit: str, n: int = 1):
 988    from dateutil.relativedelta import relativedelta
 989
 990    if unit == "year":
 991        return relativedelta(years=1 * n)
 992    if unit == "quarter":
 993        return relativedelta(months=3 * n)
 994    if unit == "month":
 995        return relativedelta(months=1 * n)
 996    if unit == "week":
 997        return relativedelta(weeks=1 * n)
 998    if unit == "day":
 999        return relativedelta(days=1 * n)
1000    if unit == "hour":
1001        return relativedelta(hours=1 * n)
1002    if unit == "minute":
1003        return relativedelta(minutes=1 * n)
1004    if unit == "second":
1005        return relativedelta(seconds=1 * n)
1006
1007    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1008
1009
1010def date_floor(d: datetime.date, unit: str) -> datetime.date:
1011    if unit == "year":
1012        return d.replace(month=1, day=1)
1013    if unit == "quarter":
1014        if d.month <= 3:
1015            return d.replace(month=1, day=1)
1016        elif d.month <= 6:
1017            return d.replace(month=4, day=1)
1018        elif d.month <= 9:
1019            return d.replace(month=7, day=1)
1020        else:
1021            return d.replace(month=10, day=1)
1022    if unit == "month":
1023        return d.replace(month=d.month, day=1)
1024    if unit == "week":
1025        # Assuming week starts on Monday (0) and ends on Sunday (6)
1026        return d - datetime.timedelta(days=d.weekday())
1027    if unit == "day":
1028        return d
1029
1030    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1031
1032
1033def date_ceil(d: datetime.date, unit: str) -> datetime.date:
1034    floor = date_floor(d, unit)
1035
1036    if floor == d:
1037        return d
1038
1039    return floor + interval(unit)
1040
1041
1042def boolean_literal(condition):
1043    return exp.true() if condition else exp.false()
1044
1045
1046def _flat_simplify(expression, simplifier, root=True):
1047    if root or not expression.same_parent:
1048        operands = []
1049        queue = deque(expression.flatten(unnest=False))
1050        size = len(queue)
1051
1052        while queue:
1053            a = queue.popleft()
1054
1055            for b in queue:
1056                result = simplifier(expression, a, b)
1057
1058                if result and result is not expression:
1059                    queue.remove(b)
1060                    queue.appendleft(result)
1061                    break
1062            else:
1063                operands.append(a)
1064
1065        if len(operands) < size:
1066            return functools.reduce(
1067                lambda a, b: expression.__class__(this=a, expression=b), operands
1068            )
1069    return expression
1070
1071
1072def gen(expression: t.Any) -> str:
1073    """Simple pseudo sql generator for quickly generating sortable and uniq strings.
1074
1075    Sorting and deduping sql is a necessary step for optimization. Calling the actual
1076    generator is expensive so we have a bare minimum sql generator here.
1077    """
1078    if expression is None:
1079        return "_"
1080    if is_iterable(expression):
1081        return ",".join(gen(e) for e in expression)
1082    if not isinstance(expression, exp.Expression):
1083        return str(expression)
1084
1085    etype = type(expression)
1086    if etype in GEN_MAP:
1087        return GEN_MAP[etype](expression)
1088    return f"{expression.key} {gen(expression.args.values())}"
1089
1090
1091GEN_MAP = {
1092    exp.Add: lambda e: _binary(e, "+"),
1093    exp.And: lambda e: _binary(e, "AND"),
1094    exp.Anonymous: lambda e: f"{e.this} {','.join(gen(e) for e in e.expressions)}",
1095    exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
1096    exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
1097    exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
1098    exp.Column: lambda e: ".".join(gen(p) for p in e.parts),
1099    exp.DataType: lambda e: f"{e.this.name} {gen(tuple(e.args.values())[1:])}",
1100    exp.Div: lambda e: _binary(e, "/"),
1101    exp.Dot: lambda e: _binary(e, "."),
1102    exp.DPipe: lambda e: _binary(e, "||"),
1103    exp.SafeDPipe: lambda e: _binary(e, "||"),
1104    exp.EQ: lambda e: _binary(e, "="),
1105    exp.GT: lambda e: _binary(e, ">"),
1106    exp.GTE: lambda e: _binary(e, ">="),
1107    exp.Identifier: lambda e: f'"{e.name}"' if e.quoted else e.name,
1108    exp.ILike: lambda e: _binary(e, "ILIKE"),
1109    exp.In: lambda e: f"{gen(e.this)} IN ({gen(tuple(e.args.values())[1:])})",
1110    exp.Is: lambda e: _binary(e, "IS"),
1111    exp.Like: lambda e: _binary(e, "LIKE"),
1112    exp.Literal: lambda e: f"'{e.name}'" if e.is_string else e.name,
1113    exp.LT: lambda e: _binary(e, "<"),
1114    exp.LTE: lambda e: _binary(e, "<="),
1115    exp.Mod: lambda e: _binary(e, "%"),
1116    exp.Mul: lambda e: _binary(e, "*"),
1117    exp.Neg: lambda e: _unary(e, "-"),
1118    exp.NEQ: lambda e: _binary(e, "<>"),
1119    exp.Not: lambda e: _unary(e, "NOT"),
1120    exp.Null: lambda e: "NULL",
1121    exp.Or: lambda e: _binary(e, "OR"),
1122    exp.Paren: lambda e: f"({gen(e.this)})",
1123    exp.Sub: lambda e: _binary(e, "-"),
1124    exp.Subquery: lambda e: f"({gen(e.args.values())})",
1125    exp.Table: lambda e: gen(e.args.values()),
1126    exp.Var: lambda e: e.name,
1127}
1128
1129
1130def _binary(e: exp.Binary, op: str) -> str:
1131    return f"{gen(e.left)} {op} {gen(e.right)}"
1132
1133
1134def _unary(e: exp.Unary, op: str) -> str:
1135    return f"{op} {gen(e.this)}"
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
18class UnsupportedUnit(Exception):
19    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
22def simplify(expression, constant_propagation=False):
23    """
24    Rewrite sqlglot AST to simplify expressions.
25
26    Example:
27        >>> import sqlglot
28        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
29        >>> simplify(expression).sql()
30        'TRUE'
31
32    Args:
33        expression (sqlglot.Expression): expression to simplify
34        constant_propagation: whether or not the constant propagation rule should be used
35
36    Returns:
37        sqlglot.Expression: simplified expression
38    """
39
40    # group by expressions cannot be simplified, for example
41    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
42    # the projection must exactly match the group by key
43    for group in expression.find_all(exp.Group):
44        select = group.parent
45        groups = set(group.expressions)
46        group.meta[FINAL] = True
47
48        for e in select.selects:
49            for node, *_ in e.walk():
50                if node in groups:
51                    e.meta[FINAL] = True
52                    break
53
54        having = select.args.get("having")
55        if having:
56            for node, *_ in having.walk():
57                if node in groups:
58                    having.meta[FINAL] = True
59                    break
60
61    def _simplify(expression, root=True):
62        if expression.meta.get(FINAL):
63            return expression
64
65        # Pre-order transformations
66        node = expression
67        node = rewrite_between(node)
68        node = uniq_sort(node, root)
69        node = absorb_and_eliminate(node, root)
70        node = simplify_concat(node)
71        node = simplify_conditionals(node)
72
73        if constant_propagation:
74            node = propagate_constants(node, root)
75
76        exp.replace_children(node, lambda e: _simplify(e, False))
77
78        # Post-order transformations
79        node = simplify_not(node)
80        node = flatten(node)
81        node = simplify_connectors(node, root)
82        node = remove_complements(node, root)
83        node = simplify_coalesce(node)
84        node.parent = expression.parent
85        node = simplify_literals(node, root)
86        node = simplify_equality(node)
87        node = simplify_parens(node)
88        node = simplify_datetrunc_predicate(node)
89
90        if root:
91            expression.replace(node)
92
93        return node
94
95    expression = while_changing(expression, _simplify)
96    remove_where_true(expression)
97    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
100def catch(*exceptions):
101    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
102
103    def decorator(func):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression
109
110        return wrapped
111
112    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
115def rewrite_between(expression: exp.Expression) -> exp.Expression:
116    """Rewrite x between y and z to x >= y AND x <= z.
117
118    This is done because comparison simplification is only done on lt/lte/gt/gte.
119    """
120    if isinstance(expression, exp.Between):
121        return exp.and_(
122            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
123            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
124            copy=False,
125        )
126    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
129def simplify_not(expression):
130    """
131    Demorgan's Law
132    NOT (x OR y) -> NOT x AND NOT y
133    NOT (x AND y) -> NOT x OR NOT y
134    """
135    if isinstance(expression, exp.Not):
136        if is_null(expression.this):
137            return exp.null()
138        if isinstance(expression.this, exp.Paren):
139            condition = expression.this.unnest()
140            if isinstance(condition, exp.And):
141                return exp.or_(
142                    exp.not_(condition.left, copy=False),
143                    exp.not_(condition.right, copy=False),
144                    copy=False,
145                )
146            if isinstance(condition, exp.Or):
147                return exp.and_(
148                    exp.not_(condition.left, copy=False),
149                    exp.not_(condition.right, copy=False),
150                    copy=False,
151                )
152            if is_null(condition):
153                return exp.null()
154        if always_true(expression.this):
155            return exp.false()
156        if is_false(expression.this):
157            return exp.true()
158        if isinstance(expression.this, exp.Not):
159            # double negation
160            # NOT NOT x -> x
161            return expression.this.this
162    return expression

Demorgan's Law NOT (x OR y) -> NOT x AND NOT y NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
165def flatten(expression):
166    """
167    A AND (B AND C) -> A AND B AND C
168    A OR (B OR C) -> A OR B OR C
169    """
170    if isinstance(expression, exp.Connector):
171        for node in expression.args.values():
172            child = node.unnest()
173            if isinstance(child, expression.__class__):
174                node.replace(child)
175    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
178def simplify_connectors(expression, root=True):
179    def _simplify_connectors(expression, left, right):
180        if left == right:
181            return left
182        if isinstance(expression, exp.And):
183            if is_false(left) or is_false(right):
184                return exp.false()
185            if is_null(left) or is_null(right):
186                return exp.null()
187            if always_true(left) and always_true(right):
188                return exp.true()
189            if always_true(left):
190                return right
191            if always_true(right):
192                return left
193            return _simplify_comparison(expression, left, right)
194        elif isinstance(expression, exp.Or):
195            if always_true(left) or always_true(right):
196                return exp.true()
197            if is_false(left) and is_false(right):
198                return exp.false()
199            if (
200                (is_null(left) and is_null(right))
201                or (is_null(left) and is_false(right))
202                or (is_false(left) and is_null(right))
203            ):
204                return exp.null()
205            if is_false(left):
206                return right
207            if is_false(right):
208                return left
209            return _simplify_comparison(expression, left, right, or_=True)
210
211    if isinstance(expression, exp.Connector):
212        return _flat_simplify(expression, _simplify_connectors, root)
213    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
296def remove_complements(expression, root=True):
297    """
298    Removing complements.
299
300    A AND NOT A -> FALSE
301    A OR NOT A -> TRUE
302    """
303    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
304        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
305
306        for a, b in itertools.permutations(expression.flatten(), 2):
307            if is_complement(a, b):
308                return complement
309    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
312def uniq_sort(expression, root=True):
313    """
314    Uniq and sort a connector.
315
316    C AND A AND B AND B -> A AND B AND C
317    """
318    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
319        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
320        flattened = tuple(expression.flatten())
321        deduped = {gen(e): e for e in flattened}
322        arr = tuple(deduped.items())
323
324        # check if the operands are already sorted, if not sort them
325        # A AND C AND B -> A AND B AND C
326        for i, (sql, e) in enumerate(arr[1:]):
327            if sql < arr[i][0]:
328                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
329                break
330        else:
331            # we didn't have to sort but maybe we need to dedup
332            if len(deduped) < len(flattened):
333                expression = result_func(*deduped.values(), copy=False)
334
335    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
338def absorb_and_eliminate(expression, root=True):
339    """
340    absorption:
341        A AND (A OR B) -> A
342        A OR (A AND B) -> A
343        A AND (NOT A OR B) -> A AND B
344        A OR (NOT A AND B) -> A OR B
345    elimination:
346        (A AND B) OR (A AND NOT B) -> A
347        (A OR B) AND (A OR NOT B) -> A
348    """
349    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
350        kind = exp.Or if isinstance(expression, exp.And) else exp.And
351
352        for a, b in itertools.permutations(expression.flatten(), 2):
353            if isinstance(a, kind):
354                aa, ab = a.unnest_operands()
355
356                # absorb
357                if is_complement(b, aa):
358                    aa.replace(exp.true() if kind == exp.And else exp.false())
359                elif is_complement(b, ab):
360                    ab.replace(exp.true() if kind == exp.And else exp.false())
361                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
362                    a.replace(exp.false() if kind == exp.And else exp.true())
363                elif isinstance(b, kind):
364                    # eliminate
365                    rhs = b.unnest_operands()
366                    ba, bb = rhs
367
368                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
369                        a.replace(aa)
370                        b.replace(aa)
371                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
372                        a.replace(ab)
373                        b.replace(ab)
374
375    return expression

absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
378def propagate_constants(expression, root=True):
379    """
380    Propagate constants for conjunctions in DNF:
381
382    SELECT * FROM t WHERE a = b AND b = 5 becomes
383    SELECT * FROM t WHERE a = 5 AND b = 5
384
385    Reference: https://www.sqlite.org/optoverview.html
386    """
387
388    if (
389        isinstance(expression, exp.And)
390        and (root or not expression.same_parent)
391        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
392    ):
393        constant_mapping = {}
394        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
395            if isinstance(expr, exp.EQ):
396                l, r = expr.left, expr.right
397
398                # TODO: create a helper that can be used to detect nested literal expressions such
399                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
400                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
401                    pass
402                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
403                    l, r = r, l
404                else:
405                    continue
406
407                constant_mapping[l] = (id(l), r)
408
409        if constant_mapping:
410            for column in find_all_in_scope(expression, exp.Column):
411                parent = column.parent
412                column_id, constant = constant_mapping.get(column) or (None, None)
413                if (
414                    column_id is not None
415                    and id(column) != column_id
416                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
417                ):
418                    column.replace(constant.copy())
419
420    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
499def simplify_literals(expression, root=True):
500    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
501        return _flat_simplify(expression, _simplify_binary, root)
502
503    if isinstance(expression, exp.Neg):
504        this = expression.this
505        if this.is_number:
506            value = this.name
507            if value[0] == "-":
508                return exp.Literal.number(value[1:])
509            return exp.Literal.number(f"-{value}")
510
511    return expression
def simplify_parens(expression):
574def simplify_parens(expression):
575    if not isinstance(expression, exp.Paren):
576        return expression
577
578    this = expression.this
579    parent = expression.parent
580
581    if not isinstance(this, exp.Select) and (
582        not isinstance(parent, (exp.Condition, exp.Binary))
583        or isinstance(parent, exp.Paren)
584        or not isinstance(this, exp.Binary)
585        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
586        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
587        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
588        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
589    ):
590        return this
591    return expression
def simplify_coalesce(expression):
601def simplify_coalesce(expression):
602    # COALESCE(x) -> x
603    if (
604        isinstance(expression, exp.Coalesce)
605        and not expression.expressions
606        # COALESCE is also used as a Spark partitioning hint
607        and not isinstance(expression.parent, exp.Hint)
608    ):
609        return expression.this
610
611    if not isinstance(expression, COMPARISONS):
612        return expression
613
614    if isinstance(expression.left, exp.Coalesce):
615        coalesce = expression.left
616        other = expression.right
617    elif isinstance(expression.right, exp.Coalesce):
618        coalesce = expression.right
619        other = expression.left
620    else:
621        return expression
622
623    # This transformation is valid for non-constants,
624    # but it really only does anything if they are both constants.
625    if not isinstance(other, CONSTANTS):
626        return expression
627
628    # Find the first constant arg
629    for arg_index, arg in enumerate(coalesce.expressions):
630        if isinstance(arg, CONSTANTS):
631            break
632    else:
633        return expression
634
635    coalesce.set("expressions", coalesce.expressions[:arg_index])
636
637    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
638    # since we already remove COALESCE at the top of this function.
639    coalesce = coalesce if coalesce.expressions else coalesce.this
640
641    # This expression is more complex than when we started, but it will get simplified further
642    return exp.paren(
643        exp.or_(
644            exp.and_(
645                coalesce.is_(exp.null()).not_(copy=False),
646                expression.copy(),
647                copy=False,
648            ),
649            exp.and_(
650                coalesce.is_(exp.null()),
651                type(expression)(this=arg.copy(), expression=other.copy()),
652                copy=False,
653            ),
654            copy=False,
655        )
656    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
663def simplify_concat(expression):
664    """Reduces all groups that contain string literals by concatenating them."""
665    if not isinstance(expression, CONCATS) or (
666        # We can't reduce a CONCAT_WS call if we don't statically know the separator
667        isinstance(expression, exp.ConcatWs)
668        and not expression.expressions[0].is_string
669    ):
670        return expression
671
672    if isinstance(expression, exp.ConcatWs):
673        sep_expr, *expressions = expression.expressions
674        sep = sep_expr.name
675        concat_type = exp.ConcatWs
676    else:
677        expressions = expression.expressions
678        sep = ""
679        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
680
681    new_args = []
682    for is_string_group, group in itertools.groupby(
683        expressions or expression.flatten(), lambda e: e.is_string
684    ):
685        if is_string_group:
686            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
687        else:
688            new_args.extend(group)
689
690    if len(new_args) == 1 and new_args[0].is_string:
691        return new_args[0]
692
693    if concat_type is exp.ConcatWs:
694        new_args = [sep_expr] + new_args
695
696    return concat_type(expressions=new_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
699def simplify_conditionals(expression):
700    """Simplifies expressions like IF, CASE if their condition is statically known."""
701    if isinstance(expression, exp.Case):
702        this = expression.this
703        for case in expression.args["ifs"]:
704            cond = case.this
705            if this:
706                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
707                cond = cond.replace(this.pop().eq(cond))
708
709            if always_true(cond):
710                return case.args["true"]
711
712            if always_false(cond):
713                case.pop()
714                if not expression.args["ifs"]:
715                    return expression.args.get("default") or exp.null()
716    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
717        if always_true(expression.this):
718            return expression.args["true"]
719        if always_false(expression.this):
720            return expression.args.get("false") or exp.null()
721
722    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LTE'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
104        def wrapped(expression, *args, **kwargs):
105            try:
106                return func(expression, *args, **kwargs)
107            except exceptions:
108                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER'), ('', '')}
def remove_where_true(expression):
862def remove_where_true(expression):
863    for where in expression.find_all(exp.Where):
864        if always_true(where.this):
865            where.parent.set("where", None)
866    for join in expression.find_all(exp.Join):
867        if (
868            always_true(join.args.get("on"))
869            and not join.args.get("using")
870            and not join.args.get("method")
871            and (join.side, join.kind) in JOINS
872        ):
873            join.set("on", None)
874            join.set("side", None)
875            join.set("kind", "CROSS")
def always_true(expression):
878def always_true(expression):
879    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
880        expression, exp.Literal
881    )
def always_false(expression):
884def always_false(expression):
885    return is_false(expression) or is_null(expression)
def is_complement(a, b):
888def is_complement(a, b):
889    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
892def is_false(a: exp.Expression) -> bool:
893    return type(a) is exp.Boolean and not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
896def is_null(a: exp.Expression) -> bool:
897    return type(a) is exp.Null
def eval_boolean(expression, a, b):
900def eval_boolean(expression, a, b):
901    if isinstance(expression, (exp.EQ, exp.Is)):
902        return boolean_literal(a == b)
903    if isinstance(expression, exp.NEQ):
904        return boolean_literal(a != b)
905    if isinstance(expression, exp.GT):
906        return boolean_literal(a > b)
907    if isinstance(expression, exp.GTE):
908        return boolean_literal(a >= b)
909    if isinstance(expression, exp.LT):
910        return boolean_literal(a < b)
911    if isinstance(expression, exp.LTE):
912        return boolean_literal(a <= b)
913    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
916def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
917    if isinstance(value, datetime.datetime):
918        return value.date()
919    if isinstance(value, datetime.date):
920        return value
921    try:
922        return datetime.datetime.fromisoformat(value).date()
923    except ValueError:
924        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
927def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
928    if isinstance(value, datetime.datetime):
929        return value
930    if isinstance(value, datetime.date):
931        return datetime.datetime(year=value.year, month=value.month, day=value.day)
932    try:
933        return datetime.datetime.fromisoformat(value)
934    except ValueError:
935        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
938def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
939    if not value:
940        return None
941    if to.is_type(exp.DataType.Type.DATE):
942        return cast_as_date(value)
943    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
944        return cast_as_datetime(value)
945    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
948def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
949    if isinstance(cast, exp.Cast):
950        to = cast.to
951    elif isinstance(cast, exp.TsOrDsToDate):
952        to = exp.DataType.build(exp.DataType.Type.DATE)
953    else:
954        return None
955
956    if isinstance(cast.this, exp.Literal):
957        value: t.Any = cast.this.name
958    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
959        value = extract_date(cast.this)
960    else:
961        return None
962    return cast_value(value, to)
def extract_interval(expression):
969def extract_interval(expression):
970    n = int(expression.name)
971    unit = expression.text("unit").lower()
972
973    try:
974        return interval(unit, n)
975    except (UnsupportedUnit, ModuleNotFoundError):
976        return None
def date_literal(date):
979def date_literal(date):
980    return exp.cast(
981        exp.Literal.string(date),
982        exp.DataType.Type.DATETIME
983        if isinstance(date, datetime.datetime)
984        else exp.DataType.Type.DATE,
985    )
def interval(unit: str, n: int = 1):
 988def interval(unit: str, n: int = 1):
 989    from dateutil.relativedelta import relativedelta
 990
 991    if unit == "year":
 992        return relativedelta(years=1 * n)
 993    if unit == "quarter":
 994        return relativedelta(months=3 * n)
 995    if unit == "month":
 996        return relativedelta(months=1 * n)
 997    if unit == "week":
 998        return relativedelta(weeks=1 * n)
 999    if unit == "day":
1000        return relativedelta(days=1 * n)
1001    if unit == "hour":
1002        return relativedelta(hours=1 * n)
1003    if unit == "minute":
1004        return relativedelta(minutes=1 * n)
1005    if unit == "second":
1006        return relativedelta(seconds=1 * n)
1007
1008    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
1011def date_floor(d: datetime.date, unit: str) -> datetime.date:
1012    if unit == "year":
1013        return d.replace(month=1, day=1)
1014    if unit == "quarter":
1015        if d.month <= 3:
1016            return d.replace(month=1, day=1)
1017        elif d.month <= 6:
1018            return d.replace(month=4, day=1)
1019        elif d.month <= 9:
1020            return d.replace(month=7, day=1)
1021        else:
1022            return d.replace(month=10, day=1)
1023    if unit == "month":
1024        return d.replace(month=d.month, day=1)
1025    if unit == "week":
1026        # Assuming week starts on Monday (0) and ends on Sunday (6)
1027        return d - datetime.timedelta(days=d.weekday())
1028    if unit == "day":
1029        return d
1030
1031    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
1034def date_ceil(d: datetime.date, unit: str) -> datetime.date:
1035    floor = date_floor(d, unit)
1036
1037    if floor == d:
1038        return d
1039
1040    return floor + interval(unit)
def boolean_literal(condition):
1043def boolean_literal(condition):
1044    return exp.true() if condition else exp.false()
def gen(expression: Any) -> str:
1073def gen(expression: t.Any) -> str:
1074    """Simple pseudo sql generator for quickly generating sortable and uniq strings.
1075
1076    Sorting and deduping sql is a necessary step for optimization. Calling the actual
1077    generator is expensive so we have a bare minimum sql generator here.
1078    """
1079    if expression is None:
1080        return "_"
1081    if is_iterable(expression):
1082        return ",".join(gen(e) for e in expression)
1083    if not isinstance(expression, exp.Expression):
1084        return str(expression)
1085
1086    etype = type(expression)
1087    if etype in GEN_MAP:
1088        return GEN_MAP[etype](expression)
1089    return f"{expression.key} {gen(expression.args.values())}"

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

GEN_MAP = {<class 'sqlglot.expressions.Add'>: <function <lambda>>, <class 'sqlglot.expressions.And'>: <function <lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function <lambda>>, <class 'sqlglot.expressions.Between'>: <function <lambda>>, <class 'sqlglot.expressions.Boolean'>: <function <lambda>>, <class 'sqlglot.expressions.Bracket'>: <function <lambda>>, <class 'sqlglot.expressions.Column'>: <function <lambda>>, <class 'sqlglot.expressions.DataType'>: <function <lambda>>, <class 'sqlglot.expressions.Div'>: <function <lambda>>, <class 'sqlglot.expressions.Dot'>: <function <lambda>>, <class 'sqlglot.expressions.DPipe'>: <function <lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.Identifier'>: <function <lambda>>, <class 'sqlglot.expressions.ILike'>: <function <lambda>>, <class 'sqlglot.expressions.In'>: <function <lambda>>, <class 'sqlglot.expressions.Is'>: <function <lambda>>, <class 'sqlglot.expressions.Like'>: <function <lambda>>, <class 'sqlglot.expressions.Literal'>: <function <lambda>>, <class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.Mod'>: <function <lambda>>, <class 'sqlglot.expressions.Mul'>: <function <lambda>>, <class 'sqlglot.expressions.Neg'>: <function <lambda>>, <class 'sqlglot.expressions.NEQ'>: <function <lambda>>, <class 'sqlglot.expressions.Not'>: <function <lambda>>, <class 'sqlglot.expressions.Null'>: <function <lambda>>, <class 'sqlglot.expressions.Or'>: <function <lambda>>, <class 'sqlglot.expressions.Paren'>: <function <lambda>>, <class 'sqlglot.expressions.Sub'>: <function <lambda>>, <class 'sqlglot.expressions.Subquery'>: <function <lambda>>, <class 'sqlglot.expressions.Table'>: <function <lambda>>, <class 'sqlglot.expressions.Var'>: <function <lambda>>}