Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
# Key looked up in an expression's ``meta`` dict. When the stored value is
# truthy, simplify() returns that expression as-is instead of rewriting it.
FINAL = "final"
 13
 14
 15def simplify(expression):
 16    """
 17    Rewrite sqlglot AST to simplify expressions.
 18
 19    Example:
 20        >>> import sqlglot
 21        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 22        >>> simplify(expression).sql()
 23        'TRUE'
 24
 25    Args:
 26        expression (sqlglot.Expression): expression to simplify
 27    Returns:
 28        sqlglot.Expression: simplified expression
 29    """
 30
 31    generate = cached_generator()
 32
 33    # group by expressions cannot be simplified, for example
 34    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 35    # the projection must exactly match the group by key
 36    for group in expression.find_all(exp.Group):
 37        select = group.parent
 38        groups = set(group.expressions)
 39        group.meta[FINAL] = True
 40
 41        for e in select.selects:
 42            for node, *_ in e.walk():
 43                if node in groups:
 44                    e.meta[FINAL] = True
 45                    break
 46
 47        having = select.args.get("having")
 48        if having:
 49            for node, *_ in having.walk():
 50                if node in groups:
 51                    having.meta[FINAL] = True
 52                    break
 53
 54    def _simplify(expression, root=True):
 55        if expression.meta.get(FINAL):
 56            return expression
 57        node = expression
 58        node = rewrite_between(node)
 59        node = uniq_sort(node, generate, root)
 60        node = absorb_and_eliminate(node, root)
 61        exp.replace_children(node, lambda e: _simplify(e, False))
 62        node = simplify_not(node)
 63        node = flatten(node)
 64        node = simplify_connectors(node, root)
 65        node = remove_compliments(node, root)
 66        node.parent = expression.parent
 67        node = simplify_literals(node, root)
 68        node = simplify_parens(node)
 69        if root:
 70            expression.replace(node)
 71        return node
 72
 73    expression = while_changing(expression, _simplify)
 74    remove_where_true(expression)
 75    return expression
 76
 77
 78def rewrite_between(expression: exp.Expression) -> exp.Expression:
 79    """Rewrite x between y and z to x >= y AND x <= z.
 80
 81    This is done because comparison simplification is only done on lt/lte/gt/gte.
 82    """
 83    if isinstance(expression, exp.Between):
 84        return exp.and_(
 85            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 86            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 87            copy=False,
 88        )
 89    return expression
 90
 91
 92def simplify_not(expression):
 93    """
 94    Demorgan's Law
 95    NOT (x OR y) -> NOT x AND NOT y
 96    NOT (x AND y) -> NOT x OR NOT y
 97    """
 98    if isinstance(expression, exp.Not):
 99        if is_null(expression.this):
100            return exp.null()
101        if isinstance(expression.this, exp.Paren):
102            condition = expression.this.unnest()
103            if isinstance(condition, exp.And):
104                return exp.or_(
105                    exp.not_(condition.left, copy=False),
106                    exp.not_(condition.right, copy=False),
107                    copy=False,
108                )
109            if isinstance(condition, exp.Or):
110                return exp.and_(
111                    exp.not_(condition.left, copy=False),
112                    exp.not_(condition.right, copy=False),
113                    copy=False,
114                )
115            if is_null(condition):
116                return exp.null()
117        if always_true(expression.this):
118            return exp.false()
119        if is_false(expression.this):
120            return exp.true()
121        if isinstance(expression.this, exp.Not):
122            # double negation
123            # NOT NOT x -> x
124            return expression.this.this
125    return expression
126
127
def flatten(expression):
    """
    Drop parentheses around a nested connector of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    The operands are replaced in place and ``expression`` itself is returned.
    """
    if isinstance(expression, exp.Connector):
        connector_type = expression.__class__
        for operand in expression.args.values():
            unwrapped = operand.unnest()
            if isinstance(unwrapped, connector_type):
                operand.replace(unwrapped)
    return expression
139
140
def simplify_connectors(expression, root=True):
    """
    Fold pairs of AND / OR operands that involve constants or comparable
    comparisons. Non-connector expressions are returned unchanged.
    """

    def _fold_and(expression, left, right):
        if is_false(left) or is_false(right):
            return exp.false()
        if is_null(left) or is_null(right):
            return exp.null()
        if always_true(left):
            # TRUE AND TRUE -> TRUE, TRUE AND x -> x
            return exp.true() if always_true(right) else right
        if always_true(right):
            return left
        return _simplify_comparison(expression, left, right)

    def _fold_or(expression, left, right):
        if always_true(left) or always_true(right):
            return exp.true()

        left_false, right_false = is_false(left), is_false(right)
        if left_false and right_false:
            return exp.false()

        left_null, right_null = is_null(left), is_null(right)
        if (left_null and (right_null or right_false)) or (left_false and right_null):
            return exp.null()

        if left_false:
            return right
        if right_false:
            return left
        return _simplify_comparison(expression, left, right, or_=True)

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left
        if isinstance(expression, exp.And):
            return _fold_and(expression, left, right)
        if isinstance(expression, exp.Or):
            return _fold_or(expression, left, right)
        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
177
178
# Inequality node types grouped by direction.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison node type that _simplify_comparison accepts as an operand.
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
)

# Maps an inequality type to the type that expresses the same relation with
# its operands swapped (``1 < x`` is ``x > 1``).
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
195
196
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons of the same column against constants.

    Args:
        expression: the connector (AND / OR) the two operands belong to.
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the operands are joined by OR, False for AND.

    Returns:
        The expression that replaces both operands (one of the comparisons,
        possibly normalized to ``column <op> constant`` form, or FALSE), or
        None when the pair cannot be simplified.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if isinstance(m, exp.Column)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                # One side consists of the shared column only (e.g. x = x), so
                # there is no constant to compare. Report "no simplification";
                # returning the whole connector here would make the caller
                # splice it back in as if it were the merged operand.
                return None

            # make sure the comparison is always of the form x > 1 instead of 1 < x
            if left.__class__ in INVERSE_COMPARISONS and l == ll:
                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
            if right.__class__ in INVERSE_COMPARISONS and r == rl:
                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    if av == bv:
                        # Same bound: AND keeps the strict comparison, OR keeps
                        # the inclusive one (x <= 1 AND x < 1 is x < 1).
                        wanted = exp.LTE if or_ else exp.LT
                        return left if isinstance(left, wanted) else right
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    if av == bv:
                        # Same bound: AND keeps the strict comparison, OR keeps
                        # the inclusive one (x >= 1 AND x > 1 is x > 1).
                        wanted = exp.GTE if or_ else exp.GT
                        return left if isinstance(left, wanted) else right
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        # x = c combined with another constraint on x is either
                        # contradictory (FALSE) or already implied by x = c.
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None
256
257
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains a term together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only connectors are inspected, and a non-root connector is skipped when
    ``expression.same_parent`` is truthy.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(term, other) for term, other in operand_pairs):
        if isinstance(expression, exp.And):
            return exp.false()
        return exp.true()
    return expression
272
273
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: node to inspect; only connectors are rewritten.
        generate: callable mapping an operand to the string used both as the
            deduplication key and as the sort key.
        root: when False, the connector is processed only if
            ``expression.same_parent`` is falsy.

    Returns:
        A rebuilt connector when operands had to be reordered or removed,
        otherwise the original expression.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    unique = {generate(operand): operand for operand in operands}
    keyed = tuple(unique.items())

    # A AND C AND B -> A AND B AND C: rebuild only if some neighbour pair is
    # out of order.
    out_of_order = any(
        later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:])
    )

    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)
    if len(unique) < len(operands):
        # already sorted, but duplicates still have to go
        return build(*unique.values(), copy=False)
    return expression
298
299
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are rewritten in place through ``replace``; the constants and
    duplicates this leaves behind are not cleaned up here.

    Args:
        expression: node to inspect; only connectors are processed.
        root: when False, the connector is processed only if
            ``expression.same_parent`` is falsy.

    Returns:
        The ``expression`` object that was passed in.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested groups of interest use the opposite connector:
        # OR-groups inside an AND, AND-groups inside an OR.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # One side of the group is NOT b: swap it for the neutral
                # element of the group (TRUE inside AND, FALSE inside OR).
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's terms are a proper subset of the group's terms, so the
                # group is redundant: swap it for the neutral element of the
                # outer connector (FALSE for an outer OR, TRUE for an outer AND).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # Both groups share one term and differ only by a term and
                    # its negation: replace both groups with the shared term.
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
338
339
def simplify_literals(expression, root=True):
    """
    Fold literal operands of non-connector binary expressions and push a
    unary minus into a numeric literal. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already negative literal strips its sign.
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    return expression
352
353
def _simplify_binary(expression, a, b):
    """
    Evaluate a binary expression whose two operands are constants.

    Args:
        expression: the binary node being simplified.
        a: left operand.
        b: right operand.

    Returns:
        A literal / boolean / NULL / cast expression holding the folded value,
        or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        # Split "IS NOT <c>" into the compared value and a negation flag.
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            # <literal> IS NULL -> FALSE, NULL IS NULL -> TRUE (inverted for IS NOT)
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises instead of returning a value;
            # leave the expression alone and let the engine decide.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
412
413
def simplify_parens(expression):
    """
    Strip a pair of parentheses when removing it cannot change the meaning.

    Returns the wrapped expression when the parentheses are redundant,
    otherwise ``expression`` itself. Parenthesized SELECTs are always kept.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    outer = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    redundant = (
        not isinstance(outer, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
        # (a + b) + c
        or (isinstance(inner, exp.Add) and isinstance(outer, exp.Add))
        # (a * b) * c, (a * b) + c, (a * b) - c
        or (isinstance(inner, exp.Mul) and isinstance(outer, (exp.Mul, exp.Add, exp.Sub)))
    )
    return inner if redundant else expression
431
432
def remove_where_true(expression):
    """
    Drop conditions that are always true, modifying ``expression`` in place.

    A WHERE clause whose condition is always true is removed from its parent.
    A join whose ON condition is always true, and that has neither USING nor
    a method, has its ``on`` and ``side`` cleared and its kind set to CROSS.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        for key, value in (("on", None), ("side", None), ("kind", "CROSS")):
            join.set(key, value)
446
447
def always_true(expression):
    """
    Tell whether ``expression`` is a constant that is true in a boolean context.

    A ``exp.Boolean`` counts when its value is truthy. A literal counts unless
    it is a numeric literal equal to zero: treating ``0`` as true would, for
    example, let ``WHERE 0`` be removed as if it were ``WHERE TRUE``.

    Returns:
        bool: True when the expression is known to be true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if isinstance(expression, exp.Literal):
        if expression.is_number:
            try:
                return float(expression.name) != 0
            except ValueError:
                # Unparseable numeric text: keep the historical answer.
                return True
        # NOTE(review): string literals are still treated as true, as before;
        # their truthiness is engine-specific -- confirm this is intended.
        return True

    return False
452
453
def is_complement(a, b):
    """Return whether ``b`` is ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
456
457
def is_false(a: exp.Expression) -> bool:
    """Return whether ``a`` is exactly an ``exp.Boolean`` holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
460
461
def is_null(a: exp.Expression) -> bool:
    """Return whether ``a`` is exactly an ``exp.Null`` node (subclasses excluded)."""
    node_type = type(a)
    return node_type is exp.Null
464
465
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node on two already-extracted Python values.

    Returns the boolean expression for the outcome, or None when
    ``expression`` is not one of the supported comparison types.
    """
    # Checked in order; each comparison is evaluated only if its type matches.
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
480
481
def extract_date(cast):
    """
    Parse the value of a CAST to DATE or DATETIME into a Python object.

    Returns a ``datetime.date`` / ``datetime.datetime``, or None when the
    target type is something else or the value is not in ISO format.
    """
    try:
        target = cast.args["to"].this
        if target == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        # "fromisoformat" fails when the cast wraps something that is not an
        # ISO date string (an identifier, for instance); nothing to extract.
        return None
    return None
492
493
def extract_interval(interval):
    """
    Convert an interval expression into a ``dateutil`` ``relativedelta``.

    Args:
        interval: object exposing ``name`` (the interval amount as text) and
            ``text("unit")`` (the unit name).

    Returns:
        A ``relativedelta`` for year / month / week / day units, or None when
        ``dateutil`` is not installed, the amount is not an integer, or the
        unit is not supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # Fractional or non-literal amounts cannot be folded; skip them
        # instead of aborting the whole simplification.
        return None

    unit = interval.text("unit").lower()
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(unit)

    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
512
513
def date_literal(date):
    """
    Build a CAST of ``date``'s string form to DATETIME (for
    ``datetime.datetime`` values) or DATE (for anything else).
    """
    type_name = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), type_name)
519
520
def boolean_literal(condition):
    """Return the TRUE expression when ``condition`` is truthy, else FALSE."""
    if condition:
        return exp.true()
    return exp.false()
523
524
def _flat_simplify(expression, simplifier, root=True):
    """
    Pairwise-simplify the flattened operands of ``expression``.

    Args:
        expression: node whose ``flatten(unnest=False)`` yields the operands.
        simplifier: callable ``(expression, a, b)`` returning the replacement
            for the pair ``a``, ``b`` or a falsy value when they cannot merge.
        root: when False, nothing is done if ``expression.same_parent`` is
            truthy.

    Returns:
        The remaining operands chained left to right with the type of
        ``expression`` when at least one pair was merged, otherwise
        ``expression`` itself.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    settled = []

    while pending:
        current = pending.popleft()

        for other in pending:
            merged = simplifier(expression, current, other)
            if merged:
                # The merged value replaces both operands and is tried again
                # against the operands that are still pending.
                pending.remove(other)
                pending.appendleft(merged)
                break
        else:
            # No partner found: this operand is final.
            settled.append(current)

    if len(settled) >= original_count:
        return expression

    node_type = expression.__class__
    combined = settled[0]
    for operand in settled[1:]:
        combined = node_type(this=combined, expression=operand)
    return combined
FINAL = 'final'
def simplify(expression):
16def simplify(expression):
17    """
18    Rewrite sqlglot AST to simplify expressions.
19
20    Example:
21        >>> import sqlglot
22        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
23        >>> simplify(expression).sql()
24        'TRUE'
25
26    Args:
27        expression (sqlglot.Expression): expression to simplify
28    Returns:
29        sqlglot.Expression: simplified expression
30    """
31
32    generate = cached_generator()
33
34    # group by expressions cannot be simplified, for example
35    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
36    # the projection must exactly match the group by key
37    for group in expression.find_all(exp.Group):
38        select = group.parent
39        groups = set(group.expressions)
40        group.meta[FINAL] = True
41
42        for e in select.selects:
43            for node, *_ in e.walk():
44                if node in groups:
45                    e.meta[FINAL] = True
46                    break
47
48        having = select.args.get("having")
49        if having:
50            for node, *_ in having.walk():
51                if node in groups:
52                    having.meta[FINAL] = True
53                    break
54
55    def _simplify(expression, root=True):
56        if expression.meta.get(FINAL):
57            return expression
58        node = expression
59        node = rewrite_between(node)
60        node = uniq_sort(node, generate, root)
61        node = absorb_and_eliminate(node, root)
62        exp.replace_children(node, lambda e: _simplify(e, False))
63        node = simplify_not(node)
64        node = flatten(node)
65        node = simplify_connectors(node, root)
66        node = remove_compliments(node, root)
67        node.parent = expression.parent
68        node = simplify_literals(node, root)
69        node = simplify_parens(node)
70        if root:
71            expression.replace(node)
72        return node
73
74    expression = while_changing(expression, _simplify)
75    remove_where_true(expression)
76    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
79def rewrite_between(expression: exp.Expression) -> exp.Expression:
80    """Rewrite x between y and z to x >= y AND x <= z.
81
82    This is done because comparison simplification is only done on lt/lte/gt/gte.
83    """
84    if isinstance(expression, exp.Between):
85        return exp.and_(
86            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
87            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
88            copy=False,
89        )
90    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
 93def simplify_not(expression):
 94    """
 95    Demorgan's Law
 96    NOT (x OR y) -> NOT x AND NOT y
 97    NOT (x AND y) -> NOT x OR NOT y
 98    """
 99    if isinstance(expression, exp.Not):
100        if is_null(expression.this):
101            return exp.null()
102        if isinstance(expression.this, exp.Paren):
103            condition = expression.this.unnest()
104            if isinstance(condition, exp.And):
105                return exp.or_(
106                    exp.not_(condition.left, copy=False),
107                    exp.not_(condition.right, copy=False),
108                    copy=False,
109                )
110            if isinstance(condition, exp.Or):
111                return exp.and_(
112                    exp.not_(condition.left, copy=False),
113                    exp.not_(condition.right, copy=False),
114                    copy=False,
115                )
116            if is_null(condition):
117                return exp.null()
118        if always_true(expression.this):
119            return exp.false()
120        if is_false(expression.this):
121            return exp.true()
122        if isinstance(expression.this, exp.Not):
123            # double negation
124            # NOT NOT x -> x
125            return expression.this.this
126    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
129def flatten(expression):
130    """
131    A AND (B AND C) -> A AND B AND C
132    A OR (B OR C) -> A OR B OR C
133    """
134    if isinstance(expression, exp.Connector):
135        for node in expression.args.values():
136            child = node.unnest()
137            if isinstance(child, expression.__class__):
138                node.replace(child)
139    return expression

Flatten nested connectors of the same type: A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
142def simplify_connectors(expression, root=True):
143    def _simplify_connectors(expression, left, right):
144        if left == right:
145            return left
146        if isinstance(expression, exp.And):
147            if is_false(left) or is_false(right):
148                return exp.false()
149            if is_null(left) or is_null(right):
150                return exp.null()
151            if always_true(left) and always_true(right):
152                return exp.true()
153            if always_true(left):
154                return right
155            if always_true(right):
156                return left
157            return _simplify_comparison(expression, left, right)
158        elif isinstance(expression, exp.Or):
159            if always_true(left) or always_true(right):
160                return exp.true()
161            if is_false(left) and is_false(right):
162                return exp.false()
163            if (
164                (is_null(left) and is_null(right))
165                or (is_null(left) and is_false(right))
166                or (is_false(left) and is_null(right))
167            ):
168                return exp.null()
169            if is_false(left):
170                return right
171            if is_false(right):
172                return left
173            return _simplify_comparison(expression, left, right, or_=True)
174
175    if isinstance(expression, exp.Connector):
176        return _flat_simplify(expression, _simplify_connectors, root)
177    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
INVERSE_COMPARISONS = {<class 'sqlglot.expressions.LT'>: <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GT'>: <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>: <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.GTE'>: <class 'sqlglot.expressions.LTE'>}
def remove_compliments(expression, root=True):
259def remove_compliments(expression, root=True):
260    """
261    Removing compliments.
262
263    A AND NOT A -> FALSE
264    A OR NOT A -> TRUE
265    """
266    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
267        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
268
269        for a, b in itertools.permutations(expression.flatten(), 2):
270            if is_complement(a, b):
271                return compliment
272    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, generate, root=True):
275def uniq_sort(expression, generate, root=True):
276    """
277    Uniq and sort a connector.
278
279    C AND A AND B AND B -> A AND B AND C
280    """
281    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
282        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
283        flattened = tuple(expression.flatten())
284        deduped = {generate(e): e for e in flattened}
285        arr = tuple(deduped.items())
286
287        # check if the operands are already sorted, if not sort them
288        # A AND C AND B -> A AND B AND C
289        for i, (sql, e) in enumerate(arr[1:]):
290            if sql < arr[i][0]:
291                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
292                break
293        else:
294            # we didn't have to sort but maybe we need to dedup
295            if len(deduped) < len(flattened):
296                expression = result_func(*deduped.values(), copy=False)
297
298    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
301def absorb_and_eliminate(expression, root=True):
302    """
303    absorption:
304        A AND (A OR B) -> A
305        A OR (A AND B) -> A
306        A AND (NOT A OR B) -> A AND B
307        A OR (NOT A AND B) -> A OR B
308    elimination:
309        (A AND B) OR (A AND NOT B) -> A
310        (A OR B) AND (A OR NOT B) -> A
311    """
312    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
313        kind = exp.Or if isinstance(expression, exp.And) else exp.And
314
315        for a, b in itertools.permutations(expression.flatten(), 2):
316            if isinstance(a, kind):
317                aa, ab = a.unnest_operands()
318
319                # absorb
320                if is_complement(b, aa):
321                    aa.replace(exp.true() if kind == exp.And else exp.false())
322                elif is_complement(b, ab):
323                    ab.replace(exp.true() if kind == exp.And else exp.false())
324                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
325                    a.replace(exp.false() if kind == exp.And else exp.true())
326                elif isinstance(b, kind):
327                    # eliminate
328                    rhs = b.unnest_operands()
329                    ba, bb = rhs
330
331                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
332                        a.replace(aa)
333                        b.replace(aa)
334                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
335                        a.replace(ab)
336                        b.replace(ab)
337
338    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_literals(expression, root=True):
341def simplify_literals(expression, root=True):
342    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
343        return _flat_simplify(expression, _simplify_binary, root)
344    elif isinstance(expression, exp.Neg):
345        this = expression.this
346        if this.is_number:
347            value = this.name
348            if value[0] == "-":
349                return exp.Literal.number(value[1:])
350            return exp.Literal.number(f"-{value}")
351
352    return expression
def simplify_parens(expression):
415def simplify_parens(expression):
416    if not isinstance(expression, exp.Paren):
417        return expression
418
419    this = expression.this
420    parent = expression.parent
421
422    if not isinstance(this, exp.Select) and (
423        not isinstance(parent, (exp.Condition, exp.Binary))
424        or isinstance(this, exp.Predicate)
425        or not isinstance(this, exp.Binary)
426        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
427        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
428        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
429    ):
430        return expression.this
431    return expression
def remove_where_true(expression):
434def remove_where_true(expression):
435    for where in expression.find_all(exp.Where):
436        if always_true(where.this):
437            where.parent.set("where", None)
438    for join in expression.find_all(exp.Join):
439        if (
440            always_true(join.args.get("on"))
441            and not join.args.get("using")
442            and not join.args.get("method")
443        ):
444            join.set("on", None)
445            join.set("side", None)
446            join.set("kind", "CROSS")
def always_true(expression):
449def always_true(expression):
450    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
451        expression, exp.Literal
452    )
def is_complement(a, b):
455def is_complement(a, b):
456    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
459def is_false(a: exp.Expression) -> bool:
460    return type(a) is exp.Boolean and not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
463def is_null(a: exp.Expression) -> bool:
464    return type(a) is exp.Null
def eval_boolean(expression, a, b):
467def eval_boolean(expression, a, b):
468    if isinstance(expression, (exp.EQ, exp.Is)):
469        return boolean_literal(a == b)
470    if isinstance(expression, exp.NEQ):
471        return boolean_literal(a != b)
472    if isinstance(expression, exp.GT):
473        return boolean_literal(a > b)
474    if isinstance(expression, exp.GTE):
475        return boolean_literal(a >= b)
476    if isinstance(expression, exp.LT):
477        return boolean_literal(a < b)
478    if isinstance(expression, exp.LTE):
479        return boolean_literal(a <= b)
480    return None
def extract_date(cast):
483def extract_date(cast):
484    # The "fromisoformat" conversion could fail if the cast is used on an identifier,
485    # so in that case we can't extract the date.
486    try:
487        if cast.args["to"].this == exp.DataType.Type.DATE:
488            return datetime.date.fromisoformat(cast.name)
489        if cast.args["to"].this == exp.DataType.Type.DATETIME:
490            return datetime.datetime.fromisoformat(cast.name)
491    except ValueError:
492        return None
def extract_interval(interval):
495def extract_interval(interval):
496    try:
497        from dateutil.relativedelta import relativedelta  # type: ignore
498    except ModuleNotFoundError:
499        return None
500
501    n = int(interval.name)
502    unit = interval.text("unit").lower()
503
504    if unit == "year":
505        return relativedelta(years=n)
506    if unit == "month":
507        return relativedelta(months=n)
508    if unit == "week":
509        return relativedelta(weeks=n)
510    if unit == "day":
511        return relativedelta(days=n)
512    return None
def date_literal(date):
515def date_literal(date):
516    return exp.cast(
517        exp.Literal.string(date),
518        "DATETIME" if isinstance(date, datetime.datetime) else "DATE",
519    )
def boolean_literal(condition):
522def boolean_literal(condition):
523    return exp.true() if condition else exp.false()