summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- sqlglot/optimizer/annotate_types.py 115
1 files changed, 109 insertions, 6 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index afc6995..17af6ac 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import datetime
+import functools
import typing as t
from sqlglot import exp
@@ -11,6 +13,16 @@ from sqlglot.schema import Schema, ensure_schema
# Names bound under this guard exist only when t.TYPE_CHECKING is true, so they
# are usable in annotations but are not defined at runtime.
if t.TYPE_CHECKING:
B = t.TypeVar("B", bound=exp.Binary)
# A callable that receives two expressions and returns a DataType.Type.
+ BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
# A mapping from a pair of DataType.Type values to a BinaryCoercionFunc.
+ BinaryCoercions = t.Dict[
+ t.Tuple[exp.DataType.Type, exp.DataType.Type],
+ BinaryCoercionFunc,
+ ]
+
+
# The members are lowercase; code in this file lowercases an interval's unit
# text before testing membership in this set.
+# Interval units that operate on date components
+DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
+
def annotate_types(
expression: E,
@@ -48,6 +60,59 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)
def _is_iso_date(text: str) -> bool:
    """Report whether ``text`` is accepted by ``datetime.date.fromisoformat``.

    Args:
        text: The candidate date string.

    Returns:
        True when parsing succeeds, False when it raises ``ValueError``.
    """
    try:
        datetime.date.fromisoformat(text)
    except ValueError:
        return False
    else:
        return True
+
+
def _is_iso_datetime(text: str) -> bool:
    """Report whether ``text`` is accepted by ``datetime.datetime.fromisoformat``.

    Args:
        text: The candidate datetime string.

    Returns:
        True when parsing succeeds, False when it raises ``ValueError``.
    """
    try:
        datetime.datetime.fromisoformat(text)
    except ValueError:
        return False
    else:
        return True
+
+
def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Resolve the result type of combining a text operand with an interval.

    The decision is based on ``l.name`` and on the lowercased ``unit`` text of ``r``:

    * the name parses as an ISO date and the unit is in ``DATE_UNITS`` -> ``DATE``
    * the name parses as an ISO date or an ISO datetime -> ``DATETIME``
    * otherwise -> ``UNKNOWN``

    Side effect: for the first two outcomes, ``l.replace`` is called with a cast
    of a copy of ``l`` to the resolved type. Nothing is replaced for ``UNKNOWN``.

    Args:
        l: The text-typed operand.
        r: The interval operand; its ``unit`` text is read.

    Returns:
        The resolved ``exp.DataType.Type``.
    """
    literal_text = l.name
    interval_unit = r.text("unit").lower()
    parses_as_date = _is_iso_date(literal_text)

    if parses_as_date and interval_unit in DATE_UNITS:
        resolved = exp.DataType.Type.DATE
    elif parses_as_date or _is_iso_datetime(literal_text):
        # An ISO date is also an ISO datetime, but not vice versa
        resolved = exp.DataType.Type.DATETIME
    else:
        return exp.DataType.Type.UNKNOWN

    l.replace(exp.cast(l.copy(), to=resolved))
    return resolved
+
+
def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
    """Resolve the result type of combining a date operand with an interval.

    Args:
        l: The date operand; its ``type`` is read.
        r: The interval operand; its ``unit`` text is read and lowercased.

    Returns:
        ``DATETIME`` when the unit is not in ``DATE_UNITS``. Otherwise
        ``l.type.this``, or ``UNKNOWN`` when ``l`` has no type.
    """
    if r.text("unit").lower() in DATE_UNITS:
        operand_type = l.type
        if operand_type:
            return operand_type.this
        return exp.DataType.Type.UNKNOWN

    return exp.DataType.Type.DATETIME
+
+
def swap_args(func: BinaryCoercionFunc) -> BinaryCoercionFunc:
    """Return a wrapper that calls ``func`` with its two operands reversed.

    Args:
        func: A two-argument coercion function.

    Returns:
        A function ``g`` such that ``g(l, r) == func(r, l)``. The metadata of
        ``func`` is copied onto it with ``functools.wraps``.
    """

    @functools.wraps(func)
    def _swapped(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
        return func(r, l)

    return _swapped


def swap_all(coercions: BinaryCoercions) -> BinaryCoercions:
    """Extend a coercion table with the mirrored entry for every type pair.

    For each ``(a, b) -> func`` entry, a ``(b, a)`` entry is added whose function
    calls ``func`` with the operands swapped back into ``(a, b)`` order.

    Entries given explicitly in ``coercions`` always take precedence over a
    generated mirror. This matters for a symmetric key ``(a, a)`` and for tables
    that already define both ``(a, b)`` and ``(b, a)``: a mirror that replaced
    the explicit entry would call its function with reversed operands.

    Args:
        coercions: Mapping of ``(left type, right type)`` to a coercion function.

    Returns:
        A new dict holding the original entries followed by the mirrored ones.
        The input mapping is not modified.
    """
    combined = dict(coercions)
    for (a, b), func in coercions.items():
        # setdefault keeps an explicit (b, a) entry instead of overwriting it.
        combined.setdefault((b, a), swap_args(func))
    return combined
+
+
class _TypeAnnotator(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
@@ -104,10 +169,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DATE: {
exp.CurrentDate,
exp.Date,
- exp.DateAdd,
exp.DateFromParts,
exp.DateStrToDate,
- exp.DateSub,
exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
@@ -212,6 +275,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+ exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
+ exp.DateSub: lambda self, e: self._annotate_dateadd(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
@@ -234,21 +299,41 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
# Specifies what types a given type can be coerced into (autofilled)
COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
# Coercion functions for binary operations.
# Maps a (left type, right type) pair to a callable that takes both sides of the
# binary operation and returns the resulting type. Each pair is registered in
# both operand orders through swap_all.
BINARY_COERCIONS: BinaryCoercions = {
    **swap_all(
        {
            (text_type, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
            for text_type in exp.DataType.TEXT_TYPES
        }
    ),
    **swap_all({(exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval}),
}
+
# Stores the schema and the lookup tables used by this annotator. Each optional
# table argument falls back to the matching class-level table when it is falsy,
# so passing an empty dict also selects the class default.
def __init__(
self,
schema: Schema,
annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+ binary_coercions: t.Optional[BinaryCoercions] = None,
) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
+ self.binary_coercions = binary_coercions or self.BINARY_COERCIONS
# Caches the ids of annotated sub-Expressions, to ensure we only visit them once
self._visited: t.Set[int] = set()
# Assigns target_type to expression.type and records id(expression) in
# self._visited. The added signature also accepts a bare exp.DataType.Type,
# which is why the assignment carries a "type: ignore".
# NOTE(review): presumably the `type` setter on exp.Expression normalizes a bare
# DataType.Type into a DataType - confirm against exp.Expression.
- def _set_type(self, expression: exp.Expression, target_type: exp.DataType) -> None:
- expression.type = target_type
+ def _set_type(
+ self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
+ ) -> None:
+ expression.type = target_type # type: ignore
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
@@ -342,8 +427,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
- left_type = expression.left.type.this
- right_type = expression.right.type.this
+ left, right = expression.left, expression.right
+ left_type, right_type = left.type.this, right.type.this
if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -357,6 +442,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
elif isinstance(expression, exp.Predicate):
self._set_type(expression, exp.DataType.Type.BOOLEAN)
+ elif (left_type, right_type) in self.binary_coercions:
+ self._set_type(expression, self.binary_coercions[(left_type, right_type)](left, right))
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
@@ -421,3 +508,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
return expression
+
def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
    """Annotate an interval operation (registered for ``exp.DateAdd`` / ``exp.DateSub``).

    The arguments are annotated first, then the result type is chosen from the
    type of ``expression.this``:

    * a text type -> whatever ``_coerce_literal_and_interval`` resolves for the
      operand and ``expression.interval()``
    * ``DATE`` with a unit outside ``DATE_UNITS`` -> ``DATETIME``
    * anything else -> the operand's own type

    Args:
        expression: The interval operation to annotate.

    Returns:
        The same ``expression``, with its type set.
    """
    self._annotate_args(expression)

    operand = expression.this
    if operand.type.this in exp.DataType.TEXT_TYPES:
        result_type = _coerce_literal_and_interval(operand, expression.interval())
    else:
        widens_to_datetime = (
            operand.type.is_type(exp.DataType.Type.DATE)
            and expression.text("unit").lower() not in DATE_UNITS
        )
        result_type = exp.DataType.Type.DATETIME if widens_to_datetime else operand.type

    self._set_type(expression, result_type)
    return expression