diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 110 |
1 files changed, 61 insertions, 49 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 69d4567..7b990f1 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,12 +1,18 @@ from __future__ import annotations -import datetime import functools import typing as t from sqlglot import exp from sqlglot._typing import E -from sqlglot.helper import ensure_list, seq_get, subclasses +from sqlglot.helper import ( + ensure_list, + is_date_unit, + is_iso_date, + is_iso_datetime, + seq_get, + subclasses, +) from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema @@ -20,10 +26,6 @@ if t.TYPE_CHECKING: ] -# Interval units that operate on date components -DATE_UNITS = {"day", "week", "month", "quarter", "year", "year_month"} - - def annotate_types( expression: E, schema: t.Optional[t.Dict | Schema] = None, @@ -60,43 +62,22 @@ def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[Type return lambda self, e: self._annotate_with_type(e, data_type) -def _is_iso_date(text: str) -> bool: - try: - datetime.date.fromisoformat(text) - return True - except ValueError: - return False - - -def _is_iso_datetime(text: str) -> bool: - try: - datetime.datetime.fromisoformat(text) - return True - except ValueError: - return False - - -def _coerce_literal_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: +def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: date_text = l.name - unit = r.text("unit").lower() - - is_iso_date = _is_iso_date(date_text) + is_iso_date_ = is_iso_date(date_text) - if is_iso_date and unit in DATE_UNITS: - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATE)) + if is_iso_date_ and is_date_unit(unit): return exp.DataType.Type.DATE # An ISO date is also an ISO datetime, but not vice versa - if is_iso_date or _is_iso_datetime(date_text): - l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DATETIME)) + if is_iso_date_ or is_iso_datetime(date_text): return exp.DataType.Type.DATETIME return exp.DataType.Type.UNKNOWN -def _coerce_date_and_interval(l: exp.Expression, r: exp.Expression) -> exp.DataType.Type: - unit = r.text("unit").lower() - if unit not in DATE_UNITS: +def _coerce_date(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type: + if not is_date_unit(unit): return exp.DataType.Type.DATETIME return l.type.this if l.type else exp.DataType.Type.UNKNOWN @@ -171,7 +152,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Date, exp.DateFromParts, exp.DateStrToDate, - exp.DateTrunc, exp.DiToDate, exp.StrToDate, exp.TimeStrToDate, @@ -185,6 +165,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DataType.Type.DOUBLE: { exp.ApproxQuantile, exp.Avg, + exp.Div, exp.Exp, exp.Ln, exp.Log, @@ -203,8 +184,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): }, exp.DataType.Type.INT: { exp.Ceil, - exp.DateDiff, exp.DatetimeDiff, + exp.DateDiff, exp.Extract, exp.TimestampDiff, exp.TimeDiff, @@ -240,8 +221,6 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.GroupConcat, exp.Initcap, exp.Lower, - exp.SafeConcat, - exp.SafeDPipe, exp.Substring, exp.TimeToStr, exp.TimeToTimeStr, @@ -267,6 +246,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): for data_type, expressions in TYPE_TO_EXPRESSIONS.items() for expr_type in expressions }, + exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), @@ -276,9 +256,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), - exp.DateAdd: lambda self, e: self._annotate_dateadd(e), - exp.DateSub: lambda self, e: self._annotate_dateadd(e), + exp.DateAdd: lambda self, e: self._annotate_timeunit(e), + exp.DateSub: lambda self, e: self._annotate_timeunit(e), + exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Div: lambda self, e: self._annotate_div(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), @@ -288,6 +270,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), @@ -306,13 +289,27 @@ class TypeAnnotator(metaclass=_TypeAnnotator): BINARY_COERCIONS: BinaryCoercions = { **swap_all( { - (t, exp.DataType.Type.INTERVAL): _coerce_literal_and_interval + (t, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date_literal( + l, r.args.get("unit") + ) for t in exp.DataType.TEXT_TYPES } ), **swap_all( { - (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): _coerce_date_and_interval, + # text + numeric will yield the numeric type to match most dialects' semantics + (text, numeric): lambda l, r: t.cast( + exp.DataType.Type, l.type if l.type in exp.DataType.NUMERIC_TYPES else r.type + ) + for text in exp.DataType.TEXT_TYPES + for numeric in exp.DataType.NUMERIC_TYPES + } + ), + **swap_all( + { + (exp.DataType.Type.DATE, exp.DataType.Type.INTERVAL): lambda l, r: _coerce_date( + l, r.args.get("unit") + ), } ), } @@ -511,18 +508,17 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return expression - def _annotate_dateadd(self, expression: exp.IntervalOp) -> exp.IntervalOp: + def _annotate_timeunit( + self, expression: exp.TimeUnit | exp.DateTrunc + ) -> exp.TimeUnit | exp.DateTrunc: self._annotate_args(expression) if expression.this.type.this in exp.DataType.TEXT_TYPES: - datatype = _coerce_literal_and_interval(expression.this, expression.interval()) - elif ( - expression.this.type.is_type(exp.DataType.Type.DATE) - and expression.text("unit").lower() not in DATE_UNITS - ): - datatype = exp.DataType.Type.DATETIME + datatype = _coerce_date_literal(expression.this, expression.unit) + elif expression.this.type.this in exp.DataType.TEMPORAL_TYPES: + datatype = _coerce_date(expression.this, expression.unit) else: - datatype = expression.this.type + datatype = exp.DataType.Type.UNKNOWN self._set_type(expression, datatype) return expression @@ -547,3 +543,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, exp.DataType.Type.UNKNOWN) return expression + + def _annotate_div(self, expression: exp.Div) -> exp.Div: + self._annotate_args(expression) + + left_type, right_type = expression.left.type.this, expression.right.type.this # type: ignore + + if ( + expression.args.get("typed") + and left_type in exp.DataType.INTEGER_TYPES + and right_type in exp.DataType.INTEGER_TYPES + ): + self._set_type(expression, exp.DataType.Type.BIGINT) + else: + self._set_type(expression, self._maybe_coerce(left_type, right_type)) + + return expression |