Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import cached_generator
  9from sqlglot.helper import first, while_changing
 10
 11
 12def simplify(expression):
 13    """
 14    Rewrite sqlglot AST to simplify expressions.
 15
 16    Example:
 17        >>> import sqlglot
 18        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 19        >>> simplify(expression).sql()
 20        'TRUE'
 21
 22    Args:
 23        expression (sqlglot.Expression): expression to simplify
 24    Returns:
 25        sqlglot.Expression: simplified expression
 26    """
 27
 28    generate = cached_generator()
 29
 30    def _simplify(expression, root=True):
 31        node = expression
 32        node = rewrite_between(node)
 33        node = uniq_sort(node, generate, root)
 34        node = absorb_and_eliminate(node, root)
 35        exp.replace_children(node, lambda e: _simplify(e, False))
 36        node = simplify_not(node)
 37        node = flatten(node)
 38        node = simplify_connectors(node, root)
 39        node = remove_compliments(node, root)
 40        node.parent = expression.parent
 41        node = simplify_literals(node, root)
 42        node = simplify_parens(node)
 43        if root:
 44            expression.replace(node)
 45        return node
 46
 47    expression = while_changing(expression, _simplify)
 48    remove_where_true(expression)
 49    return expression
 50
 51
 52def rewrite_between(expression: exp.Expression) -> exp.Expression:
 53    """Rewrite x between y and z to x >= y AND x <= z.
 54
 55    This is done because comparison simplification is only done on lt/lte/gt/gte.
 56    """
 57    if isinstance(expression, exp.Between):
 58        return exp.and_(
 59            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 60            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 61            copy=False,
 62        )
 63    return expression
 64
 65
 66def simplify_not(expression):
 67    """
 68    Demorgan's Law
 69    NOT (x OR y) -> NOT x AND NOT y
 70    NOT (x AND y) -> NOT x OR NOT y
 71    """
 72    if isinstance(expression, exp.Not):
 73        if is_null(expression.this):
 74            return exp.null()
 75        if isinstance(expression.this, exp.Paren):
 76            condition = expression.this.unnest()
 77            if isinstance(condition, exp.And):
 78                return exp.or_(
 79                    exp.not_(condition.left, copy=False),
 80                    exp.not_(condition.right, copy=False),
 81                    copy=False,
 82                )
 83            if isinstance(condition, exp.Or):
 84                return exp.and_(
 85                    exp.not_(condition.left, copy=False),
 86                    exp.not_(condition.right, copy=False),
 87                    copy=False,
 88                )
 89            if is_null(condition):
 90                return exp.null()
 91        if always_true(expression.this):
 92            return exp.false()
 93        if is_false(expression.this):
 94            return exp.true()
 95        if isinstance(expression.this, exp.Not):
 96            # double negation
 97            # NOT NOT x -> x
 98            return expression.this.this
 99    return expression
100
101
def flatten(expression):
    """
    Inline nested connectors of the same type, in place.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Each direct operand is unwrapped from its parentheses; if what remains is
    the same connector class as `expression`, it replaces the operand.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
113
114
def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR chains.

    AND: FALSE wins over everything, then NULL; always-true operands drop out.
    OR: an always-true operand wins; FALSE operands drop out; NULL combined
    with NULL or FALSE yields NULL.
    Identical operands collapse to one. Remaining pairs are handed to
    `_simplify_comparison`. Non-connector nodes are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()
            # NULL OR NULL, NULL OR FALSE, FALSE OR NULL -> NULL
            if (is_null(left) or left_false) and (is_null(right) or right_false):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
151
152
# Comparison operators grouped by direction (used by `_simplify_comparison`).
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison type `_simplify_comparison` can reason about.
COMPARISONS = (*LT_LTE, *GT_GTE, exp.EQ, exp.NEQ)

# Operator to use once the operands of a comparison are swapped
# (1 < x  is the same as  x > 1). EQ / NEQ are symmetric, so they are absent.
INVERSE_COMPARISONS = {
    **dict(zip(LT_LTE, GT_GTE)),
    **dict(zip(GT_GTE, LT_LTE)),
}
169
170
def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons that bound the same column by literal values.

    Comparisons written as ``1 < x`` are first normalized to ``x > 1``.

    AND examples: x > 1 AND x > 3 -> x > 3; x < 1 AND x > 3 -> FALSE;
    x = 1 AND x < 3 -> x = 1; x <= 5 AND x < 5 -> x < 5.
    OR examples: x > 1 OR x > 3 -> x > 1; x <= 5 OR x < 5 -> x <= 5.

    Args:
        expression: the enclosing connector (part of the simplifier callback
            signature; not otherwise used).
        left: first operand of the pair.
        right: second operand of the pair.
        or_: True when the operands are OR-ed together, False for AND.

    Returns:
        The replacement node, or None when the pair cannot be simplified.
        None is the "no change" signal for `_flat_simplify`, which splices any
        truthy result in place of both operands.
    """
    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
        return None

    ll, lr = left.args.values()
    rl, rr = right.args.values()

    largs = {ll, lr}
    rargs = {rl, rr}

    matching = largs & rargs
    columns = {m for m in matching if isinstance(m, exp.Column)}

    if not (matching and columns):
        return None

    try:
        l = first(largs - columns)
        r = first(rargs - columns)
    except StopIteration:
        # Only columns on one side (x > y AND x < y, x = x AND x > 1): there is
        # no literal bound to compare. Returning the whole connector would make
        # _flat_simplify splice the connector into itself, so report "no change".
        return None

    # make sure the comparison is always of the form x > 1 instead of 1 < x
    if left.__class__ in INVERSE_COMPARISONS and l == ll:
        left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
    if right.__class__ in INVERSE_COMPARISONS and r == rl:
        right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

    if l.is_number and r.is_number:
        l = float(l.name)
        r = float(r.name)
    elif l.is_string and r.is_string:
        l = l.name
        r = r.name
    else:
        return None

    for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
        if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
            if av == bv:
                # Same bound, possibly different strictness: AND keeps the strict
                # form (x < 5), OR keeps the inclusive one (x <= 5).
                keep_left = isinstance(left, exp.LTE) if or_ else isinstance(left, exp.LT)
                return left if keep_left else right
            return left if (av > bv if or_ else av <= bv) else right
        if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
            if av == bv:
                keep_left = isinstance(left, exp.GTE) if or_ else isinstance(left, exp.GT)
                return left if keep_left else right
            return left if (av < bv if or_ else av >= bv) else right

        # we can't ever shortcut to true because the column could be null
        if not or_:
            if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                if av <= bv:
                    return exp.false()
            elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                if av >= bv:
                    return exp.false()
            elif isinstance(a, exp.EQ):
                if isinstance(b, exp.LT):
                    return exp.false() if av >= bv else a
                if isinstance(b, exp.LTE):
                    return exp.false() if av > bv else a
                if isinstance(b, exp.GT):
                    return exp.false() if av <= bv else a
                if isinstance(b, exp.GTE):
                    return exp.false() if av < bv else a
                if isinstance(b, exp.NEQ):
                    return exp.false() if av == bv else a
    return None
230
231
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only applied when `root` is True or `expression.same_parent` is falsy;
    every other node is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
246
247
def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: node to process; non-connectors are returned as is.
        generate: callable returning an operand's SQL text, used both as the
            deduplication key and as the sort key.
        root: whether `expression` is the root of the current simplify call.

    Returns:
        A rebuilt connector when operands were reordered or duplicates were
        dropped, otherwise `expression` itself.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # First occurrence fixes the key's position; the last duplicate is the
    # node that is kept (ordinary dict assignment semantics).
    by_sql = {}
    for operand in operands:
        by_sql[generate(operand)] = operand

    keys = list(by_sql)
    if any(nxt < prev for prev, nxt in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        # Already ordered, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression
272
273
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The tree is rewritten in place. An absorbed operand is replaced with the
    neutral element of the connector it sits in (TRUE inside AND, FALSE inside
    OR), leaving the actual removal to later passes. Eliminated pairs are both
    replaced with their shared operand. Always returns `expression`.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # Operands of the opposite connector type are the ones that can be rewritten.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = kind == exp.And
    # Neutral element of `kind`: drops one operand inside an inner connector.
    inner_neutral = exp.true if inner_is_and else exp.false
    # Neutral element of the outer connector: drops a whole inner connector.
    outer_neutral = exp.false if inner_is_and else exp.true

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(inner_neutral())
        elif is_complement(b, ab):
            ab.replace(inner_neutral())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(outer_neutral())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                survivor = aa
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                survivor = ab
            else:
                continue
            a.replace(survivor)
            b.replace(survivor)

    return expression
312
313
def simplify_literals(expression, root=True):
    """
    Fold constant operands of non-connector binary nodes and negate literals.

    Binary nodes other than AND / OR are processed pairwise by
    `_flat_simplify` with `_simplify_binary`. A NEG of a numeric literal is
    replaced by a single literal: a leading '-' in the literal text is
    stripped, otherwise one is added. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary):
        if isinstance(expression, exp.Connector):
            return expression
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
326
327
def _simplify_binary(expression, a, b):
    """
    Evaluate a binary node over a pair of constant operands.

    Handles, in order:
    - ``<literal> IS [NOT] NULL`` and ``NULL IS [NOT] NULL``;
    - NULL propagation for every other operator except the null-safe ones;
    - arithmetic (+, -, *, non-integer /) and comparisons of numeric literals;
    - comparisons of string literals;
    - DATE/DATETIME casts plus/minus an INTERVAL.

    Args:
        expression: the binary node the operands belong to.
        a: left operand.
        b: right operand.

    Returns:
        The folded replacement node, or None when the pair cannot be
        simplified (the "no change" signal for `_flat_simplify`).
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Decimal division by zero raises (DivisionByZero / InvalidOperation),
            # and engines disagree on the result (error vs NULL): leave it as is.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
386
387
def simplify_parens(expression):
    """
    Drop parentheses that do not affect evaluation.

    A Paren is kept when it wraps a SELECT. It is also kept when it wraps a
    non-predicate binary node whose parent is a Condition or Binary node,
    except for an addition directly inside an addition (likewise for
    multiplication). In all other cases the wrapped node is returned.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    needs_parens = (
        isinstance(parent, (exp.Condition, exp.Binary))
        and not isinstance(inner, exp.Predicate)
        and isinstance(inner, exp.Binary)
        and not (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        and not (isinstance(inner, exp.Mul) and isinstance(parent, exp.Mul))
    )
    return expression if needs_parens else inner
404
405
def remove_where_true(expression):
    """
    Drop filters whose condition is always true (mutates `expression`).

    ``WHERE <always true>`` is unset on its parent, and
    ``JOIN ... ON <always true>`` becomes a CROSS join without an ON clause.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        condition = join.args.get("on")
        if not always_true(condition):
            continue
        join.set("kind", "CROSS")
        join.set("on", None)
414
415
def always_true(expression):
    """
    Return True when `expression` is a constant that always evaluates to true.

    TRUE qualifies, and so does any literal except a numeric zero: in engines
    that accept numeric conditions (e.g. MySQL, SQLite) ``0`` is false, so
    ``WHERE 0``, ``NOT 0`` or ``x AND 0`` must not be folded as if it were TRUE.
    String literals keep being treated as true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal):
        if expression.is_number:
            try:
                return float(expression.name) != 0
            except ValueError:
                # Unparseable numeric text: keep the historical "literal is true".
                return True
        return True
    return False
420
421
def is_complement(a, b):
    """Return True if `b` is a NOT whose operand is structurally equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
424
425
def is_false(a: exp.Expression) -> bool:
    """Return True only for a FALSE node of exactly type `exp.Boolean` (no subclasses)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
428
429
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses excluded)."""
    return exp.Null is type(a)
432
433
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison node `expression` on the Python values `a`, `b`.

    Returns a TRUE/FALSE node for EQ/IS, NEQ, GT, GTE, LT and LTE (checked in
    that order), and None for any other node type.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in checks:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
448
449
def extract_date(cast):
    """
    Return the `date` / `datetime` denoted by a DATE / DATETIME cast.

    Returns None when the cast targets another type, or when `cast.name` is
    not ISO-formatted (e.g. a column rather than a string literal is cast).
    """
    to_type = cast.args["to"].this
    if to_type == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif to_type == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    # fromisoformat rejects non-ISO text such as an identifier name.
    try:
        return parse(cast.name)
    except ValueError:
        return None
460
461
def extract_interval(interval):
    """
    Convert an INTERVAL node into a `dateutil.relativedelta.relativedelta`.

    Supports integer amounts with a YEAR, MONTH, WEEK or DAY unit (any case).

    Returns:
        The relativedelta, or None when python-dateutil is not installed, the
        unit is unsupported, or the amount is not an integer
        (e.g. ``INTERVAL x DAY`` or ``INTERVAL '1.5' DAY``).
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # A column reference or fractional amount cannot be folded; raising here
        # would abort the whole simplification.
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
480
481
def date_literal(date):
    """
    Build a cast of `date` (rendered as a string literal) to DATETIME when it
    is a `datetime.datetime`, otherwise to DATE.
    """
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
487
488
def boolean_literal(condition):
    """Map a Python truth value to a TRUE or FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
491
492
493def _flat_simplify(expression, simplifier, root=True):
494    if root or not expression.same_parent:
495        operands = []
496        queue = deque(expression.flatten(unnest=False))
497        size = len(queue)
498
499        while queue:
500            a = queue.popleft()
501
502            for b in queue:
503                result = simplifier(expression, a, b)
504
505                if result:
506                    queue.remove(b)
507                    queue.appendleft(result)
508                    break
509            else:
510                operands.append(a)
511
512        if len(operands) < size:
513            return functools.reduce(
514                lambda a, b: expression.__class__(this=a, expression=b), operands
515            )
516    return expression
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    Every node is run through a fixed sequence of rewrite passes, children
    before the post-order passes, and the whole rewrite is driven by
    `while_changing` (presumably re-run until the tree stops changing).
    Afterwards WHERE / JOIN ON clauses that became always-true are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    generate = cached_generator()

    def _simplify(original, root=True):
        # Pre-order passes: they look at the node before its children change.
        node = rewrite_between(original)
        node = uniq_sort(node, generate, root)
        node = absorb_and_eliminate(node, root)

        # Recurse into the children; they are never treated as the root.
        exp.replace_children(node, lambda child: _simplify(child, False))

        # Post-order passes on the node whose children are already simplified.
        node = flatten(simplify_not(node))
        node = remove_compliments(simplify_connectors(node, root), root)

        # simplify_parens reads `.parent`, so re-attach the (possibly new) node first.
        node.parent = original.parent
        node = simplify_parens(simplify_literals(node, root))

        if root:
            original.replace(node)
        return node

    result = while_changing(expression, _simplify)
    remove_where_true(result)
    return result

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded into the equivalent pair of range comparisons. Every other node
    is returned untouched.
    """
    if not isinstance(expression, exp.Between):
        return expression

    subject = expression.this
    lower_bound = exp.GTE(this=subject.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=subject.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Simplify a NOT node.

    - NOT NULL -> NULL (also when the NULL is parenthesized)
    - De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    - NOT <always_true(...)> -> FALSE, NOT FALSE -> TRUE
    - NOT NOT x -> x

    Any node that is not a NOT is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, (exp.And, exp.Or)):
            # De Morgan: negate both sides and swap the connector.
            swapped = exp.or_ if isinstance(inner, exp.And) else exp.and_
            return swapped(
                exp.not_(inner.left, copy=False),
                exp.not_(inner.right, copy=False),
                copy=False,
            )
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # Double negation cancels out.
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
    """
    Inline nested connectors of the same type, in place.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Each direct operand is unwrapped from its parentheses; if what remains is
    the same connector class as `expression`, it replaces the operand.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = type(expression)
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """
    Fold constant operands of AND / OR chains.

    AND: FALSE wins over everything, then NULL; always-true operands drop out.
    OR: an always-true operand wins; FALSE operands drop out; NULL combined
    with NULL or FALSE yields NULL.
    Identical operands collapse to one. Remaining pairs are handed to
    `_simplify_comparison`. Non-connector nodes are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()
            # NULL OR NULL, NULL OR FALSE, FALSE OR NULL -> NULL
            if (is_null(left) or left_false) and (is_null(right) or right_false):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only applied when `root` is True or `expression.same_parent` is falsy;
    every other node is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    Args:
        expression: node to process; non-connectors are returned as is.
        generate: callable returning an operand's SQL text, used both as the
            deduplication key and as the sort key.
        root: whether `expression` is the root of the current simplify call.

    Returns:
        A rebuilt connector when operands were reordered or duplicates were
        dropped, otherwise `expression` itself.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # First occurrence fixes the key's position; the last duplicate is the
    # node that is kept (ordinary dict assignment semantics).
    by_sql = {}
    for operand in operands:
        by_sql[generate(operand)] = operand

    keys = list(by_sql)
    if any(nxt < prev for prev, nxt in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)
    if len(by_sql) < len(operands):
        # Already ordered, but duplicates were dropped.
        return build(*by_sql.values(), copy=False)
    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The tree is rewritten in place. An absorbed operand is replaced with the
    neutral element of the connector it sits in (TRUE inside AND, FALSE inside
    OR), leaving the actual removal to later passes. Eliminated pairs are both
    replaced with their shared operand. Always returns `expression`.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # Operands of the opposite connector type are the ones that can be rewritten.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    inner_is_and = kind == exp.And
    # Neutral element of `kind`: drops one operand inside an inner connector.
    inner_neutral = exp.true if inner_is_and else exp.false
    # Neutral element of the outer connector: drops a whole inner connector.
    outer_neutral = exp.false if inner_is_and else exp.true

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(inner_neutral())
        elif is_complement(b, ab):
            ab.replace(inner_neutral())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(outer_neutral())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                survivor = aa
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                survivor = ab
            else:
                continue
            a.replace(survivor)
            b.replace(survivor)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_literals(expression, root=True):
    """
    Fold constant operands of non-connector binary nodes and negate literals.

    Binary nodes other than AND / OR are processed pairwise by
    `_flat_simplify` with `_simplify_binary`. A NEG of a numeric literal is
    replaced by a single literal: a leading '-' in the literal text is
    stripped, otherwise one is added. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary):
        if isinstance(expression, exp.Connector):
            return expression
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg) and expression.this.is_number:
        digits = expression.this.name
        negated = digits[1:] if digits[0] == "-" else f"-{digits}"
        return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
    """
    Drop parentheses that do not affect evaluation.

    A Paren is kept when it wraps a SELECT. It is also kept when it wraps a
    non-predicate binary node whose parent is a Condition or Binary node,
    except for an addition directly inside an addition (likewise for
    multiplication). In all other cases the wrapped node is returned.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    parent = expression.parent

    if isinstance(inner, exp.Select):
        return expression

    needs_parens = (
        isinstance(parent, (exp.Condition, exp.Binary))
        and not isinstance(inner, exp.Predicate)
        and isinstance(inner, exp.Binary)
        and not (isinstance(inner, exp.Add) and isinstance(parent, exp.Add))
        and not (isinstance(inner, exp.Mul) and isinstance(parent, exp.Mul))
    )
    return expression if needs_parens else inner
def remove_where_true(expression):
    """
    Drop filters whose condition is always true (mutates `expression`).

    ``WHERE <always true>`` is unset on its parent, and
    ``JOIN ... ON <always true>`` becomes a CROSS join without an ON clause.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        condition = join.args.get("on")
        if not always_true(condition):
            continue
        join.set("kind", "CROSS")
        join.set("on", None)
def always_true(expression):
    """
    Return True when `expression` is a constant that always evaluates to true.

    TRUE qualifies, and so does any literal except a numeric zero: in engines
    that accept numeric conditions (e.g. MySQL, SQLite) ``0`` is false, so
    ``WHERE 0``, ``NOT 0`` or ``x AND 0`` must not be folded as if it were TRUE.
    String literals keep being treated as true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if isinstance(expression, exp.Literal):
        if expression.is_number:
            try:
                return float(expression.name) != 0
            except ValueError:
                # Unparseable numeric text: keep the historical "literal is true".
                return True
        return True
    return False
def is_complement(a, b):
    """Return True if `b` is a NOT whose operand is structurally equal to `a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return True only for a FALSE node of exactly type `exp.Boolean` (no subclasses)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses excluded)."""
    return exp.Null is type(a)
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison node `expression` on the Python values `a`, `b`.

    Returns a TRUE/FALSE node for EQ/IS, NEQ, GT, GTE, LT and LTE (checked in
    that order), and None for any other node type.
    """
    checks = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in checks:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
    """
    Return the `date` / `datetime` denoted by a DATE / DATETIME cast.

    Returns None when the cast targets another type, or when `cast.name` is
    not ISO-formatted (e.g. a column rather than a string literal is cast).
    """
    to_type = cast.args["to"].this
    if to_type == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif to_type == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    # fromisoformat rejects non-ISO text such as an identifier name.
    try:
        return parse(cast.name)
    except ValueError:
        return None
def extract_interval(interval):
    """
    Convert an INTERVAL node into a `dateutil.relativedelta.relativedelta`.

    Supports integer amounts with a YEAR, MONTH, WEEK or DAY unit (any case).

    Returns:
        The relativedelta, or None when python-dateutil is not installed, the
        unit is unsupported, or the amount is not an integer
        (e.g. ``INTERVAL x DAY`` or ``INTERVAL '1.5' DAY``).
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # A column reference or fractional amount cannot be folded; raising here
        # would abort the whole simplification.
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
def date_literal(date):
    """
    Build a cast of `date` (rendered as a string literal) to DATETIME when it
    is a `datetime.datetime`, otherwise to DATE.
    """
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def boolean_literal(condition):
    """Map a Python truth value to a TRUE or FALSE node."""
    if condition:
        return exp.true()
    return exp.false()