Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9from functools import reduce
  10
  11import sqlglot
  12from sqlglot import Dialect, exp
  13from sqlglot.helper import first, merge_ranges, while_changing
  14from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  15
  16if t.TYPE_CHECKING:
  17    from sqlglot.dialects.dialect import DialectType
  18
  19    DateTruncBinaryTransform = t.Callable[
  20        [exp.Expression, datetime.date, str, Dialect, exp.DataType], t.Optional[exp.Expression]
  21    ]
  22
# Key set in an expression's `meta` dict; `simplify` returns any node carrying it
# unchanged, i.e. final means that an expression should not be simplified
FINAL = "final"

# Value ranges for byte-sized signed/unsigned integers, used to decide whether an
# integer cast around a literal can be dropped
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
  31
  32
  33class UnsupportedUnit(Exception):
  34    pass
  35
  36
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect specification, resolved with `Dialect.get_or_raise` and passed
            on to the DATE_TRUNC simplification rule

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as final are returned as-is (and their children are not visited)
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same treatment for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The rules above may have produced a new node; point it at the original parent
        # because some of the following rules (e.g. simplify_parens) inspect `parent`
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the rules until `while_changing` reports that the expression is stable
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 119
 120
 121def catch(*exceptions):
 122    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 123
 124    def decorator(func):
 125        def wrapped(expression, *args, **kwargs):
 126            try:
 127                return func(expression, *args, **kwargs)
 128            except exceptions:
 129                return expression
 130
 131        return wrapped
 132
 133    return decorator
 134
 135
 136def rewrite_between(expression: exp.Expression) -> exp.Expression:
 137    """Rewrite x between y and z to x >= y AND x <= z.
 138
 139    This is done because comparison simplification is only done on lt/lte/gt/gte.
 140    """
 141    if isinstance(expression, exp.Between):
 142        negate = isinstance(expression.parent, exp.Not)
 143
 144        expression = exp.and_(
 145            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 146            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 147            copy=False,
 148        )
 149
 150        if negate:
 151            expression = exp.paren(expression, copy=False)
 152
 153    return expression
 154
 155
# Maps a comparison type to the type expressing its negation. simplify_not uses it
# to rewrite a negated comparison, e.g. NOT a < b becomes a >= b.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 164
 165
 166def simplify_not(expression):
 167    """
 168    Demorgan's Law
 169    NOT (x OR y) -> NOT x AND NOT y
 170    NOT (x AND y) -> NOT x OR NOT y
 171    """
 172    if isinstance(expression, exp.Not):
 173        this = expression.this
 174        if is_null(this):
 175            return exp.null()
 176        if this.__class__ in COMPLEMENT_COMPARISONS:
 177            return COMPLEMENT_COMPARISONS[this.__class__](
 178                this=this.this, expression=this.expression
 179            )
 180        if isinstance(this, exp.Paren):
 181            condition = this.unnest()
 182            if isinstance(condition, exp.And):
 183                return exp.paren(
 184                    exp.or_(
 185                        exp.not_(condition.left, copy=False),
 186                        exp.not_(condition.right, copy=False),
 187                        copy=False,
 188                    )
 189                )
 190            if isinstance(condition, exp.Or):
 191                return exp.paren(
 192                    exp.and_(
 193                        exp.not_(condition.left, copy=False),
 194                        exp.not_(condition.right, copy=False),
 195                        copy=False,
 196                    )
 197                )
 198            if is_null(condition):
 199                return exp.null()
 200        if always_true(this):
 201            return exp.false()
 202        if is_false(this):
 203            return exp.true()
 204        if isinstance(this, exp.Not):
 205            # double negation
 206            # NOT NOT x -> x
 207            return this.this
 208    return expression
 209
 210
 211def flatten(expression):
 212    """
 213    A AND (B AND C) -> A AND B AND C
 214    A OR (B OR C) -> A OR B OR C
 215    """
 216    if isinstance(expression, exp.Connector):
 217        for node in expression.args.values():
 218            child = node.unnest()
 219            if isinstance(child, expression.__class__):
 220                node.replace(child)
 221    return expression
 222
 223
def simplify_connectors(expression, root=True):
    """
    Simplify a connector (AND / OR / XOR) by folding pairs of its operands.

    Connectors are handed to `_flat_simplify` together with the pairwise rule defined
    below; any other expression is returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Produces the replacement for the operand pair, or None when nothing applies
        # (for XOR with two different operands the function falls off the end).
        if left == right:
            if isinstance(expression, exp.Xor):
                # A XOR A -> FALSE
                return exp.false()
            # A AND A -> A, A OR A -> A
            return left
        if isinstance(expression, exp.And):
            # FALSE wins over NULL: FALSE AND NULL -> FALSE
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE AND x -> x
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE wins over NULL: TRUE OR NULL -> TRUE
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # NULL OR NULL, NULL OR FALSE and FALSE OR NULL are all NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE OR x -> x
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
 262
 263
# Upper-bound and lower-bound style comparisons, grouped by direction
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Node types that the rules in this module treat as comparisons
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps an inequality type to the one pointing in the opposite direction
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# A shared operand containing any of these is not used to merge two comparisons
# (see _simplify_comparison)
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 283
 284
 285def _simplify_comparison(expression, left, right, or_=False):
 286    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 287        ll, lr = left.args.values()
 288        rl, rr = right.args.values()
 289
 290        largs = {ll, lr}
 291        rargs = {rl, rr}
 292
 293        matching = largs & rargs
 294        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
 295
 296        if matching and columns:
 297            try:
 298                l = first(largs - columns)
 299                r = first(rargs - columns)
 300            except StopIteration:
 301                return expression
 302
 303            if l.is_number and r.is_number:
 304                l = float(l.name)
 305                r = float(r.name)
 306            elif l.is_string and r.is_string:
 307                l = l.name
 308                r = r.name
 309            else:
 310                l = extract_date(l)
 311                if not l:
 312                    return None
 313                r = extract_date(r)
 314                if not r:
 315                    return None
 316                # python won't compare date and datetime, but many engines will upcast
 317                l, r = cast_as_datetime(l), cast_as_datetime(r)
 318
 319            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 320                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 321                    return left if (av > bv if or_ else av <= bv) else right
 322                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 323                    return left if (av < bv if or_ else av >= bv) else right
 324
 325                # we can't ever shortcut to true because the column could be null
 326                if not or_:
 327                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 328                        if av <= bv:
 329                            return exp.false()
 330                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 331                        if av >= bv:
 332                            return exp.false()
 333                    elif isinstance(a, exp.EQ):
 334                        if isinstance(b, exp.LT):
 335                            return exp.false() if av >= bv else a
 336                        if isinstance(b, exp.LTE):
 337                            return exp.false() if av > bv else a
 338                        if isinstance(b, exp.GT):
 339                            return exp.false() if av <= bv else a
 340                        if isinstance(b, exp.GTE):
 341                            return exp.false() if av < bv else a
 342                        if isinstance(b, exp.NEQ):
 343                            return exp.false() if av == bv else a
 344    return None
 345
 346
 347def remove_complements(expression, root=True):
 348    """
 349    Removing complements.
 350
 351    A AND NOT A -> FALSE
 352    A OR NOT A -> TRUE
 353    """
 354    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 355        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 356
 357        for a, b in itertools.permutations(expression.flatten(), 2):
 358            if is_complement(a, b):
 359                return complement
 360    return expression
 361
 362
def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are ordered (and, except for XOR, deduplicated) by the string `gen` produces
    for them. The rule only applies when `root` is true or `expression.same_parent` is false.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        flattened = tuple(expression.flatten())

        if isinstance(expression, exp.Xor):
            result_func = exp.xor
            # Do not deduplicate XOR as A XOR A != A if A == True
            deduped = None
            arr = tuple((gen(e), e) for e in flattened)
        else:
            result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
            # Keyed by generated string, so operands with the same string collapse into one
            deduped = {gen(e): e for e in flattened}
            arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            # arr[i] is the element right before (sql, e), since enumeration starts at arr[1]
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if deduped and len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression
 394
 395
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place with boolean constants or duplicated terms;
    the input expression itself is what gets returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The connector type looked for among the operands: OR inside an AND, AND otherwise
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # b contradicts aa: swap aa for the constant that is neutral inside `a`
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's terms are a strict subset of a's: a is redundant, so swap it for
                    # the constant that is neutral in the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # a and b share one term and disagree on the other: both become the shared term
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
 434
 435
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    The columns are replaced in place and the input expression is returned.

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # column -> (id of the column node in the defining equality, literal)
        constant_mapping = {}
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # skip the column node of the `column = literal` equality itself
                    and id(column) != column_id
                    # leave `column IS NULL` checks untouched
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
 473
 474
# Maps a date/datetime add or subtract function to the binary operator that undoes it
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# Same as above, extended with plain addition and subtraction; simplify_equality uses
# it to move a constant to the other side of a comparison
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 487
 488
 489def _is_number(expression: exp.Expression) -> bool:
 490    return expression.is_number
 491
 492
 493def _is_interval(expression: exp.Expression) -> bool:
 494    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 495
 496
 497@catch(ModuleNotFoundError, UnsupportedUnit)
 498def simplify_equality(expression: exp.Expression) -> exp.Expression:
 499    """
 500    Use the subtraction and addition properties of equality to simplify expressions:
 501
 502        x + 1 = 3 becomes x = 2
 503
 504    There are two binary operations in the above expression: + and =
 505    Here's how we reference all the operands in the code below:
 506
 507          l     r
 508        x + 1 = 3
 509        a   b
 510    """
 511    if isinstance(expression, COMPARISONS):
 512        l, r = expression.left, expression.right
 513
 514        if l.__class__ not in INVERSE_OPS:
 515            return expression
 516
 517        if r.is_number:
 518            a_predicate = _is_number
 519            b_predicate = _is_number
 520        elif _is_date_literal(r):
 521            a_predicate = _is_date_literal
 522            b_predicate = _is_interval
 523        else:
 524            return expression
 525
 526        if l.__class__ in INVERSE_DATE_OPS:
 527            l = t.cast(exp.IntervalOp, l)
 528            a = l.this
 529            b = l.interval()
 530        else:
 531            l = t.cast(exp.Binary, l)
 532            a, b = l.left, l.right
 533
 534        if not a_predicate(a) and b_predicate(b):
 535            pass
 536        elif not a_predicate(b) and b_predicate(a):
 537            a, b = b, a
 538        else:
 539            return expression
 540
 541        return expression.__class__(
 542            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 543        )
 544    return expression
 545
 546
 547def simplify_literals(expression, root=True):
 548    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 549        return _flat_simplify(expression, _simplify_binary, root)
 550
 551    if isinstance(expression, exp.Neg):
 552        this = expression.this
 553        if this.is_number:
 554            value = this.name
 555            if value[0] == "-":
 556                return exp.Literal.number(value[1:])
 557            return exp.Literal.number(f"-{value}")
 558
 559    if type(expression) in INVERSE_DATE_OPS:
 560        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 561
 562    return expression
 563
 564
# Binary operators that _simplify_binary never folds; in particular, a NULL operand
# does not turn them into NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 566
 567
 568def _simplify_integer_cast(expr: exp.Expression) -> exp.Expression:
 569    if isinstance(expr, exp.Cast) and isinstance(expr.this, exp.Cast):
 570        this = _simplify_integer_cast(expr.this)
 571    else:
 572        this = expr.this
 573
 574    if isinstance(expr, exp.Cast) and this.is_int:
 575        num = int(this.name)
 576
 577        # Remove the (up)cast from small (byte-sized) integers in predicates which is side-effect free. Downcasts on any
 578        # integer type might cause overflow, thus the cast cannot be eliminated and the behavior is
 579        # engine-dependent
 580        if (
 581            TINYINT_MIN <= num <= TINYINT_MAX and expr.to.this in exp.DataType.SIGNED_INTEGER_TYPES
 582        ) or (
 583            UTINYINT_MIN <= num <= UTINYINT_MAX
 584            and expr.to.this in exp.DataType.UNSIGNED_INTEGER_TYPES
 585        ):
 586            return this
 587
 588    return expr
 589
 590
 591def _simplify_binary(expression, a, b):
 592    if isinstance(expression, COMPARISONS):
 593        a = _simplify_integer_cast(a)
 594        b = _simplify_integer_cast(b)
 595
 596    if isinstance(expression, exp.Is):
 597        if isinstance(b, exp.Not):
 598            c = b.this
 599            not_ = True
 600        else:
 601            c = b
 602            not_ = False
 603
 604        if is_null(c):
 605            if isinstance(a, exp.Literal):
 606                return exp.true() if not_ else exp.false()
 607            if is_null(a):
 608                return exp.false() if not_ else exp.true()
 609    elif isinstance(expression, NULL_OK):
 610        return None
 611    elif is_null(a) or is_null(b):
 612        return exp.null()
 613
 614    if a.is_number and b.is_number:
 615        num_a = int(a.name) if a.is_int else Decimal(a.name)
 616        num_b = int(b.name) if b.is_int else Decimal(b.name)
 617
 618        if isinstance(expression, exp.Add):
 619            return exp.Literal.number(num_a + num_b)
 620        if isinstance(expression, exp.Mul):
 621            return exp.Literal.number(num_a * num_b)
 622
 623        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 624        if isinstance(expression, exp.Sub):
 625            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 626        if isinstance(expression, exp.Div):
 627            # engines have differing int div behavior so intdiv is not safe
 628            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 629                return None
 630            return exp.Literal.number(num_a / num_b)
 631
 632        boolean = eval_boolean(expression, num_a, num_b)
 633
 634        if boolean:
 635            return boolean
 636    elif a.is_string and b.is_string:
 637        boolean = eval_boolean(expression, a.this, b.this)
 638
 639        if boolean:
 640            return boolean
 641    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 642        date, b = extract_date(a), extract_interval(b)
 643        if date and b:
 644            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 645                return date_literal(date + b, extract_type(a))
 646            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 647                return date_literal(date - b, extract_type(a))
 648    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 649        a, date = extract_interval(a), extract_date(b)
 650        # you cannot subtract a date from an interval
 651        if a and b and isinstance(expression, exp.Add):
 652            return date_literal(a + date, extract_type(b))
 653    elif _is_date_literal(a) and _is_date_literal(b):
 654        if isinstance(expression, exp.Predicate):
 655            a, b = extract_date(a), extract_date(b)
 656            boolean = eval_boolean(expression, a, b)
 657            if boolean:
 658                return boolean
 659
 660    return None
 661
 662
def simplify_parens(expression):
    """
    Remove a Paren node when the checks below show that the grouping isn't needed.

    Returns the wrapped expression in that case, and the Paren itself otherwise.
    Anything that isn't a Paren is returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if (
        # Parentheses around a SELECT, or directly below a subquery predicate, are kept
        not isinstance(this, exp.Select)
        and not isinstance(parent, exp.SubqueryPredicate)
        and (
            # The parent is neither a condition nor a binary expression
            not isinstance(parent, (exp.Condition, exp.Binary))
            # Nested parentheses
            or isinstance(parent, exp.Paren)
            # The content is not a binary expression (NOT / IS below a predicate keep theirs)
            or (
                not isinstance(this, exp.Binary)
                and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
            )
            # A predicate below a non-predicate parent
            or (isinstance(this, exp.Predicate) and not parent_is_predicate)
            # x + (y + z), x * (y * z), and a product below + or -
            or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
            or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
            or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
        )
    ):
        return this
    return expression
 689
 690
 691def _is_nonnull_constant(expression: exp.Expression) -> bool:
 692    return isinstance(expression, exp.NONNULL_CONSTANTS) or _is_date_literal(expression)
 693
 694
 695def _is_constant(expression: exp.Expression) -> bool:
 696    return isinstance(expression, exp.CONSTANTS) or _is_date_literal(expression)
 697
 698
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls, on their own or as one side of a comparison with a constant.

    A COALESCE with a single argument, or whose first argument is a non-null constant,
    is replaced by that argument. A comparison between a COALESCE and a constant is
    split on whether the arguments before the first constant one are NULL.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    # Find which side of the comparison holds the COALESCE
    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it from the COALESCE
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # The remaining arguments are not NULL: the comparison applies to them
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # They are NULL: the constant argument is what gets compared
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
 755
 756
# Expression types handled by simplify_concat
CONCATS = (exp.Concat, exp.DPipe)
 758
 759
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged into one literal (joined with the separator
    for CONCAT_WS). If only a single string literal is left, it replaces the whole
    call. Expressions other than the CONCATS types are returned unchanged.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument is the separator, the rest are the values to join
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        # Carried over to the rebuilt Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Falls back to the flattened operands when there is no `expressions` list
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args
    elif isinstance(expression, exp.DPipe):
        # Rebuild the chain of || operators from left to right
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), new_args)

    return concat_type(expressions=new_args, **args)
 800
 801
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    A branch whose condition is always true replaces the whole expression with its
    result; branches whose condition is always false are removed, falling back to the
    default (or NULL) once none are left.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # NOTE(review): `case.pop()` below removes entries from the collection being
        # iterated here -- confirm that no branch gets skipped within a single pass
        for case in expression.args["ifs"]:
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    # IF nodes that are the branches of a CASE are handled by the loop above
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
 826
 827
 828def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 829    """
 830    Reduces a prefix check to either TRUE or FALSE if both the string and the
 831    prefix are statically known.
 832
 833    Example:
 834        >>> from sqlglot import parse_one
 835        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 836        'TRUE'
 837    """
 838    if (
 839        isinstance(expression, exp.StartsWith)
 840        and expression.this.is_string
 841        and expression.expression.is_string
 842    ):
 843        return exp.convert(expression.name.startswith(expression.expression.name))
 844
 845    return expression
 846
 847
# A (start, end) pair of dates; _datetrunc_range returns it as the half-open range [start, end)
DateRange = t.Tuple[datetime.date, datetime.date]
 849
 850
 851def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 852    """
 853    Get the date range for a DATE_TRUNC equality comparison:
 854
 855    Example:
 856        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 857    Returns:
 858        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 859    """
 860    floor = date_floor(date, unit, dialect)
 861
 862    if date != floor:
 863        # This will always be False, except for NULL values.
 864        return None
 865
 866    return floor, floor + interval(unit)
 867
 868
 869def _datetrunc_eq_expression(
 870    left: exp.Expression, drange: DateRange, target_type: t.Optional[exp.DataType]
 871) -> exp.Expression:
 872    """Get the logical expression for a date range"""
 873    return exp.and_(
 874        left >= date_literal(drange[0], target_type),
 875        left < date_literal(drange[1], target_type),
 876        copy=False,
 877    )
 878
 879
 880def _datetrunc_eq(
 881    left: exp.Expression,
 882    date: datetime.date,
 883    unit: str,
 884    dialect: Dialect,
 885    target_type: t.Optional[exp.DataType],
 886) -> t.Optional[exp.Expression]:
 887    drange = _datetrunc_range(date, unit, dialect)
 888    if not drange:
 889        return None
 890
 891    return _datetrunc_eq_expression(left, drange, target_type)
 892
 893
 894def _datetrunc_neq(
 895    left: exp.Expression,
 896    date: datetime.date,
 897    unit: str,
 898    dialect: Dialect,
 899    target_type: t.Optional[exp.DataType],
 900) -> t.Optional[exp.Expression]:
 901    drange = _datetrunc_range(date, unit, dialect)
 902    if not drange:
 903        return None
 904
 905    return exp.and_(
 906        left < date_literal(drange[0], target_type),
 907        left >= date_literal(drange[1], target_type),
 908        copy=False,
 909    )
 910
 911
 912DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
 913    exp.LT: lambda l, dt, u, d, t: l
 914    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u), t),
 915    exp.GT: lambda l, dt, u, d, t: l >= date_literal(date_floor(dt, u, d) + interval(u), t),
 916    exp.LTE: lambda l, dt, u, d, t: l < date_literal(date_floor(dt, u, d) + interval(u), t),
 917    exp.GTE: lambda l, dt, u, d, t: l >= date_literal(date_ceil(dt, u, d), t),
 918    exp.EQ: _datetrunc_eq,
 919    exp.NEQ: _datetrunc_neq,
 920}
 921DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
 922DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 923
 924
 925def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 926    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 927
 928
 929@catch(ModuleNotFoundError, UnsupportedUnit)
 930def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 931    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 932    comparison = expression.__class__
 933
 934    if isinstance(expression, DATETRUNCS):
 935        this = expression.this
 936        trunc_type = extract_type(this)
 937        date = extract_date(this)
 938        if date and expression.unit:
 939            return date_literal(date_floor(date, expression.unit.name.lower(), dialect), trunc_type)
 940    elif comparison not in DATETRUNC_COMPARISONS:
 941        return expression
 942
 943    if isinstance(expression, exp.Binary):
 944        l, r = expression.left, expression.right
 945
 946        if not _is_datetrunc_predicate(l, r):
 947            return expression
 948
 949        l = t.cast(exp.DateTrunc, l)
 950        trunc_arg = l.this
 951        unit = l.unit.name.lower()
 952        date = extract_date(r)
 953
 954        if not date:
 955            return expression
 956
 957        return (
 958            DATETRUNC_BINARY_COMPARISONS[comparison](
 959                trunc_arg, date, unit, dialect, extract_type(trunc_arg, r)
 960            )
 961            or expression
 962        )
 963
 964    if isinstance(expression, exp.In):
 965        l = expression.this
 966        rs = expression.expressions
 967
 968        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 969            l = t.cast(exp.DateTrunc, l)
 970            unit = l.unit.name.lower()
 971
 972            ranges = []
 973            for r in rs:
 974                date = extract_date(r)
 975                if not date:
 976                    return expression
 977                drange = _datetrunc_range(date, unit, dialect)
 978                if drange:
 979                    ranges.append(drange)
 980
 981            if not ranges:
 982                return expression
 983
 984            ranges = merge_ranges(ranges)
 985            target_type = extract_type(l, *rs)
 986
 987            return exp.or_(
 988                *[_datetrunc_eq_expression(l, drange, target_type) for drange in ranges], copy=False
 989            )
 990
 991    return expression
 992
 993
 994def sort_comparison(expression: exp.Expression) -> exp.Expression:
 995    if expression.__class__ in COMPLEMENT_COMPARISONS:
 996        l, r = expression.this, expression.expression
 997        l_column = isinstance(l, exp.Column)
 998        r_column = isinstance(r, exp.Column)
 999        l_const = _is_constant(l)
1000        r_const = _is_constant(r)
1001
1002        if (l_column and not r_column) or (r_const and not l_const):
1003            return expression
1004        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1005            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1006                this=r, expression=l
1007            )
1008    return expression
1009
1010
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
# Each entry is a (side, kind) pair.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """
    Drop WHERE clauses whose condition is always true, and turn joins whose
    ON condition is always true into CROSS joins.

    A join is only converted when it has no USING clause, no join method,
    and its (side, kind) pair is listed in JOINS. The tree is modified in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
1036
1037
def always_true(expression):
    """
    Return whether `expression` is statically known to be truthy.

    True for the boolean TRUE and for literals, except numeric literals whose
    value is zero: a numeric 0 is the false value in dialects that accept
    numbers as conditions, so treating it as true would wrongly fold
    `WHERE 0`, `x AND 0` or `CASE WHEN 0 ...`.

    String literals and numeric literals whose text cannot be parsed as a
    decimal keep the previous behaviour and count as true. Anything else
    (including None) is not considered always true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_string:
        return True

    try:
        return Decimal(expression.name) != 0
    except ArithmeticError:
        # decimal.InvalidOperation is an ArithmeticError: the literal's text
        # is not a plain decimal number, so its value cannot be judged here.
        return True
1042
1043
def always_false(expression):
    """Return True when `expression` is the boolean FALSE or NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
1046
1047
def is_complement(a, b):
    """Return True when `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
1050
1051
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Boolean` node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
1054
1055
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly an `exp.Null` node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
1058
1059
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` on the Python values `a` and `b`.

    Returns a boolean literal expression for EQ/IS, NEQ, GT, GTE, LT and LTE,
    and None for any other kind of expression.
    """
    # Checked in order; each comparison is deferred so only the matching one runs.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for kinds, compare in comparisons:
        if isinstance(expression, kinds):
            return boolean_literal(compare())

    return None
1074
1075
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are reduced to their date part, dates are returned as they are,
    and anything else is parsed as an ISO-format string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date, so check it here.
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed.date()
1085
1086
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned as they are, dates become midnight of that day,
    and anything else is parsed as an ISO-format string. Returns None when the
    string cannot be parsed.
    """
    if isinstance(value, datetime.datetime):
        return value

    if isinstance(value, datetime.date):
        # The default datetime.time() is 00:00:00 with no tzinfo.
        return datetime.datetime.combine(value, datetime.time())

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None

    return parsed
1096
1097
def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Cast `value` to the Python type matching the data type `to`.

    DATE uses `cast_as_date`; any other type in `exp.DataType.TEMPORAL_TYPES`
    uses `cast_as_datetime`. Returns None for a falsy `value` or for a
    non-temporal target type.
    """
    if not value:
        return None

    if to.is_type(exp.DataType.Type.DATE):
        caster = cast_as_date
    elif to.is_type(*exp.DataType.TEMPORAL_TYPES):
        caster = cast_as_datetime
    else:
        return None

    return caster(value)
1106
1107
def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
    """
    Evaluate a cast of a literal to a Python date/datetime.

    Accepts a CAST, or a TS_OR_DS_TO_DATE without a format argument (treated
    as a cast to DATE). The operand must be a literal or another such cast,
    which is evaluated recursively. Returns None for any other shape or when
    the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        target = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        target = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        raw: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        raw = extract_date(operand)
    else:
        return None

    return cast_value(raw, target)
1123
1124
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True when `extract_date` can evaluate `expression` to a value."""
    extracted = extract_date(expression)
    return extracted is not None
1127
1128
def extract_interval(expression):
    """
    Build the delta returned by `interval` from an expression whose name is
    an integer count and whose `unit` text names the unit (lower-cased).

    Returns None when the count is not an integer, the unit is unsupported,
    or the module `interval` imports is not installed.
    """
    try:
        count = int(expression.name)
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
1136
1137
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For a Cast the candidate is its `to` type, otherwise the expression's
    `type`. When no candidate is truthy, the last candidate examined is
    returned (None if no expressions were given).
    """
    found = None
    for node in expressions:
        found = node.to if isinstance(node, exp.Cast) else node.type
        if found:
            return found
    return found
1146
1147
def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is one of `exp.DataType.TEMPORAL_TYPES`;
    otherwise the type is DATETIME for datetime values and DATE for others.
    """
    if not (target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)):
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
1157
1158
def interval(unit: str, n: int = 1):
    """
    Return a `relativedelta` spanning `n` of the given `unit`.

    Supported units: year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: for any other unit.
        ModuleNotFoundError: if `dateutil` is not installed.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, multiplier applied to n)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
1180
1181
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of its year, quarter, month, week or day.

    Only the date fields are replaced; `d` is otherwise left as it is (the
    "day" unit returns `d` unchanged). The week start is shifted by
    `dialect.WEEK_OFFSET`, which is the only use made of `dialect`.

    Raises:
        UnsupportedUnit: for any other unit.
    """
    if unit == "year":
        return d.replace(month=1, day=1)

    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10.
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)

    if unit == "month":
        return d.replace(day=1)

    if unit == "week":
        # Assuming week starts on Monday (0) and ends on Sunday (6)
        days_into_week = d.weekday() - dialect.WEEK_OFFSET
        return d - datetime.timedelta(days=days_into_week)

    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1203
1204
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next `unit` boundary.

    A value that already sits on a boundary (its floor equals itself) is
    returned unchanged.
    """
    floor = date_floor(d, unit, dialect)
    return d if floor == d else floor + interval(unit)
1212
1213
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()
1216
1217
1218def _flat_simplify(expression, simplifier, root=True):
1219    if root or not expression.same_parent:
1220        operands = []
1221        queue = deque(expression.flatten(unnest=False))
1222        size = len(queue)
1223
1224        while queue:
1225            a = queue.popleft()
1226
1227            for b in queue:
1228                result = simplifier(expression, a, b)
1229
1230                if result and result is not expression:
1231                    queue.remove(b)
1232                    queue.appendleft(result)
1233                    break
1234            else:
1235                operands.append(a)
1236
1237        if len(operands) < size:
1238            return functools.reduce(
1239                lambda a, b: expression.__class__(this=a, expression=b), operands
1240            )
1241    return expression
1242
1243
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    generator = Gen()
    return generator.gen(expression)
1251
1252
class Gen:
    """
    Bare-minimum pseudo-SQL generator driven by an explicit stack.

    `gen` pops items off `self.stack`: expressions are expanded by a
    `<key>_sql` handler, lists are expanded into comma-separated items, and
    everything else is stringified into `self.sqls`. Because the stack is
    last-in-first-out, every handler pushes its pieces in reverse output order.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL text for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by the args.
                    key = node.key.upper()
                    has_args = self._args(node)
                    self.stack.append(f"{key} " if has_args else key)
            elif type(node) is list:
                for item in reversed(node):
                    if item is not None:
                        self.stack.append(item)
                        self.stack.append(",")
                if node:
                    # Drop the separator pushed after the first item.
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend((e.args.get("alias"), " AS ", e.args.get("this")))

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Emit `NAME(args)`; quoted identifiers keep their case, other names are upper-cased."""
        this = e.this

        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            raw = this.this
            name = f'"{raw}"' if this.quoted else raw.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend((")", e.expressions, "(", name))

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (e.args.get("high"), " AND ", e.args.get("low"), " BETWEEN ", e.this)
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(("]", e.expressions, "[", e.this))

    def column_sql(self, e: exp.Column) -> None:
        """Emit the column's parts joined by dots."""
        for part in reversed(e.parts):
            self.stack.extend((part, "."))
        # Remove the dot pushed after the first part.
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        """Emit `this IN (<remaining args>)`."""
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(("(", " IN ", e.this))

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend((")", e.this, "("))

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        """Emit `(this)`, then the alias if any, then the args from index 2 on."""
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        """Emit the dotted table parts, then the alias if any, then the args from index 4 on."""
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for part in reversed(e.parts):
            self.stack.extend((part, "."))
        # Remove the dot pushed after the first part.
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this op expression`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `op this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `SQL_NAME(<all arg values>)`."""
        self.stack.extend((")", list(e.args.values()), "(", e.sql_name()))

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Push the node's set args as `:name value` pairs, skipping the first
        `arg_index` entries of `arg_types`.

        Returns True when at least one arg was pushed.
        """
        keys = list(node.arg_types)[arg_index:] if arg_index else node.arg_types
        values = ((key, node.args.get(key)) for key in keys)
        pairs = [[f":{key}", value] for key, value in values if value is not None]

        if not pairs:
            return False

        self.stack.append(pairs)
        return True
FINAL = 'final'
TINYINT_MIN = -128
TINYINT_MAX = 127
UTINYINT_MIN = 0
UTINYINT_MAX = 255
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised when a date/time unit is not supported by the date helpers."""

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
 38def simplify(
 39    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
 40):
 41    """
 42    Rewrite sqlglot AST to simplify expressions.
 43
 44    Example:
 45        >>> import sqlglot
 46        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 47        >>> simplify(expression).sql()
 48        'TRUE'
 49
 50    Args:
 51        expression (sqlglot.Expression): expression to simplify
 52        constant_propagation: whether the constant propagation rule should be used
 53
 54    Returns:
 55        sqlglot.Expression: simplified expression
 56    """
 57
 58    dialect = Dialect.get_or_raise(dialect)
 59
 60    def _simplify(expression, root=True):
 61        if expression.meta.get(FINAL):
 62            return expression
 63
 64        # group by expressions cannot be simplified, for example
 65        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 66        # the projection must exactly match the group by key
 67        group = expression.args.get("group")
 68
 69        if group and hasattr(expression, "selects"):
 70            groups = set(group.expressions)
 71            group.meta[FINAL] = True
 72
 73            for e in expression.selects:
 74                for node in e.walk():
 75                    if node in groups:
 76                        e.meta[FINAL] = True
 77                        break
 78
 79            having = expression.args.get("having")
 80            if having:
 81                for node in having.walk():
 82                    if node in groups:
 83                        having.meta[FINAL] = True
 84                        break
 85
 86        # Pre-order transformations
 87        node = expression
 88        node = rewrite_between(node)
 89        node = uniq_sort(node, root)
 90        node = absorb_and_eliminate(node, root)
 91        node = simplify_concat(node)
 92        node = simplify_conditionals(node)
 93
 94        if constant_propagation:
 95            node = propagate_constants(node, root)
 96
 97        exp.replace_children(node, lambda e: _simplify(e, False))
 98
 99        # Post-order transformations
100        node = simplify_not(node)
101        node = flatten(node)
102        node = simplify_connectors(node, root)
103        node = remove_complements(node, root)
104        node = simplify_coalesce(node)
105        node.parent = expression.parent
106        node = simplify_literals(node, root)
107        node = simplify_equality(node)
108        node = simplify_parens(node)
109        node = simplify_datetrunc(node, dialect)
110        node = sort_comparison(node)
111        node = simplify_startswith(node)
112
113        if root:
114            expression.replace(node)
115        return node
116
117    expression = while_changing(expression, _simplify)
118    remove_where_true(expression)
119    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """
    Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as usual; if it raises one of
    `exceptions`, its first argument (the expression) is returned unchanged
    instead. Other exceptions propagate. The wrapper keeps the decorated
    function's name and docstring.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    The result is wrapped in parentheses when the BETWEEN sits directly under a NOT.
    """
    if not isinstance(expression, exp.Between):
        return expression

    negated = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    rewritten = exp.and_(lower_bound, upper_bound, copy=False)

    return exp.paren(rewritten, copy=False) if negated else rewritten

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws are applied to parenthesized connectors:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    In addition, NOT NULL becomes NULL, a negated comparison becomes its
    complement comparison, NOT of an always-true value becomes FALSE,
    NOT FALSE becomes TRUE, and a double negation is removed.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if operand.__class__ in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand.__class__]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        if isinstance(condition, (exp.And, exp.Or)):
            # AND flips to OR and vice versa, with both sides negated.
            dual = exp.or_ if isinstance(condition, exp.And) else exp.and_
            return exp.paren(
                dual(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )

        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()

    if is_false(operand):
        return exp.true()

    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Remove parentheses around an operand that is the same kind of connector
    as its parent:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector):
            operand.replace(inner)

    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Simplify AND / OR / XOR connectors whose operands are statically known
    or identical, combining operands pairwise through `_flat_simplify`.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            # x XOR x is FALSE; for AND / OR the duplicate is dropped.
            return exp.false() if isinstance(expression, exp.Xor) else left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true, right_true = always_true(left), always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()

            # FALSE OR FALSE was handled above, so at least one side is NULL here.
            left_unknown = is_null(left) or is_false(left)
            right_unknown = is_null(right) or is_false(right)
            if left_unknown and right_unknown:
                return exp.null()

            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only applied to a root connector or one whose `same_parent` is falsy.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their `gen` string. XOR operands are sorted but
    never deduplicated.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    operands = tuple(expression.flatten())

    if isinstance(expression, exp.Xor):
        rebuild = exp.xor
        # Do not deduplicate XOR as A XOR A != A if A == True
        unique = None
        keyed = tuple((gen(operand), operand) for operand in operands)
    else:
        rebuild = exp.and_ if isinstance(expression, exp.And) else exp.or_
        unique = {gen(operand): operand for operand in operands}
        keyed = tuple(unique.items())

    # A AND C AND B -> A AND B AND C, only rebuilt when some neighbour pair is out of order
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))
    if out_of_order:
        return rebuild(*(operand for _, operand in sorted(keyed)), copy=False)

    # already sorted, but duplicates may still have to be dropped
    if unique and len(unique) < len(operands):
        return rebuild(*unique.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector's operands.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Matching operands are rewritten in place through `replace`; the
    connector that was passed in is what gets returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not (root or not expression.same_parent):
        return expression

    # Only operands of the opposite connector kind are candidates for rewriting
    nested_kind = exp.Or if isinstance(expression, exp.And) else exp.And

    for nested, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(nested, nested_kind):
            continue

        nested_lhs, nested_rhs = nested.unnest_operands()

        # absorb
        if is_complement(other, nested_lhs):
            nested_lhs.replace(exp.true() if nested_kind == exp.And else exp.false())
        elif is_complement(other, nested_rhs):
            nested_rhs.replace(exp.true() if nested_kind == exp.And else exp.false())
        elif (set(other.flatten()) if isinstance(other, nested_kind) else {other}) < set(
            nested.flatten()
        ):
            nested.replace(exp.false() if nested_kind == exp.And else exp.true())
        elif isinstance(other, nested_kind):
            # eliminate
            other_operands = other.unnest_operands()
            other_lhs, other_rhs = other_operands

            if nested_lhs in other_operands and (
                is_complement(nested_rhs, other_lhs) or is_complement(nested_rhs, other_rhs)
            ):
                nested.replace(nested_lhs)
                other.replace(nested_lhs)
            elif nested_rhs in other_operands and (
                is_complement(nested_lhs, other_lhs) or is_complement(nested_lhs, other_rhs)
            ):
                nested.replace(nested_rhs)
                other.replace(nested_rhs)

    return expression

absorption:
    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B
elimination:
    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Columns that appear as the operand of `IS NULL` are left untouched, and
    `IF` nodes are pruned from the walk that collects constants.

    Reference: https://www.sqlite.org/optoverview.html
    """
    applicable = (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    )
    if not applicable:
        return expression

    # column -> (id of the column node that defines the constant, literal)
    literal_by_column = {}
    for node in walk_in_scope(expression, prune=lambda n: isinstance(n, exp.If)):
        if not isinstance(node, exp.EQ):
            continue

        lhs, rhs = node.left, node.right

        # TODO: create a helper that can be used to detect nested literal expressions such
        # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
        if isinstance(lhs, exp.Column) and isinstance(rhs, exp.Literal):
            literal_by_column[lhs] = (id(lhs), rhs)

    if not literal_by_column:
        return expression

    for column in find_all_in_scope(expression, exp.Column):
        parent = column.parent
        entry = literal_by_column.get(column)
        if not entry:
            continue

        defining_id, literal = entry

        # the column inside the defining `col = literal` equality must stay as is
        if id(column) == defining_id:
            continue
        if isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null):
            continue

        column.replace(literal.copy())

    return expression

Propagate constants for conjunctions in DNF:

`SELECT * FROM t WHERE a = b AND b = 5` becomes `SELECT * FROM t WHERE a = 5 AND b = 5`

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
126        def wrapped(expression, *args, **kwargs):
127            try:
128                return func(expression, *args, **kwargs)
129            except exceptions:
130                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: `+` and `=`. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binaries, negated numbers and inverse date ops.

    Binary expressions other than connectors are delegated to `_flat_simplify`
    with `_simplify_binary`. A `Neg` over a number becomes a single signed
    number literal. Types listed in `INVERSE_DATE_OPS` go through
    `_simplify_binary`, falling back to the input when that returns nothing.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # flip the sign textually: "-5" -> "5", "5" -> "-5"
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop a `Paren` wrapper when one of the structural checks below allows it.

    Returns the wrapped expression when the parentheses are removed, otherwise
    the input unchanged. Parentheses around a `Select`, or directly under a
    `SubqueryPredicate`, are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if isinstance(inner, exp.Select) or isinstance(parent, exp.SubqueryPredicate):
        return expression

    removable = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(inner, exp.Binary)
            and not (isinstance(inner, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(inner, exp.Predicate) and not parent_is_predicate)
        or (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(inner, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    )
    return inner if removable else expression
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    A COALESCE with no fallback arguments, or whose first argument satisfies
    `_is_nonnull_constant`, is replaced by that first argument. A comparison of
    a COALESCE with a constant is split at the first constant argument into
    `(<prefix> IS NOT NULL AND <comparison>) OR (<prefix> IS NULL AND <constant comparison>)`.
    """
    # COALESCE(x) -> x
    is_trivial_coalesce = (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    )
    if is_trivial_coalesce:
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    lhs, rhs = expression.left, expression.right
    if isinstance(lhs, exp.Coalesce):
        coalesce, other = lhs, rhs
    elif isinstance(rhs, exp.Coalesce):
        coalesce, other = rhs, lhs
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    first_constant = next(
        ((index, arg) for index, arg in enumerate(coalesce.expressions) if _is_constant(arg)),
        None,
    )
    if first_constant is None:
        return expression
    arg_index, arg = first_constant

    # Keep only the arguments that precede the constant
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    not_null_branch = exp.and_(
        coalesce.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    # NOTE(review): the constant always lands on the left (`arg <op> other`), even when the
    # COALESCE was the right operand - confirm this is safe for non-symmetric comparisons.
    null_branch = exp.and_(
        coalesce.is_(exp.null()),
        type(expression)(this=arg.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(not_null_branch, null_branch, copy=False))
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """
    Reduces all groups that contain string literals by concatenating them.

    Adjacent string literals are merged into one literal, joined with the
    separator for CONCAT_WS and with the empty string otherwise. If everything
    folds into a single string literal, that literal is returned.
    """
    if not isinstance(expression, CONCATS):
        return expression

    is_concat_ws = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if is_concat_ws and not expression.expressions[0].is_string:
        return expression

    if is_concat_ws:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    merged = []
    operands = expressions or expression.flatten()
    for all_strings, run in itertools.groupby(operands, lambda e: e.is_string):
        if all_strings:
            joined = sep.join(literal.name for literal in run)
            merged.append(exp.Literal.string(joined))
        else:
            merged.extend(run)

    if len(merged) == 1 and merged[0].is_string:
        return merged[0]

    if is_concat_ws:
        merged = [sep_expr, *merged]
    elif isinstance(expression, exp.DPipe):
        # rebuild the left-deep `a || b || c` chain
        return reduce(lambda x, y: exp.DPipe(this=x, expression=y), merged)

    return concat_type(expressions=merged, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """
    Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed and the
    result of the first always-true branch is returned. When every branch
    has been removed, the default (or NULL) is returned. `CASE x WHEN v ...`
    is first rewritten to `CASE WHEN x = v ...`. A standalone IF (one whose
    parent is not a CASE) is replaced by its true or false result when its
    condition is statically known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this

        # Iterate over a snapshot: `case.pop()` below detaches branches from the CASE,
        # and removing items from a list while iterating over it skips the element
        # that follows each removal.
        for case in tuple(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            # NOTE(review): an always-true branch is returned even when earlier branches
            # are still undecided - confirm this is intended.
            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a prefix check into TRUE or FALSE when both the string and the
    prefix are string literals; any other expression is returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    subject, prefix = expression.this, expression.expression
    if subject.is_string and prefix.is_string:
        return exp.convert(expression.name.startswith(prefix.name))

    return expression

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect, sqlglot.expressions.DataType], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc(expression, *args, **kwargs):
126        def wrapped(expression, *args, **kwargs):
127            try:
128                return func(expression, *args, **kwargs)
129            except exceptions:
130                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
 995def sort_comparison(expression: exp.Expression) -> exp.Expression:
 996    if expression.__class__ in COMPLEMENT_COMPARISONS:
 997        l, r = expression.this, expression.expression
 998        l_column = isinstance(l, exp.Column)
 999        r_column = isinstance(r, exp.Column)
1000        l_const = _is_constant(l)
1001        r_const = _is_constant(r)
1002
1003        if (l_column and not r_column) or (r_const and not l_const):
1004            return expression
1005        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
1006            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
1007                this=r, expression=l
1008            )
1009    return expression
JOINS = {('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER'), ('', '')}
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Strip conditions that are always true, mutating `expression` in place.

    WHERE clauses whose condition is always true are removed. Joins whose ON
    condition is always true, that have no USING or method argument, and
    whose (side, kind) pair is listed in `JOINS`, lose the ON condition and
    become CROSS joins. Nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()

    for join in expression.find_all(exp.Join):
        join_args = join.args

        if not always_true(join_args.get("on")):
            continue
        if join_args.get("using") or join_args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join_args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """
    Tell whether `expression` is treated as statically true.

    That holds for a Boolean node with a truthy value and for any Literal.
    NOTE(review): every Literal counts as true here, including e.g. `0` -
    confirm this is intended for all dialects.
    """
    truthy_boolean = isinstance(expression, exp.Boolean) and expression.this
    return truthy_boolean or isinstance(expression, exp.Literal)
def always_false(expression):
def always_false(expression):
    """Tell whether `expression` is the FALSE boolean or NULL."""
    if is_false(expression):
        return True
    return is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Tell whether `b` is `NOT a`. The check is one-directional: `a` being `NOT b` does not count."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a Boolean node (subclasses excluded) holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Tell whether `a` is exactly a Null node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` over the Python values `a` and `b`.

    Returns a TRUE/FALSE boolean literal for EQ, IS, NEQ, GT, GTE, LT and
    LTE, and None for any other expression type.
    """
    # The comparisons are wrapped in lambdas so only the matching one is evaluated
    comparators = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for kinds, compare in comparators:
        if isinstance(expression, kinds):
            return boolean_literal(compare())
    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """
    Convert `value` to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as is, and
    anything else is parsed with `datetime.fromisoformat`. Returns None when
    that parse raises ValueError.
    """
    if isinstance(value, datetime.date):
        # datetime is a subclass of date, so the more specific case is checked here
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """
    Convert `value` to a `datetime.datetime`.

    Datetimes are returned as is, dates become midnight of that day, and
    anything else is parsed with `datetime.fromisoformat`. Returns None when
    that parse raises ValueError.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Convert `value` to the Python temporal value matching the data type `to`.

    Args:
        value: the raw value to convert; any falsy value yields None.
        to: the target data type.

    Returns:
        A date when `to` is DATE, a datetime when `to` is one of the other
        temporal types, and None for non-temporal targets or when the
        conversion helper returns None.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """
    Extract the Python date/datetime denoted by a cast over a literal.

    Accepts a `Cast`, or a `TsOrDsToDate` without a format argument (treated
    as a cast to DATE). The operand must be a literal or another such cast,
    which is resolved recursively.

    Returns:
        The converted value, or None when the expression has a different
        shape or the value cannot be converted to the target type.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    operand = cast.this
    if isinstance(operand, exp.Literal):
        value: t.Any = operand.name
    elif isinstance(operand, (exp.Cast, exp.TsOrDsToDate)):
        # nested casts: resolve the inner one first, then convert to the outer type
        value = extract_date(operand)
    else:
        return None

    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """
    Build the delta returned by `interval` for an interval expression.

    The expression's name is parsed as an integer count and its `unit` text
    is lower-cased. Returns None when the count is not an integer, the unit
    is unsupported, or a module required by `interval` cannot be imported.
    """
    try:
        count = int(expression.name)
        unit_name = expression.text("unit").lower()
        delta = interval(unit_name, count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
    return delta
def extract_type(*expressions):
def extract_type(*expressions):
    """
    Return the first truthy type found among `expressions`.

    For a `Cast` the target type (`to`) is used, otherwise the expression's
    `type` attribute. If none is truthy, the value taken from the last
    expression is returned (None when no expressions are given).
    """
    found = None
    for node in expressions:
        found = node.to if isinstance(node, exp.Cast) else node.type
        if found:
            return found
    return found
def date_literal(date, target_type=None):
def date_literal(date, target_type=None):
    """
    Build `CAST('<date>' AS <type>)` for a Python date or datetime.

    `target_type` is used when it is a temporal type. Otherwise the type
    falls back to DATETIME for datetime values and DATE for anything else.
    """
    has_temporal_target = target_type and target_type.is_type(*exp.DataType.TEMPORAL_TYPES)
    if not has_temporal_target:
        if isinstance(date, datetime.datetime):
            target_type = exp.DataType.Type.DATETIME
        else:
            target_type = exp.DataType.Type.DATE

    literal = exp.Literal.string(date)
    return exp.cast(literal, target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """
    Build a `relativedelta` spanning `n` of the given unit.

    Supported units: year, quarter (three months), month, week, day, hour,
    minute and second.

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of that keyword make one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = span
    return relativedelta(**{keyword: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` down to the start of the given unit.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: supplies `WEEK_OFFSET`, which shifts the first day of the
            week relative to Monday.

    Returns:
        The first day of the year, quarter, month or week containing `d`,
        or `d` itself for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the supported units.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The offset moves the start of the week,
        # so the distance back to it must wrap around modulo 7. Without the wrap, a date
        # that is itself a week start (e.g. a Sunday with offset -1) floors to the
        # previous week, and a positive offset can produce a "floor" later than `d`.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """
    Round `d` up to the next unit boundary.

    A date that already equals its own floor is returned unchanged;
    otherwise the result is the floor plus one interval of the unit.
    """
    floored = date_floor(d, unit, dialect)
    return d if floored == d else floored + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression when `condition` is truthy, otherwise the FALSE expression."""
    if condition:
        return exp.true()
    return exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.

    A fresh `Gen` instance is created for every call.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """
    Minimal stack-based pseudo SQL generator used to build sort/dedup keys.

    Handlers never emit text directly. They push the pieces of a node onto
    `self.stack` in reverse order, so that popping yields them left to right.
    `gen` drains the stack, dispatching expressions to `<key>_sql` handlers
    and collecting everything else as text in `self.sqls`.
    """

    def __init__(self):
        # pending nodes / text fragments, processed last-in first-out
        self.stack = []
        # text fragments of the output, in final order
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # generic fallback: "KEY <args>", or just "KEY" when no args are set
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that would precede the first item. This must only
                # happen when something was pushed: for a list holding nothing but None,
                # popping would remove an unrelated entry (e.g. a closing parenthesis).
                if pushed:
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            # quoted names keep their case, unquoted ones are upper-cased
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # drop the separator that would precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # the first arg type is skipped here because it is emitted as the type name below
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # the first two arg types are skipped; `this` and the alias are pushed explicitly below
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # the first four arg types are skipped; the parts and the alias are pushed explicitly below
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # drop the separator that would precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `<SQL_NAME>(<arg values>)` for a function without a dedicated handler."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """
        Queue the set args of `node` as `:name,value` pairs.

        Args:
            node: the expression whose args are queued.
            arg_index: number of leading arg types to skip.

        Returns:
            True if at least one arg was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
1259    def gen(self, expression: exp.Expression) -> str:
1260        self.stack = [expression]
1261        self.sqls.clear()
1262
1263        while self.stack:
1264            node = self.stack.pop()
1265
1266            if isinstance(node, exp.Expression):
1267                exp_handler_name = f"{node.key}_sql"
1268
1269                if hasattr(self, exp_handler_name):
1270                    getattr(self, exp_handler_name)(node)
1271                elif isinstance(node, exp.Func):
1272                    self._function(node)
1273                else:
1274                    key = node.key.upper()
1275                    self.stack.append(f"{key} " if self._args(node) else key)
1276            elif type(node) is list:
1277                for n in reversed(node):
1278                    if n is not None:
1279                        self.stack.extend((n, ","))
1280                if node:
1281                    self.stack.pop()
1282            else:
1283                if node is not None:
1284                    self.sqls.append(str(node))
1285
1286        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
1288    def add_sql(self, e: exp.Add) -> None:
1289        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
1291    def alias_sql(self, e: exp.Alias) -> None:
1292        self.stack.extend(
1293            (
1294                e.args.get("alias"),
1295                " AS ",
1296                e.args.get("this"),
1297            )
1298        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
1300    def and_sql(self, e: exp.And) -> None:
1301        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
1303    def anonymous_sql(self, e: exp.Anonymous) -> None:
1304        this = e.this
1305        if isinstance(this, str):
1306            name = this.upper()
1307        elif isinstance(this, exp.Identifier):
1308            name = this.this
1309            name = f'"{name}"' if this.quoted else name.upper()
1310        else:
1311            raise ValueError(
1312                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1313            )
1314
1315        self.stack.extend(
1316            (
1317                ")",
1318                e.expressions,
1319                "(",
1320                name,
1321            )
1322        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
1324    def between_sql(self, e: exp.Between) -> None:
1325        self.stack.extend(
1326            (
1327                e.args.get("high"),
1328                " AND ",
1329                e.args.get("low"),
1330                " BETWEEN ",
1331                e.this,
1332            )
1333        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
1335    def boolean_sql(self, e: exp.Boolean) -> None:
1336        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
1338    def bracket_sql(self, e: exp.Bracket) -> None:
1339        self.stack.extend(
1340            (
1341                "]",
1342                e.expressions,
1343                "[",
1344                e.this,
1345            )
1346        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
1348    def column_sql(self, e: exp.Column) -> None:
1349        for p in reversed(e.parts):
1350            self.stack.extend((p, "."))
1351        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
1353    def datatype_sql(self, e: exp.DataType) -> None:
1354        self._args(e, 1)
1355        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
1357    def div_sql(self, e: exp.Div) -> None:
1358        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
1360    def dot_sql(self, e: exp.Dot) -> None:
1361        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
1363    def eq_sql(self, e: exp.EQ) -> None:
1364        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
1366    def from_sql(self, e: exp.From) -> None:
1367        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
1369    def gt_sql(self, e: exp.GT) -> None:
1370        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
1372    def gte_sql(self, e: exp.GTE) -> None:
1373        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
1375    def identifier_sql(self, e: exp.Identifier) -> None:
1376        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
1378    def ilike_sql(self, e: exp.ILike) -> None:
1379        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1381    def in_sql(self, e: exp.In) -> None:
1382        self.stack.append(")")
1383        self._args(e, 1)
1384        self.stack.extend(
1385            (
1386                "(",
1387                " IN ",
1388                e.this,
1389            )
1390        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
1392    def intdiv_sql(self, e: exp.IntDiv) -> None:
1393        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
1395    def is_sql(self, e: exp.Is) -> None:
1396        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1398    def like_sql(self, e: exp.Like) -> None:
1399        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1401    def literal_sql(self, e: exp.Literal) -> None:
1402        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
1404    def lt_sql(self, e: exp.LT) -> None:
1405        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
1407    def lte_sql(self, e: exp.LTE) -> None:
1408        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
1410    def mod_sql(self, e: exp.Mod) -> None:
1411        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
1413    def mul_sql(self, e: exp.Mul) -> None:
1414        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
1416    def neg_sql(self, e: exp.Neg) -> None:
1417        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
1419    def neq_sql(self, e: exp.NEQ) -> None:
1420        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
1422    def not_sql(self, e: exp.Not) -> None:
1423        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
1425    def null_sql(self, e: exp.Null) -> None:
1426        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
1428    def or_sql(self, e: exp.Or) -> None:
1429        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1431    def paren_sql(self, e: exp.Paren) -> None:
1432        self.stack.extend(
1433            (
1434                ")",
1435                e.this,
1436                "(",
1437            )
1438        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
1440    def sub_sql(self, e: exp.Sub) -> None:
1441        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1443    def subquery_sql(self, e: exp.Subquery) -> None:
1444        self._args(e, 2)
1445        alias = e.args.get("alias")
1446        if alias:
1447            self.stack.append(alias)
1448        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1450    def table_sql(self, e: exp.Table) -> None:
1451        self._args(e, 4)
1452        alias = e.args.get("alias")
1453        if alias:
1454            self.stack.append(alias)
1455        for p in reversed(e.parts):
1456            self.stack.extend((p, "."))
1457        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1459    def tablealias_sql(self, e: exp.TableAlias) -> None:
1460        columns = e.columns
1461
1462        if columns:
1463            self.stack.extend((")", columns, "("))
1464
1465        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
1467    def var_sql(self, e: exp.Var) -> None:
1468        self.stack.append(e.this)