diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-25 08:20:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-09-25 08:20:09 +0000 |
commit | 4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6 (patch) | |
tree | 8f4f60a82ab9cd6dcd41397e4ecb2960c332b209 /sqlglot/optimizer/simplify.py | |
parent | Releasing debian version 18.5.1-1. (diff) | |
download | sqlglot-4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6.tar.xz sqlglot-4554ab4c7d6b2bbbaa6f4d0b810bf477d1a505a6.zip |
Merging upstream version 18.7.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/optimizer/simplify.py')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 299 |
1 files changed, 282 insertions, 17 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 3974ea4..d08c692 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -1,17 +1,22 @@ import datetime import functools import itertools +import typing as t from collections import deque from decimal import Decimal from sqlglot import exp from sqlglot.generator import cached_generator -from sqlglot.helper import first, while_changing +from sqlglot.helper import first, merge_ranges, while_changing # Final means that an expression should not be simplified FINAL = "final" +class UnsupportedUnit(Exception): + pass + + def simplify(expression): """ Rewrite sqlglot AST to simplify expressions. @@ -72,7 +77,9 @@ def simplify(expression): node = simplify_coalesce(node) node.parent = expression.parent node = simplify_literals(node, root) + node = simplify_equality(node) node = simplify_parens(node) + node = simplify_datetrunc_predicate(node) if root: expression.replace(node) @@ -84,6 +91,21 @@ def simplify(expression): return expression +def catch(*exceptions): + """Decorator that ignores a simplification function if any of `exceptions` are raised""" + + def decorator(func): + def wrapped(expression, *args, **kwargs): + try: + return func(expression, *args, **kwargs) + except exceptions: + return expression + + return wrapped + + return decorator + + def rewrite_between(expression: exp.Expression) -> exp.Expression: """Rewrite x between y and z to x >= y AND x <= z. @@ -196,7 +218,7 @@ COMPARISONS = ( exp.Is, ) -INVERSE_COMPARISONS = { +INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.LT: exp.GT, exp.GT: exp.LT, exp.LTE: exp.GTE, @@ -347,6 +369,87 @@ def absorb_and_eliminate(expression, root=True): return expression +INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + exp.DateAdd: exp.Sub, + exp.DateSub: exp.Add, + exp.DatetimeAdd: exp.Sub, + exp.DatetimeSub: exp.Add, +} + +INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { + **INVERSE_DATE_OPS, + exp.Add: exp.Sub, + exp.Sub: exp.Add, +} + + +def _is_number(expression: exp.Expression) -> bool: + return expression.is_number + + +def _is_date(expression: exp.Expression) -> bool: + return isinstance(expression, exp.Cast) and extract_date(expression) is not None + + +def _is_interval(expression: exp.Expression) -> bool: + return isinstance(expression, exp.Interval) and extract_interval(expression) is not None + + +@catch(ModuleNotFoundError, UnsupportedUnit) +def simplify_equality(expression: exp.Expression) -> exp.Expression: + """ + Use the subtraction and addition properties of equality to simplify expressions: + + x + 1 = 3 becomes x = 2 + + There are two binary operations in the above expression: + and = + Here's how we reference all the operands in the code below: + + l r + x + 1 = 3 + a b + """ + if isinstance(expression, COMPARISONS): + l, r = expression.left, expression.right + + if l.__class__ in INVERSE_OPS: + pass + elif r.__class__ in INVERSE_OPS: + l, r = r, l + else: + return expression + + if r.is_number: + a_predicate = _is_number + b_predicate = _is_number + elif _is_date(r): + a_predicate = _is_date + b_predicate = _is_interval + else: + return expression + + if l.__class__ in INVERSE_DATE_OPS: + a = l.this + b = exp.Interval( + this=l.expression.copy(), + unit=l.unit.copy(), + ) + else: + a, b = l.left, l.right + + if not a_predicate(a) and b_predicate(b): + pass + elif not a_predicate(b) and b_predicate(a): + a, b = b, a + else: + return expression + + return expression.__class__( + this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b) + ) + return expression + + def simplify_literals(expression, root=True): if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector): return _flat_simplify(expression, _simplify_binary, root) @@ -530,6 +633,123 @@ def simplify_concat(expression): return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args) +DateRange = t.Tuple[datetime.date, datetime.date] + + +def _datetrunc_range(date: datetime.date, unit: str) -> t.Optional[DateRange]: + """ + Get the date range for a DATE_TRUNC equality comparison: + + Example: + _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01)) + Returns: + tuple of [min, max) or None if a value can never be equal to `date` for `unit` + """ + floor = date_floor(date, unit) + + if date != floor: + # This will always be False, except for NULL values. + return None + + return floor, floor + interval(unit) + + +def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression: + """Get the logical expression for a date range""" + return exp.and_( + left >= date_literal(drange[0]), + left < date_literal(drange[1]), + copy=False, + ) + + +def _datetrunc_eq( + left: exp.Expression, date: datetime.date, unit: str +) -> t.Optional[exp.Expression]: + drange = _datetrunc_range(date, unit) + if not drange: + return None + + return _datetrunc_eq_expression(left, drange) + + +def _datetrunc_neq( + left: exp.Expression, date: datetime.date, unit: str +) -> t.Optional[exp.Expression]: + drange = _datetrunc_range(date, unit) + if not drange: + return None + + return exp.and_( + left < date_literal(drange[0]), + left >= date_literal(drange[1]), + copy=False, + ) + + +DateTruncBinaryTransform = t.Callable[ + [exp.Expression, datetime.date, str], t.Optional[exp.Expression] +] +DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = { + exp.LT: lambda l, d, u: l < date_literal(date_floor(d, u)), + exp.GT: lambda l, d, u: l >= date_literal(date_floor(d, u) + interval(u)), + exp.LTE: lambda l, d, u: l < date_literal(date_floor(d, u) + interval(u)), + exp.GTE: lambda l, d, u: l >= date_literal(date_ceil(d, u)), + exp.EQ: _datetrunc_eq, + exp.NEQ: _datetrunc_neq, +} +DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} + + +def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: + return ( + isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) + and isinstance(right, exp.Cast) + and right.is_type(*exp.DataType.TEMPORAL_TYPES) + ) + + +@catch(ModuleNotFoundError, UnsupportedUnit) +def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: + """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`""" + comparison = expression.__class__ + + if comparison not in DATETRUNC_COMPARISONS: + return expression + + if isinstance(expression, exp.Binary): + l, r = expression.left, expression.right + + if _is_datetrunc_predicate(l, r): + pass + elif _is_datetrunc_predicate(r, l): + comparison = INVERSE_COMPARISONS.get(comparison, comparison) + l, r = r, l + else: + return expression + + unit = l.unit.name.lower() + date = extract_date(r) + + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression + elif isinstance(expression, exp.In): + l = expression.this + rs = expression.expressions + + if all(_is_datetrunc_predicate(l, r) for r in rs): + unit = l.unit.name.lower() + + ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r] + if not ranges: + return expression + + ranges = merge_ranges(ranges) + + return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False) + + return expression + + # CROSS joins result in an empty table if the right table is empty. # So we can only simplify certain types of joins to CROSS. # Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x @@ -603,31 +823,76 @@ def extract_date(cast): return None -def extract_interval(interval): +def extract_interval(expression): + n = int(expression.name) + unit = expression.text("unit").lower() + try: - from dateutil.relativedelta import relativedelta # type: ignore - except ModuleNotFoundError: + return interval(unit, n) + except (UnsupportedUnit, ModuleNotFoundError): return None - n = int(interval.name) - unit = interval.text("unit").lower() + +def date_literal(date): + return exp.cast( + exp.Literal.string(date), + "DATETIME" if isinstance(date, datetime.datetime) else "DATE", + ) + + +def interval(unit: str, n: int = 1): + from dateutil.relativedelta import relativedelta if unit == "year": - return relativedelta(years=n) + return relativedelta(years=1 * n) + if unit == "quarter": + return relativedelta(months=3 * n) if unit == "month": - return relativedelta(months=n) + return relativedelta(months=1 * n) if unit == "week": - return relativedelta(weeks=n) + return relativedelta(weeks=1 * n) if unit == "day": - return relativedelta(days=n) - return None + return relativedelta(days=1 * n) + if unit == "hour": + return relativedelta(hours=1 * n) + if unit == "minute": + return relativedelta(minutes=1 * n) + if unit == "second": + return relativedelta(seconds=1 * n) + raise UnsupportedUnit(f"Unsupported unit: {unit}") -def date_literal(date): - return exp.cast( - exp.Literal.string(date), - "DATETIME" if isinstance(date, datetime.datetime) else "DATE", - ) + +def date_floor(d: datetime.date, unit: str) -> datetime.date: + if unit == "year": + return d.replace(month=1, day=1) + if unit == "quarter": + if d.month <= 3: + return d.replace(month=1, day=1) + elif d.month <= 6: + return d.replace(month=4, day=1) + elif d.month <= 9: + return d.replace(month=7, day=1) + else: + return d.replace(month=10, day=1) + if unit == "month": + return d.replace(month=d.month, day=1) + if unit == "week": + # Assuming week starts on Monday (0) and ends on Sunday (6) + return d - datetime.timedelta(days=d.weekday()) + if unit == "day": + return d + + raise UnsupportedUnit(f"Unsupported unit: {unit}") + + +def date_ceil(d: datetime.date, unit: str) -> datetime.date: + floor = date_floor(d, unit) + + if floor == d: + return d + + return floor + interval(unit) def boolean_literal(condition): |