Edit on GitHub

sqlglot.optimizer.simplify

   1import datetime
   2import functools
   3import itertools
   4import typing as t
   5from collections import deque
   6from decimal import Decimal
   7
   8import sqlglot
   9from sqlglot import exp
  10from sqlglot.generator import cached_generator
  11from sqlglot.helper import first, merge_ranges, while_changing
  12from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  13
# Key stored in `Expression.meta`. `simplify` returns any node flagged with it untouched;
# it is set on GROUP BY clauses and on the projections / HAVING that reference a group key.
FINAL = "final"
  16
  17
  18class UnsupportedUnit(Exception):
  19    pass
  20
  21
def simplify(expression, constant_propagation=False):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rules are applied repeatedly through `while_changing`, and the root expression
    is replaced in its parent whenever a rule produces a new node.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether or not the constant propagation rule should be used

    Returns:
        sqlglot.Expression: simplified expression
    """

    # Generator handed to `uniq_sort`, which uses the generated SQL as its sort / dedup key
    generate = cached_generator()

    # group by expressions cannot be simplified, for example
    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
    # the projection must exactly match the group by key
    for group in expression.find_all(exp.Group):
        select = group.parent
        groups = set(group.expressions)
        group.meta[FINAL] = True

        # Freeze every projection that contains one of the group keys
        for e in select.selects:
            for node, *_ in e.walk():
                if node in groups:
                    e.meta[FINAL] = True
                    break

        # Same for the HAVING clause
        having = select.args.get("having")
        if having:
            for node, *_ in having.walk():
                if node in groups:
                    having.meta[FINAL] = True
                    break

    def _simplify(expression, root=True):
        # Nodes flagged above are returned as-is, together with their whole subtree
        if expression.meta.get(FINAL):
            return expression

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are simplified with root=False
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # A rule above may have returned a brand new node; give it the original parent so
        # that parent-sensitive rules below (e.g. simplify_parens) see the right context
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc_predicate(node)

        # Only the root call swaps the result into the surrounding tree
        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    # Final cleanup: drop WHERE / JOIN ... ON conditions that are always true
    remove_where_true(expression)
    return expression
 100
 101
 102def catch(*exceptions):
 103    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 104
 105    def decorator(func):
 106        def wrapped(expression, *args, **kwargs):
 107            try:
 108                return func(expression, *args, **kwargs)
 109            except exceptions:
 110                return expression
 111
 112        return wrapped
 113
 114    return decorator
 115
 116
 117def rewrite_between(expression: exp.Expression) -> exp.Expression:
 118    """Rewrite x between y and z to x >= y AND x <= z.
 119
 120    This is done because comparison simplification is only done on lt/lte/gt/gte.
 121    """
 122    if isinstance(expression, exp.Between):
 123        return exp.and_(
 124            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 125            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 126            copy=False,
 127        )
 128    return expression
 129
 130
 131def simplify_not(expression):
 132    """
 133    Demorgan's Law
 134    NOT (x OR y) -> NOT x AND NOT y
 135    NOT (x AND y) -> NOT x OR NOT y
 136    """
 137    if isinstance(expression, exp.Not):
 138        if is_null(expression.this):
 139            return exp.null()
 140        if isinstance(expression.this, exp.Paren):
 141            condition = expression.this.unnest()
 142            if isinstance(condition, exp.And):
 143                return exp.or_(
 144                    exp.not_(condition.left, copy=False),
 145                    exp.not_(condition.right, copy=False),
 146                    copy=False,
 147                )
 148            if isinstance(condition, exp.Or):
 149                return exp.and_(
 150                    exp.not_(condition.left, copy=False),
 151                    exp.not_(condition.right, copy=False),
 152                    copy=False,
 153                )
 154            if is_null(condition):
 155                return exp.null()
 156        if always_true(expression.this):
 157            return exp.false()
 158        if is_false(expression.this):
 159            return exp.true()
 160        if isinstance(expression.this, exp.Not):
 161            # double negation
 162            # NOT NOT x -> x
 163            return expression.this.this
 164    return expression
 165
 166
 167def flatten(expression):
 168    """
 169    A AND (B AND C) -> A AND B AND C
 170    A OR (B OR C) -> A OR B OR C
 171    """
 172    if isinstance(expression, exp.Connector):
 173        for node in expression.args.values():
 174            child = node.unnest()
 175            if isinstance(child, expression.__class__):
 176                node.replace(child)
 177    return expression
 178
 179
 180def simplify_connectors(expression, root=True):
 181    def _simplify_connectors(expression, left, right):
 182        if left == right:
 183            return left
 184        if isinstance(expression, exp.And):
 185            if is_false(left) or is_false(right):
 186                return exp.false()
 187            if is_null(left) or is_null(right):
 188                return exp.null()
 189            if always_true(left) and always_true(right):
 190                return exp.true()
 191            if always_true(left):
 192                return right
 193            if always_true(right):
 194                return left
 195            return _simplify_comparison(expression, left, right)
 196        elif isinstance(expression, exp.Or):
 197            if always_true(left) or always_true(right):
 198                return exp.true()
 199            if is_false(left) and is_false(right):
 200                return exp.false()
 201            if (
 202                (is_null(left) and is_null(right))
 203                or (is_null(left) and is_false(right))
 204                or (is_false(left) and is_null(right))
 205            ):
 206                return exp.null()
 207            if is_false(left):
 208                return right
 209            if is_false(right):
 210                return left
 211            return _simplify_comparison(expression, left, right, or_=True)
 212
 213    if isinstance(expression, exp.Connector):
 214        return _flat_simplify(expression, _simplify_connectors, root)
 215    return expression
 216
 217
# Comparison families used by `_simplify_comparison` to merge bounds that point the same way
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Binary predicates that the comparison / equality / coalesce rules operate on
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Comparison to use when the two operands are swapped (e.g. 1 < x  <=>  x > 1)
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
 235
 236
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons of the same column against constants.

    Args:
        expression: the AND / OR connector that joins `left` and `right`.
        left: left operand of the connector.
        right: right operand of the connector.
        or_: True when the connector is an OR; the default handles AND.

    Returns:
        The operand that survives (possibly rewritten as `column <op> constant`),
        FALSE for a contradictory AND, `expression` itself when no non-column operand
        can be picked, or None when no rule applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        # Only proceed when both comparisons share an operand and that operand is a column
        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                # l / r are the operands that are NOT the shared column
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            # Only number-vs-number and string-vs-string constants can be compared here
            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            # Try the rules with the operands in both orders
            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                # Two upper bounds: AND keeps the smaller one, OR the larger one
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                # Two lower bounds: AND keeps the larger one, OR the smaller one
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # x = c AND x <op> d: either contradictory or implied by x = c
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
 296
 297
 298def remove_complements(expression, root=True):
 299    """
 300    Removing complements.
 301
 302    A AND NOT A -> FALSE
 303    A OR NOT A -> TRUE
 304    """
 305    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 306        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 307
 308        for a, b in itertools.permutations(expression.flatten(), 2):
 309            if is_complement(a, b):
 310                return complement
 311    return expression
 312
 313
def uniq_sort(expression, generate, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: the expression to process; only connectors are touched.
        generate: callable that maps an operand to the key used for dedup and ordering.
        root: when False, connectors whose `same_parent` is truthy are skipped.

    Returns:
        A rebuilt connector when operands had to be reordered or deduplicated,
        otherwise the original expression.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # Keyed by generated SQL, so operands with the same SQL collapse into one entry
        deduped = {generate(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        # (arr[i] is the element right before `sql`, since enumeration starts at arr[1])
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
 338
 339
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector, editing operands in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced by boolean constants or by the surviving term; the constants
    are left in the tree for the other rules (e.g. `simplify_connectors`) to fold.
    The (mutated) input expression is always returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector we look for is the opposite of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # `aa` is NOT b: swap it for the constant that is neutral inside `a`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # every term of `b` also appears in `a`: swap `a` for the constant
                    # that is neutral in the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 378
 379
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Only AND expressions that `normalized(..., dnf=True)` accepts are processed, and
    columns are replaced in place. The input expression is always returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node inside the defining equality, literal)
        constant_mapping = {}
        # exp.If nodes are pruned from the walk
        for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    pass
                elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
                    # normalise to (column, literal)
                    l, r = r, l
                else:
                    continue

                constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                # Skip the column node recorded for the defining equality (so `b = 5`
                # stays intact) and leave `col IS NULL` checks alone
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 423
 424
# Date arithmetic functions mapped to the binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Every operation `simplify_equality` can move to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 437
 438
def _is_number(expression: exp.Expression) -> bool:
    """Predicate form of `expression.is_number`, used by `simplify_equality`."""
    return expression.is_number
 441
 442
 443def _is_interval(expression: exp.Expression) -> bool:
 444    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 445
 446
 447@catch(ModuleNotFoundError, UnsupportedUnit)
 448def simplify_equality(expression: exp.Expression) -> exp.Expression:
 449    """
 450    Use the subtraction and addition properties of equality to simplify expressions:
 451
 452        x + 1 = 3 becomes x = 2
 453
 454    There are two binary operations in the above expression: + and =
 455    Here's how we reference all the operands in the code below:
 456
 457          l     r
 458        x + 1 = 3
 459        a   b
 460    """
 461    if isinstance(expression, COMPARISONS):
 462        l, r = expression.left, expression.right
 463
 464        if l.__class__ in INVERSE_OPS:
 465            pass
 466        elif r.__class__ in INVERSE_OPS:
 467            l, r = r, l
 468        else:
 469            return expression
 470
 471        if r.is_number:
 472            a_predicate = _is_number
 473            b_predicate = _is_number
 474        elif _is_date_literal(r):
 475            a_predicate = _is_date_literal
 476            b_predicate = _is_interval
 477        else:
 478            return expression
 479
 480        if l.__class__ in INVERSE_DATE_OPS:
 481            l = t.cast(exp.IntervalOp, l)
 482            a = l.this
 483            b = l.interval()
 484        else:
 485            l = t.cast(exp.Binary, l)
 486            a, b = l.left, l.right
 487
 488        if not a_predicate(a) and b_predicate(b):
 489            pass
 490        elif not a_predicate(b) and b_predicate(a):
 491            a, b = b, a
 492        else:
 493            return expression
 494
 495        return expression.__class__(
 496            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 497        )
 498    return expression
 499
 500
 501def simplify_literals(expression, root=True):
 502    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 503        return _flat_simplify(expression, _simplify_binary, root)
 504
 505    if isinstance(expression, exp.Neg):
 506        this = expression.this
 507        if this.is_number:
 508            value = this.name
 509            if value[0] == "-":
 510                return exp.Literal.number(value[1:])
 511            return exp.Literal.number(f"-{value}")
 512
 513    return expression
 514
 515
 516def _simplify_binary(expression, a, b):
 517    if isinstance(expression, exp.Is):
 518        if isinstance(b, exp.Not):
 519            c = b.this
 520            not_ = True
 521        else:
 522            c = b
 523            not_ = False
 524
 525        if is_null(c):
 526            if isinstance(a, exp.Literal):
 527                return exp.true() if not_ else exp.false()
 528            if is_null(a):
 529                return exp.false() if not_ else exp.true()
 530    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
 531        return None
 532    elif is_null(a) or is_null(b):
 533        return exp.null()
 534
 535    if a.is_number and b.is_number:
 536        a = int(a.name) if a.is_int else Decimal(a.name)
 537        b = int(b.name) if b.is_int else Decimal(b.name)
 538
 539        if isinstance(expression, exp.Add):
 540            return exp.Literal.number(a + b)
 541        if isinstance(expression, exp.Sub):
 542            return exp.Literal.number(a - b)
 543        if isinstance(expression, exp.Mul):
 544            return exp.Literal.number(a * b)
 545        if isinstance(expression, exp.Div):
 546            # engines have differing int div behavior so intdiv is not safe
 547            if isinstance(a, int) and isinstance(b, int):
 548                return None
 549            return exp.Literal.number(a / b)
 550
 551        boolean = eval_boolean(expression, a, b)
 552
 553        if boolean:
 554            return boolean
 555    elif a.is_string and b.is_string:
 556        boolean = eval_boolean(expression, a.this, b.this)
 557
 558        if boolean:
 559            return boolean
 560    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 561        a, b = extract_date(a), extract_interval(b)
 562        if a and b:
 563            if isinstance(expression, exp.Add):
 564                return date_literal(a + b)
 565            if isinstance(expression, exp.Sub):
 566                return date_literal(a - b)
 567    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 568        a, b = extract_interval(a), extract_date(b)
 569        # you cannot subtract a date from an interval
 570        if a and b and isinstance(expression, exp.Add):
 571            return date_literal(a + b)
 572
 573    return None
 574
 575
def simplify_parens(expression):
    """Remove a pair of parentheses when one of the conditions below holds.

    Returns the wrapped expression when the parentheses are dropped, otherwise the
    input expression. Parenthesised SELECTs are never unwrapped.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent

    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary operation
        not isinstance(parent, (exp.Condition, exp.Binary))
        # directly nested parentheses
        or isinstance(parent, exp.Paren)
        # the wrapped expression is not a binary operation
        or not isinstance(this, exp.Binary)
        # a predicate wrapped inside a non-predicate parent
        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
        # a + (b + c)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        # a * (b * c)
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # a + (b * c) and a - (b * c)
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
 594
 595
# Node types that `simplify_coalesce` treats as constants
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
 601
 602
 603def simplify_coalesce(expression):
 604    # COALESCE(x) -> x
 605    if (
 606        isinstance(expression, exp.Coalesce)
 607        and not expression.expressions
 608        # COALESCE is also used as a Spark partitioning hint
 609        and not isinstance(expression.parent, exp.Hint)
 610    ):
 611        return expression.this
 612
 613    if not isinstance(expression, COMPARISONS):
 614        return expression
 615
 616    if isinstance(expression.left, exp.Coalesce):
 617        coalesce = expression.left
 618        other = expression.right
 619    elif isinstance(expression.right, exp.Coalesce):
 620        coalesce = expression.right
 621        other = expression.left
 622    else:
 623        return expression
 624
 625    # This transformation is valid for non-constants,
 626    # but it really only does anything if they are both constants.
 627    if not isinstance(other, CONSTANTS):
 628        return expression
 629
 630    # Find the first constant arg
 631    for arg_index, arg in enumerate(coalesce.expressions):
 632        if isinstance(arg, CONSTANTS):
 633            break
 634    else:
 635        return expression
 636
 637    coalesce.set("expressions", coalesce.expressions[:arg_index])
 638
 639    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 640    # since we already remove COALESCE at the top of this function.
 641    coalesce = coalesce if coalesce.expressions else coalesce.this
 642
 643    # This expression is more complex than when we started, but it will get simplified further
 644    return exp.paren(
 645        exp.or_(
 646            exp.and_(
 647                coalesce.is_(exp.null()).not_(copy=False),
 648                expression.copy(),
 649                copy=False,
 650            ),
 651            exp.and_(
 652                coalesce.is_(exp.null()),
 653                type(expression)(this=arg.copy(), expression=other.copy()),
 654                copy=False,
 655            ),
 656            copy=False,
 657        )
 658    )
 659
 660
# Concatenation node types accepted by `simplify_concat`
CONCATS = (exp.Concat, exp.DPipe)
# Variants for which `simplify_concat` rebuilds the result as `exp.SafeConcat`
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
 663
 664
 665def simplify_concat(expression):
 666    """Reduces all groups that contain string literals by concatenating them."""
 667    if not isinstance(expression, CONCATS) or (
 668        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 669        isinstance(expression, exp.ConcatWs)
 670        and not expression.expressions[0].is_string
 671    ):
 672        return expression
 673
 674    if isinstance(expression, exp.ConcatWs):
 675        sep_expr, *expressions = expression.expressions
 676        sep = sep_expr.name
 677        concat_type = exp.ConcatWs
 678    else:
 679        expressions = expression.expressions
 680        sep = ""
 681        concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
 682
 683    new_args = []
 684    for is_string_group, group in itertools.groupby(
 685        expressions or expression.flatten(), lambda e: e.is_string
 686    ):
 687        if is_string_group:
 688            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 689        else:
 690            new_args.extend(group)
 691
 692    if len(new_args) == 1 and new_args[0].is_string:
 693        return new_args[0]
 694
 695    if concat_type is exp.ConcatWs:
 696        new_args = [sep_expr] + new_args
 697
 698    return concat_type(expressions=new_args)
 699
 700
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    Returns the selected branch when a condition is always true, the default / false
    branch (or NULL) when every condition is always false, and the (possibly pruned)
    input expression otherwise.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                # this WHEN branch can never match, so it is detached from the CASE
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        # standalone IF; the exp.If nodes that sit under a CASE are handled above
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 725
 726
# Half-open date range [min, max) as produced by `_datetrunc_range`
DateRange = t.Tuple[datetime.date, datetime.date]
 728
 729
 730def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
 731    """
 732    Get the date range for a DATE_TRUNC equality comparison:
 733
 734    Example:
 735        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 736    Returns:
 737        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 738    """
 739    floor = date_floor(date, unit)
 740
 741    if date != floor:
 742        # This will always be False, except for NULL values.
 743        return None
 744
 745    return floor, floor + interval(unit)
 746
 747
 748def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 749    """Get the logical expression for a date range"""
 750    return exp.and_(
 751        left >= date_literal(drange[0]),
 752        left < date_literal(drange[1]),
 753        copy=False,
 754    )
 755
 756
 757def _datetrunc_eq(
 758    left: exp.Expression, date: datetime.date, unit: str
 759) -> t.Optional[exp.Expression]:
 760    drange = _datetrunc_range(date, unit)
 761    if not drange:
 762        return None
 763
 764    return _datetrunc_eq_expression(left, drange)
 765
 766
 767def _datetrunc_neq(
 768    left: exp.Expression, date: datetime.date, unit: str
 769) -> t.Optional[exp.Expression]:
 770    drange = _datetrunc_range(date, unit)
 771    if not drange:
 772        return None
 773
 774    return exp.and_(
 775        left < date_literal(drange[0]),
 776        left >= date_literal(drange[1]),
 777        copy=False,
 778    )
 779
 780
 781DateTruncBinaryTransform = t.Callable[
 782    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
 783]
 784DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 785    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
 786    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
 787    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
 788    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
 789    exp.EQ: _datetrunc_eq,
 790    exp.NEQ: _datetrunc_neq,
 791}
 792DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 793
 794
 795def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 796    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
 797
 798
@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`

    Binary comparisons are rewritten through `DATETRUNC_BINARY_COMPARISONS`; an IN list
    of date literals becomes an OR of range checks. The input expression is returned
    whenever the pattern does not match or a date cannot be extracted.
    """
    comparison = expression.__class__

    if comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if _is_datetrunc_predicate(l, r):
            pass
        elif _is_datetrunc_predicate(r, l):
            # The truncation is on the right: swap the operands and flip the comparison
            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
            l, r = r, l
        else:
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        # The transform may return None, in which case the expression is kept as is
        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                # Dates for which no range exists are left out of the result
                drange = _datetrunc_range(date, unit)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)

    return expression
 851
 852
# A CROSS join yields an empty result if the right table is empty, so only some
# (side, kind) combinations with an always-true condition may be rewritten as CROSS.
# In other words, LEFT JOIN x ON TRUE != CROSS JOIN x.
# `remove_where_true` checks `(join.side, join.kind)` against this set.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 862
 863
 864def remove_where_true(expression):
 865    for where in expression.find_all(exp.Where):
 866        if always_true(where.this):
 867            where.parent.set("where", None)
 868    for join in expression.find_all(exp.Join):
 869        if (
 870            always_true(join.args.get("on"))
 871            and not join.args.get("using")
 872            and not join.args.get("method")
 873            and (join.side, join.kind) in JOINS
 874        ):
 875            join.set("on", None)
 876            join.set("side", None)
 877            join.set("kind", "CROSS")
 878
 879
 880def always_true(expression):
 881    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 882        expression, exp.Literal
 883    )
 884
 885
 886def always_false(expression):
 887    return is_false(expression) or is_null(expression)
 888
 889
 890def is_complement(a, b):
 891    return isinstance(b, exp.Not) and b.this == a
 892
 893
 894def is_false(a: exp.Expression) -> bool:
 895    return type(a) is exp.Boolean and not a.this
 896
 897
 898def is_null(a: exp.Expression) -> bool:
 899    return type(a) is exp.Null
 900
 901
 902def eval_boolean(expression, a, b):
 903    if isinstance(expression, (exp.EQ, exp.Is)):
 904        return boolean_literal(a == b)
 905    if isinstance(expression, exp.NEQ):
 906        return boolean_literal(a != b)
 907    if isinstance(expression, exp.GT):
 908        return boolean_literal(a > b)
 909    if isinstance(expression, exp.GTE):
 910        return boolean_literal(a >= b)
 911    if isinstance(expression, exp.LT):
 912        return boolean_literal(a < b)
 913    if isinstance(expression, exp.LTE):
 914        return boolean_literal(a <= b)
 915    return None
 916
 917
 918def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
 919    if isinstance(value, datetime.datetime):
 920        return value.date()
 921    if isinstance(value, datetime.date):
 922        return value
 923    try:
 924        return datetime.datetime.fromisoformat(value).date()
 925    except ValueError:
 926        return None
 927
 928
 929def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
 930    if isinstance(value, datetime.datetime):
 931        return value
 932    if isinstance(value, datetime.date):
 933        return datetime.datetime(year=value.year, month=value.month, day=value.day)
 934    try:
 935        return datetime.datetime.fromisoformat(value)
 936    except ValueError:
 937        return None
 938
 939
 940def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 941    if not value:
 942        return None
 943    if to.is_type(exp.DataType.Type.DATE):
 944        return cast_as_date(value)
 945    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
 946        return cast_as_datetime(value)
 947    return None
 948
 949
 950def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
 951    if isinstance(cast, exp.Cast):
 952        to = cast.to
 953    elif isinstance(cast, exp.TsOrDsToDate):
 954        to = exp.DataType.build(exp.DataType.Type.DATE)
 955    else:
 956        return None
 957
 958    if isinstance(cast.this, exp.Literal):
 959        value: t.Any = cast.this.name
 960    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
 961        value = extract_date(cast.this)
 962    else:
 963        return None
 964    return cast_value(value, to)
 965
 966
 967def _is_date_literal(expression: exp.Expression) -> bool:
 968    return extract_date(expression) is not None
 969
 970
 971def extract_interval(expression):
 972    n = int(expression.name)
 973    unit = expression.text("unit").lower()
 974
 975    try:
 976        return interval(unit, n)
 977    except (UnsupportedUnit, ModuleNotFoundError):
 978        return None
 979
 980
 981def date_literal(date):
 982    return exp.cast(
 983        exp.Literal.string(date),
 984        exp.DataType.Type.DATETIME
 985        if isinstance(date, datetime.datetime)
 986        else exp.DataType.Type.DATE,
 987    )
 988
 989
 990def interval(unit: str, n: int = 1):
 991    from dateutil.relativedelta import relativedelta
 992
 993    if unit == "year":
 994        return relativedelta(years=1 * n)
 995    if unit == "quarter":
 996        return relativedelta(months=3 * n)
 997    if unit == "month":
 998        return relativedelta(months=1 * n)
 999    if unit == "week":
1000        return relativedelta(weeks=1 * n)
1001    if unit == "day":
1002        return relativedelta(days=1 * n)
1003    if unit == "hour":
1004        return relativedelta(hours=1 * n)
1005    if unit == "minute":
1006        return relativedelta(minutes=1 * n)
1007    if unit == "second":
1008        return relativedelta(seconds=1 * n)
1009
1010    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1011
1012
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Move `d` back to the first day of the `unit` that contains it.

    Supported units are year, quarter, month, week and day. Weeks are assumed
    to start on Monday. Only the date fields are adjusted.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "day":
        return d
    if unit == "week":
        # weekday() is 0 for Monday and 6 for Sunday
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1034
1035
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """
    Round `d` up to the start of the next `unit`.

    `d` is returned unchanged when it is already equal to its own floor.
    """
    start = date_floor(d, unit)
    return d if start == d else start + interval(unit)
1043
1044
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy, otherwise FALSE."""
    if condition:
        return exp.true()
    return exp.false()
1047
1048
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify pairs of operands of a flattened binary expression.

    The operands of `expression` are flattened and `simplifier(expression, a, b)`
    is tried on each operand against the ones after it. A truthy result that is
    not `expression` itself replaces both operands and is processed again.

    Returns `expression` unchanged when nothing was merged, or when it is
    neither the root nor the top of a chain of same-typed parents. Otherwise
    the remaining operands are folded left to right into new nodes of the
    same class.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    kept = []

    while pending:
        head = pending.popleft()

        for other in pending:
            merged = simplifier(expression, head, other)

            if merged and merged is not expression:
                # Swap the pair for its simplification and revisit it first
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            kept.append(head)

    if len(kept) >= original_count:
        return expression

    def _combine(left, right):
        return expression.__class__(this=left, expression=right)

    return functools.reduce(_combine, kept)
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date or interval helper is given a unit it does not handle."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression, constant_propagation=False):
 23def simplify(expression, constant_propagation=False):
 24    """
 25    Rewrite sqlglot AST to simplify expressions.
 26
 27    Example:
 28        >>> import sqlglot
 29        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 30        >>> simplify(expression).sql()
 31        'TRUE'
 32
 33    Args:
 34        expression (sqlglot.Expression): expression to simplify
 35        constant_propagation: whether or not the constant propagation rule should be used
 36
 37    Returns:
 38        sqlglot.Expression: simplified expression
 39    """
 40
 41    generate = cached_generator()
 42
 43    # group by expressions cannot be simplified, for example
 44    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 45    # the projection must exactly match the group by key
 46    for group in expression.find_all(exp.Group):
 47        select = group.parent
 48        groups = set(group.expressions)
 49        group.meta[FINAL] = True
 50
 51        for e in select.selects:
 52            for node, *_ in e.walk():
 53                if node in groups:
 54                    e.meta[FINAL] = True
 55                    break
 56
 57        having = select.args.get("having")
 58        if having:
 59            for node, *_ in having.walk():
 60                if node in groups:
 61                    having.meta[FINAL] = True
 62                    break
 63
 64    def _simplify(expression, root=True):
 65        if expression.meta.get(FINAL):
 66            return expression
 67
 68        # Pre-order transformations
 69        node = expression
 70        node = rewrite_between(node)
 71        node = uniq_sort(node, generate, root)
 72        node = absorb_and_eliminate(node, root)
 73        node = simplify_concat(node)
 74        node = simplify_conditionals(node)
 75
 76        if constant_propagation:
 77            node = propagate_constants(node, root)
 78
 79        exp.replace_children(node, lambda e: _simplify(e, False))
 80
 81        # Post-order transformations
 82        node = simplify_not(node)
 83        node = flatten(node)
 84        node = simplify_connectors(node, root)
 85        node = remove_complements(node, root)
 86        node = simplify_coalesce(node)
 87        node.parent = expression.parent
 88        node = simplify_literals(node, root)
 89        node = simplify_equality(node)
 90        node = simplify_parens(node)
 91        node = simplify_datetrunc_predicate(node)
 92
 93        if root:
 94            expression.replace(node)
 95
 96        return node
 97
 98    expression = while_changing(expression, _simplify)
 99    remove_where_true(expression)
100    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether or not the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    When the decorated function raises one of `exceptions`, its first argument
    (the expression) is returned unchanged instead. The wrapper keeps the name
    and docstring of the decorated function.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that skips a simplification function, returning the expression unchanged, if any of `exceptions` is raised.

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into those operators first. Other expressions are returned as is.
    """
    if not isinstance(expression, exp.Between):
        return expression

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify NOT expressions.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT of NULL stays NULL, NOT of a constant is folded, and a double
    negation is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, exp.And):
            return exp.or_(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if isinstance(condition, exp.Or):
            return exp.and_(
                exp.not_(condition.left, copy=False),
                exp.not_(condition.right, copy=False),
                copy=False,
            )
        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around operands that use the same connector as their parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__

    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)

    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold pairs of operands of an AND / OR chain.

    Equal operands are merged, TRUE / FALSE / NULL operands are folded, and
    the remaining pairs are handed to `_simplify_comparison`. Expressions that
    are not connectors are returned unchanged.
    """

    def _simplify_and(expression, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        if always_true(left):
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return _simplify_comparison(expression, left, right)

    def _simplify_or(expression, left, right):
        if always_true(left) or always_true(right):
            return exp.true()

        left_false, right_false = is_false(left), is_false(right)
        if left_false and right_false:
            return exp.false()

        left_null, right_null = is_null(left), is_null(right)
        if (left_null and (right_null or right_false)) or (left_false and right_null):
            return exp.null()

        if left_false:
            return right
        if right_false:
            return left
        return _simplify_comparison(expression, left, right, or_=True)

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            return _simplify_and(expression, left, right)
        if isinstance(expression, exp.Or):
            return _simplify_or(expression, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Fold a connector that contains both an operand and its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)

    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate the operands of a connector and order them by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    `generate` is the callable used to produce the SQL text of each operand.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {generate(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # Only rebuild when some operand sorts before its predecessor
    # A AND C AND B -> A AND B AND C
    if any(current < previous for previous, current in zip(keys, keys[1:])):
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # Already ordered, but duplicates may still have to be dropped
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination rules to a connector, in place.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # The nested connector we look for is the opposite of the outer one
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    nested_is_and = kind == exp.And

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(exp.true() if nested_is_and else exp.false())
        elif is_complement(b, ab):
            ab.replace(exp.true() if nested_is_and else exp.false())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(exp.false() if nested_is_and else exp.true())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if not (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        return expression

    # Maps each column compared to a literal to (id of that column node, literal)
    constant_mapping = {}
    for expr, *_ in walk_in_scope(expression, prune=lambda node, *_: isinstance(node, exp.If)):
        if not isinstance(expr, exp.EQ):
            continue

        l, r = expr.left, expr.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
            column, literal = l, r
        elif isinstance(r, exp.Column) and isinstance(l, exp.Literal):
            column, literal = r, l
        else:
            continue

        constant_mapping[column] = (id(column), literal)

    if not constant_mapping:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        entry = constant_mapping.get(column)
        if not entry:
            continue

        column_id, constant = entry

        # Leave the column of the defining equality itself alone
        if id(column) == column_id:
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
107        def wrapped(expression, *args, **kwargs):
108            try:
109                return func(expression, *args, **kwargs)
110            except exceptions:
111                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands of binary expressions and negated numbers.

    Non-connector binary expressions are simplified pairwise with
    `_simplify_binary`. The negation of a number literal becomes a number
    literal with the sign flipped. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    text = operand.name
    flipped = text[1:] if text[0] == "-" else f"-{text}"
    return exp.Literal.number(flipped)
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Remove parentheses that are not needed to preserve the meaning of an expression.

    Returns the wrapped expression when the parentheses can be dropped and the
    original expression otherwise. Parenthesized SELECTs are never unwrapped.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(outer, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )

    return inner if redundant else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls.

    A COALESCE with a single argument is replaced by that argument. A comparison
    between a COALESCE and a constant is rewritten so that the arguments from the
    first constant one onwards are dropped from the COALESCE, yielding
    (coalesce IS NOT NULL AND <comparison>) OR (coalesce IS NULL AND <constant comparison>).
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    left, right = expression.left, expression.right
    if isinstance(left, exp.Coalesce):
        coalesce, other = left, right
    elif isinstance(right, exp.Coalesce):
        coalesce, other = right, left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not isinstance(other, CONSTANTS):
        return expression

    # Find the first constant arg
    first_constant = next(
        (
            (index, candidate)
            for index, candidate in enumerate(coalesce.expressions)
            if isinstance(candidate, CONSTANTS)
        ),
        None,
    )
    if first_constant is None:
        return expression

    arg_index, arg = first_constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Merge each run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    if is_concat_ws:
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        if not expression.expressions[0].is_string:
            return expression

        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
    else:
        expressions = expression.expressions
        sep = ""
        if isinstance(expression, SAFE_CONCATS):
            concat_type = exp.SafeConcat
        else:
            concat_type = exp.Concat

    new_args = []
    operands = expressions or expression.flatten()

    for is_string_group, group in itertools.groupby(operands, lambda e: e.is_string):
        if is_string_group:
            joined = sep.join(string.name for string in group)
            new_args.append(exp.Literal.string(joined))
        else:
            new_args.extend(group)

    # Everything collapsed into a single string literal
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if is_concat_ws:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Resolve IF and CASE expressions whose conditions are statically known."""
    if isinstance(expression, exp.Case):
        subject = expression.this

        for branch in expression.args["ifs"]:
            cond = branch.this
            if subject:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(subject.pop().eq(cond))

            if always_true(cond):
                return branch.args["true"]

            if always_false(cond):
                # This branch can never be taken
                branch.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()

        return expression

    if isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        condition = expression.this
        if always_true(condition):
            return expression.args["true"]
        if always_false(condition):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.EQ'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
107        def wrapped(expression, *args, **kwargs):
108            try:
109                return func(expression, *args, **kwargs)
110            except exceptions:
111                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Drop conditions that are statically true from WHERE clauses and joins.

    A WHERE whose condition satisfies `always_true` is removed from its parent.
    A join whose ON condition satisfies `always_true`, that has neither a USING
    clause nor a join method, and whose (side, kind) pair is in JOINS is turned
    into a CROSS join. The tree is modified in place and nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if not always_true(where.this):
            continue
        where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Check whether `expression` is a constant that can be treated as TRUE.

    A Boolean node is true when its value is truthy. A Literal is treated as
    true unless it is a numeric literal equal to zero, so that e.g. `NOT 0`
    is not folded into FALSE. Anything else (including None) is not true.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return float(expression.name) != 0
        except ValueError:
            # The numeric text can't be parsed here; keep treating it as true.
            return True

    return True
def always_false(expression):
def always_false(expression):
    """Return True when `expression` is the FALSE boolean or a NULL node."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Tell whether `b` is a NOT node whose operand equals `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True only for an exact `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True only when `a` is exactly an `exp.Null` node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns the boolean literal for the outcome, or None when `expression` is
    not one of the supported comparison node types.
    """
    # The comparisons are wrapped in lambdas so that only the matching one runs.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned unchanged and
    anything else is parsed as an ISO formatted string. Returns None when the
    value can't be parsed, including when it isn't a string at all.
    """
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-string arguments
        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned unchanged, dates become midnight of that day and
    anything else is parsed as an ISO formatted string. Returns None when the
    value can't be parsed, including when it isn't a string at all.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # fromisoformat raises TypeError for non-string arguments
        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python temporal type matching the data type `to`.

    Returns a date for DATE, a datetime for the other temporal types, and None
    when `value` is falsy, can't be converted or `to` isn't a temporal type.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract a Python date/datetime from a cast-like expression.

    Args:
        cast: an `exp.Cast` (target type taken from `cast.to`) or an
            `exp.TsOrDsToDate` (target type is DATE)

    Returns:
        The value converted by `cast_value`, or None when `cast` is neither of the
        supported node types or its operand is not a literal or a nested cast.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts are resolved from the inside out
        value = extract_date(operand)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
    """
    Convert an interval expression into the delta returned by `interval`.

    Args:
        expression: node whose `name` holds the count and whose "unit" text holds the unit

    Returns:
        The result of `interval(unit, n)`, or None when the count is not an integer,
        the unit is unsupported, or `interval` raises ModuleNotFoundError.
    """
    try:
        n = int(expression.name)
    except ValueError:
        # Non-integer counts (e.g. '1.5' or arbitrary text) cannot be evaluated here
        return None

    unit = expression.text("unit").lower()

    try:
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError):
        return None
def date_literal(date):
    """
    Build a cast expression wrapping a date value as a string literal.

    Args:
        date: the value to wrap

    Returns:
        A cast of the string literal to DATETIME when `date` is a `datetime.datetime`
        instance, and to DATE otherwise.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` of the given unit.

    Args:
        unit: one of year, quarter, month, week, day, hour, minute or second
        n: how many units the delta should span

    Returns:
        The corresponding `dateutil.relativedelta.relativedelta`.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, number of that keyword in one unit)
    scales = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    scale = scales.get(unit)
    if scale is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = scale
    return relativedelta(**{keyword: factor * n})
def date_floor(d: datetime.date, unit: str) -> datetime.date:
    """
    Round a date down to the start of the given unit.

    Args:
        d: the date to round
        unit: one of year, quarter, month, week or day

    Returns:
        The first day of the year, quarter, month or week containing `d`,
        or `d` itself for the day unit.

    Raises:
        UnsupportedUnit: if `unit` is not one of the units listed above.
    """
    if unit == "day":
        return d
    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        return d - datetime.timedelta(days=d.weekday())
    if unit == "month":
        return d.replace(day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        quarter_start = d.month - (d.month - 1) % 3
        return d.replace(month=quarter_start, day=1)
    if unit == "year":
        return d.replace(month=1, day=1)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
    """
    Round a date up to the next boundary of the given unit.

    Args:
        d: the date to round
        unit: unit passed through to `date_floor` and `interval`

    Returns:
        `d` itself when it already sits on a unit boundary, otherwise the floor
        of `d` advanced by one unit.
    """
    floored = date_floor(d, unit)

    # interval() is only needed when d is not already on a boundary
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
    """
    Turn a Python truth value into a boolean expression.

    Args:
        condition: any value; only its truthiness is used

    Returns:
        `exp.true()` when `condition` is truthy, `exp.false()` otherwise.
    """
    if condition:
        return exp.true()
    return exp.false()