Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
 11# Final means that an expression should not be simplified
 12FINAL = "final"
 13
 14
 15def simplify(expression):
 16    """
 17    Rewrite sqlglot AST to simplify expressions.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 22        >>> simplify(expression).sql()
 23        'TRUE'
 24
 25    Args:
 26        expression (sqlglot.Expression): expression to simplify
 27    Returns:
 28        sqlglot.Expression: simplified expression
 29    """
 30
 31    generate = cached_generator()
 32
 33    # group by expressions cannot be simplified, for example
 34    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 35    # the projection must exactly match the group by key
 36    for group in expression.find_all(exp.Group):
 37        select = group.parent
 38        groups = set(group.expressions)
 39        group.meta[FINAL] = True
 40
 41        for e in select.selects:
 42            for node, *_ in e.walk():
 43                if node in groups:
 44                    e.meta[FINAL] = True
 45                    break
 46
 47        having = select.args.get("having")
 48        if having:
 49            for node, *_ in having.walk():
 50                if node in groups:
 51                    having.meta[FINAL] = True
 52                    break
 53
 54    def _simplify(expression, root=True):
 55        if expression.meta.get(FINAL):
 56            return expression
 57
 58        # Pre-order transformations
 59        node = expression
 60        node = rewrite_between(node)
 61        node = uniq_sort(node, generate, root)
 62        node = absorb_and_eliminate(node, root)
 63        node = simplify_concat(node)
 64
 65        exp.replace_children(node, lambda e: _simplify(e, False))
 66
 67        # Post-order transformations
 68        node = simplify_not(node)
 69        node = flatten(node)
 70        node = simplify_connectors(node, root)
 71        node = remove_compliments(node, root)
 72        node = simplify_coalesce(node)
 73        node.parent = expression.parent
 74        node = simplify_literals(node, root)
 75        node = simplify_parens(node)
 76
 77        if root:
 78            expression.replace(node)
 79
 80        return node
 81
 82    expression = while_changing(expression, _simplify)
 83    remove_where_true(expression)
 84    return expression
 85
 86
 87def rewrite_between(expression: exp.Expression) -> exp.Expression:
 88    """Rewrite x between y and z to x >= y AND x <= z.
 89
 90    This is done because comparison simplification is only done on lt/lte/gt/gte.
 91    """
 92    if isinstance(expression, exp.Between):
 93        return exp.and_(
 94            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 95            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 96            copy=False,
 97        )
 98    return expression
 99
100
def simplify_not(expression):
    """
    Push NOT inward and fold NOT over constants.

    De Morgan's laws (applied when the operand is parenthesized):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Also: NOT NULL -> NULL, NOT <always-true constant> -> FALSE,
    NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: flip the connector and negate both sides.
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return flipped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression
135
136
def flatten(expression):
    """
    Hoist same-type connectors out of parentheses:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
148
149
def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND / OR chains.

    AND: FALSE wins, then NULL wins, TRUE is dropped.
    OR: TRUE wins, FALSE is dropped, and NULL OR NULL / NULL OR FALSE -> NULL.
    Identical operands collapse to one; other pairs go to _simplify_comparison.
    Non-connectors are returned unchanged.
    """

    def _simplify_pair(connector, left, right):
        if left == right:
            return left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                # TRUE AND TRUE -> TRUE, TRUE AND x -> x
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()
            if (left_false or is_null(left)) and (right_false or is_null(right)):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_pair, root)
186
187
# Upper-bound and lower-bound comparison families.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every binary predicate _simplify_comparison knows how to reason about.
COMPARISONS = LT_LTE + GT_GTE + (exp.EQ, exp.NEQ, exp.Is)

# Operator to use after swapping operands, e.g. `1 < x` becomes `x > 1`.
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
205
206
def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons that test the same column against constants.

    Only applies when both operands are comparisons sharing a column and the
    remaining sides are both numeric or both string literals. Bounds are first
    normalized to the ``column <op> constant`` form, then:

    * two upper (or two lower) bounds keep the tighter one under AND and the
      looser one under OR. With equal constants, AND keeps the strict operator
      (< / >) and OR keeps the inclusive one (<= / >=);
    * under AND, contradictory pairs (``x < 1 AND x > 5``, ``x = 1 AND x > 1``)
      become FALSE, and an equality compatible with the other bound is kept alone.

    Nothing is ever shortcut to TRUE because the column could be NULL.

    Args:
        expression: the enclosing connector.
        left, right: the two operands being combined.
        or_: True when the connector is OR.

    Returns:
        The replacement expression; ``expression`` itself when one side has no
        operand besides the shared column(s); otherwise None.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv:
                        # Equal bounds: AND needs the strict "<", OR the inclusive "<=".
                        keep_strict = not or_
                        return left if isinstance(left, exp.LT) == keep_strict else right
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv:
                        # Equal bounds: AND needs the strict ">", OR the inclusive ">=".
                        keep_strict = not or_
                        return left if isinstance(left, exp.GT) == keep_strict else right
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
266
267
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
282
283
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort a connector's operands by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when sorting or deduplication changes
    something; otherwise the original expression is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # A repeated SQL string keeps its first position but maps to the last operand seen.
    by_sql = {generate(operand): operand for operand in operands}
    keys = list(by_sql)

    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return build(*(by_sql[sql] for sql in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        # already sorted, but duplicates were dropped
        return build(*by_sql.values(), copy=False)
    return expression
308
309
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to a connector's operands.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place via ``replace``; the same ``expression``
    object is returned.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    outer_is_and = isinstance(expression, exp.And)
    # Operands of this connector type are candidates for absorption / elimination.
    inner = exp.Or if outer_is_and else exp.And
    # Identity of the inner connector: replaces a term negated by a sibling operand.
    inner_identity = exp.false if outer_is_and else exp.true
    # Identity of the outer connector: replaces an operand that is fully absorbed.
    outer_identity = exp.true if outer_is_and else exp.false

    for candidate, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(candidate, inner):
            continue

        first_term, second_term = candidate.unnest_operands()

        # absorb
        if is_complement(other, first_term):
            first_term.replace(inner_identity())
        elif is_complement(other, second_term):
            second_term.replace(inner_identity())
        elif (set(other.flatten()) if isinstance(other, inner) else {other}) < set(
            candidate.flatten()
        ):
            candidate.replace(outer_identity())
        elif isinstance(other, inner):
            # eliminate
            other_terms = other.unnest_operands()
            other_first, other_second = other_terms

            if first_term in other_terms and (
                is_complement(second_term, other_first)
                or is_complement(second_term, other_second)
            ):
                candidate.replace(first_term)
                other.replace(first_term)
            elif second_term in other_terms and (
                is_complement(first_term, other_first)
                or is_complement(first_term, other_second)
            ):
                candidate.replace(second_term)
                other.replace(second_term)

    return expression
348
349
def simplify_literals(expression, root=True):
    """
    Constant-fold binary operators (other than AND/OR) and numeric negation.

    Binary folding is delegated to _simplify_binary; a negated number literal is
    rewritten as a single literal (-(5) -> -5, -(-5) -> 5).
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
363
364
def _simplify_binary(expression, a, b):
    """
    Fold a binary operation whose operands ``a`` and ``b`` are constants.

    Handles ``IS [NOT] NULL`` on literals and NULL, NULL propagation (except for
    the null-safe comparisons), arithmetic and comparisons on numeric literals,
    comparisons on string literals, and DATE/DATETIME +/- INTERVAL arithmetic.

    Returns:
        The folded expression, or None when nothing can be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        # Any other binary operation with a NULL operand folds to NULL.
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Leave division by zero to the engine; folding it here would raise a
            # decimal exception (DivisionByZero / InvalidOperation) and abort.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
423
424
def simplify_parens(expression):
    """
    Drop parentheses that do not affect evaluation order.

    Parentheses around a Select (subquery) are always kept; otherwise the
    decision depends on the wrapped node and the Paren's parent.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or not isinstance(inner, exp.Binary)
        or (isinstance(inner, exp.Predicate) and not isinstance(parent, exp.Predicate))
        or (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(inner, exp.Mul) and isinstance(parent, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if removable else expression
443
444
# Node types simplify_coalesce treats as constants.
CONSTANTS = (exp.Literal, exp.Boolean, exp.Null)
450
451
def simplify_coalesce(expression):
    """
    Remove single-argument COALESCE and split COALESCE-vs-constant comparisons.

    COALESCE(x) -> x (except inside a Hint).
    COALESCE(x, ..., c, ...) <op> k, with constants c and k, becomes
    (COALESCE(x, ...) IS NOT NULL AND COALESCE(x, ...) <op> k)
    OR (COALESCE(x, ...) IS NULL AND c <op> k), where the COALESCE keeps only
    the arguments before c (and collapses to x when none remain).
    """
    # COALESCE is also used as a Spark partitioning hint, so leave hints alone.
    if (
        isinstance(expression, exp.Coalesce)
        and not expression.expressions
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce, other = expression.left, expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce, other = expression.right, expression.left
    else:
        return expression

    # Valid for non-constants too, but it only pays off when both sides are constant.
    if not isinstance(other, CONSTANTS):
        return expression

    first_constant = next(
        (
            index
            for index, candidate in enumerate(coalesce.expressions)
            if isinstance(candidate, CONSTANTS)
        ),
        None,
    )
    if first_constant is None:
        return expression
    constant = coalesce.expressions[first_constant]

    # Keep only the fallbacks before the first constant; the constant itself moves
    # into the IS NULL branch below.
    coalesce.set("expressions", coalesce.expressions[:first_constant])
    # Skip a simplify iteration: COALESCE(x) with no fallbacks is just x.
    collapsed = coalesce if coalesce.expressions else coalesce.this

    # More complex than the input, but later passes simplify it further.
    when_not_null = exp.and_(
        collapsed.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        collapsed.is_(exp.null()),
        type(expression)(this=constant.copy(), expression=other.copy()),
        copy=False,
    )
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
508
509
# Concatenation nodes whose adjacent string literals simplify_concat can merge.
CONCATS = (exp.Concat, exp.DPipe)
# "Safe" variants; simplify_concat rebuilds these as SafeConcat rather than Concat.
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
512
513
def simplify_concat(expression):
    """Merge runs of adjacent string literals inside CONCAT / ``||``.

    CONCAT('a', 'b', x) -> CONCAT('ab', x); if a single argument remains it is
    returned directly. CONCAT_WS and other nodes are returned unchanged.
    """
    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for literal_run, run in itertools.groupby(operands, lambda e: e.is_string):
        if literal_run:
            merged.append(exp.Literal.string("".join(piece.name for piece in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]
    # Preserve whether the original concat was the "safe" flavour.
    result_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return result_type(expressions=merged)
531
532
# (join.side, join.kind) pairs for which "JOIN x ON <always true>" may be rewritten
# as "CROSS JOIN x". A CROSS JOIN is empty whenever EITHER input is empty, while an
# outer join still returns its preserved side's rows padded with NULLs. Hence
# LEFT JOIN x ON TRUE != CROSS JOIN x (they differ when x is empty), and likewise
# RIGHT [OUTER] JOIN x ON TRUE != CROSS JOIN x (they differ when the left input is
# empty). Only inner joins qualify.
JOINS = {
    ("", ""),
    ("", "INNER"),
}
542
543
def remove_where_true(expression):
    """
    Drop tautological filters in place.

    Removes every ``WHERE <always true>`` clause, and turns ``JOIN x ON <always true>``
    into ``CROSS JOIN x`` when the join has no USING / method and its
    (side, kind) pair is listed in JOINS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        args = join.args
        convertible = (
            always_true(args.get("on"))
            and not args.get("using")
            and not args.get("method")
            and (join.side, join.kind) in JOINS
        )
        if convertible:
            join.set("on", None)
            join.set("side", None)
            join.set("kind", "CROSS")
558
559
def always_true(expression):
    """
    Return a truthy value when ``expression`` is a constant that is always true.

    TRUE counts, and so does any literal except a numeric zero. Engines that
    accept numbers as conditions generally treat 0 as false, so ``WHERE 0`` or
    ``x OR 0`` must not be folded as if they were TRUE. String literals are
    still treated as true, as before.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_string:
        return True
    try:
        return Decimal(expression.name) != 0
    except (ArithmeticError, ValueError):
        # Not parseable as a number: keep the previous "literal is true" behavior.
        return True
564
565
def is_complement(a, b):
    """Return True when ``b`` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
568
569
def is_false(a: exp.Expression) -> bool:
    """Return True only for the boolean literal FALSE (exact Boolean type)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
572
573
def is_null(a: exp.Expression) -> bool:
    """Return True only for a bare NULL node (exact type match; subclasses excluded)."""
    return exp.Null is type(a)
576
577
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node on two Python values and return TRUE/FALSE.

    Supports =, IS, <>, >, >=, < and <=; returns None for any other node type.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_type, compare in checks:
        if isinstance(expression, node_type):
            return boolean_literal(compare(a, b))
    return None
592
593
def extract_date(cast):
    """
    Parse the operand of ``CAST(... AS DATE | DATETIME)`` into a Python value.

    Returns a ``datetime.date`` for DATE and a ``datetime.datetime`` for DATETIME.
    Returns None for other target types, or when ``fromisoformat`` rejects the
    operand's name (e.g. the cast wraps an identifier rather than an ISO string).
    """
    to_type = cast.args["to"].this
    try:
        if to_type == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if to_type == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
604
605
def extract_interval(interval):
    """
    Convert a literal ``INTERVAL 'n' <unit>`` into a ``relativedelta``.

    Supported units (case-insensitive): year, month, week, day.

    Returns None when the interval's value is not an integer (e.g. ``INTERVAL x DAY``
    over a column, or ``INTERVAL '1.5' DAY``), when ``python-dateutil`` is not
    installed, or when the unit is unsupported.
    """
    try:
        n = int(interval.name)
    except ValueError:
        # Non-integer / non-literal amount: nothing can be folded, so don't raise.
        return None

    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    unit_to_kwarg = {"year": "years", "month": "months", "week": "weeks", "day": "days"}
    kwarg = unit_to_kwarg.get(interval.text("unit").lower())
    if kwarg is None:
        return None
    return relativedelta(**{kwarg: n})
624
625
def date_literal(date):
    """Build ``CAST('<str(date)>' AS DATE)``, or ``AS DATETIME`` for datetime values."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
631
632
def boolean_literal(condition):
    """Return a TRUE node when ``condition`` is truthy, otherwise a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
635
636
637def _flat_simplify(expression, simplifier, root=True):
638    if root or not expression.same_parent:
639        operands = []
640        queue = deque(expression.flatten(unnest=False))
641        size = len(queue)
642
643        while queue:
644            a = queue.popleft()
645
646            for b in queue:
647                result = simplifier(expression, a, b)
648
649                if result and result is not expression:
650                    queue.remove(b)
651                    queue.appendleft(result)
652                    break
653            else:
654                operands.append(a)
655
656        if len(operands) < size:
657            return functools.reduce(
658                lambda a, b: expression.__class__(this=a, expression=b), operands
659            )
660    return expression
# Key stored in an expression's ``meta`` dict; when set, simplify() leaves the node untouched.
FINAL = "final"
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    generate = cached_generator()

    def _references_group_key(root_node, keys):
        # True when any node in the subtree rooted at `root_node` equals a GROUP BY key.
        return any(child in keys for child, *_ in root_node.walk())

    # GROUP BY keys must stay textually identical to the projections / HAVING terms
    # that use them (e.g. SELECT x + 1 + 1 FROM y GROUP BY x + 1 + 1), so the keys
    # and every projection / HAVING clause that references one are marked FINAL.
    for group in expression.find_all(exp.Group):
        select = group.parent
        keys = set(group.expressions)
        group.meta[FINAL] = True

        for projection in select.selects:
            if _references_group_key(projection, keys):
                projection.meta[FINAL] = True

        having = select.args.get("having")
        if having and _references_group_key(having, keys):
            having.meta[FINAL] = True

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # Pre-order: rewrite this node before its children are visited.
        node = rewrite_between(expression)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)

        # Recurse into children; they are never treated as the root.
        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order: rewrite this node now that its children are simplified.
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        node = simplify_coalesce(node)
        # The next steps look at the parent (same_parent / parent type), so the
        # possibly-new node is reattached to the original parent first.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)

        if root:
            expression.replace(node)

        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:
  • sqlglot.Expression: the simplified expression

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand ``x BETWEEN low AND high`` into ``x >= low AND x <= high``.

    Comparison simplification only understands <, <=, > and >=, so BETWEEN is
    lowered into that form first. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Push NOT inward and fold NOT over constants.

    De Morgan's laws (applied when the operand is parenthesized):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    Also: NOT NULL -> NULL, NOT <always-true constant> -> FALSE,
    NOT FALSE -> TRUE, NOT NOT x -> x.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: flip the connector and negate both sides.
            flipped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return flipped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
    """
    Hoist same-type connectors out of parentheses:

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
    """
    Fold constant operands out of AND / OR chains.

    AND: FALSE wins, then NULL wins, TRUE is dropped.
    OR: TRUE wins, FALSE is dropped, and NULL OR NULL / NULL OR FALSE -> NULL.
    Identical operands collapse to one; other pairs go to _simplify_comparison.
    Non-connectors are returned unchanged.
    """

    def _simplify_pair(connector, left, right):
        if left == right:
            return left

        if isinstance(connector, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                # TRUE AND TRUE -> TRUE, TRUE AND x -> x
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(connector, left, right)

        if isinstance(connector, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()
            if (left_false or is_null(left)) and (right_false or is_null(right)):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(connector, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_pair, root)
# Upper-bound and lower-bound comparison families.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Operator to use after swapping operands, e.g. `1 < x` becomes `x > 1`.
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector, keyed by their generated SQL.

        C AND A AND B AND B -> A AND B AND C

    Args:
        expression: node to normalize; non-connectors are returned unchanged.
        generate: callable turning a node into its SQL text (used as the sort/dedup key).
        root: when false, connectors whose `same_parent` flag is set are left alone.

    Returns:
        The original node if its operands are already sorted and unique, otherwise a
        newly built AND/OR connector over the normalized operands.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # Keyed by SQL text: a later duplicate overwrites the stored node but keeps the
    # slot (insertion position) of the first occurrence.
    by_sql = {generate(operand): operand for operand in operands}
    keys = list(by_sql)

    in_order = all(previous <= current for previous, current in zip(keys, keys[1:]))
    if not in_order:
        # A AND C AND B -> A AND B AND C
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    if len(by_sql) < len(operands):
        # Already ordered, but duplicates were collapsed.
        return build(*by_sql.values(), copy=False)

    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws across a connector's operands.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Rewrites happen in place through `replace`; the (possibly mutated) input node is
    returned. Non-connectors, and non-root connectors whose `same_parent` flag is set,
    are returned untouched.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    # Operands of the opposite connector type are the ones that can be absorbed/eliminated.
    inner_kind = exp.Or if isinstance(expression, exp.And) else exp.And

    # Node factories (called per replacement so every replacement is a fresh node):
    # `identity` is the neutral element of inner_kind, `annihilator` the value dominating it.
    if inner_kind == exp.And:
        identity, annihilator = exp.true, exp.false
    else:
        identity, annihilator = exp.false, exp.true

    for outer, other in itertools.permutations(expression.flatten(), 2):
        if not isinstance(outer, inner_kind):
            continue

        first, second = outer.unnest_operands()

        # absorption
        if is_complement(other, first):
            first.replace(identity())
        elif is_complement(other, second):
            second.replace(identity())
        elif (
            set(other.flatten()) if isinstance(other, inner_kind) else {other}
        ) < set(outer.flatten()):
            outer.replace(annihilator())
        elif isinstance(other, inner_kind):
            # elimination
            other_operands = other.unnest_operands()
            other_first, other_second = other_operands

            if first in other_operands and (
                is_complement(second, other_first) or is_complement(second, other_second)
            ):
                outer.replace(first)
                other.replace(first)
            elif second in other_operands and (
                is_complement(first, other_first) or is_complement(first, other_second)
            ):
                outer.replace(second)
                other.replace(second)

    return expression

absorption:

    A AND (A OR B) -> A
    A OR (A AND B) -> A
    A AND (NOT A OR B) -> A AND B
    A OR (NOT A AND B) -> A OR B

elimination:

    (A AND B) OR (A AND NOT B) -> A
    (A OR B) AND (A OR NOT B) -> A

def simplify_literals(expression, root=True):
    """
    Fold literal arithmetic and negated numeric literals.

    Non-connector binary nodes are delegated to `_flat_simplify` with `_simplify_binary`.
    A negation of a numeric literal is folded into the literal itself:
    -(5) -> -5 and -(-5) -> 5. Anything else is returned unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    text = operand.name
    negated = text[1:] if text[0] == "-" else f"-{text}"
    return exp.Literal.number(negated)
def simplify_parens(expression):
    """
    Drop a Paren wrapper when the parentheses are redundant.

    A parenthesized Select (subquery) always keeps its parentheses. Otherwise they are
    kept only when the parent is a Condition/Binary (and not itself a Paren), the wrapped
    node is a Binary, and none of these make them redundant: a predicate under a
    non-predicate parent, Add inside Add, Mul inside Mul, or Mul inside Add/Sub.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    outer = expression.parent
    needed = (
        isinstance(outer, (exp.Condition, exp.Binary))
        and not isinstance(outer, exp.Paren)
        and isinstance(inner, exp.Binary)
        and not (isinstance(inner, exp.Predicate) and not isinstance(outer, exp.Predicate))
        and not (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        and not (isinstance(inner, exp.Mul) and isinstance(outer, exp.Mul))
        and not (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Add, exp.Sub)))
    )
    return expression if needed else inner
def simplify_coalesce(expression):
    """
    Simplify COALESCE calls, alone or compared against a constant.

    - COALESCE(x) -> x, unless the COALESCE sits under a Hint (Spark partitioning hint).
    - For a comparison where one side is a COALESCE and the other side is a constant,
      the comparison is expanded around the first constant argument of the COALESCE:

          COALESCE(x, <c>, ...) <op> k
          -> (x IS NOT NULL AND COALESCE(x) <op> k) OR (x IS NULL AND <c> <op> k)

      Note: the COALESCE node inside `expression` is truncated in place before copying.
    """
    if isinstance(expression, exp.Coalesce):
        # COALESCE is also used as a Spark partitioning hint, so leave that form alone.
        if not expression.expressions and not isinstance(expression.parent, exp.Hint):
            return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    left, right = expression.left, expression.right
    if isinstance(left, exp.Coalesce):
        coalesce, other = left, right
    elif isinstance(right, exp.Coalesce):
        coalesce, other = right, left
    else:
        return expression

    # Valid for non-constants too, but only pays off when both sides end up constant.
    if not isinstance(other, CONSTANTS):
        return expression

    constant_index = next(
        (
            index
            for index, candidate in enumerate(coalesce.expressions)
            if isinstance(candidate, CONSTANTS)
        ),
        None,
    )
    if constant_index is None:
        return expression
    first_constant = coalesce.expressions[constant_index]

    # Drop the first constant argument and everything after it; the constant itself is
    # used in the "IS NULL" branch below.
    coalesce.set("expressions", coalesce.expressions[:constant_index])

    # Unwrap an argument-less COALESCE right away instead of waiting for another pass.
    nullable = coalesce if coalesce.expressions else coalesce.this

    when_not_null = exp.and_(
        nullable.is_(exp.null()).not_(copy=False),
        expression.copy(),
        copy=False,
    )
    when_null = exp.and_(
        nullable.is_(exp.null()),
        type(expression)(this=first_constant.copy(), expression=other.copy()),
        copy=False,
    )
    # Larger than the input, but later simplification passes shrink it again.
    return exp.paren(exp.or_(when_not_null, when_null, copy=False))
# Concatenation node types handled by simplify_concat.
CONCATS = (exp.Concat, exp.DPipe)
# Their "safe" variants; simplify_concat keeps the result safe when the input was.
SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
def simplify_concat(expression):
    """
    Merge each run of adjacent string literals in a CONCAT / `||` chain into one literal.

    CONCAT_WS and non-concatenation nodes are returned unchanged. If only one operand
    remains it is returned directly; otherwise a Concat (or SafeConcat, when the input
    was a safe variant) is built over the merged operands.
    """
    if isinstance(expression, exp.ConcatWs) or not isinstance(expression, CONCATS):
        return expression

    operands = expression.expressions or expression.flatten()
    merged = []
    for all_strings, run in itertools.groupby(operands, key=lambda node: node.is_string):
        if all_strings:
            merged.append(exp.Literal.string("".join(literal.name for literal in run)))
        else:
            merged.extend(run)

    if len(merged) == 1:
        return merged[0]

    # Keep the "safe" flavour of the original concatenation.
    result_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
    return result_type(expressions=merged)

Reduces each run of adjacent string literals to a single literal by concatenating them.

# (side, kind) pairs of joins that remove_where_true may rewrite into a CROSS JOIN when
# their ON condition is always true.
# NOTE(review): a RIGHT [OUTER] JOIN ... ON TRUE still returns the right-hand rows when the
# left side is empty, whereas a CROSS JOIN returns none — confirm RIGHT belongs here.
JOINS = {('RIGHT', 'OUTER'), ('', 'INNER'), ('RIGHT', ''), ('', '')}
def remove_where_true(expression):
    """
    Strip always-true filters from `expression`, mutating it in place.

    Every WHERE whose condition is always true is detached from its parent. Every join
    whose ON condition is always true, that has no USING or method, and whose
    (side, kind) is listed in JOINS is turned into a side-less CROSS JOIN without ON.
    Returns None.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        if (join.side, join.kind) not in JOINS:
            continue

        join.set("on", None)
        join.set("side", None)
        join.set("kind", "CROSS")
def always_true(expression):
    """
    Return True if `expression` is a constant that is truthy in a boolean context.

    TRUE qualifies, and so does any literal except a numeric literal equal to zero.
    A zero must not count as always true: `WHERE 0` filters every row, and `x OR 0` is
    not TRUE. Anything else, including None, returns False.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if not isinstance(expression, exp.Literal):
        return False

    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation: unparseable number text; keep the literal truthy
            # as before rather than failing the optimizer.
            return True

    return True
def is_complement(a, b):
    """Return True when `b` is a NOT node whose operand equals `a` (i.e. `b` is `NOT a`)."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True only for a FALSE boolean node (exact `exp.Boolean` type, no subclasses)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True only for a NULL node (exact `exp.Null` type, no subclasses)."""
    return exp.Null is type(a)
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison type of `expression` on the Python values `a` and `b`.

    EQ and IS use `==`; NEQ, GT, GTE, LT and LTE use their Python counterparts.
    Returns a TRUE/FALSE node, or None when `expression` is not one of these comparisons.
    """
    # Checked in order; comparisons are evaluated lazily, only for the matching type.
    dispatch = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in dispatch:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def extract_date(cast):
    """
    Parse the text of a CAST to DATE/DATETIME into a Python date/datetime.

    Returns None for any other target type, or when the cast operand is not an
    ISO-formatted string (e.g. the cast is applied to an identifier).
    """
    target = cast.args["to"].this
    if target == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
def extract_interval(interval):
    """
    Convert an INTERVAL node into a `dateutil.relativedelta.relativedelta`.

    Supports integer amounts with YEAR, MONTH, WEEK or DAY units (case-insensitive).

    Returns None when python-dateutil is not installed, when the unit is unsupported, or
    when the amount is not an integer (e.g. a column reference or `1.5`), so that such
    intervals are simply not folded instead of aborting simplification.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Non-integer amounts cannot be folded into a relativedelta.
        return None

    # SQL unit -> relativedelta keyword argument.
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(interval.text("unit").lower())
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
    """Build a CAST of a string literal of `date` to DATETIME (datetime values) or DATE."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def boolean_literal(condition):
    """Map a Python truth value to a fresh TRUE or FALSE node."""
    if condition:
        return exp.true()
    return exp.false()