Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import first, while_changing
 10
# Module-level SQL generator. uniq_sort uses it to render connector operands
# to SQL strings, which then serve as the keys for deduplication and sorting.
GENERATOR = Generator(normalize=True, identify="safe")
 12
 13
 14def simplify(expression):
 15    """
 16    Rewrite sqlglot AST to simplify expressions.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 21        >>> simplify(expression).sql()
 22        'TRUE'
 23
 24    Args:
 25        expression (sqlglot.Expression): expression to simplify
 26    Returns:
 27        sqlglot.Expression: simplified expression
 28    """
 29
 30    cache = {}
 31
 32    def _simplify(expression, root=True):
 33        node = expression
 34        node = rewrite_between(node)
 35        node = uniq_sort(node, cache, root)
 36        node = absorb_and_eliminate(node, root)
 37        exp.replace_children(node, lambda e: _simplify(e, False))
 38        node = simplify_not(node)
 39        node = flatten(node)
 40        node = simplify_connectors(node, root)
 41        node = remove_compliments(node, root)
 42        node.parent = expression.parent
 43        node = simplify_literals(node, root)
 44        node = simplify_parens(node)
 45        if root:
 46            expression.replace(node)
 47        return node
 48
 49    expression = while_changing(expression, _simplify)
 50    remove_where_true(expression)
 51    return expression
 52
 53
 54def rewrite_between(expression: exp.Expression) -> exp.Expression:
 55    """Rewrite x between y and z to x >= y AND x <= z.
 56
 57    This is done because comparison simplification is only done on lt/lte/gt/gte.
 58    """
 59    if isinstance(expression, exp.Between):
 60        return exp.and_(
 61            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 62            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 63            copy=False,
 64        )
 65    return expression
 66
 67
 68def simplify_not(expression):
 69    """
 70    Demorgan's Law
 71    NOT (x OR y) -> NOT x AND NOT y
 72    NOT (x AND y) -> NOT x OR NOT y
 73    """
 74    if isinstance(expression, exp.Not):
 75        if is_null(expression.this):
 76            return exp.null()
 77        if isinstance(expression.this, exp.Paren):
 78            condition = expression.this.unnest()
 79            if isinstance(condition, exp.And):
 80                return exp.or_(
 81                    exp.not_(condition.left, copy=False),
 82                    exp.not_(condition.right, copy=False),
 83                    copy=False,
 84                )
 85            if isinstance(condition, exp.Or):
 86                return exp.and_(
 87                    exp.not_(condition.left, copy=False),
 88                    exp.not_(condition.right, copy=False),
 89                    copy=False,
 90                )
 91            if is_null(condition):
 92                return exp.null()
 93        if always_true(expression.this):
 94            return exp.false()
 95        if is_false(expression.this):
 96            return exp.true()
 97        if isinstance(expression.this, exp.Not):
 98            # double negation
 99            # NOT NOT x -> x
100            return expression.this.this
101    return expression
102
103
def flatten(expression):
    """
    Unwrap operands that are a nested connector of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression
115
116
def simplify_connectors(expression, root=True):
    """Fold AND/OR operands that are duplicates, constants, NULL or mergeable comparisons."""

    def _simplify_connectors(expression, left, right):
        # identical operands collapse into one
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false = is_false(left)
            right_false = is_false(right)
            if left_false and right_false:
                return exp.false()

            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()

            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
153
154
# "less than" style comparisons
LT_LTE = (exp.LT, exp.LTE)
# "greater than" style comparisons
GT_GTE = (exp.GT, exp.GTE)

# comparison node types that _simplify_comparison is willing to look at
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
)

# maps an inequality to the one that holds once its operands are swapped
# (1 < x  <->  x > 1); _simplify_comparison uses it to move the literal to
# the right-hand side
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
171
172
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons of the same column against literals.

    Args:
        expression: the enclosing AND/OR connector (not inspected here).
        left: first operand of the connector.
        right: second operand of the connector.
        or_: True when the connector is OR, False when it is AND.
    Returns:
        The expression that replaces both `left` and `right`, or None when
        the pair cannot be simplified.
    """
    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
        return None

    ll, lr = left.args.values()
    rl, rr = right.args.values()

    largs = {ll, lr}
    rargs = {rl, rr}

    matching = largs & rargs
    columns = {m for m in matching if isinstance(m, exp.Column)}

    if not (matching and columns):
        return None

    try:
        l = first(largs - columns)
        r = first(rargs - columns)
    except StopIteration:
        # one side only contains the shared columns (e.g. x = y AND y = x), so
        # there is no literal to compare. Report "no simplification": returning
        # the whole connector here would make the caller substitute it for
        # these two operands and duplicate every other operand.
        return None

    # make sure the comparison is always of the form x > 1 instead of 1 < x
    if left.__class__ in INVERSE_COMPARISONS and l == ll:
        left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
    if right.__class__ in INVERSE_COMPARISONS and r == rl:
        right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

    if l.is_number and r.is_number:
        lv = float(l.name)
        rv = float(r.name)
    elif l.is_string and r.is_string:
        lv = l.name
        rv = r.name
    else:
        return None

    # two bounds in the same direction: keep the tighter one for AND and the
    # looser one for OR
    for bounds, strict in ((LT_LTE, exp.LT), (GT_GTE, exp.GT)):
        if isinstance(left, bounds) and isinstance(right, bounds):
            if lv == rv:
                # same value: only strictness differs, so the choice must not
                # depend on operand order (x <= 1 AND x < 1 -> x < 1)
                if or_:
                    return right if isinstance(left, strict) else left
                return left if isinstance(left, strict) else right

            keep_smaller = (bounds is LT_LTE) != or_
            return left if (lv < rv) == keep_smaller else right

    if or_:
        # we can't ever shortcut to true because the column could be null
        return None

    for (a, av), (b, bv) in itertools.permutations(((left, lv), (right, rv))):
        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
            if av <= bv:
                return exp.false()
        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
            if av >= bv:
                return exp.false()
        elif isinstance(a, exp.EQ):
            # an equality either contradicts the other comparison or implies it
            if isinstance(b, exp.LT):
                return exp.false() if av >= bv else a
            if isinstance(b, exp.LTE):
                return exp.false() if av > bv else a
            if isinstance(b, exp.GT):
                return exp.false() if av <= bv else a
            if isinstance(b, exp.GTE):
                return exp.false() if av < bv else a
            if isinstance(b, exp.NEQ):
                return exp.false() if av == bv else a
    return None
232
233
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
248
249
def uniq_sort(expression, cache=None, root=True):
    """
    Remove duplicate operands of a connector and put the rest in sorted order.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # keyed by generated SQL, so operands that render identically collapse
    by_sql = {GENERATOR.generate(operand, cache): operand for operand in operands}
    keys = list(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression
274
275
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced in place (in several cases by TRUE/FALSE constants)
    and `expression` itself is returned.

    Args:
        expression: node to inspect; only connectors are processed.
        root: when False, the connector is skipped if `expression.same_parent`
            is truthy.
    Returns:
        The `expression` that was passed in.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # the laws apply to operands built with the *other* connector type
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for a constant,
                    # e.g. A AND (NOT A OR B) -> A AND (FALSE OR B)
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    # same as above, with the negation in the second operand
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's: a is absorbed,
                    # e.g. A AND (A OR B) -> A AND TRUE
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # one operand is shared and the other appears negated in b:
                    # both a and b reduce to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
314
315
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary expressions and negated number literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # flip the sign textually: a leading "-" is dropped, otherwise one is added
    negated = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(negated)
328
329
def _simplify_binary(expression, a, b):
    """
    Try to fold two operands of a non-connector binary expression.

    Handles IS [NOT] NULL on literals and NULL, NULL propagation, arithmetic
    and comparisons on number literals, comparisons on string literals, and
    date +/- interval arithmetic.

    Args:
        expression: the binary expression the operands belong to.
        a: left operand.
        b: right operand.
    Returns:
        The folded expression, or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # dividing by zero would raise here; leave the expression as it is
            # so the engine decides what happens
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        date, delta = extract_date(a), extract_interval(b)
        # compare against None: a zero-length relativedelta is falsy
        if date is not None and delta is not None:
            if isinstance(expression, exp.Add):
                return date_literal(date + delta)
            if isinstance(expression, exp.Sub):
                return date_literal(date - delta)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        delta, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if delta is not None and date is not None and isinstance(expression, exp.Add):
            return date_literal(delta + date)

    return None
388
389
def simplify_parens(expression):
    """Return the content of a Paren when the parentheses are not needed, else the node itself."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this

    # parentheses around a SELECT are always kept
    if isinstance(inner, exp.Select):
        return expression

    # the parent is not a condition or binary expression
    if not isinstance(expression.parent, (exp.Condition, exp.Binary)):
        return inner

    # inside a condition/binary parent, only a non-predicate binary keeps its parentheses
    if isinstance(inner, exp.Predicate) or not isinstance(inner, exp.Binary):
        return inner

    return expression
402
403
def remove_where_true(expression):
    """
    Drop conditions that are always true from WHERE clauses and joins.

    A WHERE whose condition is always true is removed from its parent. A join
    whose ON condition is always true becomes a CROSS join without ON, unless
    the join has a side: an outer join with a true condition still emits a row
    when the other input is empty, so it is not equivalent to a cross join.

    Args:
        expression: root of the tree; it is rewritten in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        # NOTE(review): assumes the outer-join side (LEFT/RIGHT/FULL) is stored
        # under the "side" arg of the join node; confirm against exp.Join
        if join.args.get("side"):
            continue
        if always_true(join.args.get("on")):
            join.set("kind", "CROSS")
            join.set("on", None)
412
413
def _is_zero_literal(expression):
    """Return True if `expression` is a number literal whose value is zero."""
    if not expression.is_number:
        return False
    try:
        return float(expression.name) == 0
    except ValueError:
        # not parseable as a float: do not claim it is zero
        return False


def always_true(expression):
    """
    Tell whether `expression` is a constant that is always true.

    The boolean TRUE and literals qualify, except number literals equal to
    zero: a zero is never a true condition, so treating it as one would, for
    example, drop a `WHERE 0` filter or fold `NOT 0` into FALSE.

    Args:
        expression: node to test (may be None).
    Returns:
        A truthy value when the expression is always true, a falsy one otherwise.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this
    if isinstance(expression, exp.Literal):
        return not _is_zero_literal(expression)
    return False
418
419
def is_complement(a, b):
    """Return a truthy value when `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
422
423
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node whose value is falsy."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
426
427
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
430
431
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between two Python values.

    Args:
        expression: comparison node that selects the operator.
        a: left value.
        b: right value.
    Returns:
        A boolean literal expression, or None when `expression` is not one of
        the supported comparisons.
    """
    # each comparison is wrapped in a lambda so that only the matching one runs
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
446
447
def extract_date(cast):
    """
    Parse the operand of a cast to DATE or DATETIME into a Python date object.

    Returns None when the cast targets another type, or when the operand's
    name is not an ISO formatted value (e.g. the cast is used on an identifier).
    """
    target_type = cast.args["to"].this

    if target_type == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target_type == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
458
459
def extract_interval(interval):
    """
    Convert an interval node into a `dateutil` relativedelta.

    Args:
        interval: node whose `name` holds the amount and whose "unit" text
            holds the unit.
    Returns:
        A relativedelta for year/month/week/day units, or None when dateutil
        is not installed, the amount is not an integer, or the unit is not
        supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # non-integer amount (e.g. '1.5' or a column reference): nothing to fold
        return None

    # interval unit -> relativedelta keyword
    keyword_by_unit = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }
    keyword = keyword_by_unit.get(interval.text("unit").lower())
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
478
479
def date_literal(date):
    """Build a cast of a string literal made from `date` to DATETIME or DATE."""
    if isinstance(date, datetime.datetime):
        target_type = "DATETIME"
    else:
        target_type = "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
485
486
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()
489
490
def _flat_simplify(expression, simplifier, root=True):
    """
    Simplify the flattened operands of `expression` pairwise.

    Each operand is offered to `simplifier` together with every operand after
    it; the first truthy result replaces both operands and is tried again.

    Args:
        expression: node whose flattened operands are processed.
        simplifier: callable taking (expression, a, b) and returning the merged
            operand, or a falsy value when the pair cannot be merged.
        root: when False, nodes whose `same_parent` is truthy are skipped.
    Returns:
        A rebuilt left-deep chain of `expression`'s class when at least one
        pair was merged, otherwise `expression` itself.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    kept = []

    while pending:
        head = pending.popleft()

        for candidate in pending:
            merged = simplifier(expression, head, candidate)

            if merged:
                # the merged operand goes back to the front for another round
                pending.remove(candidate)
                pending.appendleft(merged)
                break
        else:
            kept.append(head)

    if len(kept) >= original_count:
        return expression

    node_type = expression.__class__

    def _combine(lhs, rhs):
        return node_type(this=lhs, expression=rhs)

    return functools.reduce(_combine, kept)
def simplify(expression):
def simplify(expression):
    """
    Simplify a sqlglot AST by applying rewrite rules until it stops changing.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # a single cache dict is handed to every uniq_sort call made during this run
    sql_cache = {}

    def _simplify(expression, root=True):
        # rules applied before the children are visited
        result = rewrite_between(expression)
        result = uniq_sort(result, sql_cache, root)
        result = absorb_and_eliminate(result, root)

        def _simplify_child(child):
            return _simplify(child, False)

        exp.replace_children(result, _simplify_child)

        # rules applied once the children have been simplified
        result = simplify_not(result)
        result = flatten(result)
        result = simplify_connectors(result, root)
        result = remove_compliments(result, root)
        # a rule may have produced a new node: attach it to the original parent
        # before the remaining rules run (simplify_parens reads .parent)
        result.parent = expression.parent
        result = simplify_literals(result, root)
        result = simplify_parens(result)
        if root:
            expression.replace(result)
        return result

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification is only done on lt/lte/gt/gte, so BETWEEN is
    rewritten into that form. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # each comparison gets its own copy of the operand; the bounds are reused as-is
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound, copy=False)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
 69def simplify_not(expression):
 70    """
 71    Demorgan's Law
 72    NOT (x OR y) -> NOT x AND NOT y
 73    NOT (x AND y) -> NOT x OR NOT y
 74    """
 75    if isinstance(expression, exp.Not):
 76        if is_null(expression.this):
 77            return exp.null()
 78        if isinstance(expression.this, exp.Paren):
 79            condition = expression.this.unnest()
 80            if isinstance(condition, exp.And):
 81                return exp.or_(
 82                    exp.not_(condition.left, copy=False),
 83                    exp.not_(condition.right, copy=False),
 84                    copy=False,
 85                )
 86            if isinstance(condition, exp.Or):
 87                return exp.and_(
 88                    exp.not_(condition.left, copy=False),
 89                    exp.not_(condition.right, copy=False),
 90                    copy=False,
 91                )
 92            if is_null(condition):
 93                return exp.null()
 94        if always_true(expression.this):
 95            return exp.false()
 96        if is_false(expression.this):
 97            return exp.true()
 98        if isinstance(expression.this, exp.Not):
 99            # double negation
100            # NOT NOT x -> x
101            return expression.this.this
102    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    Unwrap operands that are a nested connector of the same type.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Fold AND/OR operands that are duplicates, constants, NULL or mergeable comparisons."""

    def _simplify_connectors(expression, left, right):
        # identical operands collapse into one
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()

            left_true = always_true(left)
            right_true = always_true(right)
            if left_true and right_true:
                return exp.true()
            if left_true:
                return right
            if right_true:
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false = is_false(left)
            right_false = is_false(right)
            if left_false and right_false:
                return exp.false()

            left_null = is_null(left)
            right_null = is_null(right)
            if (left_null and (right_null or right_false)) or (left_false and right_null):
                return exp.null()

            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, cache=None, root=True):
def uniq_sort(expression, cache=None, root=True):
    """
    Remove duplicate operands of a connector and put the rest in sorted order.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    # keyed by generated SQL, so operands that render identically collapse
    by_sql = {GENERATOR.generate(operand, cache): operand for operand in operands}
    keys = list(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        return build(*(by_sql[key] for key in sorted(keys)), copy=False)

    # already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return build(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    Apply the absorption and elimination laws to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Operands are replaced in place (in several cases by TRUE/FALSE constants)
    and `expression` itself is returned.

    Args:
        expression: node to inspect; only connectors are processed.
        root: when False, the connector is skipped if `expression.same_parent`
            is truthy.
    Returns:
        The `expression` that was passed in.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # the laws apply to operands built with the *other* connector type
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for a constant,
                    # e.g. A AND (NOT A OR B) -> A AND (FALSE OR B)
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    # same as above, with the negation in the second operand
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a proper subset of a's: a is absorbed,
                    # e.g. A AND (A OR B) -> A AND TRUE
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    # one operand is shared and the other appears negated in b:
                    # both a and b reduce to the shared operand
                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal operands of non-connector binary expressions and negated number literals."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if not isinstance(expression, exp.Neg):
        return expression

    operand = expression.this
    if not operand.is_number:
        return expression

    digits = operand.name
    # flip the sign textually: a leading "-" is dropped, otherwise one is added
    negated = digits[1:] if digits[0] == "-" else f"-{digits}"
    return exp.Literal.number(negated)
def simplify_parens(expression):
def simplify_parens(expression):
    """Return the content of a Paren when the parentheses are not needed, else the node itself."""
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this

    # parentheses around a SELECT are always kept
    if isinstance(inner, exp.Select):
        return expression

    # the parent is not a condition or binary expression
    if not isinstance(expression.parent, (exp.Condition, exp.Binary)):
        return inner

    # inside a condition/binary parent, only a non-predicate binary keeps its parentheses
    if isinstance(inner, exp.Predicate) or not isinstance(inner, exp.Binary):
        return inner

    return expression
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Drop conditions that are always true from WHERE clauses and joins.

    A WHERE whose condition is always true is removed from its parent. A join
    whose ON condition is always true becomes a CROSS join without ON, unless
    the join has a side: an outer join with a true condition still emits a row
    when the other input is empty, so it is not equivalent to a cross join.

    Args:
        expression: root of the tree; it is rewritten in place.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        # NOTE(review): assumes the outer-join side (LEFT/RIGHT/FULL) is stored
        # under the "side" arg of the join node; confirm against exp.Join
        if join.args.get("side"):
            continue
        if always_true(join.args.get("on")):
            join.set("kind", "CROSS")
            join.set("on", None)
def always_true(expression):
def _is_zero_literal(expression):
    """Return True if `expression` is a number literal whose value is zero."""
    if not expression.is_number:
        return False
    try:
        return float(expression.name) == 0
    except ValueError:
        # not parseable as a float: do not claim it is zero
        return False


def always_true(expression):
    """
    Tell whether `expression` is a constant that is always true.

    The boolean TRUE and literals qualify, except number literals equal to
    zero: a zero is never a true condition, so treating it as one would, for
    example, drop a `WHERE 0` filter or fold `NOT 0` into FALSE.

    Args:
        expression: node to test (may be None).
    Returns:
        A truthy value when the expression is always true, a falsy one otherwise.
    """
    if isinstance(expression, exp.Boolean):
        return expression.this
    if isinstance(expression, exp.Literal):
        return not _is_zero_literal(expression)
    return False
def is_complement(a, b):
def is_complement(a, b):
    """Return a truthy value when `b` is `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node whose value is falsy."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison between two Python values.

    Args:
        expression: comparison node that selects the operator.
        a: left value.
        b: right value.
    Returns:
        A boolean literal expression, or None when `expression` is not one of
        the supported comparisons.
    """
    # each comparison is wrapped in a lambda so that only the matching one runs
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare())
    return None
def extract_date(cast):
def extract_date(cast):
    """
    Parse the operand of a cast to DATE or DATETIME into a Python date object.

    Returns None when the cast targets another type, or when the operand's
    name is not an ISO formatted value (e.g. the cast is used on an identifier).
    """
    target_type = cast.args["to"].this

    if target_type == exp.DataType.Type.DATE:
        parse = datetime.date.fromisoformat
    elif target_type == exp.DataType.Type.DATETIME:
        parse = datetime.datetime.fromisoformat
    else:
        return None

    try:
        return parse(cast.name)
    except ValueError:
        return None
def extract_interval(interval):
def extract_interval(interval):
    """
    Convert an interval node into a `dateutil` relativedelta.

    Args:
        interval: node whose `name` holds the amount and whose "unit" text
            holds the unit.
    Returns:
        A relativedelta for year/month/week/day units, or None when dateutil
        is not installed, the amount is not an integer, or the unit is not
        supported.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        # non-integer amount (e.g. '1.5' or a column reference): nothing to fold
        return None

    # interval unit -> relativedelta keyword
    keyword_by_unit = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }
    keyword = keyword_by_unit.get(interval.text("unit").lower())
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
def date_literal(date):
    """Build a cast of a string literal made from `date` to DATETIME or DATE."""
    if isinstance(date, datetime.datetime):
        target_type = "DATETIME"
    else:
        target_type = "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()