Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import first, while_changing
 10
 11GENERATOR = Generator(normalize=True, identify="safe")
 12
 13
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite passes in ``_simplify`` are driven by ``while_changing``
    (presumably re-run until the tree stops changing); afterwards
    ``remove_where_true`` drops always-true WHERE / JOIN ON conditions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # Shared across every pass; handed to GENERATOR.generate inside uniq_sort
    # (presumably a render cache keyed per node -- TODO confirm).
    cache = {}

    def _simplify(expression, root=True):
        # `root` is True only for the top-level node. Children are visited with
        # root=False; the connector helpers then only act on a connector when
        # `not expression.same_parent`.
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, cache, root)
        node = absorb_and_eliminate(node, root)
        # Recurse into the children before the remaining passes run on `node`.
        exp.replace_children(node, lambda e: _simplify(e, False))
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        # The passes may return a brand-new node; attach it to the original
        # parent so parent-dependent passes (simplify_parens reads
        # expression.parent) see the right context.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
 52
 53
 54def rewrite_between(expression: exp.Expression) -> exp.Expression:
 55    """Rewrite x between y and z to x >= y AND x <= z.
 56
 57    This is done because comparison simplification is only done on lt/lte/gt/gte.
 58    """
 59    if isinstance(expression, exp.Between):
 60        return exp.and_(
 61            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 62            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 63        )
 64    return expression
 65
 66
 67def simplify_not(expression):
 68    """
 69    Demorgan's Law
 70    NOT (x OR y) -> NOT x AND NOT y
 71    NOT (x AND y) -> NOT x OR NOT y
 72    """
 73    if isinstance(expression, exp.Not):
 74        if is_null(expression.this):
 75            return exp.null()
 76        if isinstance(expression.this, exp.Paren):
 77            condition = expression.this.unnest()
 78            if isinstance(condition, exp.And):
 79                return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
 80            if isinstance(condition, exp.Or):
 81                return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
 82            if is_null(condition):
 83                return exp.null()
 84        if always_true(expression.this):
 85            return exp.false()
 86        if is_false(expression.this):
 87            return exp.true()
 88        if isinstance(expression.this, exp.Not):
 89            # double negation
 90            # NOT NOT x -> x
 91            return expression.this.this
 92    return expression
 93
 94
 95def flatten(expression):
 96    """
 97    A AND (B AND C) -> A AND B AND C
 98    A OR (B OR C) -> A OR B OR C
 99    """
100    if isinstance(expression, exp.Connector):
101        for node in expression.args.values():
102            child = node.unnest()
103            if isinstance(child, expression.__class__):
104                node.replace(child)
105    return expression
106
107
def simplify_connectors(expression, root=True):
    """
    Fold constants and redundant operands in AND/OR chains.

    Operands are combined pairwise through ``_flat_simplify``. For each pair:

      - identical operands collapse into one;
      - AND: FALSE absorbs the pair; NULL AND NULL -> NULL; always-true
        operands (see ``always_true``) are dropped;
      - OR: an always-true operand absorbs the pair; NULL OR NULL and
        NULL OR FALSE -> NULL; FALSE operands are dropped;
      - otherwise comparisons on the same column are merged by
        ``_simplify_comparison``.

    Args:
        expression: node to simplify; non-connectors are returned unchanged.
        root: whether this is the top-level call (forwarded to
            ``_flat_simplify``).
    Returns:
        The simplified node.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # Three-valued logic: NULL AND FALSE is FALSE, so a single NULL
            # operand only makes the conjunction NULL when the other side is
            # NULL as well. (NULL AND TRUE -> NULL is covered by the
            # always-true rules below, which return the NULL operand.)
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
144
145
# Comparison operators grouped by direction (upper bounds / lower bounds).
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison _simplify_comparison is willing to merge.
COMPARISONS = LT_LTE + GT_GTE + (exp.EQ, exp.NEQ)

# Operator obtained by swapping the operands (a < b  <=>  b > a); used to move
# the literal onto the right-hand side of a comparison.
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}
162
163
def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons that test the same column against literals.

    Both operands must be comparisons (<, <=, >, >=, =, <>) sharing a column
    operand, with number or string literals on the other side:

      - two upper bounds (< / <=) or two lower bounds (> / >=): AND keeps the
        tighter bound and OR keeps the looser one; on equal values AND keeps
        the strict operator and OR keeps the inclusive one;
      - for AND only, contradictory pairs (e.g. ``x < 1 AND x > 2`` or
        ``x = 1 AND x <> 1``) become FALSE, and ``x = v`` combined with a
        compatible bound reduces to ``x = v``. We never shortcut to TRUE
        because the column could be NULL.

    Args:
        expression: the enclosing connector; unused, kept for the
            ``_flat_simplify`` callback signature.
        left: first operand.
        right: second operand.
        or_: True when the operands are joined by OR, False for AND.
    Returns:
        The replacement node, or None when the pair cannot be simplified.
    """
    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
        return None

    ll, lr = left.args.values()
    rl, rr = right.args.values()

    largs = {ll, lr}
    rargs = {rl, rr}

    matching = largs & rargs
    columns = {m for m in matching if isinstance(m, exp.Column)}

    if not (matching and columns):
        return None

    try:
        l = first(largs - columns)
        r = first(rargs - columns)
    except StopIteration:
        # Every operand is a shared column (e.g. x = y AND y = x), so there is
        # no literal bound to compare. Returning the whole connector here would
        # make _flat_simplify treat it as the merge of this single pair and
        # splice all of its other operands back into the chain.
        return None

    # make sure the comparison is always of the form x > 1 instead of 1 < x
    if left.__class__ in INVERSE_COMPARISONS and l == ll:
        left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
    if right.__class__ in INVERSE_COMPARISONS and r == rl:
        right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

    if l.is_number and r.is_number:
        l = float(l.name)
        r = float(r.name)
    elif l.is_string and r.is_string:
        l = l.name
        r = r.name
    else:
        return None

    # Two upper bounds: AND keeps the smaller, OR the larger. On a tie the
    # operator matters: x <= 1 AND x < 1 -> x < 1, x <= 1 OR x < 1 -> x <= 1.
    if isinstance(left, LT_LTE) and isinstance(right, LT_LTE):
        if l == r:
            return left if isinstance(left, exp.LT) != or_ else right
        return left if (l > r if or_ else l < r) else right

    # Two lower bounds: AND keeps the larger, OR the smaller. On a tie AND
    # keeps the strict `>` and OR keeps the inclusive `>=`.
    if isinstance(left, GT_GTE) and isinstance(right, GT_GTE):
        if l == r:
            return left if isinstance(left, exp.GT) != or_ else right
        return left if (l < r if or_ else l > r) else right

    # we can't ever shortcut to true because the column could be null
    if or_:
        return None

    for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
            if av <= bv:
                return exp.false()
        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
            if av >= bv:
                return exp.false()
        elif isinstance(a, exp.EQ):
            if isinstance(b, exp.LT):
                return exp.false() if av >= bv else a
            if isinstance(b, exp.LTE):
                return exp.false() if av > bv else a
            if isinstance(b, exp.GT):
                return exp.false() if av <= bv else a
            if isinstance(b, exp.GTE):
                return exp.false() if av < bv else a
            if isinstance(b, exp.NEQ):
                return exp.false() if av == bv else a
    return None
223
224
def remove_compliments(expression, root=True):
    """
    Replace a connector holding an operand and its negation by a constant.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Args:
        expression: node to inspect; anything but a connector is returned as-is.
        root: whether this is the top-level call; a non-root connector is only
            inspected when ``not expression.same_parent``.
    Returns:
        FALSE / TRUE when a complementary pair is present, else ``expression``.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    replacement = exp.false() if isinstance(expression, exp.And) else exp.true()
    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(first_op, second_op) for first_op, second_op in pairs):
        return replacement
    return expression
239
240
def uniq_sort(expression, cache=None, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when the operands are out of order or
    contain duplicates; otherwise ``expression`` is returned unchanged.

    Args:
        expression: node to normalize; non-connectors are returned as-is.
        cache: forwarded to ``GENERATOR.generate``.
        root: whether this is the top-level call; a non-root connector is only
            processed when ``not expression.same_parent``.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {GENERATOR.generate(operand, cache): operand for operand in operands}
    entries = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(nxt[0] < cur[0] for cur, nxt in zip(entries, entries[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(entries)))

    # already sorted, but duplicates may still need to be dropped
    if len(by_sql) < len(operands):
        return build(*by_sql.values())
    return expression
265
266
def absorb_and_eliminate(expression, root=True):
    """
    Apply absorption and elimination to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing sub-nodes with TRUE/FALSE
    or with the surviving operand; leftover constants and duplicates are
    presumably cleaned up by the later passes run from ``simplify``.

    Args:
        expression: node to rewrite; non-connectors are returned unchanged.
        root: whether this is the top-level call; a non-root connector is only
            processed when ``not expression.same_parent``.
    Returns:
        ``expression`` itself (possibly mutated).
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector type to look for: ORs inside an AND chain,
        # ANDs inside an OR chain.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # NOTE(review): itertools.permutations snapshots the flattened operands
        # up front, while the body replaces nodes in place; results depend on
        # this exact visiting order.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # `b` next to (NOT b <kind> x): replace the NOT b with the
                # identity element of `kind` (TRUE for AND, FALSE for OR) so it
                # drops out of `a`.
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's, so `a` is redundant
                # under the outer connector: replace it with the outer
                # connector's identity element (TRUE for AND, FALSE for OR).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # (X . Y) and (X . NOT Y) share X and differ only by a
                    # complemented operand: both collapse to X.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
305
306
def simplify_literals(expression, root=True):
    """
    Constant-fold literal operands.

    Non-connector binary nodes are folded pairwise by ``_simplify_binary``
    (via ``_flat_simplify``). A negated numeric literal becomes a single
    literal: -(5) -> -5 and -(-5) -> 5. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    return expression
319
320
def _simplify_binary(expression, a, b):
    """
    Constant-fold one pair of operands of a non-connector binary node.

    Handles:
      - ``x IS [NOT] NULL`` when x is a literal or NULL;
      - NULL propagation for other operators (null-safe comparisons are left
        alone);
      - arithmetic (+, -, *, and / unless both sides are integers or the
        divisor is zero) and comparisons on numeric literals;
      - comparisons on string literals;
      - DATE/DATETIME casts plus/minus YEAR/MONTH/WEEK/DAY intervals.

    Returns:
        The folded expression, or None when the pair cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        a = int(a.name) if a.is_int else Decimal(a.name)
        b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(a + b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(a - b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(a * b)
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if isinstance(a, int) and isinstance(b, int):
                return None
            # Leave x / 0 unfolded: Decimal division by zero raises
            # (DivisionByZero, or InvalidOperation for 0 / 0), which would
            # abort the whole simplification.
            if b == 0:
                return None
            return exp.Literal.number(a / b)

        boolean = eval_boolean(expression, a, b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, exp.Add):
                return date_literal(a + b)
            if isinstance(expression, exp.Sub):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)

    return None
379
380
def simplify_parens(expression):
    """
    Drop redundant parentheses.

    A Paren is unwrapped unless it wraps a SELECT, or it wraps a non-predicate
    binary expression while its parent is a Condition/Binary (presumably to
    preserve operator precedence). Other nodes are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    inside_operator = isinstance(expression.parent, (exp.Condition, exp.Binary))
    keep_parens = (
        inside_operator
        and not isinstance(inner, exp.Predicate)
        and isinstance(inner, exp.Binary)
    )
    return expression if keep_parens else inner
393
394
def remove_where_true(expression):
    """
    Drop always-true filters from the tree, in place.

    - ``WHERE <always true>`` is removed from its parent.
    - ``JOIN ... ON <always true>`` is turned into a CROSS join, but only for
      joins without a side. An outer (LEFT/RIGHT/FULL) join with an
      always-true condition still keeps unmatched rows, so it is not
      equivalent to a CROSS join and is left untouched.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if always_true(join.args.get("on")) and not join.args.get("side"):
            join.set("kind", "CROSS")
            join.set("on", None)
403
404
def always_true(expression):
    """
    Return whether ``expression`` is a constant that is always truthy.

    TRUE and literals count as always true, except numeric literals equal to
    zero (``0``, ``0.0``). Those are falsy, and treating them as true would
    turn ``WHERE 0`` into no filter at all and ``NOT 0`` into FALSE.
    """
    if isinstance(expression, exp.Boolean) and expression.this:
        return True
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:  # decimal.InvalidOperation on unparsable text
            return True
    return True
409
410
def is_complement(a, b):
    """Return whether ``b`` is exactly ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
413
414
def is_false(a: exp.Expression) -> bool:
    """Return whether ``a`` is the FALSE boolean node (exact type, no subclasses)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
417
418
def is_null(a: exp.Expression) -> bool:
    """Return whether ``a`` is a NULL node (exact type, no subclasses)."""
    node_type = type(a)
    return node_type is exp.Null
421
422
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node on two plain Python values.

    Returns a TRUE/FALSE node for EQ/IS, NEQ, GT, GTE, LT and LTE nodes (checked
    in that order), or None for any other node type.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
437
438
def extract_date(cast):
    """
    Convert ``CAST(<string> AS DATE|DATETIME)`` into a Python date/datetime.

    Returns None when the target type is neither DATE nor DATETIME, or when
    ``fromisoformat`` rejects the operand (e.g. the cast wraps an identifier
    rather than an ISO-format string).
    """
    parsers = (
        (exp.DataType.Type.DATE, datetime.date.fromisoformat),
        (exp.DataType.Type.DATETIME, datetime.datetime.fromisoformat),
    )
    try:
        target = cast.args["to"].this
        for data_type, parse in parsers:
            if target == data_type:
                return parse(cast.name)
    except ValueError:
        return None
    return None
449
450
def extract_interval(interval):
    """
    Convert an INTERVAL node into a ``dateutil.relativedelta``.

    Only integer amounts with a YEAR, MONTH, WEEK or DAY unit
    (case-insensitive) are supported.

    Returns None when:
      - ``python-dateutil`` is not installed;
      - the unit is anything else;
      - the amount is not an integer, e.g. a column reference or ``'1.5'``.
        Without this, ``int()`` would raise and abort the whole simplification.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except (TypeError, ValueError):
        # e.g. INTERVAL col DAY or INTERVAL '1.5' DAY: cannot fold statically
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
469
470
def date_literal(date):
    """Wrap ``date`` in a string literal cast to DATE, or to DATETIME for datetime values."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
476
477
def boolean_literal(condition):
    """Return a TRUE node when ``condition`` is truthy, else a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
480
481
def _flat_simplify(expression, simplifier, root=True):
    """
    Pairwise-simplify the operands of a flattened binary chain.

    The operands from ``expression.flatten(unnest=False)`` go into a queue.
    Each popped operand ``a`` is tried against every remaining operand ``b``.
    When ``simplifier(expression, a, b)`` returns a truthy result, ``b`` is
    removed and the result is pushed to the front of the queue, so it can
    combine further. Operands that combine with nothing are kept as-is.

    Args:
        expression: the binary node whose operand chain is simplified.
        simplifier: callable ``(expression, a, b)`` returning a replacement
            node, or a falsy value when the pair cannot be combined.
        root: whether this is the top-level call; a non-root node is only
            processed when ``not expression.same_parent``.
    Returns:
        A rebuilt left-nested chain of the same class when at least one pair
        was combined, otherwise ``expression`` unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result:
                    # Mutating the deque mid-iteration is only safe because we
                    # break out of the loop immediately afterwards.
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                operands.append(a)

        # Fewer survivors than inputs means at least one pair was combined.
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite passes in ``_simplify`` are driven by ``while_changing``
    (presumably re-run until the tree stops changing); afterwards
    ``remove_where_true`` drops always-true WHERE / JOIN ON conditions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """

    # Shared across every pass; handed to GENERATOR.generate inside uniq_sort
    # (presumably a render cache keyed per node -- TODO confirm).
    cache = {}

    def _simplify(expression, root=True):
        # `root` is True only for the top-level node. Children are visited with
        # root=False; the connector helpers then only act on a connector when
        # `not expression.same_parent`.
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, cache, root)
        node = absorb_and_eliminate(node, root)
        # Recurse into the children before the remaining passes run on `node`.
        exp.replace_children(node, lambda e: _simplify(e, False))
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_compliments(node, root)
        # The passes may return a brand-new node; attach it to the original
        # parent so parent-dependent passes (simplify_parens reads
        # expression.parent) see the right context.
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_parens(node)
        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand ``x BETWEEN low AND high`` into ``x >= low AND x <= high``.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    lowered into that form first. Any other node is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    target = expression.this
    lower_bound = exp.GTE(this=target.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=target.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
    """
    Simplify a NOT node.

    De Morgan's laws on a parenthesized connector:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    plus constant folding:
        NOT NULL, NOT (NULL) -> NULL
        NOT <always true>    -> FALSE
        NOT FALSE            -> TRUE
        NOT NOT x            -> x

    Any node that is not a NOT is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, exp.And):
            return exp.or_(exp.not_(inner.left), exp.not_(inner.right))
        if isinstance(inner, exp.Or):
            return exp.and_(exp.not_(inner.left), exp.not_(inner.right))
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()

    # double negation: NOT NOT x -> x
    if isinstance(operand, exp.Not):
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y. NOT NULL, NOT TRUE/FALSE and double negation (NOT NOT x -> x) are also folded.

def flatten(expression):
    """
    Pull directly nested connectors of the same type up into the parent.

    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    The node is modified in place and returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        unwrapped = operand.unnest()
        if isinstance(unwrapped, connector_type):
            operand.replace(unwrapped)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
    """
    Fold constants and redundant operands in AND/OR chains.

    Operands are combined pairwise through ``_flat_simplify``. For each pair:

      - identical operands collapse into one;
      - AND: FALSE absorbs the pair; NULL AND NULL -> NULL; always-true
        operands (see ``always_true``) are dropped;
      - OR: an always-true operand absorbs the pair; NULL OR NULL and
        NULL OR FALSE -> NULL; FALSE operands are dropped;
      - otherwise comparisons on the same column are merged by
        ``_simplify_comparison``.

    Args:
        expression: node to simplify; non-connectors are returned unchanged.
        root: whether this is the top-level call (forwarded to
            ``_flat_simplify``).
    Returns:
        The simplified node.
    """

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # Three-valued logic: NULL AND FALSE is FALSE, so a single NULL
            # operand only makes the conjunction NULL when the other side is
            # NULL as well. (NULL AND TRUE -> NULL is covered by the
            # always-true rules below, which return the NULL operand.)
            if is_null(left) and is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_compliments(expression, root=True):
    """
    Replace a connector holding an operand and its negation by a constant.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Args:
        expression: node to inspect; anything but a connector is returned as-is.
        root: whether this is the top-level call; a non-root connector is only
            inspected when ``not expression.same_parent``.
    Returns:
        FALSE / TRUE when a complementary pair is present, else ``expression``.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    replacement = exp.false() if isinstance(expression, exp.And) else exp.true()
    pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(first_op, second_op) for first_op, second_op in pairs):
        return replacement
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, cache=None, root=True):
    """
    Deduplicate and sort the operands of a connector by their generated SQL.

    C AND A AND B AND B -> A AND B AND C

    A new connector is built only when the operands are out of order or
    contain duplicates; otherwise ``expression`` is returned unchanged.

    Args:
        expression: node to normalize; non-connectors are returned as-is.
        cache: forwarded to ``GENERATOR.generate``.
        root: whether this is the top-level call; a non-root connector is only
            processed when ``not expression.same_parent``.
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {GENERATOR.generate(operand, cache): operand for operand in operands}
    entries = tuple(by_sql.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(nxt[0] < cur[0] for cur, nxt in zip(entries, entries[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(entries)))

    # already sorted, but duplicates may still need to be dropped
    if len(by_sql) < len(operands):
        return build(*by_sql.values())
    return expression

Deduplicate and sort the operands of a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
    """
    Apply absorption and elimination to the operands of a connector.

    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The rewrites are applied in place by replacing sub-nodes with TRUE/FALSE
    or with the surviving operand; leftover constants and duplicates are
    presumably cleaned up by the later passes run from ``simplify``.

    Args:
        expression: node to rewrite; non-connectors are returned unchanged.
        root: whether this is the top-level call; a non-root connector is only
            processed when ``not expression.same_parent``.
    Returns:
        ``expression`` itself (possibly mutated).
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector type to look for: ORs inside an AND chain,
        # ANDs inside an OR chain.
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        # NOTE(review): itertools.permutations snapshots the flattened operands
        # up front, while the body replaces nodes in place; results depend on
        # this exact visiting order.
        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                # `b` next to (NOT b <kind> x): replace the NOT b with the
                # identity element of `kind` (TRUE for AND, FALSE for OR) so it
                # drops out of `a`.
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                # b's operands are a strict subset of a's, so `a` is redundant
                # under the outer connector: replace it with the outer
                # connector's identity element (TRUE for AND, FALSE for OR).
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    # (X . Y) and (X . NOT Y) share X and differ only by a
                    # complemented operand: both collapse to X.
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_literals(expression, root=True):
    """
    Constant-fold literal operands.

    Non-connector binary nodes are folded pairwise by ``_simplify_binary``
    (via ``_flat_simplify``). A negated numeric literal becomes a single
    literal: -(5) -> -5 and -(-5) -> 5. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            negated = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(negated)

    return expression
def simplify_parens(expression):
    """
    Drop redundant parentheses.

    A Paren is unwrapped unless it wraps a SELECT, or it wraps a non-predicate
    binary expression while its parent is a Condition/Binary (presumably to
    preserve operator precedence). Other nodes are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    inside_operator = isinstance(expression.parent, (exp.Condition, exp.Binary))
    keep_parens = (
        inside_operator
        and not isinstance(inner, exp.Predicate)
        and isinstance(inner, exp.Binary)
    )
    return expression if keep_parens else inner
def remove_where_true(expression):
    """
    Drop always-true filters from the tree, in place.

    - ``WHERE <always true>`` is removed from its parent.
    - ``JOIN ... ON <always true>`` is turned into a CROSS join, but only for
      joins without a side. An outer (LEFT/RIGHT/FULL) join with an
      always-true condition still keeps unmatched rows, so it is not
      equivalent to a CROSS join and is left untouched.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)
    for join in expression.find_all(exp.Join):
        if always_true(join.args.get("on")) and not join.args.get("side"):
            join.set("kind", "CROSS")
            join.set("on", None)
def always_true(expression):
    """
    Return whether ``expression`` is a constant that is always truthy.

    TRUE and literals count as always true, except numeric literals equal to
    zero (``0``, ``0.0``). Those are falsy, and treating them as true would
    turn ``WHERE 0`` into no filter at all and ``NOT 0`` into FALSE.
    """
    if isinstance(expression, exp.Boolean) and expression.this:
        return True
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:  # decimal.InvalidOperation on unparsable text
            return True
    return True
def is_complement(a, b):
    """Return whether ``b`` is exactly ``NOT a``."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: exp.Expression) -> bool:
    """Return whether ``a`` is the FALSE boolean node (exact type, no subclasses)."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: exp.Expression) -> bool:
    """Return whether ``a`` is a NULL node (exact type, no subclasses)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
    """
    Evaluate a comparison node on two plain Python values.

    Returns a TRUE/FALSE node for EQ/IS, NEQ, GT, GTE, LT and LTE nodes (checked
    in that order), or None for any other node type.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for node_types, compare in comparisons:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
    """
    Convert ``CAST(<string> AS DATE|DATETIME)`` into a Python date/datetime.

    Returns None when the target type is neither DATE nor DATETIME, or when
    ``fromisoformat`` rejects the operand (e.g. the cast wraps an identifier
    rather than an ISO-format string).
    """
    parsers = (
        (exp.DataType.Type.DATE, datetime.date.fromisoformat),
        (exp.DataType.Type.DATETIME, datetime.datetime.fromisoformat),
    )
    try:
        target = cast.args["to"].this
        for data_type, parse in parsers:
            if target == data_type:
                return parse(cast.name)
    except ValueError:
        return None
    return None
def extract_interval(interval):
    """
    Convert an INTERVAL node into a ``dateutil.relativedelta``.

    Only integer amounts with a YEAR, MONTH, WEEK or DAY unit
    (case-insensitive) are supported.

    Returns None when:
      - ``python-dateutil`` is not installed;
      - the unit is anything else;
      - the amount is not an integer, e.g. a column reference or ``'1.5'``.
        Without this, ``int()`` would raise and abort the whole simplification.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except (TypeError, ValueError):
        # e.g. INTERVAL col DAY or INTERVAL '1.5' DAY: cannot fold statically
        return None

    unit = interval.text("unit").lower()

    if unit == "year":
        return relativedelta(years=n)
    if unit == "month":
        return relativedelta(months=n)
    if unit == "week":
        return relativedelta(weeks=n)
    if unit == "day":
        return relativedelta(days=n)
    return None
def date_literal(date):
    """Wrap ``date`` in a string literal cast to DATE, or to DATETIME for datetime values."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def boolean_literal(condition):
    """Return a TRUE node when ``condition`` is truthy, else a FALSE node."""
    if condition:
        return exp.true()
    return exp.false()