Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4from collections import deque
  5from decimal import Decimal
  6
  7from sqlglot import exp
  8from sqlglot.generator import Generator
  9from sqlglot.helper import first, while_changing
 10
 11GENERATOR = Generator(normalize=True, identify="safe")
 12
 13
 14def simplify(expression):
 15    """
 16    Rewrite sqlglot AST to simplify expressions.
 17
 18    Example:
 19        >>> import sqlglot
 20        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 21        >>> simplify(expression).sql()
 22        'TRUE'
 23
 24    Args:
 25        expression (sqlglot.Expression): expression to simplify
 26    Returns:
 27        sqlglot.Expression: simplified expression
 28    """
 29
 30    cache = {}
 31
 32    def _simplify(expression, root=True):
 33        node = expression
 34        node = rewrite_between(node)
 35        node = uniq_sort(node, cache, root)
 36        node = absorb_and_eliminate(node, root)
 37        exp.replace_children(node, lambda e: _simplify(e, False))
 38        node = simplify_not(node)
 39        node = flatten(node)
 40        node = simplify_connectors(node, root)
 41        node = remove_compliments(node, root)
 42        node.parent = expression.parent
 43        node = simplify_literals(node, root)
 44        node = simplify_parens(node)
 45        if root:
 46            expression.replace(node)
 47        return node
 48
 49    expression = while_changing(expression, _simplify)
 50    remove_where_true(expression)
 51    return expression
 52
 53
 54def rewrite_between(expression: exp.Expression) -> exp.Expression:
 55    """Rewrite x between y and z to x >= y AND x <= z.
 56
 57    This is done because comparison simplification is only done on lt/lte/gt/gte.
 58    """
 59    if isinstance(expression, exp.Between):
 60        return exp.and_(
 61            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 62            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 63        )
 64    return expression
 65
 66
 67def simplify_not(expression):
 68    """
 69    Demorgan's Law
 70    NOT (x OR y) -> NOT x AND NOT y
 71    NOT (x AND y) -> NOT x OR NOT y
 72    """
 73    if isinstance(expression, exp.Not):
 74        if is_null(expression.this):
 75            return exp.null()
 76        if isinstance(expression.this, exp.Paren):
 77            condition = expression.this.unnest()
 78            if isinstance(condition, exp.And):
 79                return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
 80            if isinstance(condition, exp.Or):
 81                return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
 82            if is_null(condition):
 83                return exp.null()
 84        if always_true(expression.this):
 85            return exp.false()
 86        if is_false(expression.this):
 87            return exp.true()
 88        if isinstance(expression.this, exp.Not):
 89            # double negation
 90            # NOT NOT x -> x
 91            return expression.this.this
 92    return expression
 93
 94
 95def flatten(expression):
 96    """
 97    A AND (B AND C) -> A AND B AND C
 98    A OR (B OR C) -> A OR B OR C
 99    """
100    if isinstance(expression, exp.Connector):
101        for node in expression.args.values():
102            child = node.unnest()
103            if isinstance(child, expression.__class__):
104                node.replace(child)
105    return expression
106
107
def simplify_connectors(expression, root=True):
    """
    Fold the operands of an AND/OR chain pairwise.

    Constant operands (TRUE, FALSE, NULL, literals) are resolved directly and
    any remaining pair is handed to _simplify_comparison. Expressions that are
    not connectors are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Returns the replacement for the pair, or None if it cannot be merged.
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()
            # Both sides are NULL or FALSE, and at least one of them is NULL.
            if (left_false or is_null(left)) and (right_false or is_null(right)):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
144
145
# Inequality classes grouped by direction.
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Every comparison class that _simplify_comparison inspects.
COMPARISONS = LT_LTE + GT_GTE + (exp.EQ, exp.NEQ)

# Class to use for an inequality once its two operands are swapped.
INVERSE_COMPARISONS = {
    exp.LT: exp.GT,
    exp.LTE: exp.GTE,
    exp.GT: exp.LT,
    exp.GTE: exp.LTE,
}
162
163
def _simplify_comparison(expression, left, right, or_=False):
    """
    Try to merge two comparisons of the same column against literal values.

    Args:
        expression: the connector that holds both operands.
        left: first operand of the pair.
        right: second operand of the pair.
        or_: True when the operands are joined by OR, False for AND.

    Returns:
        The expression that replaces the pair, or None when the pair cannot be
        merged. If one comparison has no operand left once the shared columns
        are removed, `expression` itself is returned.
    """
    if not (isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS)):
        return None

    ll, lr = left.args.values()
    rl, rr = right.args.values()

    largs = {ll, lr}
    rargs = {rl, rr}

    # Only a column that appears in both comparisons can anchor a merge.
    columns = {m for m in largs & rargs if isinstance(m, exp.Column)}
    if not columns:
        return None

    try:
        l = first(largs - columns)
        r = first(rargs - columns)
    except StopIteration:
        return expression

    # make sure the comparison is always of the form x > 1 instead of 1 < x
    if left.__class__ in INVERSE_COMPARISONS and l == ll:
        left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
    if right.__class__ in INVERSE_COMPARISONS and r == rl:
        right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)

    if l.is_number and r.is_number:
        lv, rv = float(l.name), float(r.name)
    elif l.is_string and r.is_string:
        lv, rv = l.name, r.name
    else:
        return None

    # Two upper bounds: AND keeps the tighter one, OR the looser one. On equal
    # values the strict form is tighter, regardless of operand order.
    if isinstance(left, LT_LTE) and isinstance(right, LT_LTE):
        if lv == rv:
            keep = exp.LTE if or_ else exp.LT
            return left if isinstance(left, keep) else right
        return left if (lv > rv if or_ else lv < rv) else right

    # Two lower bounds: same reasoning, mirrored.
    if isinstance(left, GT_GTE) and isinstance(right, GT_GTE):
        if lv == rv:
            keep = exp.GTE if or_ else exp.GT
            return left if isinstance(left, keep) else right
        return left if (lv < rv if or_ else lv > rv) else right

    # The remaining rules either prove that both operands can never hold at
    # once or keep only the equality; both are valid for AND only.
    if or_:
        return None

    for (a, av), (b, bv) in itertools.permutations(((left, lv), (right, rv))):
        # we can't ever shortcut to true because the column could be null
        if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
            if av <= bv:
                return exp.false()
        elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
            if av >= bv:
                return exp.false()
        elif isinstance(a, exp.EQ):
            if isinstance(b, exp.LT):
                return exp.false() if av >= bv else a
            if isinstance(b, exp.LTE):
                return exp.false() if av > bv else a
            if isinstance(b, exp.GT):
                return exp.false() if av <= bv else a
            if isinstance(b, exp.GTE):
                return exp.false() if av < bv else a
            if isinstance(b, exp.NEQ):
                return exp.false() if av == bv else a
    return None
222
223
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only connectors are examined, and a non-root connector is skipped when its
    `same_parent` flag is set.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    for operand, other in itertools.permutations(expression.flatten(), 2):
        if is_complement(operand, other):
            return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression
238
239
def uniq_sort(expression, cache=None, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by the SQL text that GENERATOR produces for them
    (`cache` is forwarded to GENERATOR.generate). The connector is rebuilt only
    if its operands are out of order or contain duplicates.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # A duplicate overwrites the stored operand but keeps the first position.
    by_sql = {}
    for operand in operands:
        by_sql[GENERATOR.generate(operand, cache)] = operand

    keys = list(by_sql)
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return combine(*(by_sql[key] for key in sorted(keys)))
    if len(by_sql) < len(operands):
        # already sorted, but duplicates have to go
        return combine(*by_sql.values())
    return expression
264
265
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Matching operands are replaced in place with constants or with the
    surviving operand; the connector that was passed in is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # Nested operands must be the opposite connector for these laws to apply.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    if kind == exp.And:
        inner_neutral, outer_neutral = exp.true, exp.false
    else:
        inner_neutral, outer_neutral = exp.false, exp.true

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(inner_neutral())
        elif is_complement(b, ab):
            ab.replace(inner_neutral())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(outer_neutral())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression
304
305
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Binary expressions other than connectors are folded pairwise with
    _simplify_binary; a negated numeric literal is turned into a single literal
    with its sign flipped. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
318
319
def _simplify_binary(expression, a, b):
    """
    Constant-fold one pair of operands of a binary expression.

    Handles IS [NOT] NULL checks, NULL propagation, arithmetic and comparisons
    on numeric literals, comparisons on string literals, and date +/- interval
    arithmetic.

    Args:
        expression: the binary expression the operands belong to.
        a: left operand.
        b: right operand.

    Returns:
        The expression that replaces the pair, or None when nothing can be
        folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)
        if isinstance(expression, exp.Div):
            # Leave division by zero for the engine to report.
            if num_b == 0:
                return None
            if isinstance(num_a, int) and isinstance(num_b, int):
                # Engines disagree on inexact integer division (truncation vs.
                # true division) and Python's // floors, so only exact
                # quotients are folded.
                if num_a % num_b:
                    return None
                return exp.Literal.number(num_a // num_b)
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        # Compare the literal text, not the expression nodes.
        boolean = eval_boolean(expression, a.name, b.name)

        if boolean:
            return boolean
    elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
        date, delta = extract_date(a), extract_interval(b)
        if date and delta:
            if isinstance(expression, exp.Add):
                return date_literal(date + delta)
            if isinstance(expression, exp.Sub):
                return date_literal(date - delta)
    elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
        delta, date = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if delta and date and isinstance(expression, exp.Add):
            return date_literal(delta + date)

    return None
377
378
def simplify_parens(expression):
    """
    Drop parentheses that are not needed.

    A Paren is replaced by its content unless the content is a SELECT, or the
    content is a non-predicate binary expression whose Paren sits inside a
    Condition/Binary parent. Other expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(expression.parent, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
    )
    return inner if removable else expression
391
392
def remove_where_true(expression):
    """
    Remove filters that always hold. The tree is modified in place.

    - A WHERE whose condition is always true is removed from its parent.
    - A join whose ON condition is always true becomes a CROSS JOIN without a
      condition. This is done only for joins that have no side and whose kind
      is empty, INNER or CROSS: outer joins keep unmatched rows and SEMI/ANTI
      joins filter rows, so they are not equivalent to a cross join, and
      setting their kind to CROSS would yield SQL such as LEFT CROSS JOIN.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("side"):
            continue
        kind = str(join.args.get("kind") or "").upper()
        if kind not in ("", "INNER", "CROSS"):
            continue
        join.set("kind", "CROSS")
        join.set("on", None)
401
402
def always_true(expression):
    """
    Return True when `expression` can be treated as a condition that always holds.

    That is the case for a true Boolean and for literals, except numeric
    literals equal to zero: engines that accept numbers as conditions treat 0
    as false, so removing a filter such as WHERE 0 would change the result.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except (ArithmeticError, ValueError):
            # Not parseable as a decimal number: keep treating it as true.
            return True
    return True
407
408
def is_complement(a, b):
    """Return a truthy value when `b` is the negation of `a`, i.e. `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
411
412
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
415
416
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
419
420
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` stands for on the values `a` and `b`.

    Returns:
        A boolean literal with the outcome, or None when `expression` is not
        one of =, IS, <>, >, >=, <, <=.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for classes, compare in comparisons:
        if isinstance(expression, classes):
            return boolean_literal(compare(a, b))
    return None
435
436
def extract_date(cast):
    """
    Parse the value of a CAST to DATE or DATETIME.

    Returns:
        A `datetime.date` / `datetime.datetime` parsed from the text being
        cast, or None when the target type is neither DATE nor DATETIME or the
        text is not in ISO format (for instance when an identifier is cast).
    """
    try:
        target = cast.args["to"].this
        if target == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
447
448
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Returns:
        A relativedelta for year/month/week/day intervals, or None when
        `dateutil` is not installed, the unit is not supported, or the amount
        is not an integer (e.g. INTERVAL '1.5' DAY), in which case the interval
        cannot be folded.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()

    # relativedelta keyword for each supported unit
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
467
468
def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
474
475
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()
478
479
def _flat_simplify(expression, simplifier, root=True):
    """
    Merge pairs of operands of a flattened chain until none can be merged.

    `simplifier(expression, a, b)` is called for pairs of operands; a truthy
    result replaces both operands and is queued again so that it can take part
    in further merges. The chain is rebuilt only if at least one merge
    happened; otherwise `expression` is returned as is. A non-root expression
    whose `same_parent` flag is set is skipped.
    """
    if not root and expression.same_parent:
        return expression

    pending = deque(expression.flatten(unnest=False))
    original_count = len(pending)
    settled = []

    while pending:
        head = pending.popleft()

        for candidate in pending:
            merged = simplifier(expression, head, candidate)

            if merged:
                # Changing the deque is fine here: the loop is left right away.
                pending.remove(candidate)
                pending.append(merged)
                break
        else:
            # nothing merges with this operand, so it is final
            settled.append(head)

    if len(settled) < original_count:
        connector = expression.__class__
        return functools.reduce(
            lambda acc, operand: connector(this=acc, expression=operand), settled
        )
    return expression
def simplify(expression):
def simplify(expression):
    """
    Rewrite sqlglot AST to simplify expressions.

    The rewrite passes are applied repeatedly (through `while_changing`) and,
    once they are done, WHERE/ON conditions that are always true are removed.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
    Returns:
        sqlglot.Expression: simplified expression
    """
    # Cache handed to every uniq_sort call made during this run.
    sql_cache = {}

    def _simplify(expression, root=True):
        # Passes that look at the node before its children are visited.
        current = rewrite_between(expression)
        current = uniq_sort(current, sql_cache, root)
        current = absorb_and_eliminate(current, root)

        # Recurse: every child is simplified as a non-root node.
        exp.replace_children(current, lambda child: _simplify(child, False))

        # Passes that run once the children have been simplified.
        current = simplify_not(current)
        current = flatten(current)
        current = simplify_connectors(current, root)
        current = remove_compliments(current, root)

        # Give the (possibly new) node the parent of the node it stands in for;
        # simplify_parens below inspects `parent`.
        current.parent = expression.parent
        current = simplify_literals(current, root)
        current = simplify_parens(current)

        if root:
            expression.replace(current)
        return current

    simplified = while_changing(expression, _simplify)
    remove_where_true(simplified)
    return simplified

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Expand ``x BETWEEN low AND high`` into ``x >= low AND x <= high``.

    Comparison simplification only handles lt/lte/gt/gte, so BETWEEN is
    rewritten into that form. Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    # The subject appears on both sides, so each side gets its own copy.
    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    return exp.and_(lower_bound, upper_bound)

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a NOT expression.

    De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    NOT over NULL yields NULL, NOT over an operand that is always true / false
    yields FALSE / TRUE, and a double negation is removed. Expressions that are
    not a NOT, or that match none of these cases, are returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()
        if isinstance(inner, exp.And):
            return exp.or_(exp.not_(inner.left), exp.not_(inner.right))
        if isinstance(inner, exp.Or):
            return exp.and_(exp.not_(inner.left), exp.not_(inner.right))
        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this
    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y.

def flatten(expression):
 96def flatten(expression):
 97    """
 98    A AND (B AND C) -> A AND B AND C
 99    A OR (B OR C) -> A OR B OR C
100    """
101    if isinstance(expression, exp.Connector):
102        for node in expression.args.values():
103            child = node.unnest()
104            if isinstance(child, expression.__class__):
105                node.replace(child)
106    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C.

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """
    Fold the operands of an AND/OR chain pairwise.

    Constant operands (TRUE, FALSE, NULL, literals) are resolved directly and
    any remaining pair is handed to _simplify_comparison. Expressions that are
    not connectors are returned unchanged.
    """

    def _simplify_connectors(expression, left, right):
        # Returns the replacement for the pair, or None if it cannot be merged.
        if left == right:
            return left

        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left):
                return exp.true() if always_true(right) else right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)

        if isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()

            left_false, right_false = is_false(left), is_false(right)
            if left_false and right_false:
                return exp.false()
            # Both sides are NULL or FALSE, and at least one of them is NULL.
            if (left_false or is_null(left)) and (right_false or is_null(right)):
                return exp.null()
            if left_false:
                return right
            if right_false:
                return left
            return _simplify_comparison(expression, left, right, or_=True)

        return None

    if not isinstance(expression, exp.Connector):
        return expression
    return _flat_simplify(expression, _simplify_connectors, root)
def remove_compliments(expression, root=True):
def remove_compliments(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Only connectors are examined, and a non-root connector is skipped when its
    `same_parent` flag is set.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    for operand, other in itertools.permutations(expression.flatten(), 2):
        if is_complement(operand, other):
            return exp.false() if isinstance(expression, exp.And) else exp.true()
    return expression

Remove complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE.

def uniq_sort(expression, cache=None, root=True):
def uniq_sort(expression, cache=None, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by the SQL text that GENERATOR produces for them
    (`cache` is forwarded to GENERATOR.generate). The connector is rebuilt only
    if its operands are out of order or contain duplicates.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())

    # A duplicate overwrites the stored operand but keeps the first position.
    by_sql = {}
    for operand in operands:
        by_sql[GENERATOR.generate(operand, cache)] = operand

    keys = list(by_sql)
    if any(later < earlier for earlier, later in zip(keys, keys[1:])):
        # A AND C AND B -> A AND B AND C
        return combine(*(by_sql[key] for key in sorted(keys)))
    if len(by_sql) < len(operands):
        # already sorted, but duplicates have to go
        return combine(*by_sql.values())
    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    Matching operands are replaced in place with constants or with the
    surviving operand; the connector that was passed in is returned.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    # Nested operands must be the opposite connector for these laws to apply.
    kind = exp.Or if isinstance(expression, exp.And) else exp.And
    if kind == exp.And:
        inner_neutral, outer_neutral = exp.true, exp.false
    else:
        inner_neutral, outer_neutral = exp.false, exp.true

    for a, b in itertools.permutations(expression.flatten(), 2):
        if not isinstance(a, kind):
            continue

        aa, ab = a.unnest_operands()

        # absorb
        if is_complement(b, aa):
            aa.replace(inner_neutral())
        elif is_complement(b, ab):
            ab.replace(inner_neutral())
        elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
            a.replace(outer_neutral())
        elif isinstance(b, kind):
            # eliminate
            rhs = b.unnest_operands()
            ba, bb = rhs

            if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                a.replace(aa)
                b.replace(aa)
            elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                a.replace(ab)
                b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """
    Fold literal operands.

    Binary expressions other than connectors are folded pairwise with
    _simplify_binary; a negated numeric literal is turned into a single literal
    with its sign flipped. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """
    Drop parentheses that are not needed.

    A Paren is replaced by its content unless the content is a SELECT, or the
    content is a non-predicate binary expression whose Paren sits inside a
    Condition/Binary parent. Other expressions are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    inner = expression.this
    if isinstance(inner, exp.Select):
        return expression

    removable = (
        not isinstance(expression.parent, (exp.Condition, exp.Binary))
        or isinstance(inner, exp.Predicate)
        or not isinstance(inner, exp.Binary)
    )
    return inner if removable else expression
def remove_where_true(expression):
def remove_where_true(expression):
    """
    Remove filters that always hold. The tree is modified in place.

    - A WHERE whose condition is always true is removed from its parent.
    - A join whose ON condition is always true becomes a CROSS JOIN without a
      condition. This is done only for joins that have no side and whose kind
      is empty, INNER or CROSS: outer joins keep unmatched rows and SEMI/ANTI
      joins filter rows, so they are not equivalent to a cross join, and
      setting their kind to CROSS would yield SQL such as LEFT CROSS JOIN.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.parent.set("where", None)

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("side"):
            continue
        kind = str(join.args.get("kind") or "").upper()
        if kind not in ("", "INNER", "CROSS"):
            continue
        join.set("kind", "CROSS")
        join.set("on", None)
def always_true(expression):
def always_true(expression):
    """
    Return True when `expression` can be treated as a condition that always holds.

    That is the case for a true Boolean and for literals, except numeric
    literals equal to zero: engines that accept numbers as conditions treat 0
    as false, so removing a filter such as WHERE 0 would change the result.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)
    if not isinstance(expression, exp.Literal):
        return False
    if expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except (ArithmeticError, ValueError):
            # Not parseable as a decimal number: keep treating it as true.
            return True
    return True
def is_complement(a, b):
def is_complement(a, b):
    """Return a truthy value when `b` is the negation of `a`, i.e. `NOT a`."""
    if not isinstance(b, exp.Not):
        return False
    return b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Boolean node holding a falsy value."""
    if type(a) is not exp.Boolean:
        return False
    return not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node (subclasses do not count)."""
    node_type = type(a)
    return node_type is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """
    Evaluate the comparison `expression` stands for on the values `a` and `b`.

    Returns:
        A boolean literal with the outcome, or None when `expression` is not
        one of =, IS, <>, >, >=, <, <=.
    """
    comparisons = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )
    for classes, compare in comparisons:
        if isinstance(expression, classes):
            return boolean_literal(compare(a, b))
    return None
def extract_date(cast):
def extract_date(cast):
    """
    Parse the value of a CAST to DATE or DATETIME.

    Returns:
        A `datetime.date` / `datetime.datetime` parsed from the text being
        cast, or None when the target type is neither DATE nor DATETIME or the
        text is not in ISO format (for instance when an identifier is cast).
    """
    try:
        target = cast.args["to"].this
        if target == exp.DataType.Type.DATE:
            return datetime.date.fromisoformat(cast.name)
        if target == exp.DataType.Type.DATETIME:
            return datetime.datetime.fromisoformat(cast.name)
    except ValueError:
        return None
    return None
def extract_interval(interval):
def extract_interval(interval):
    """
    Convert an interval expression into a `dateutil` relativedelta.

    Returns:
        A relativedelta for year/month/week/day intervals, or None when
        `dateutil` is not installed, the unit is not supported, or the amount
        is not an integer (e.g. INTERVAL '1.5' DAY), in which case the interval
        cannot be folded.
    """
    try:
        from dateutil.relativedelta import relativedelta  # type: ignore
    except ModuleNotFoundError:
        return None

    try:
        n = int(interval.name)
    except ValueError:
        return None

    unit = interval.text("unit").lower()

    # relativedelta keyword for each supported unit
    keyword = {
        "year": "years",
        "month": "months",
        "week": "weeks",
        "day": "days",
    }.get(unit)
    if keyword is None:
        return None
    return relativedelta(**{keyword: n})
def date_literal(date):
def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE."""
    target_type = "DATETIME" if isinstance(date, datetime.datetime) else "DATE"
    return exp.cast(exp.Literal.string(date), target_type)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return the TRUE expression for a truthy `condition`, FALSE otherwise."""
    if condition:
        return exp.true()
    return exp.false()