diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 115 |
1 files changed, 109 insertions, 6 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index afc6995..17af6ac 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,5 +1,7 @@ from __future__ import annotations +import datetime +import functools import typing as t from sqlglot import exp @@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema if t.TYPE_CHECKING: B = t.TypeVar("B", bound=exp.Binary) + BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] + BinaryCoercions = t.Dict[ + t.Tuple[exp.DataType.Type, exp.DataType.Type], + BinaryCoercionFunc, + ] + + +# Interval units that operate on date components +DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} + def annotate_types( expression: E, @@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type return lambda self, e: self._annotate_with_type(e, data_type) +def _is_iso_date(text: str) -> bool: + try: + datetime.date.fromisoformat(text) + return True + except ValueError: + return False + + +def _is_iso_datetime(text: str) -> bool: + try: + datetime.datetime.fromisoformat(text) + return True + except ValueError: + return False + + +def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + date_text = l.name + unit = r.text("unit").lower() + + is_iso_date = _is_iso_date(date_text) + + if is_iso_date and unit in DATE_UNITS: + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) + return exp.DataType.Type.DATE + + # An ISO date is also an ISO datetime, but not vice versa + if is_iso_date or _is_iso_datetime(date_text): + l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) + return exp.DataType.Type.DATETIME + + return exp.DataType.Type.UNKNOWN + + +def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + unit = r.text("unit").lower() + if unit not in DATE_UNITS: + return exp.DataType.Type.DATETIME + return l.type.this if l.type else exp.DataType.Type.UNKNOWN + + +def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc: + @functools.wraps(func) + def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: + return func(r, l) + + return _swapped + + +def swap_all(coercions: BinaryCoercions) -> BinaryCoercions: + return {**coercions, **{(b, a): swap_args(func) for (a, b), func in coercions.items()}} + + class _TypeAnnotator(type): def __new__(cls, clsname, bases, attrs): klass = super().__new__(cls, clsname, bases, attrs) @@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DataType.Type.DATE: { exp.CurrentDate, exp.Date, - exp.DateAdd, exp.DateFromParts, exp.DateStrToDate, - exp.DateSub, exp.DateTrunc, exp.DiToDate, exp.StrToDate, @@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.DateAdd: lambda self, e: self._annotate_dateadd(e), + exp.DateSub: lambda self, e: self._annotate_dateadd(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), @@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator): # Specifies what types a given type can be coerced into (autofilled) COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + # Coercion functions for binary operations. + # Map of type pairs to a callable that takes both sides of the binary operation and returns the resulting type. + BINARY_COERCIONS: BinaryCoercions = { + **swap_all( + { + (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval + for t in exp.DataType.TEXT_TYPES + } + ), + **swap_all( + { + (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, + } + ), + } + def __init__( self, schema: Schema, annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + binary_coercions: t.Optional[BinaryCoercions] = None, ) -> None: self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO + self.binary_coercions = binary_coercions or self.BINARY_COERCIONS # Caches the ids of annotated sub-Expressions, to ensure we only visit them once self._visited: t.Set[int] = set() - def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None: - expression.type = target_type + def _set_type( + self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type + ) -> None: + expression.type = target_type # type: ignore self._visited.add(id(expression)) def annotate(self, expression: E) -> E: @@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) - left_type = expression.left.type.this - right_type = expression.right.type.this + left, right = expression.left, expression.right + left_type, right_type = left.type.this, right.type.this if isinstance(expression, exp.Connector): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, exp.DataType.Type.BOOLEAN) elif isinstance(expression, exp.Predicate): self._set_type(expression, exp.DataType.Type.BOOLEAN) + elif (left_type, right_type) in self.binary_coercions: + self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right)) else: self._set_type(expression, self._maybe_coerce(left_type, right_type)) @@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ) return expression + + def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: + self._annotate_args(expression) + + if expression.this.type.this in exp.DataType.TEXT_TYPES: + datatype = _coerce_literal_and_interval(expression.this, expression.interval()) + elif ( + expression.this.type.is_type(exp.DataType.Type.DATE) + and expression.text("unit").lower() not in DATE_UNITS + ): + datatype = exp.DataType.Type.DATETIME + else: + datatype = expression.this.type + + self._set_type(expression, datatype) + return expression |