Edit on GitHub

sqlglot.optimizer.simplify

  1import datetime
  2import functools
  3import itertools
  4import typing as t
  5from collections import deque
  6from decimal import Decimal
  7
  8from sqlglot import exp
  9from sqlglot.generator import cached_generator
 10from sqlglot.helper import first, merge_ranges, while_changing
 11
 12# Final means that an expression should not be simplified
 13FINAL = "final"
 14
 15
 16class UnsupportedUnit(Exception):
 17    pass
 18
 19
 20def simplify(expression):
 21    """
 22    Rewrite sqlglot AST to simplify expressions.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
 27        >>> simplify(expression).sql()
 28        'TRUE'
 29
 30    Args:
 31        expression (sqlglot.Expression): expression to simplify
 32    Returns:
 33        sqlglot.Expression: simplified expression
 34    """
 35
 36    generate = cached_generator()
 37
 38    # group by expressions cannot be simplified, for example
 39    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
 40    # the projection must exactly match the group by key
 41    for group in expression.find_all(exp.Group):
 42        select = group.parent
 43        groups = set(group.expressions)
 44        group.meta[FINAL] = True
 45
 46        for e in select.selects:
 47            for node, *_ in e.walk():
 48                if node in groups:
 49                    e.meta[FINAL] = True
 50                    break
 51
 52        having = select.args.get("having")
 53        if having:
 54            for node, *_ in having.walk():
 55                if node in groups:
 56                    having.meta[FINAL] = True
 57                    break
 58
 59    def _simplify(expression, root=True):
 60        if expression.meta.get(FINAL):
 61            return expression
 62
 63        # Pre-order transformations
 64        node = expression
 65        node = rewrite_between(node)
 66        node = uniq_sort(node, generate, root)
 67        node = absorb_and_eliminate(node, root)
 68        node = simplify_concat(node)
 69
 70        exp.replace_children(node, lambda e: _simplify(e, False))
 71
 72        # Post-order transformations
 73        node = simplify_not(node)
 74        node = flatten(node)
 75        node = simplify_connectors(node, root)
 76        node = remove_compliments(node, root)
 77        node = simplify_coalesce(node)
 78        node.parent = expression.parent
 79        node = simplify_literals(node, root)
 80        node = simplify_equality(node)
 81        node = simplify_parens(node)
 82        node = simplify_datetrunc_predicate(node)
 83
 84        if root:
 85            expression.replace(node)
 86
 87        return node
 88
 89    expression = while_changing(expression, _simplify)
 90    remove_where_true(expression)
 91    return expression
 92
 93
 94def catch(*exceptions):
 95    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 96
 97    def decorator(func):
 98        def wrapped(expression, *args, **kwargs):
 99            try:
100                return func(expression, *args, **kwargs)
101            except exceptions:
102                return expression
103
104        return wrapped
105
106    return decorator
107
108
109def rewrite_between(expression: exp.Expression) -> exp.Expression:
110    """Rewrite x between y and z to x >= y AND x <= z.
111
112    This is done because comparison simplification is only done on lt/lte/gt/gte.
113    """
114    if isinstance(expression, exp.Between):
115        return exp.and_(
116            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
117            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
118            copy=False,
119        )
120    return expression
121
122
123def simplify_not(expression):
124    """
125    Demorgan's Law
126    NOT (x OR y) -> NOT x AND NOT y
127    NOT (x AND y) -> NOT x OR NOT y
128    """
129    if isinstance(expression, exp.Not):
130        if is_null(expression.this):
131            return exp.null()
132        if isinstance(expression.this, exp.Paren):
133            condition = expression.this.unnest()
134            if isinstance(condition, exp.And):
135                return exp.or_(
136                    exp.not_(condition.left, copy=False),
137                    exp.not_(condition.right, copy=False),
138                    copy=False,
139                )
140            if isinstance(condition, exp.Or):
141                return exp.and_(
142                    exp.not_(condition.left, copy=False),
143                    exp.not_(condition.right, copy=False),
144                    copy=False,
145                )
146            if is_null(condition):
147                return exp.null()
148        if always_true(expression.this):
149            return exp.false()
150        if is_false(expression.this):
151            return exp.true()
152        if isinstance(expression.this, exp.Not):
153            # double negation
154            # NOT NOT x -> x
155            return expression.this.this
156    return expression
157
158
159def flatten(expression):
160    """
161    A AND (B AND C) -> A AND B AND C
162    A OR (B OR C) -> A OR B OR C
163    """
164    if isinstance(expression, exp.Connector):
165        for node in expression.args.values():
166            child = node.unnest()
167            if isinstance(child, expression.__class__):
168                node.replace(child)
169    return expression
170
171
172def simplify_connectors(expression, root=True):
173    def _simplify_connectors(expression, left, right):
174        if left == right:
175            return left
176        if isinstance(expression, exp.And):
177            if is_false(left) or is_false(right):
178                return exp.false()
179            if is_null(left) or is_null(right):
180                return exp.null()
181            if always_true(left) and always_true(right):
182                return exp.true()
183            if always_true(left):
184                return right
185            if always_true(right):
186                return left
187            return _simplify_comparison(expression, left, right)
188        elif isinstance(expression, exp.Or):
189            if always_true(left) or always_true(right):
190                return exp.true()
191            if is_false(left) and is_false(right):
192                return exp.false()
193            if (
194                (is_null(left) and is_null(right))
195                or (is_null(left) and is_false(right))
196                or (is_false(left) and is_null(right))
197            ):
198                return exp.null()
199            if is_false(left):
200                return right
201            if is_false(right):
202                return left
203            return _simplify_comparison(expression, left, right, or_=True)
204
205    if isinstance(expression, exp.Connector):
206        return _flat_simplify(expression, _simplify_connectors, root)
207    return expression
208
209
210LT_LTE = (exp.LT, exp.LTE)
211GT_GTE = (exp.GT, exp.GTE)
212
213COMPARISONS = (
214    *LT_LTE,
215    *GT_GTE,
216    exp.EQ,
217    exp.NEQ,
218    exp.Is,
219)
220
221INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
222    exp.LT: exp.GT,
223    exp.GT: exp.LT,
224    exp.LTE: exp.GTE,
225    exp.GTE: exp.LTE,
226}
227
228
229def _simplify_comparison(expression, left, right, or_=False):
230    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
231        ll, lr = left.args.values()
232        rl, rr = right.args.values()
233
234        largs = {ll, lr}
235        rargs = {rl, rr}
236
237        matching = largs & rargs
238        columns = {m for m in matching if isinstance(m, exp.Column)}
239
240        if matching and columns:
241            try:
242                l = first(largs - columns)
243                r = first(rargs - columns)
244            except StopIteration:
245                return expression
246
247            # make sure the comparison is always of the form x > 1 instead of 1 < x
248            if left.__class__ in INVERSE_COMPARISONS and l == ll:
249                left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
250            if right.__class__ in INVERSE_COMPARISONS and r == rl:
251                right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
252
253            if l.is_number and r.is_number:
254                l = float(l.name)
255                r = float(r.name)
256            elif l.is_string and r.is_string:
257                l = l.name
258                r = r.name
259            else:
260                return None
261
262            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
263                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
264                    return left if (av > bv if or_ else av <= bv) else right
265                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
266                    return left if (av < bv if or_ else av >= bv) else right
267
268                # we can't ever shortcut to true because the column could be null
269                if not or_:
270                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
271                        if av <= bv:
272                            return exp.false()
273                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
274                        if av >= bv:
275                            return exp.false()
276                    elif isinstance(a, exp.EQ):
277                        if isinstance(b, exp.LT):
278                            return exp.false() if av >= bv else a
279                        if isinstance(b, exp.LTE):
280                            return exp.false() if av > bv else a
281                        if isinstance(b, exp.GT):
282                            return exp.false() if av <= bv else a
283                        if isinstance(b, exp.GTE):
284                            return exp.false() if av < bv else a
285                        if isinstance(b, exp.NEQ):
286                            return exp.false() if av == bv else a
287    return None
288
289
290def remove_compliments(expression, root=True):
291    """
292    Removing compliments.
293
294    A AND NOT A -> FALSE
295    A OR NOT A -> TRUE
296    """
297    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
298        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
299
300        for a, b in itertools.permutations(expression.flatten(), 2):
301            if is_complement(a, b):
302                return compliment
303    return expression
304
305
306def uniq_sort(expression, generate, root=True):
307    """
308    Uniq and sort a connector.
309
310    C AND A AND B AND B -> A AND B AND C
311    """
312    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
313        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
314        flattened = tuple(expression.flatten())
315        deduped = {generate(e): e for e in flattened}
316        arr = tuple(deduped.items())
317
318        # check if the operands are already sorted, if not sort them
319        # A AND C AND B -> A AND B AND C
320        for i, (sql, e) in enumerate(arr[1:]):
321            if sql < arr[i][0]:
322                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
323                break
324        else:
325            # we didn't have to sort but maybe we need to dedup
326            if len(deduped) < len(flattened):
327                expression = result_func(*deduped.values(), copy=False)
328
329    return expression
330
331
332def absorb_and_eliminate(expression, root=True):
333    """
334    absorption:
335        A AND (A OR B) -> A
336        A OR (A AND B) -> A
337        A AND (NOT A OR B) -> A AND B
338        A OR (NOT A AND B) -> A OR B
339    elimination:
340        (A AND B) OR (A AND NOT B) -> A
341        (A OR B) AND (A OR NOT B) -> A
342    """
343    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
344        kind = exp.Or if isinstance(expression, exp.And) else exp.And
345
346        for a, b in itertools.permutations(expression.flatten(), 2):
347            if isinstance(a, kind):
348                aa, ab = a.unnest_operands()
349
350                # absorb
351                if is_complement(b, aa):
352                    aa.replace(exp.true() if kind == exp.And else exp.false())
353                elif is_complement(b, ab):
354                    ab.replace(exp.true() if kind == exp.And else exp.false())
355                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
356                    a.replace(exp.false() if kind == exp.And else exp.true())
357                elif isinstance(b, kind):
358                    # eliminate
359                    rhs = b.unnest_operands()
360                    ba, bb = rhs
361
362                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
363                        a.replace(aa)
364                        b.replace(aa)
365                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
366                        a.replace(ab)
367                        b.replace(ab)
368
369    return expression
370
371
372INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
373    exp.DateAdd: exp.Sub,
374    exp.DateSub: exp.Add,
375    exp.DatetimeAdd: exp.Sub,
376    exp.DatetimeSub: exp.Add,
377}
378
379INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
380    **INVERSE_DATE_OPS,
381    exp.Add: exp.Sub,
382    exp.Sub: exp.Add,
383}
384
385
386def _is_number(expression: exp.Expression) -> bool:
387    return expression.is_number
388
389
390def _is_interval(expression: exp.Expression) -> bool:
391    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
392
393
394@catch(ModuleNotFoundError, UnsupportedUnit)
395def simplify_equality(expression: exp.Expression) -> exp.Expression:
396    """
397    Use the subtraction and addition properties of equality to simplify expressions:
398
399        x + 1 = 3 becomes x = 2
400
401    There are two binary operations in the above expression: + and =
402    Here's how we reference all the operands in the code below:
403
404          l     r
405        x + 1 = 3
406        a   b
407    """
408    if isinstance(expression, COMPARISONS):
409        l, r = expression.left, expression.right
410
411        if l.__class__ in INVERSE_OPS:
412            pass
413        elif r.__class__ in INVERSE_OPS:
414            l, r = r, l
415        else:
416            return expression
417
418        if r.is_number:
419            a_predicate = _is_number
420            b_predicate = _is_number
421        elif _is_date_literal(r):
422            a_predicate = _is_date_literal
423            b_predicate = _is_interval
424        else:
425            return expression
426
427        if l.__class__ in INVERSE_DATE_OPS:
428            a = l.this
429            b = l.interval()
430        else:
431            a, b = l.left, l.right
432
433        if not a_predicate(a) and b_predicate(b):
434            pass
435        elif not a_predicate(b) and b_predicate(a):
436            a, b = b, a
437        else:
438            return expression
439
440        return expression.__class__(
441            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
442        )
443    return expression
444
445
446def simplify_literals(expression, root=True):
447    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
448        return _flat_simplify(expression, _simplify_binary, root)
449
450    if isinstance(expression, exp.Neg):
451        this = expression.this
452        if this.is_number:
453            value = this.name
454            if value[0] == "-":
455                return exp.Literal.number(value[1:])
456            return exp.Literal.number(f"-{value}")
457
458    return expression
459
460
461def _simplify_binary(expression, a, b):
462    if isinstance(expression, exp.Is):
463        if isinstance(b, exp.Not):
464            c = b.this
465            not_ = True
466        else:
467            c = b
468            not_ = False
469
470        if is_null(c):
471            if isinstance(a, exp.Literal):
472                return exp.true() if not_ else exp.false()
473            if is_null(a):
474                return exp.false() if not_ else exp.true()
475    elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
476        return None
477    elif is_null(a) or is_null(b):
478        return exp.null()
479
480    if a.is_number and b.is_number:
481        a = int(a.name) if a.is_int else Decimal(a.name)
482        b = int(b.name) if b.is_int else Decimal(b.name)
483
484        if isinstance(expression, exp.Add):
485            return exp.Literal.number(a + b)
486        if isinstance(expression, exp.Sub):
487            return exp.Literal.number(a - b)
488        if isinstance(expression, exp.Mul):
489            return exp.Literal.number(a * b)
490        if isinstance(expression, exp.Div):
491            # engines have differing int div behavior so intdiv is not safe
492            if isinstance(a, int) and isinstance(b, int):
493                return None
494            return exp.Literal.number(a / b)
495
496        boolean = eval_boolean(expression, a, b)
497
498        if boolean:
499            return boolean
500    elif a.is_string and b.is_string:
501        boolean = eval_boolean(expression, a.this, b.this)
502
503        if boolean:
504            return boolean
505    elif _is_date_literal(a) and isinstance(b, exp.Interval):
506        a, b = extract_date(a), extract_interval(b)
507        if a and b:
508            if isinstance(expression, exp.Add):
509                return date_literal(a + b)
510            if isinstance(expression, exp.Sub):
511                return date_literal(a - b)
512    elif isinstance(a, exp.Interval) and _is_date_literal(b):
513        a, b = extract_interval(a), extract_date(b)
514        # you cannot subtract a date from an interval
515        if a and b and isinstance(expression, exp.Add):
516            return date_literal(a + b)
517
518    return None
519
520
521def simplify_parens(expression):
522    if not isinstance(expression, exp.Paren):
523        return expression
524
525    this = expression.this
526    parent = expression.parent
527
528    if not isinstance(this, exp.Select) and (
529        not isinstance(parent, (exp.Condition, exp.Binary))
530        or isinstance(parent, exp.Paren)
531        or not isinstance(this, exp.Binary)
532        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
533        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
534        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
535        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
536    ):
537        return this
538    return expression
539
540
541CONSTANTS = (
542    exp.Literal,
543    exp.Boolean,
544    exp.Null,
545)
546
547
548def simplify_coalesce(expression):
549    # COALESCE(x) -> x
550    if (
551        isinstance(expression, exp.Coalesce)
552        and not expression.expressions
553        # COALESCE is also used as a Spark partitioning hint
554        and not isinstance(expression.parent, exp.Hint)
555    ):
556        return expression.this
557
558    if not isinstance(expression, COMPARISONS):
559        return expression
560
561    if isinstance(expression.left, exp.Coalesce):
562        coalesce = expression.left
563        other = expression.right
564    elif isinstance(expression.right, exp.Coalesce):
565        coalesce = expression.right
566        other = expression.left
567    else:
568        return expression
569
570    # This transformation is valid for non-constants,
571    # but it really only does anything if they are both constants.
572    if not isinstance(other, CONSTANTS):
573        return expression
574
575    # Find the first constant arg
576    for arg_index, arg in enumerate(coalesce.expressions):
577        if isinstance(arg, CONSTANTS):
578            break
579    else:
580        return expression
581
582    coalesce.set("expressions", coalesce.expressions[:arg_index])
583
584    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
585    # since we already remove COALESCE at the top of this function.
586    coalesce = coalesce if coalesce.expressions else coalesce.this
587
588    # This expression is more complex than when we started, but it will get simplified further
589    return exp.paren(
590        exp.or_(
591            exp.and_(
592                coalesce.is_(exp.null()).not_(copy=False),
593                expression.copy(),
594                copy=False,
595            ),
596            exp.and_(
597                coalesce.is_(exp.null()),
598                type(expression)(this=arg.copy(), expression=other.copy()),
599                copy=False,
600            ),
601            copy=False,
602        )
603    )
604
605
606CONCATS = (exp.Concat, exp.DPipe)
607SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe)
608
609
610def simplify_concat(expression):
611    """Reduces all groups that contain string literals by concatenating them."""
612    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
613        return expression
614
615    new_args = []
616    for is_string_group, group in itertools.groupby(
617        expression.expressions or expression.flatten(), lambda e: e.is_string
618    ):
619        if is_string_group:
620            new_args.append(exp.Literal.string("".join(string.name for string in group)))
621        else:
622            new_args.extend(group)
623
624    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
625    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
626    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)
627
628
629DateRange = t.Tuple[datetime.date, datetime.date]
630
631
632def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]:
633    """
634    Get the date range for a DATE_TRUNC equality comparison:
635
636    Example:
637        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
638    Returns:
639        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
640    """
641    floor = date_floor(date, unit)
642
643    if date != floor:
644        # This will always be False, except for NULL values.
645        return None
646
647    return floor, floor + interval(unit)
648
649
650def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
651    """Get the logical expression for a date range"""
652    return exp.and_(
653        left >= date_literal(drange[0]),
654        left < date_literal(drange[1]),
655        copy=False,
656    )
657
658
659def _datetrunc_eq(
660    left: exp.Expression, date: datetime.date, unit: str
661) -> t.Optional[exp.Expression]:
662    drange = _datetrunc_range(date, unit)
663    if not drange:
664        return None
665
666    return _datetrunc_eq_expression(left, drange)
667
668
669def _datetrunc_neq(
670    left: exp.Expression, date: datetime.date, unit: str
671) -> t.Optional[exp.Expression]:
672    drange = _datetrunc_range(date, unit)
673    if not drange:
674        return None
675
676    return exp.and_(
677        left < date_literal(drange[0]),
678        left >= date_literal(drange[1]),
679        copy=False,
680    )
681
682
683DateTruncBinaryTransform = t.Callable[
684    [exp.Expression, datetime.date, str], t.Optional[exp.Expression]
685]
686DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
687    exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)),
688    exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)),
689    exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)),
690    exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)),
691    exp.EQ: _datetrunc_eq,
692    exp.NEQ: _datetrunc_neq,
693}
694DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
695
696
697def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
698    return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right)
699
700
701@catch(ModuleNotFoundError, UnsupportedUnit)
702def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression:
703    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
704    comparison = expression.__class__
705
706    if comparison not in DATETRUNC_COMPARISONS:
707        return expression
708
709    if isinstance(expression, exp.Binary):
710        l, r = expression.left, expression.right
711
712        if _is_datetrunc_predicate(l, r):
713            pass
714        elif _is_datetrunc_predicate(r, l):
715            comparison = INVERSE_COMPARISONS.get(comparison, comparison)
716            l, r = r, l
717        else:
718            return expression
719
720        unit = l.unit.name.lower()
721        date = extract_date(r)
722
723        if not date:
724            return expression
725
726        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression
727    elif isinstance(expression, exp.In):
728        l = expression.this
729        rs = expression.expressions
730
731        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
732            unit = l.unit.name.lower()
733
734            ranges = []
735            for r in rs:
736                date = extract_date(r)
737                if not date:
738                    return expression
739                drange = _datetrunc_range(date, unit)
740                if drange:
741                    ranges.append(drange)
742
743            if not ranges:
744                return expression
745
746            ranges = merge_ranges(ranges)
747
748            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
749
750    return expression
751
752
753# CROSS joins result in an empty table if the right table is empty.
754# So we can only simplify certain types of joins to CROSS.
755# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
756JOINS = {
757    ("", ""),
758    ("", "INNER"),
759    ("RIGHT", ""),
760    ("RIGHT", "OUTER"),
761}
762
763
764def remove_where_true(expression):
765    for where in expression.find_all(exp.Where):
766        if always_true(where.this):
767            where.parent.set("where", None)
768    for join in expression.find_all(exp.Join):
769        if (
770            always_true(join.args.get("on"))
771            and not join.args.get("using")
772            and not join.args.get("method")
773            and (join.side, join.kind) in JOINS
774        ):
775            join.set("on", None)
776            join.set("side", None)
777            join.set("kind", "CROSS")
778
779
780def always_true(expression):
781    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
782        expression, exp.Literal
783    )
784
785
786def is_complement(a, b):
787    return isinstance(b, exp.Not) and b.this == a
788
789
790def is_false(a: exp.Expression) -> bool:
791    return type(a) is exp.Boolean and not a.this
792
793
794def is_null(a: exp.Expression) -> bool:
795    return type(a) is exp.Null
796
797
798def eval_boolean(expression, a, b):
799    if isinstance(expression, (exp.EQ, exp.Is)):
800        return boolean_literal(a == b)
801    if isinstance(expression, exp.NEQ):
802        return boolean_literal(a != b)
803    if isinstance(expression, exp.GT):
804        return boolean_literal(a > b)
805    if isinstance(expression, exp.GTE):
806        return boolean_literal(a >= b)
807    if isinstance(expression, exp.LT):
808        return boolean_literal(a < b)
809    if isinstance(expression, exp.LTE):
810        return boolean_literal(a <= b)
811    return None
812
813
814def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
815    if isinstance(value, datetime.datetime):
816        return value.date()
817    if isinstance(value, datetime.date):
818        return value
819    try:
820        return datetime.datetime.fromisoformat(value).date()
821    except ValueError:
822        return None
823
824
825def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
826    if isinstance(value, datetime.datetime):
827        return value
828    if isinstance(value, datetime.date):
829        return datetime.datetime(year=value.year, month=value.month, day=value.day)
830    try:
831        return datetime.datetime.fromisoformat(value)
832    except ValueError:
833        return None
834
835
836def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
837    if not value:
838        return None
839    if to.is_type(exp.DataType.Type.DATE):
840        return cast_as_date(value)
841    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
842        return cast_as_datetime(value)
843    return None
844
845
846def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
847    if isinstance(cast, exp.Cast):
848        to = cast.to
849    elif isinstance(cast, exp.TsOrDsToDate):
850        to = exp.DataType.build(exp.DataType.Type.DATE)
851    else:
852        return None
853
854    if isinstance(cast.this, exp.Literal):
855        value: t.Any = cast.this.name
856    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
857        value = extract_date(cast.this)
858    else:
859        return None
860    return cast_value(value, to)
861
862
863def _is_date_literal(expression: exp.Expression) -> bool:
864    return extract_date(expression) is not None
865
866
867def extract_interval(expression):
868    n = int(expression.name)
869    unit = expression.text("unit").lower()
870
871    try:
872        return interval(unit, n)
873    except (UnsupportedUnit, ModuleNotFoundError):
874        return None
875
876
877def date_literal(date):
878    return exp.cast(
879        exp.Literal.string(date),
880        exp.DataType.Type.DATETIME
881        if isinstance(date, datetime.datetime)
882        else exp.DataType.Type.DATE,
883    )
884
885
886def interval(unit: str, n: int = 1):
887    from dateutil.relativedelta import relativedelta
888
889    if unit == "year":
890        return relativedelta(years=1 * n)
891    if unit == "quarter":
892        return relativedelta(months=3 * n)
893    if unit == "month":
894        return relativedelta(months=1 * n)
895    if unit == "week":
896        return relativedelta(weeks=1 * n)
897    if unit == "day":
898        return relativedelta(days=1 * n)
899    if unit == "hour":
900        return relativedelta(hours=1 * n)
901    if unit == "minute":
902        return relativedelta(minutes=1 * n)
903    if unit == "second":
904        return relativedelta(seconds=1 * n)
905
906    raise UnsupportedUnit(f"Unsupported unit: {unit}")
907
908
909def date_floor(d: datetime.date, unit: str) -> datetime.date:
910    if unit == "year":
911        return d.replace(month=1, day=1)
912    if unit == "quarter":
913        if d.month <= 3:
914            return d.replace(month=1, day=1)
915        elif d.month <= 6:
916            return d.replace(month=4, day=1)
917        elif d.month <= 9:
918            return d.replace(month=7, day=1)
919        else:
920            return d.replace(month=10, day=1)
921    if unit == "month":
922        return d.replace(month=d.month, day=1)
923    if unit == "week":
924        # Assuming week starts on Monday (0) and ends on Sunday (6)
925        return d - datetime.timedelta(days=d.weekday())
926    if unit == "day":
927        return d
928
929    raise UnsupportedUnit(f"Unsupported unit: {unit}")
930
931
932def date_ceil(d: datetime.date, unit: str) -> datetime.date:
933    floor = date_floor(d, unit)
934
935    if floor == d:
936        return d
937
938    return floor + interval(unit)
939
940
941def boolean_literal(condition):
942    return exp.true() if condition else exp.false()
943
944
945def _flat_simplify(expression, simplifier, root=True):
946    if root or not expression.same_parent:
947        operands = []
948        queue = deque(expression.flatten(unnest=False))
949        size = len(queue)
950
951        while queue:
952            a = queue.popleft()
953
954            for b in queue:
955                result = simplifier(expression, a, b)
956
957                if result and result is not expression:
958                    queue.remove(b)
959                    queue.appendleft(result)
960                    break
961            else:
962                operands.append(a)
963
964        if len(operands) < size:
965            return functools.reduce(
966                lambda a, b: expression.__class__(this=a, expression=b), operands
967            )
968    return expression
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
17class UnsupportedUnit(Exception):
18    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify(expression):
21def simplify(expression):
22    """
23    Rewrite sqlglot AST to simplify expressions.
24
25    Example:
26        >>> import sqlglot
27        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
28        >>> simplify(expression).sql()
29        'TRUE'
30
31    Args:
32        expression (sqlglot.Expression): expression to simplify
33    Returns:
34        sqlglot.Expression: simplified expression
35    """
36
37    generate = cached_generator()
38
39    # group by expressions cannot be simplified, for example
40    # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
41    # the projection must exactly match the group by key
42    for group in expression.find_all(exp.Group):
43        select = group.parent
44        groups = set(group.expressions)
45        group.meta[FINAL] = True
46
47        for e in select.selects:
48            for node, *_ in e.walk():
49                if node in groups:
50                    e.meta[FINAL] = True
51                    break
52
53        having = select.args.get("having")
54        if having:
55            for node, *_ in having.walk():
56                if node in groups:
57                    having.meta[FINAL] = True
58                    break
59
60    def _simplify(expression, root=True):
61        if expression.meta.get(FINAL):
62            return expression
63
64        # Pre-order transformations
65        node = expression
66        node = rewrite_between(node)
67        node = uniq_sort(node, generate, root)
68        node = absorb_and_eliminate(node, root)
69        node = simplify_concat(node)
70
71        exp.replace_children(node, lambda e: _simplify(e, False))
72
73        # Post-order transformations
74        node = simplify_not(node)
75        node = flatten(node)
76        node = simplify_connectors(node, root)
77        node = remove_compliments(node, root)
78        node = simplify_coalesce(node)
79        node.parent = expression.parent
80        node = simplify_literals(node, root)
81        node = simplify_equality(node)
82        node = simplify_parens(node)
83        node = simplify_datetrunc_predicate(node)
84
85        if root:
86            expression.replace(node)
87
88        return node
89
90    expression = while_changing(expression, _simplify)
91    remove_where_true(expression)
92    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
 95def catch(*exceptions):
 96    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 97
 98    def decorator(func):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression
104
105        return wrapped
106
107    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
110def rewrite_between(expression: exp.Expression) -> exp.Expression:
111    """Rewrite x between y and z to x >= y AND x <= z.
112
113    This is done because comparison simplification is only done on lt/lte/gt/gte.
114    """
115    if isinstance(expression, exp.Between):
116        return exp.and_(
117            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
118            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
119            copy=False,
120        )
121    return expression

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
124def simplify_not(expression):
125    """
126    Demorgan's Law
127    NOT (x OR y) -> NOT x AND NOT y
128    NOT (x AND y) -> NOT x OR NOT y
129    """
130    if isinstance(expression, exp.Not):
131        if is_null(expression.this):
132            return exp.null()
133        if isinstance(expression.this, exp.Paren):
134            condition = expression.this.unnest()
135            if isinstance(condition, exp.And):
136                return exp.or_(
137                    exp.not_(condition.left, copy=False),
138                    exp.not_(condition.right, copy=False),
139                    copy=False,
140                )
141            if isinstance(condition, exp.Or):
142                return exp.and_(
143                    exp.not_(condition.left, copy=False),
144                    exp.not_(condition.right, copy=False),
145                    copy=False,
146                )
147            if is_null(condition):
148                return exp.null()
149        if always_true(expression.this):
150            return exp.false()
151        if is_false(expression.this):
152            return exp.true()
153        if isinstance(expression.this, exp.Not):
154            # double negation
155            # NOT NOT x -> x
156            return expression.this.this
157    return expression

Demorgan's Law NOT (x OR y) -> NOT x AND NOT y NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
160def flatten(expression):
161    """
162    A AND (B AND C) -> A AND B AND C
163    A OR (B OR C) -> A OR B OR C
164    """
165    if isinstance(expression, exp.Connector):
166        for node in expression.args.values():
167            child = node.unnest()
168            if isinstance(child, expression.__class__):
169                node.replace(child)
170    return expression

A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
173def simplify_connectors(expression, root=True):
174    def _simplify_connectors(expression, left, right):
175        if left == right:
176            return left
177        if isinstance(expression, exp.And):
178            if is_false(left) or is_false(right):
179                return exp.false()
180            if is_null(left) or is_null(right):
181                return exp.null()
182            if always_true(left) and always_true(right):
183                return exp.true()
184            if always_true(left):
185                return right
186            if always_true(right):
187                return left
188            return _simplify_comparison(expression, left, right)
189        elif isinstance(expression, exp.Or):
190            if always_true(left) or always_true(right):
191                return exp.true()
192            if is_false(left) and is_false(right):
193                return exp.false()
194            if (
195                (is_null(left) and is_null(right))
196                or (is_null(left) and is_false(right))
197                or (is_false(left) and is_null(right))
198            ):
199                return exp.null()
200            if is_false(left):
201                return right
202            if is_false(right):
203                return left
204            return _simplify_comparison(expression, left, right, or_=True)
205
206    if isinstance(expression, exp.Connector):
207        return _flat_simplify(expression, _simplify_connectors, root)
208    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
def remove_compliments(expression, root=True):
291def remove_compliments(expression, root=True):
292    """
293    Removing compliments.
294
295    A AND NOT A -> FALSE
296    A OR NOT A -> TRUE
297    """
298    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
299        compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
300
301        for a, b in itertools.permutations(expression.flatten(), 2):
302            if is_complement(a, b):
303                return compliment
304    return expression

Removing compliments.

A AND NOT A -> FALSE A OR NOT A -> TRUE

def uniq_sort(expression, generate, root=True):
307def uniq_sort(expression, generate, root=True):
308    """
309    Uniq and sort a connector.
310
311    C AND A AND B AND B -> A AND B AND C
312    """
313    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
314        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
315        flattened = tuple(expression.flatten())
316        deduped = {generate(e): e for e in flattened}
317        arr = tuple(deduped.items())
318
319        # check if the operands are already sorted, if not sort them
320        # A AND C AND B -> A AND B AND C
321        for i, (sql, e) in enumerate(arr[1:]):
322            if sql < arr[i][0]:
323                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
324                break
325        else:
326            # we didn't have to sort but maybe we need to dedup
327            if len(deduped) < len(flattened):
328                expression = result_func(*deduped.values(), copy=False)
329
330    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
333def absorb_and_eliminate(expression, root=True):
334    """
335    absorption:
336        A AND (A OR B) -> A
337        A OR (A AND B) -> A
338        A AND (NOT A OR B) -> A AND B
339        A OR (NOT A AND B) -> A OR B
340    elimination:
341        (A AND B) OR (A AND NOT B) -> A
342        (A OR B) AND (A OR NOT B) -> A
343    """
344    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
345        kind = exp.Or if isinstance(expression, exp.And) else exp.And
346
347        for a, b in itertools.permutations(expression.flatten(), 2):
348            if isinstance(a, kind):
349                aa, ab = a.unnest_operands()
350
351                # absorb
352                if is_complement(b, aa):
353                    aa.replace(exp.true() if kind == exp.And else exp.false())
354                elif is_complement(b, ab):
355                    ab.replace(exp.true() if kind == exp.And else exp.false())
356                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
357                    a.replace(exp.false() if kind == exp.And else exp.true())
358                elif isinstance(b, kind):
359                    # eliminate
360                    rhs = b.unnest_operands()
361                    ba, bb = rhs
362
363                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
364                        a.replace(aa)
365                        b.replace(aa)
366                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
367                        a.replace(ab)
368                        b.replace(ab)
369
370    return expression

absorption: A AND (A OR B) -> A A OR (A AND B) -> A A AND (NOT A OR B) -> A AND B A OR (NOT A AND B) -> A OR B elimination: (A AND B) OR (A AND NOT B) -> A (A OR B) AND (A OR NOT B) -> A

def simplify_equality(expression, *args, **kwargs):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
447def simplify_literals(expression, root=True):
448    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
449        return _flat_simplify(expression, _simplify_binary, root)
450
451    if isinstance(expression, exp.Neg):
452        this = expression.this
453        if this.is_number:
454            value = this.name
455            if value[0] == "-":
456                return exp.Literal.number(value[1:])
457            return exp.Literal.number(f"-{value}")
458
459    return expression
def simplify_parens(expression):
522def simplify_parens(expression):
523    if not isinstance(expression, exp.Paren):
524        return expression
525
526    this = expression.this
527    parent = expression.parent
528
529    if not isinstance(this, exp.Select) and (
530        not isinstance(parent, (exp.Condition, exp.Binary))
531        or isinstance(parent, exp.Paren)
532        or not isinstance(this, exp.Binary)
533        or (isinstance(this, exp.Predicate) and not isinstance(parent, exp.Predicate))
534        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
535        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
536        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
537    ):
538        return this
539    return expression
def simplify_coalesce(expression):
549def simplify_coalesce(expression):
550    # COALESCE(x) -> x
551    if (
552        isinstance(expression, exp.Coalesce)
553        and not expression.expressions
554        # COALESCE is also used as a Spark partitioning hint
555        and not isinstance(expression.parent, exp.Hint)
556    ):
557        return expression.this
558
559    if not isinstance(expression, COMPARISONS):
560        return expression
561
562    if isinstance(expression.left, exp.Coalesce):
563        coalesce = expression.left
564        other = expression.right
565    elif isinstance(expression.right, exp.Coalesce):
566        coalesce = expression.right
567        other = expression.left
568    else:
569        return expression
570
571    # This transformation is valid for non-constants,
572    # but it really only does anything if they are both constants.
573    if not isinstance(other, CONSTANTS):
574        return expression
575
576    # Find the first constant arg
577    for arg_index, arg in enumerate(coalesce.expressions):
578        if isinstance(arg, CONSTANTS):
579            break
580    else:
581        return expression
582
583    coalesce.set("expressions", coalesce.expressions[:arg_index])
584
585    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
586    # since we already remove COALESCE at the top of this function.
587    coalesce = coalesce if coalesce.expressions else coalesce.this
588
589    # This expression is more complex than when we started, but it will get simplified further
590    return exp.paren(
591        exp.or_(
592            exp.and_(
593                coalesce.is_(exp.null()).not_(copy=False),
594                expression.copy(),
595                copy=False,
596            ),
597            exp.and_(
598                coalesce.is_(exp.null()),
599                type(expression)(this=arg.copy(), expression=other.copy()),
600                copy=False,
601            ),
602            copy=False,
603        )
604    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
SAFE_CONCATS = (<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.SafeDPipe'>)
def simplify_concat(expression):
611def simplify_concat(expression):
612    """Reduces all groups that contain string literals by concatenating them."""
613    if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs):
614        return expression
615
616    new_args = []
617    for is_string_group, group in itertools.groupby(
618        expression.expressions or expression.flatten(), lambda e: e.is_string
619    ):
620        if is_string_group:
621            new_args.append(exp.Literal.string("".join(string.name for string in group)))
622        else:
623            new_args.extend(group)
624
625    # Ensures we preserve the right concat type, i.e. whether it's "safe" or not
626    concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat
627    return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args)

Reduces all groups that contain string literals by concatenating them.

DateRange = typing.Tuple[datetime.date, datetime.date]
DateTruncBinaryTransform = typing.Callable[[sqlglot.expressions.Expression, datetime.date, str], typing.Optional[sqlglot.expressions.Expression]]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>}
def simplify_datetrunc_predicate(expression, *args, **kwargs):
 99        def wrapped(expression, *args, **kwargs):
100            try:
101                return func(expression, *args, **kwargs)
102            except exceptions:
103                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

JOINS = {('', ''), ('', 'INNER'), ('RIGHT', ''), ('RIGHT', 'OUTER')}
def remove_where_true(expression):
765def remove_where_true(expression):
766    for where in expression.find_all(exp.Where):
767        if always_true(where.this):
768            where.parent.set("where", None)
769    for join in expression.find_all(exp.Join):
770        if (
771            always_true(join.args.get("on"))
772            and not join.args.get("using")
773            and not join.args.get("method")
774            and (join.side, join.kind) in JOINS
775        ):
776            join.set("on", None)
777            join.set("side", None)
778            join.set("kind", "CROSS")
def always_true(expression):
781def always_true(expression):
782    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
783        expression, exp.Literal
784    )
def is_complement(a, b):
787def is_complement(a, b):
788    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
791def is_false(a: exp.Expression) -> bool:
792    return type(a) is exp.Boolean and not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
795def is_null(a: exp.Expression) -> bool:
796    return type(a) is exp.Null
def eval_boolean(expression, a, b):
799def eval_boolean(expression, a, b):
800    if isinstance(expression, (exp.EQ, exp.Is)):
801        return boolean_literal(a == b)
802    if isinstance(expression, exp.NEQ):
803        return boolean_literal(a != b)
804    if isinstance(expression, exp.GT):
805        return boolean_literal(a > b)
806    if isinstance(expression, exp.GTE):
807        return boolean_literal(a >= b)
808    if isinstance(expression, exp.LT):
809        return boolean_literal(a < b)
810    if isinstance(expression, exp.LTE):
811        return boolean_literal(a <= b)
812    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
815def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
816    if isinstance(value, datetime.datetime):
817        return value.date()
818    if isinstance(value, datetime.date):
819        return value
820    try:
821        return datetime.datetime.fromisoformat(value).date()
822    except ValueError:
823        return None
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
826def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
827    if isinstance(value, datetime.datetime):
828        return value
829    if isinstance(value, datetime.date):
830        return datetime.datetime(year=value.year, month=value.month, day=value.day)
831    try:
832        return datetime.datetime.fromisoformat(value)
833    except ValueError:
834        return None
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
837def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]:
838    if not value:
839        return None
840    if to.is_type(exp.DataType.Type.DATE):
841        return cast_as_date(value)
842    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
843        return cast_as_datetime(value)
844    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
847def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]:
848    if isinstance(cast, exp.Cast):
849        to = cast.to
850    elif isinstance(cast, exp.TsOrDsToDate):
851        to = exp.DataType.build(exp.DataType.Type.DATE)
852    else:
853        return None
854
855    if isinstance(cast.this, exp.Literal):
856        value: t.Any = cast.this.name
857    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
858        value = extract_date(cast.this)
859    else:
860        return None
861    return cast_value(value, to)
def extract_interval(expression):
868def extract_interval(expression):
869    n = int(expression.name)
870    unit = expression.text("unit").lower()
871
872    try:
873        return interval(unit, n)
874    except (UnsupportedUnit, ModuleNotFoundError):
875        return None
def date_literal(date):
878def date_literal(date):
879    return exp.cast(
880        exp.Literal.string(date),
881        exp.DataType.Type.DATETIME
882        if isinstance(date, datetime.datetime)
883        else exp.DataType.Type.DATE,
884    )
def interval(unit: str, n: int = 1):
887def interval(unit: str, n: int = 1):
888    from dateutil.relativedelta import relativedelta
889
890    if unit == "year":
891        return relativedelta(years=1 * n)
892    if unit == "quarter":
893        return relativedelta(months=3 * n)
894    if unit == "month":
895        return relativedelta(months=1 * n)
896    if unit == "week":
897        return relativedelta(weeks=1 * n)
898    if unit == "day":
899        return relativedelta(days=1 * n)
900    if unit == "hour":
901        return relativedelta(hours=1 * n)
902    if unit == "minute":
903        return relativedelta(minutes=1 * n)
904    if unit == "second":
905        return relativedelta(seconds=1 * n)
906
907    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str) -> datetime.date:
910def date_floor(d: datetime.date, unit: str) -> datetime.date:
911    if unit == "year":
912        return d.replace(month=1, day=1)
913    if unit == "quarter":
914        if d.month <= 3:
915            return d.replace(month=1, day=1)
916        elif d.month <= 6:
917            return d.replace(month=4, day=1)
918        elif d.month <= 9:
919            return d.replace(month=7, day=1)
920        else:
921            return d.replace(month=10, day=1)
922    if unit == "month":
923        return d.replace(month=d.month, day=1)
924    if unit == "week":
925        # Assuming week starts on Monday (0) and ends on Sunday (6)
926        return d - datetime.timedelta(days=d.weekday())
927    if unit == "day":
928        return d
929
930    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil(d: datetime.date, unit: str) -> datetime.date:
933def date_ceil(d: datetime.date, unit: str) -> datetime.date:
934    floor = date_floor(d, unit)
935
936    if floor == d:
937        return d
938
939    return floor + interval(unit)
def boolean_literal(condition):
942def boolean_literal(condition):
943    return exp.true() if condition else exp.false()