diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/simplify.py | 97 |
1 files changed, 70 insertions, 27 deletions
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index d08c692..51214c4 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -387,10 +387,6 @@ def _is_number(expression: exp.Expression) -> bool: return expression.is_number -def _is_date(expression: exp.Expression) -> bool: - return isinstance(expression, exp.Cast) and extract_date(expression) is not None - - def _is_interval(expression: exp.Expression) -> bool: return isinstance(expression, exp.Interval) and extract_interval(expression) is not None @@ -422,18 +418,15 @@ def simplify_equality(expression: exp.Expression) -> exp.Expression: if r.is_number: a_predicate = _is_number b_predicate = _is_number - elif _is_date(r): - a_predicate = _is_date + elif _is_date_literal(r): + a_predicate = _is_date_literal b_predicate = _is_interval else: return expression if l.__class__ in INVERSE_DATE_OPS: a = l.this - b = exp.Interval( - this=l.expression.copy(), - unit=l.unit.copy(), - ) + b = l.interval() else: a, b = l.left, l.right @@ -509,14 +502,14 @@ def _simplify_binary(expression, a, b): if boolean: return boolean - elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): + elif _is_date_literal(a) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) if a and b: if isinstance(expression, exp.Add): return date_literal(a + b) if isinstance(expression, exp.Sub): return date_literal(a - b) - elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): + elif isinstance(a, exp.Interval) and _is_date_literal(b): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval if a and b and isinstance(expression, exp.Add): @@ -702,11 +695,7 @@ DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS} def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool: - return ( - isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) - and isinstance(right, exp.Cast) - and right.is_type(*exp.DataType.TEMPORAL_TYPES) - ) + return isinstance(left, (exp.DateTrunc, exp.TimestampTrunc)) and _is_date_literal(right) @catch(ModuleNotFoundError, UnsupportedUnit) @@ -731,15 +720,26 @@ def simplify_datetrunc_predicate(expression: exp.Expression) -> exp.Expression: unit = l.unit.name.lower() date = extract_date(r) + if not date: + return expression + return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit) or expression elif isinstance(expression, exp.In): l = expression.this rs = expression.expressions - if all(_is_datetrunc_predicate(l, r) for r in rs): + if rs and all(_is_datetrunc_predicate(l, r) for r in rs): unit = l.unit.name.lower() - ranges = [r for r in [_datetrunc_range(extract_date(r), unit) for r in rs] if r] + ranges = [] + for r in rs: + date = extract_date(r) + if not date: + return expression + drange = _datetrunc_range(date, unit) + if drange: + ranges.append(drange) + if not ranges: return expression @@ -811,18 +811,59 @@ def eval_boolean(expression, a, b): return None -def extract_date(cast): - # The "fromisoformat" conversion could fail if the cast is used on an identifier, - # so in that case we can't extract the date. +def cast_as_date(value: t.Any) -> t.Optional[datetime.date]: + if isinstance(value, datetime.datetime): + return value.date() + if isinstance(value, datetime.date): + return value try: - if cast.args["to"].this == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(cast.name) - if cast.args["to"].this == exp.DataType.Type.DATETIME: - return datetime.datetime.fromisoformat(cast.name) + return datetime.datetime.fromisoformat(value).date() except ValueError: return None +def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]: + if isinstance(value, datetime.datetime): + return value + if isinstance(value, datetime.date): + return datetime.datetime(year=value.year, month=value.month, day=value.day) + try: + return datetime.datetime.fromisoformat(value) + except ValueError: + return None + + +def cast_value(value: t.Any, to: exp.DataType) -> t.Optional[t.Union[datetime.date, datetime.date]]: + if not value: + return None + if to.is_type(exp.DataType.Type.DATE): + return cast_as_date(value) + if to.is_type(*exp.DataType.TEMPORAL_TYPES): + return cast_as_datetime(value) + return None + + +def extract_date(cast: exp.Expression) -> t.Optional[t.Union[datetime.date, datetime.date]]: + if isinstance(cast, exp.Cast): + to = cast.to + elif isinstance(cast, exp.TsOrDsToDate): + to = exp.DataType.build(exp.DataType.Type.DATE) + else: + return None + + if isinstance(cast.this, exp.Literal): + value: t.Any = cast.this.name + elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)): + value = extract_date(cast.this) + else: + return None + return cast_value(value, to) + + +def _is_date_literal(expression: exp.Expression) -> bool: + return extract_date(expression) is not None + + def extract_interval(expression): n = int(expression.name) unit = expression.text("unit").lower() @@ -836,7 +877,9 @@ def extract_interval(expression): def date_literal(date): return exp.cast( exp.Literal.string(date), - "DATETIME" if isinstance(date, datetime.datetime) else "DATE", + exp.DataType.Type.DATETIME + if isinstance(date, datetime.datetime) + else exp.DataType.Type.DATE, ) |