Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot._typing import E
  7from sqlglot.helper import ensure_list, subclasses
  8from sqlglot.optimizer.scope import Scope, traverse_scope
  9from sqlglot.schema import Schema, ensure_schema
 10
 11if t.TYPE_CHECKING:
 12    B = t.TypeVar("B", bound=exp.Binary)
 13
 14
 def annotate_types(
     expression: E,
     schema: t.Optional[t.Dict | Schema] = None,
     annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
     coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 ) -> E:
     """
     Infers the types of an expression, annotating its AST accordingly.

     This is a convenience wrapper: it normalizes `schema` with `ensure_schema`, builds a
     `TypeAnnotator` from the given arguments and runs it over `expression`.

     Example:
         >>> import sqlglot
         >>> schema = {"y": {"cola": "SMALLINT"}}
         >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
         >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
         >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
         <Type.DOUBLE: 'DOUBLE'>

     Args:
         expression: Expression to annotate.
         schema: Database schema, either as a plain dict or as a `Schema` instance.
         annotators: Maps expression type to corresponding annotation function. When omitted,
             `TypeAnnotator.ANNOTATORS` is used.
         coerces_to: Maps each data type to the set of data types it can be coerced into.
             When omitted, `TypeAnnotator.COERCES_TO` is used.

     Returns:
         The expression annotated with types.
     """

     type_annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
     return type_annotator.annotate(expression)
 45
 46
 def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
     """
     Builds an annotation callable that assigns `data_type` to the expression it receives.

     Going through a factory binds `data_type` at call time, so every callable produced
     inside a loop or comprehension keeps its own type.
     """

     def _annotator(self: TypeAnnotator, e: E) -> E:
         return self._annotate_with_type(e, data_type)

     return _annotator
 49
 50
 class _TypeAnnotator(type):
     """Metaclass that fills in the `COERCES_TO` table of the class being created."""

     def __new__(cls, clsname, bases, attrs):
         klass = super().__new__(cls, clsname, bases, attrs)

         # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
         # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
         precedence_chains = (
             # Text types
             (
                 exp.DataType.Type.TEXT,
                 exp.DataType.Type.NVARCHAR,
                 exp.DataType.Type.VARCHAR,
                 exp.DataType.Type.NCHAR,
                 exp.DataType.Type.CHAR,
             ),
             # Numeric types
             (
                 exp.DataType.Type.DOUBLE,
                 exp.DataType.Type.FLOAT,
                 exp.DataType.Type.DECIMAL,
                 exp.DataType.Type.BIGINT,
                 exp.DataType.Type.INT,
                 exp.DataType.Type.SMALLINT,
                 exp.DataType.Type.TINYINT,
             ),
             # Time-like types
             (
                 exp.DataType.Type.TIMESTAMPLTZ,
                 exp.DataType.Type.TIMESTAMPTZ,
                 exp.DataType.Type.TIMESTAMP,
                 exp.DataType.Type.DATETIME,
                 exp.DataType.Type.DATE,
             ),
         )

         for chain in precedence_chains:
             for position, data_type in enumerate(chain):
                 # A type can be coerced into every type listed before it in its chain
                 klass.COERCES_TO[data_type] = set(chain[:position])

         return klass
 88
 89
 class TypeAnnotator(metaclass=_TypeAnnotator):
     """
     Infers a type for the nodes of an expression tree and stores it on each node.

     Annotation is table-driven: `ANNOTATORS` maps an expression class to a callable of the
     form `(annotator, expression) -> expression`. Expression classes that have no entry in
     the table are annotated as UNKNOWN.
     """

     # Expression classes whose result type is fixed, grouped by that type
     TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
         exp.DataType.Type.BIGINT: {
             exp.ApproxDistinct,
             exp.ArraySize,
             exp.Count,
             exp.Length,
         },
         exp.DataType.Type.BOOLEAN: {
             exp.Between,
             exp.Boolean,
             exp.In,
             exp.RegexpLike,
         },
         exp.DataType.Type.DATE: {
             exp.CurrentDate,
             exp.Date,
             exp.DateAdd,
             exp.DateFromParts,
             exp.DateStrToDate,
             exp.DateSub,
             exp.DateTrunc,
             exp.DiToDate,
             exp.StrToDate,
             exp.TimeStrToDate,
             exp.TsOrDsToDate,
         },
         exp.DataType.Type.DATETIME: {
             exp.CurrentDatetime,
             exp.DatetimeAdd,
             exp.DatetimeSub,
         },
         exp.DataType.Type.DOUBLE: {
             exp.ApproxQuantile,
             exp.Avg,
             exp.Exp,
             exp.Ln,
             exp.Log,
             exp.Log2,
             exp.Log10,
             exp.Pow,
             exp.Quantile,
             exp.Round,
             exp.SafeDivide,
             exp.Sqrt,
             exp.Stddev,
             exp.StddevPop,
             exp.StddevSamp,
             exp.Variance,
             exp.VariancePop,
         },
         exp.DataType.Type.INT: {
             exp.Ceil,
             exp.DateDiff,
             exp.DatetimeDiff,
             exp.Extract,
             exp.TimestampDiff,
             exp.TimeDiff,
             exp.DateToDi,
             exp.Floor,
             exp.Levenshtein,
             exp.StrPosition,
             exp.TsOrDiToDi,
         },
         exp.DataType.Type.TIMESTAMP: {
             exp.CurrentTime,
             exp.CurrentTimestamp,
             exp.StrToTime,
             exp.TimeAdd,
             exp.TimeStrToTime,
             exp.TimeSub,
             exp.TimestampAdd,
             exp.TimestampSub,
             exp.UnixToTime,
         },
         exp.DataType.Type.TINYINT: {
             exp.Day,
             exp.Month,
             exp.Week,
             exp.Year,
         },
         exp.DataType.Type.VARCHAR: {
             exp.ArrayConcat,
             exp.Concat,
             exp.ConcatWs,
             exp.DateToDateStr,
             exp.GroupConcat,
             exp.Initcap,
             exp.Lower,
             exp.SafeConcat,
             exp.Substring,
             exp.TimeToStr,
             exp.TimeToTimeStr,
             exp.Trim,
             exp.TsOrDsToDateStr,
             exp.UnixToStr,
             exp.UnixToTimeStr,
             exp.Upper,
         },
     }

     # Maps an expression class to its annotation callable. Within this dict literal a later
     # entry replaces an earlier one with the same key, so the explicit entries at the bottom
     # win over the generated ones.
     ANNOTATORS: t.Dict = {
         **{
             expr_type: lambda self, e: self._annotate_unary(e)
             for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
         },
         **{
             expr_type: lambda self, e: self._annotate_binary(e)
             for expr_type in subclasses(exp.__name__, exp.Binary)
         },
         **{
             expr_type: _annotate_with_type_lambda(data_type)
             for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
             for expr_type in expressions
         },
         exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
         exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
         exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
         exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
         exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
         exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
         exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
         exp.Literal: lambda self, e: self._annotate_literal(e),
         exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
         exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
         exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
         exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
         exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
     }

     # Specifies what types a given type can be coerced into (autofilled)
     COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

     def __init__(
         self,
         schema: Schema,
         annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
         coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
     ) -> None:
         """
         Args:
             schema: Schema used to look up the types of table columns.
             annotators: Maps expression type to corresponding annotation function. A falsy
                 value (None or an empty dict) selects the class-level `ANNOTATORS`.
             coerces_to: Maps each data type to the set of data types it can be coerced into.
                 A falsy value selects the class-level `COERCES_TO`.
         """
         self.schema = schema
         self.annotators = annotators or self.ANNOTATORS
         self.coerces_to = coerces_to or self.COERCES_TO

     def annotate(self, expression: E) -> E:
         """
         Annotates `expression` scope by scope and returns it.

         For every scope, qualified column references are typed first (from the schema for
         tables, from the projections of the source for nested scopes), and then the scope's
         expression is annotated through `_maybe_annotate`.
         """
         for scope in traverse_scope(expression):
             # Source name -> {output column name -> expression that produces it}
             selects = {}
             for name, source in scope.sources.items():
                 if not isinstance(source, Scope):
                     continue
                 if isinstance(source.expression, exp.UDTF):
                     values = []

                     if isinstance(source.expression, exp.Lateral):
                         if isinstance(source.expression.this, exp.Explode):
                             values = [source.expression.this.this]
                     else:
                         # NOTE(review): presumably the first row of a VALUES-like UDTF - confirm
                         values = source.expression.expressions[0].expressions

                     if not values:
                         continue

                     selects[name] = {
                         alias: column
                         for alias, column in zip(
                             source.expression.alias_column_names,
                             values,
                         )
                     }
                 else:
                     selects[name] = {
                         select.alias_or_name: select for select in source.expression.selects
                     }

             # First annotate the current scope's column references
             for col in scope.columns:
                 if not col.table:
                     continue

                 source = scope.sources.get(col.table)
                 if isinstance(source, exp.Table):
                     col.type = self.schema.get_column_type(source, col)
                 elif source and col.table in selects and col.name in selects[col.table]:
                     # Reuses the type already stored on the source's projection; this relies
                     # on that scope having been annotated before the current one
                     col.type = selects[col.table][col.name].type

             # Then (possibly) annotate the remaining expressions in the scope
             self._maybe_annotate(scope.expression)

         return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

     def _maybe_annotate(self, expression: E) -> E:
         """Annotates `expression` unless it already has a type; unmapped classes get UNKNOWN."""
         if expression.type:
             return expression  # We've already inferred the expression's type

         annotator = self.annotators.get(expression.__class__)

         return (
             annotator(self, expression)
             if annotator
             else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
         )

     def _annotate_args(self, expression: E) -> E:
         """Annotates every value yielded by `expression.iter_expressions()`."""
         for _, value in expression.iter_expressions():
             self._maybe_annotate(value)

         return expression

     def _maybe_coerce(
         self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
     ) -> exp.DataType.Type:
         """
         Picks the resulting type for a pair of operand types.

         Returns `type2` when `coerces_to` lists it as a coercion target of `type1`, and
         `type1` otherwise. NULL takes precedence over UNKNOWN, and both take precedence over
         any other type.
         """
         # `DataType` instances are reduced to their `Type` member before comparing
         if isinstance(type1, exp.DataType):
             type1 = type1.this
         if isinstance(type2, exp.DataType):
             type2 = type2.this

         # We propagate the NULL / UNKNOWN types upwards if found
         if exp.DataType.Type.NULL in (type1, type2):
             return exp.DataType.Type.NULL
         if exp.DataType.Type.UNKNOWN in (type1, type2):
             return exp.DataType.Type.UNKNOWN

         return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore

     # Note: the following "no_type_check" decorators were added because mypy was yelling due
     # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
     # This is a known mypy issue: https://github.com/python/mypy/issues/3004

     @t.no_type_check
     def _annotate_binary(self, expression: B) -> B:
         """
         Annotates a binary expression from the types of its two operands.

         Connectors become NULL when both sides are NULL, a NULLABLE BOOLEAN when exactly one
         side is NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. Every remaining
         binary expression gets the coerced type of its operands.
         """
         self._annotate_args(expression)

         left_type = expression.left.type.this
         right_type = expression.right.type.this

         if isinstance(expression, exp.Connector):
             if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                 expression.type = exp.DataType.Type.NULL
             elif exp.DataType.Type.NULL in (left_type, right_type):
                 expression.type = exp.DataType.build(
                     "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                 )
             else:
                 expression.type = exp.DataType.Type.BOOLEAN
         elif isinstance(expression, exp.Predicate):
             expression.type = exp.DataType.Type.BOOLEAN
         else:
             expression.type = self._maybe_coerce(left_type, right_type)

         return expression

     @t.no_type_check
     def _annotate_unary(self, expression: E) -> E:
         """
         Annotates a unary expression (or an alias).

         Logical negation is BOOLEAN. Everything else routed here (arithmetic negation,
         bitwise negation, parentheses, aliases) takes the type of the wrapped expression.
         """
         self._annotate_args(expression)

         # Only NOT yields a boolean. Checking for `exp.Condition` instead would also match
         # the arithmetic / bitwise unary operators and wrongly type e.g. `-x` as BOOLEAN.
         if isinstance(expression, exp.Not):
             expression.type = exp.DataType.Type.BOOLEAN
         else:
             expression.type = expression.this.type

         return expression

     @t.no_type_check
     def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
         """Types string literals as VARCHAR, integer literals as INT and the rest as DOUBLE."""
         if expression.is_string:
             expression.type = exp.DataType.Type.VARCHAR
         elif expression.is_int:
             expression.type = exp.DataType.Type.INT
         else:
             expression.type = exp.DataType.Type.DOUBLE

         return expression

     @t.no_type_check
     def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
         """Assigns `target_type` to `expression`, then annotates its child expressions."""
         expression.type = target_type
         return self._annotate_args(expression)

     @t.no_type_check
     def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
         """
         Derives the type of `expression` from the expressions stored under the given arg names.

         Args:
             expression: Expression to annotate.
             *args: Names of the entries in `expression.args` whose types are combined. An
                 entry may hold a single expression or a list of them; missing entries are
                 skipped.
             promote: Whether to widen the result, turning integer types into BIGINT and
                 float types into DOUBLE.

         Returns:
             The annotated expression. Its type is UNKNOWN if none of the args hold a value.
         """
         self._annotate_args(expression)

         expressions: t.List[exp.Expression] = []
         for arg in args:
             arg_expr = expression.args.get(arg)
             expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

         # Fold the operand types pairwise, starting from the first operand's own type
         last_datatype = None
         for expr in expressions:
             last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

         expression.type = last_datatype or exp.DataType.Type.UNKNOWN

         if promote:
             if expression.type.this in exp.DataType.INTEGER_TYPES:
                 expression.type = exp.DataType.Type.BIGINT
             elif expression.type.this in exp.DataType.FLOAT_TYPES:
                 expression.type = exp.DataType.Type.DOUBLE

         return expression
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[sqlglot.optimizer.annotate_types.TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    This is a convenience wrapper: it normalizes `schema` with `ensure_schema`, builds a
    `TypeAnnotator` from the given arguments and runs it over `expression`.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate.
        schema: Database schema, either as a plain dict or as a `Schema` instance.
        annotators: Maps expression type to corresponding annotation function. When omitted,
            `TypeAnnotator.ANNOTATORS` is used.
        coerces_to: Maps each data type to the set of data types it can be coerced into.
            When omitted, `TypeAnnotator.COERCES_TO` is used.

    Returns:
        The expression annotated with types.
    """

    type_annotator = TypeAnnotator(ensure_schema(schema), annotators, coerces_to)
    return type_annotator.annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps each data type to the set of data types that it can be coerced into.
Returns:

The expression annotated with types.

class TypeAnnotator:
 class TypeAnnotator(metaclass=_TypeAnnotator):
     """
     Infers a type for the nodes of an expression tree and stores it on each node.

     Annotation is table-driven: `ANNOTATORS` maps an expression class to a callable of the
     form `(annotator, expression) -> expression`. Expression classes that have no entry in
     the table are annotated as UNKNOWN.
     """

     # Expression classes whose result type is fixed, grouped by that type
     TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
         exp.DataType.Type.BIGINT: {
             exp.ApproxDistinct,
             exp.ArraySize,
             exp.Count,
             exp.Length,
         },
         exp.DataType.Type.BOOLEAN: {
             exp.Between,
             exp.Boolean,
             exp.In,
             exp.RegexpLike,
         },
         exp.DataType.Type.DATE: {
             exp.CurrentDate,
             exp.Date,
             exp.DateAdd,
             exp.DateFromParts,
             exp.DateStrToDate,
             exp.DateSub,
             exp.DateTrunc,
             exp.DiToDate,
             exp.StrToDate,
             exp.TimeStrToDate,
             exp.TsOrDsToDate,
         },
         exp.DataType.Type.DATETIME: {
             exp.CurrentDatetime,
             exp.DatetimeAdd,
             exp.DatetimeSub,
         },
         exp.DataType.Type.DOUBLE: {
             exp.ApproxQuantile,
             exp.Avg,
             exp.Exp,
             exp.Ln,
             exp.Log,
             exp.Log2,
             exp.Log10,
             exp.Pow,
             exp.Quantile,
             exp.Round,
             exp.SafeDivide,
             exp.Sqrt,
             exp.Stddev,
             exp.StddevPop,
             exp.StddevSamp,
             exp.Variance,
             exp.VariancePop,
         },
         exp.DataType.Type.INT: {
             exp.Ceil,
             exp.DateDiff,
             exp.DatetimeDiff,
             exp.Extract,
             exp.TimestampDiff,
             exp.TimeDiff,
             exp.DateToDi,
             exp.Floor,
             exp.Levenshtein,
             exp.StrPosition,
             exp.TsOrDiToDi,
         },
         exp.DataType.Type.TIMESTAMP: {
             exp.CurrentTime,
             exp.CurrentTimestamp,
             exp.StrToTime,
             exp.TimeAdd,
             exp.TimeStrToTime,
             exp.TimeSub,
             exp.TimestampAdd,
             exp.TimestampSub,
             exp.UnixToTime,
         },
         exp.DataType.Type.TINYINT: {
             exp.Day,
             exp.Month,
             exp.Week,
             exp.Year,
         },
         exp.DataType.Type.VARCHAR: {
             exp.ArrayConcat,
             exp.Concat,
             exp.ConcatWs,
             exp.DateToDateStr,
             exp.GroupConcat,
             exp.Initcap,
             exp.Lower,
             exp.SafeConcat,
             exp.Substring,
             exp.TimeToStr,
             exp.TimeToTimeStr,
             exp.Trim,
             exp.TsOrDsToDateStr,
             exp.UnixToStr,
             exp.UnixToTimeStr,
             exp.Upper,
         },
     }

     # Maps an expression class to its annotation callable. Within this dict literal a later
     # entry replaces an earlier one with the same key, so the explicit entries at the bottom
     # win over the generated ones.
     ANNOTATORS: t.Dict = {
         **{
             expr_type: lambda self, e: self._annotate_unary(e)
             for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
         },
         **{
             expr_type: lambda self, e: self._annotate_binary(e)
             for expr_type in subclasses(exp.__name__, exp.Binary)
         },
         **{
             expr_type: _annotate_with_type_lambda(data_type)
             for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
             for expr_type in expressions
         },
         exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
         exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
         exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
         exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
         exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
         exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
         exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
         exp.Literal: lambda self, e: self._annotate_literal(e),
         exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
         exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
         exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
         exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
         exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
         exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
     }

     # Specifies what types a given type can be coerced into (autofilled)
     COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}

     def __init__(
         self,
         schema: Schema,
         annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
         coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
     ) -> None:
         """
         Args:
             schema: Schema used to look up the types of table columns.
             annotators: Maps expression type to corresponding annotation function. A falsy
                 value (None or an empty dict) selects the class-level `ANNOTATORS`.
             coerces_to: Maps each data type to the set of data types it can be coerced into.
                 A falsy value selects the class-level `COERCES_TO`.
         """
         self.schema = schema
         self.annotators = annotators or self.ANNOTATORS
         self.coerces_to = coerces_to or self.COERCES_TO

     def annotate(self, expression: E) -> E:
         """
         Annotates `expression` scope by scope and returns it.

         For every scope, qualified column references are typed first (from the schema for
         tables, from the projections of the source for nested scopes), and then the scope's
         expression is annotated through `_maybe_annotate`.
         """
         for scope in traverse_scope(expression):
             # Source name -> {output column name -> expression that produces it}
             selects = {}
             for name, source in scope.sources.items():
                 if not isinstance(source, Scope):
                     continue
                 if isinstance(source.expression, exp.UDTF):
                     values = []

                     if isinstance(source.expression, exp.Lateral):
                         if isinstance(source.expression.this, exp.Explode):
                             values = [source.expression.this.this]
                     else:
                         # NOTE(review): presumably the first row of a VALUES-like UDTF - confirm
                         values = source.expression.expressions[0].expressions

                     if not values:
                         continue

                     selects[name] = {
                         alias: column
                         for alias, column in zip(
                             source.expression.alias_column_names,
                             values,
                         )
                     }
                 else:
                     selects[name] = {
                         select.alias_or_name: select for select in source.expression.selects
                     }

             # First annotate the current scope's column references
             for col in scope.columns:
                 if not col.table:
                     continue

                 source = scope.sources.get(col.table)
                 if isinstance(source, exp.Table):
                     col.type = self.schema.get_column_type(source, col)
                 elif source and col.table in selects and col.name in selects[col.table]:
                     # Reuses the type already stored on the source's projection; this relies
                     # on that scope having been annotated before the current one
                     col.type = selects[col.table][col.name].type

             # Then (possibly) annotate the remaining expressions in the scope
             self._maybe_annotate(scope.expression)

         return self._maybe_annotate(expression)  # This takes care of non-traversable expressions

     def _maybe_annotate(self, expression: E) -> E:
         """Annotates `expression` unless it already has a type; unmapped classes get UNKNOWN."""
         if expression.type:
             return expression  # We've already inferred the expression's type

         annotator = self.annotators.get(expression.__class__)

         return (
             annotator(self, expression)
             if annotator
             else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
         )

     def _annotate_args(self, expression: E) -> E:
         """Annotates every value yielded by `expression.iter_expressions()`."""
         for _, value in expression.iter_expressions():
             self._maybe_annotate(value)

         return expression

     def _maybe_coerce(
         self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
     ) -> exp.DataType.Type:
         """
         Picks the resulting type for a pair of operand types.

         Returns `type2` when `coerces_to` lists it as a coercion target of `type1`, and
         `type1` otherwise. NULL takes precedence over UNKNOWN, and both take precedence over
         any other type.
         """
         # `DataType` instances are reduced to their `Type` member before comparing
         if isinstance(type1, exp.DataType):
             type1 = type1.this
         if isinstance(type2, exp.DataType):
             type2 = type2.this

         # We propagate the NULL / UNKNOWN types upwards if found
         if exp.DataType.Type.NULL in (type1, type2):
             return exp.DataType.Type.NULL
         if exp.DataType.Type.UNKNOWN in (type1, type2):
             return exp.DataType.Type.UNKNOWN

         return type2 if type2 in self.coerces_to.get(type1, {}) else type1  # type: ignore

     # Note: the following "no_type_check" decorators were added because mypy was yelling due
     # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
     # This is a known mypy issue: https://github.com/python/mypy/issues/3004

     @t.no_type_check
     def _annotate_binary(self, expression: B) -> B:
         """
         Annotates a binary expression from the types of its two operands.

         Connectors become NULL when both sides are NULL, a NULLABLE BOOLEAN when exactly one
         side is NULL and BOOLEAN otherwise. Other predicates are BOOLEAN. Every remaining
         binary expression gets the coerced type of its operands.
         """
         self._annotate_args(expression)

         left_type = expression.left.type.this
         right_type = expression.right.type.this

         if isinstance(expression, exp.Connector):
             if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
                 expression.type = exp.DataType.Type.NULL
             elif exp.DataType.Type.NULL in (left_type, right_type):
                 expression.type = exp.DataType.build(
                     "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
                 )
             else:
                 expression.type = exp.DataType.Type.BOOLEAN
         elif isinstance(expression, exp.Predicate):
             expression.type = exp.DataType.Type.BOOLEAN
         else:
             expression.type = self._maybe_coerce(left_type, right_type)

         return expression

     @t.no_type_check
     def _annotate_unary(self, expression: E) -> E:
         """
         Annotates a unary expression (or an alias).

         Logical negation is BOOLEAN. Everything else routed here (arithmetic negation,
         bitwise negation, parentheses, aliases) takes the type of the wrapped expression.
         """
         self._annotate_args(expression)

         # Only NOT yields a boolean. Checking for `exp.Condition` instead would also match
         # the arithmetic / bitwise unary operators and wrongly type e.g. `-x` as BOOLEAN.
         if isinstance(expression, exp.Not):
             expression.type = exp.DataType.Type.BOOLEAN
         else:
             expression.type = expression.this.type

         return expression

     @t.no_type_check
     def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
         """Types string literals as VARCHAR, integer literals as INT and the rest as DOUBLE."""
         if expression.is_string:
             expression.type = exp.DataType.Type.VARCHAR
         elif expression.is_int:
             expression.type = exp.DataType.Type.INT
         else:
             expression.type = exp.DataType.Type.DOUBLE

         return expression

     @t.no_type_check
     def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
         """Assigns `target_type` to `expression`, then annotates its child expressions."""
         expression.type = target_type
         return self._annotate_args(expression)

     @t.no_type_check
     def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
         """
         Derives the type of `expression` from the expressions stored under the given arg names.

         Args:
             expression: Expression to annotate.
             *args: Names of the entries in `expression.args` whose types are combined. An
                 entry may hold a single expression or a list of them; missing entries are
                 skipped.
             promote: Whether to widen the result, turning integer types into BIGINT and
                 float types into DOUBLE.

         Returns:
             The annotated expression. Its type is UNKNOWN if none of the args hold a value.
         """
         self._annotate_args(expression)

         expressions: t.List[exp.Expression] = []
         for arg in args:
             arg_expr = expression.args.get(arg)
             expressions.extend(expr for expr in ensure_list(arg_expr) if expr)

         # Fold the operand types pairwise, starting from the first operand's own type
         last_datatype = None
         for expr in expressions:
             last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)

         expression.type = last_datatype or exp.DataType.Type.UNKNOWN

         if promote:
             if expression.type.this in exp.DataType.INTEGER_TYPES:
                 expression.type = exp.DataType.Type.BIGINT
             elif expression.type.this in exp.DataType.FLOAT_TYPES:
                 expression.type = exp.DataType.Type.DOUBLE

         return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[sqlglot.optimizer.annotate_types.TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """
    Args:
        schema: Schema stored on the instance as `self.schema`.
        annotators: Maps expression type to corresponding annotation function.
        coerces_to: Maps each data type to the set of data types it can be coerced into.
    """
    self.schema = schema
    # A falsy override (None or an empty dict) falls back to the class-level table
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.Count'>, <class 'sqlglot.expressions.ArraySize'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.RegexpLike'>, <class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateFromParts'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeSub'>, <class 'sqlglot.expressions.DatetimeAdd'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.Round'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.Pow'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.Ceil'>, <class 
'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.TsOrDiToDi'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.TimestampSub'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Week'>, <class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Year'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.SafeConcat'>, <class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.Initcap'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, 
<class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, 
<class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NCHAR: 'NCHAR'>, <Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.INT: 'INT'>: {<Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.TINYINT: 'TINYINT'>: {<Type.SMALLINT: 'SMALLINT'>, <Type.BIGINT: 'BIGINT'>, <Type.INT: 'INT'>, <Type.DECIMAL: 'DECIMAL'>, <Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}, <Type.DATE: 'DATE'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}}
schema
annotators
coerces_to
def annotate(self, expression: ~E) -> ~E:
237    def annotate(self, expression: E) -> E:
238        for scope in traverse_scope(expression):
239            selects = {}
240            for name, source in scope.sources.items():
241                if not isinstance(source, Scope):
242                    continue
243                if isinstance(source.expression, exp.UDTF):
244                    values = []
245
246                    if isinstance(source.expression, exp.Lateral):
247                        if isinstance(source.expression.this, exp.Explode):
248                            values = [source.expression.this.this]
249                    else:
250                        values = source.expression.expressions[0].expressions
251
252                    if not values:
253                        continue
254
255                    selects[name] = {
256                        alias: column
257                        for alias, column in zip(
258                            source.expression.alias_column_names,
259                            values,
260                        )
261                    }
262                else:
263                    selects[name] = {
264                        select.alias_or_name: select for select in source.expression.selects
265                    }
266
267            # First annotate the current scope's column references
268            for col in scope.columns:
269                if not col.table:
270                    continue
271
272                source = scope.sources.get(col.table)
273                if isinstance(source, exp.Table):
274                    col.type = self.schema.get_column_type(source, col)
275                elif source and col.table in selects and col.name in selects[col.table]:
276                    col.type = selects[col.table][col.name].type
277
278            # Then (possibly) annotate the remaining expressions in the scope
279            self._maybe_annotate(scope.expression)
280
281        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions