summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py110
1 files changed, 61 insertions, 49 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 69d4567..7b990f1 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,12 +1,18 @@
from __future__ import annotations
-import datetime
import functools
import typing as t
from sqlglot import exp
from sqlglot._typing import E
-from sqlglot.helper import ensure_list, seq_get, subclasses
+from sqlglot.helper import (
+ ensure_list,
+ is_date_unit,
+ is_iso_date,
+ is_iso_datetime,
+ seq_get,
+ subclasses,
+)
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@@ -20,10 +26,6 @@ if t.TYPE_CHECKING:
]
-# Interval units that operate on date components
-DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"}
-
-
def annotate_types(
expression: E,
schema: t.Optional[t.Dict | Schema] = None,
@@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type
return lambda self, e: self._annotate_with_type(e, data_type)
-def _is_iso_date(text: str) -> bool:
- try:
- datetime.date.fromisoformat(text)
- return True
- except ValueError:
- return False
-
-
-def _is_iso_datetime(text: str) -> bool:
- try:
- datetime.datetime.fromisoformat(text)
- return True
- except ValueError:
- return False
-
-
-def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
+def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
date_text = l.name
- unit = r.text("unit").lower()
-
- is_iso_date = _is_iso_date(date_text)
+ is_iso_date_ = is_iso_date(date_text)
- if is_iso_date and unit in DATE_UNITS:
- l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE))
+ if is_iso_date_ and is_date_unit(unit):
return exp.DataType.Type.DATE
# An ISO date is also an ISO datetime, but not vice versa
- if is_iso_date or _is_iso_datetime(date_text):
- l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME))
+ if is_iso_date_ or is_iso_datetime(date_text):
return exp.DataType.Type.DATETIME
return exp.DataType.Type.UNKNOWN
-def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type:
- unit = r.text("unit").lower()
- if unit not in DATE_UNITS:
+def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
+ if not is_date_unit(unit):
return exp.DataType.Type.DATETIME
return l.type.this if l.type else exp.DataType.Type.UNKNOWN
@@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Date,
exp.DateFromParts,
exp.DateStrToDate,
- exp.DateTrunc,
exp.DiToDate,
exp.StrToDate,
exp.TimeStrToDate,
@@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DataType.Type.DOUBLE: {
exp.ApproxQuantile,
exp.Avg,
+ exp.Div,
exp.Exp,
exp.Ln,
exp.Log,
@@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
},
exp.DataType.Type.INT: {
exp.Ceil,
- exp.DateDiff,
exp.DatetimeDiff,
+ exp.DateDiff,
exp.Extract,
exp.TimestampDiff,
exp.TimeDiff,
@@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.GroupConcat,
exp.Initcap,
exp.Lower,
- exp.SafeConcat,
- exp.SafeDPipe,
exp.Substring,
exp.TimeToStr,
exp.TimeToTimeStr,
@@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
for expr_type in expressions
},
+ exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
@@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
- exp.DateAdd: lambda self, e: self._annotate_dateadd(e),
- exp.DateSub: lambda self, e: self._annotate_dateadd(e),
+ exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
+ exp.DateSub: lambda self, e: self._annotate_timeunit(e),
+ exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Div: lambda self, e: self._annotate_div(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
@@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
BINARY_COERCIONS: BinaryCoercions = {
**swap_all(
{
- (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval
+ (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal(
+ l, r.args.get("unit")
+ )
for t in exp.DataType.TEXT_TYPES
}
),
**swap_all(
{
- (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval,
+ # text + numeric will yield the numeric type to match most dialects' semantics
+ (text, numeric): lambda l, r: t.cast(
+ exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type
+ )
+ for text in exp.DataType.TEXT_TYPES
+ for numeric in exp.DataType.NUMERIC_TYPES
+ }
+ ),
+ **swap_all(
+ {
+ (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date(
+ l, r.args.get("unit")
+ ),
}
),
}
@@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return expression
- def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp:
+ def _annotate_timeunit(
+ self, expression: exp.TimeUnit | exp.DateTrunc
+ ) -> exp.TimeUnit | exp.DateTrunc:
self._annotate_args(expression)
if expression.this.type.this in exp.DataType.TEXT_TYPES:
- datatype = _coerce_literal_and_interval(expression.this, expression.interval())
- elif (
- expression.this.type.is_type(exp.DataType.Type.DATE)
- and expression.text("unit").lower() not in DATE_UNITS
- ):
- datatype = exp.DataType.Type.DATETIME
+ datatype = _coerce_date_literal(expression.this, expression.unit)
+ elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES:
+ datatype = _coerce_date(expression.this, expression.unit)
else:
- datatype = expression.this.type
+ datatype = exp.DataType.Type.UNKNOWN
self._set_type(expression, datatype)
return expression
@@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, exp.DataType.Type.UNKNOWN)
return expression
+
+ def _annotate_div(self, expression: exp.Div) -> exp.Div:
+ self._annotate_args(expression)
+
+ left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore
+
+ if (
+ expression.args.get("typed")
+ and left_type in exp.DataType.INTEGER_TYPES
+ and right_type in exp.DataType.INTEGER_TYPES
+ ):
+ self._set_type(expression, exp.DataType.Type.BIGINT)
+ else:
+ self._set_type(expression, self._maybe_coerce(left_type, right_type))
+
+ return expression