Edit on GitHub

sqlglot.optimizer.annotate_types

  1from __future__ import annotations
  2
  3import typing as t
  4
  5from sqlglot import exp
  6from sqlglot._typing import E
  7from sqlglot.helper import ensure_list, subclasses
  8from sqlglot.optimizer.scope import Scope, traverse_scope
  9from sqlglot.schema import Schema, ensure_schema
 10
 11if t.TYPE_CHECKING:
 12    B = t.TypeVar("B", bound=exp.Binary)
 13
 14
 15def annotate_types(
 16    expression: E,
 17    schema: t.Optional[t.Dict | Schema] = None,
 18    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
 19    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
 20) -> E:
 21    """
 22    Infers the types of an expression, annotating its AST accordingly.
 23
 24    Example:
 25        >>> import sqlglot
 26        >>> schema = {"y": {"cola": "SMALLINT"}}
 27        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 28        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 29        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 30        <Type.DOUBLE: 'DOUBLE'>
 31
 32    Args:
 33        expression: Expression to annotate.
 34        schema: Database schema.
 35        annotators: Maps expression type to corresponding annotation function.
 36        coerces_to: Maps expression type to set of types that it can be coerced into.
 37
 38    Returns:
 39        The expression annotated with types.
 40    """
 41
 42    schema = ensure_schema(schema)
 43
 44    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 45
 46
 47def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
 48    return lambda self, e: self._annotate_with_type(e, data_type)
 49
 50
 51class _TypeAnnotator(type):
 52    def __new__(cls, clsname, bases, attrs):
 53        klass = super().__new__(cls, clsname, bases, attrs)
 54
 55        # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
 56        # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
 57        text_precedence = (
 58            exp.DataType.Type.TEXT,
 59            exp.DataType.Type.NVARCHAR,
 60            exp.DataType.Type.VARCHAR,
 61            exp.DataType.Type.NCHAR,
 62            exp.DataType.Type.CHAR,
 63        )
 64        numeric_precedence = (
 65            exp.DataType.Type.DOUBLE,
 66            exp.DataType.Type.FLOAT,
 67            exp.DataType.Type.DECIMAL,
 68            exp.DataType.Type.BIGINT,
 69            exp.DataType.Type.INT,
 70            exp.DataType.Type.SMALLINT,
 71            exp.DataType.Type.TINYINT,
 72        )
 73        timelike_precedence = (
 74            exp.DataType.Type.TIMESTAMPLTZ,
 75            exp.DataType.Type.TIMESTAMPTZ,
 76            exp.DataType.Type.TIMESTAMP,
 77            exp.DataType.Type.DATETIME,
 78            exp.DataType.Type.DATE,
 79        )
 80
 81        for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
 82            coerces_to = set()
 83            for data_type in type_precedence:
 84                klass.COERCES_TO[data_type] = coerces_to.copy()
 85                coerces_to |= {data_type}
 86
 87        return klass
 88
 89
 90class TypeAnnotator(metaclass=_TypeAnnotator):
 91    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
 92        exp.DataType.Type.BIGINT: {
 93            exp.ApproxDistinct,
 94            exp.ArraySize,
 95            exp.Count,
 96            exp.Length,
 97        },
 98        exp.DataType.Type.BOOLEAN: {
 99            exp.Between,
100            exp.Boolean,
101            exp.In,
102            exp.RegexpLike,
103        },
104        exp.DataType.Type.DATE: {
105            exp.CurrentDate,
106            exp.Date,
107            exp.DateAdd,
108            exp.DateFromParts,
109            exp.DateStrToDate,
110            exp.DateSub,
111            exp.DateTrunc,
112            exp.DiToDate,
113            exp.StrToDate,
114            exp.TimeStrToDate,
115            exp.TsOrDsToDate,
116        },
117        exp.DataType.Type.DATETIME: {
118            exp.CurrentDatetime,
119            exp.DatetimeAdd,
120            exp.DatetimeSub,
121        },
122        exp.DataType.Type.DOUBLE: {
123            exp.ApproxQuantile,
124            exp.Avg,
125            exp.Exp,
126            exp.Ln,
127            exp.Log,
128            exp.Log2,
129            exp.Log10,
130            exp.Pow,
131            exp.Quantile,
132            exp.Round,
133            exp.SafeDivide,
134            exp.Sqrt,
135            exp.Stddev,
136            exp.StddevPop,
137            exp.StddevSamp,
138            exp.Variance,
139            exp.VariancePop,
140        },
141        exp.DataType.Type.INT: {
142            exp.Ceil,
143            exp.DateDiff,
144            exp.DatetimeDiff,
145            exp.Extract,
146            exp.TimestampDiff,
147            exp.TimeDiff,
148            exp.DateToDi,
149            exp.Floor,
150            exp.Levenshtein,
151            exp.StrPosition,
152            exp.TsOrDiToDi,
153        },
154        exp.DataType.Type.TIMESTAMP: {
155            exp.CurrentTime,
156            exp.CurrentTimestamp,
157            exp.StrToTime,
158            exp.TimeAdd,
159            exp.TimeStrToTime,
160            exp.TimeSub,
161            exp.TimestampAdd,
162            exp.TimestampSub,
163            exp.UnixToTime,
164        },
165        exp.DataType.Type.TINYINT: {
166            exp.Day,
167            exp.Month,
168            exp.Week,
169            exp.Year,
170        },
171        exp.DataType.Type.VARCHAR: {
172            exp.ArrayConcat,
173            exp.Concat,
174            exp.ConcatWs,
175            exp.DateToDateStr,
176            exp.GroupConcat,
177            exp.Initcap,
178            exp.Lower,
179            exp.SafeConcat,
180            exp.Substring,
181            exp.TimeToStr,
182            exp.TimeToTimeStr,
183            exp.Trim,
184            exp.TsOrDsToDateStr,
185            exp.UnixToStr,
186            exp.UnixToTimeStr,
187            exp.Upper,
188        },
189    }
190
191    ANNOTATORS: t.Dict = {
192        **{
193            expr_type: lambda self, e: self._annotate_unary(e)
194            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
195        },
196        **{
197            expr_type: lambda self, e: self._annotate_binary(e)
198            for expr_type in subclasses(exp.__name__, exp.Binary)
199        },
200        **{
201            expr_type: _annotate_with_type_lambda(data_type)
202            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
203            for expr_type in expressions
204        },
205        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
206        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
207        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
208        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
209        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
210        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
211        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
212        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
213        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
214        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
215        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
216        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
217        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
218        exp.Literal: lambda self, e: self._annotate_literal(e),
219        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
220        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
221        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
222        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
223        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
224        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
225        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
226    }
227
228    NESTED_TYPES = {
229        exp.DataType.Type.ARRAY,
230    }
231
232    # Specifies what types a given type can be coerced into (autofilled)
233    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
234
235    def __init__(
236        self,
237        schema: Schema,
238        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
239        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
240    ) -> None:
241        self.schema = schema
242        self.annotators = annotators or self.ANNOTATORS
243        self.coerces_to = coerces_to or self.COERCES_TO
244
245    def annotate(self, expression: E) -> E:
246        for scope in traverse_scope(expression):
247            selects = {}
248            for name, source in scope.sources.items():
249                if not isinstance(source, Scope):
250                    continue
251                if isinstance(source.expression, exp.UDTF):
252                    values = []
253
254                    if isinstance(source.expression, exp.Lateral):
255                        if isinstance(source.expression.this, exp.Explode):
256                            values = [source.expression.this.this]
257                    else:
258                        values = source.expression.expressions[0].expressions
259
260                    if not values:
261                        continue
262
263                    selects[name] = {
264                        alias: column
265                        for alias, column in zip(
266                            source.expression.alias_column_names,
267                            values,
268                        )
269                    }
270                else:
271                    selects[name] = {
272                        select.alias_or_name: select for select in source.expression.selects
273                    }
274
275            # First annotate the current scope's column references
276            for col in scope.columns:
277                if not col.table:
278                    continue
279
280                source = scope.sources.get(col.table)
281                if isinstance(source, exp.Table):
282                    col.type = self.schema.get_column_type(source, col)
283                elif source and col.table in selects and col.name in selects[col.table]:
284                    col.type = selects[col.table][col.name].type
285
286            # Then (possibly) annotate the remaining expressions in the scope
287            self._maybe_annotate(scope.expression)
288
289        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
290
291    def _maybe_annotate(self, expression: E) -> E:
292        if expression.type:
293            return expression  # We've already inferred the expression's type
294
295        annotator = self.annotators.get(expression.__class__)
296
297        return (
298            annotator(self, expression)
299            if annotator
300            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
301        )
302
303    def _annotate_args(self, expression: E) -> E:
304        for _, value in expression.iter_expressions():
305            self._maybe_annotate(value)
306
307        return expression
308
309    def _maybe_coerce(
310        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
311    ) -> exp.DataType | exp.DataType.Type:
312        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
313        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
314
315        # We propagate the NULL / UNKNOWN types upwards if found
316        if exp.DataType.Type.NULL in (type1_value, type2_value):
317            return exp.DataType.Type.NULL
318        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
319            return exp.DataType.Type.UNKNOWN
320
321        if type1_value in self.NESTED_TYPES:
322            return type1
323        if type2_value in self.NESTED_TYPES:
324            return type2
325
326        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
327
328    # Note: the following "no_type_check" decorators were added because mypy was yelling due
329    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
330    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
331
332    @t.no_type_check
333    def _annotate_binary(self, expression: B) -> B:
334        self._annotate_args(expression)
335
336        left_type = expression.left.type.this
337        right_type = expression.right.type.this
338
339        if isinstance(expression, exp.Connector):
340            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
341                expression.type = exp.DataType.Type.NULL
342            elif exp.DataType.Type.NULL in (left_type, right_type):
343                expression.type = exp.DataType.build(
344                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
345                )
346            else:
347                expression.type = exp.DataType.Type.BOOLEAN
348        elif isinstance(expression, exp.Predicate):
349            expression.type = exp.DataType.Type.BOOLEAN
350        else:
351            expression.type = self._maybe_coerce(left_type, right_type)
352
353        return expression
354
355    @t.no_type_check
356    def _annotate_unary(self, expression: E) -> E:
357        self._annotate_args(expression)
358
359        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
360            expression.type = exp.DataType.Type.BOOLEAN
361        else:
362            expression.type = expression.this.type
363
364        return expression
365
366    @t.no_type_check
367    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
368        if expression.is_string:
369            expression.type = exp.DataType.Type.VARCHAR
370        elif expression.is_int:
371            expression.type = exp.DataType.Type.INT
372        else:
373            expression.type = exp.DataType.Type.DOUBLE
374
375        return expression
376
377    @t.no_type_check
378    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
379        expression.type = target_type
380        return self._annotate_args(expression)
381
382    @t.no_type_check
383    def _annotate_by_args(
384        self, expression: E, *args: str, promote: bool = False, array: bool = False
385    ) -> E:
386        self._annotate_args(expression)
387
388        expressions: t.List[exp.Expression] = []
389        for arg in args:
390            arg_expr = expression.args.get(arg)
391            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
392
393        last_datatype = None
394        for expr in expressions:
395            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
396
397        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
398
399        if promote:
400            if expression.type.this in exp.DataType.INTEGER_TYPES:
401                expression.type = exp.DataType.Type.BIGINT
402            elif expression.type.this in exp.DataType.FLOAT_TYPES:
403                expression.type = exp.DataType.Type.DOUBLE
404
405        if array:
406            expression.type = exp.DataType(
407                this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
408            )
409
410        return expression
def annotate_types( expression: ~E, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None) -> ~E:
def annotate_types(
    expression: E,
    schema: t.Optional[t.Dict | Schema] = None,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> E:
    """
    Infers the types of an expression, annotating its AST accordingly.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression: Expression to annotate. Its nodes are annotated in place and the same
            expression is returned.
        schema: Database schema, either a `Schema` instance or a plain dict, normalized through
            `ensure_schema`.
        annotators: Maps expression classes to their annotation functions. When not provided
            (or empty), `TypeAnnotator.ANNOTATORS` is used.
        coerces_to: Maps each data type to the set of data types it can be coerced into. When
            not provided (or empty), `TypeAnnotator.COERCES_TO` is used.

    Returns:
        The expression annotated with types.
    """
    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Infers the types of an expression, annotating its AST accordingly.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression: Expression to annotate.
  • schema: Database schema.
  • annotators: Maps expression type to corresponding annotation function.
  • coerces_to: Maps each data type to the set of data types it can be coerced into.
Returns:

The expression annotated with types.

class TypeAnnotator:
 91class TypeAnnotator(metaclass=_TypeAnnotator):
 92    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
 93        exp.DataType.Type.BIGINT: {
 94            exp.ApproxDistinct,
 95            exp.ArraySize,
 96            exp.Count,
 97            exp.Length,
 98        },
 99        exp.DataType.Type.BOOLEAN: {
100            exp.Between,
101            exp.Boolean,
102            exp.In,
103            exp.RegexpLike,
104        },
105        exp.DataType.Type.DATE: {
106            exp.CurrentDate,
107            exp.Date,
108            exp.DateAdd,
109            exp.DateFromParts,
110            exp.DateStrToDate,
111            exp.DateSub,
112            exp.DateTrunc,
113            exp.DiToDate,
114            exp.StrToDate,
115            exp.TimeStrToDate,
116            exp.TsOrDsToDate,
117        },
118        exp.DataType.Type.DATETIME: {
119            exp.CurrentDatetime,
120            exp.DatetimeAdd,
121            exp.DatetimeSub,
122        },
123        exp.DataType.Type.DOUBLE: {
124            exp.ApproxQuantile,
125            exp.Avg,
126            exp.Exp,
127            exp.Ln,
128            exp.Log,
129            exp.Log2,
130            exp.Log10,
131            exp.Pow,
132            exp.Quantile,
133            exp.Round,
134            exp.SafeDivide,
135            exp.Sqrt,
136            exp.Stddev,
137            exp.StddevPop,
138            exp.StddevSamp,
139            exp.Variance,
140            exp.VariancePop,
141        },
142        exp.DataType.Type.INT: {
143            exp.Ceil,
144            exp.DateDiff,
145            exp.DatetimeDiff,
146            exp.Extract,
147            exp.TimestampDiff,
148            exp.TimeDiff,
149            exp.DateToDi,
150            exp.Floor,
151            exp.Levenshtein,
152            exp.StrPosition,
153            exp.TsOrDiToDi,
154        },
155        exp.DataType.Type.TIMESTAMP: {
156            exp.CurrentTime,
157            exp.CurrentTimestamp,
158            exp.StrToTime,
159            exp.TimeAdd,
160            exp.TimeStrToTime,
161            exp.TimeSub,
162            exp.TimestampAdd,
163            exp.TimestampSub,
164            exp.UnixToTime,
165        },
166        exp.DataType.Type.TINYINT: {
167            exp.Day,
168            exp.Month,
169            exp.Week,
170            exp.Year,
171        },
172        exp.DataType.Type.VARCHAR: {
173            exp.ArrayConcat,
174            exp.Concat,
175            exp.ConcatWs,
176            exp.DateToDateStr,
177            exp.GroupConcat,
178            exp.Initcap,
179            exp.Lower,
180            exp.SafeConcat,
181            exp.Substring,
182            exp.TimeToStr,
183            exp.TimeToTimeStr,
184            exp.Trim,
185            exp.TsOrDsToDateStr,
186            exp.UnixToStr,
187            exp.UnixToTimeStr,
188            exp.Upper,
189        },
190    }
191
192    ANNOTATORS: t.Dict = {
193        **{
194            expr_type: lambda self, e: self._annotate_unary(e)
195            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
196        },
197        **{
198            expr_type: lambda self, e: self._annotate_binary(e)
199            for expr_type in subclasses(exp.__name__, exp.Binary)
200        },
201        **{
202            expr_type: _annotate_with_type_lambda(data_type)
203            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
204            for expr_type in expressions
205        },
206        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
207        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
208        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
209        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
210        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
211        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
212        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
213        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
214        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
215        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
216        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
217        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
218        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
219        exp.Literal: lambda self, e: self._annotate_literal(e),
220        exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
221        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
222        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
223        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
224        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
225        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
226        exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
227    }
228
229    NESTED_TYPES = {
230        exp.DataType.Type.ARRAY,
231    }
232
233    # Specifies what types a given type can be coerced into (autofilled)
234    COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
235
236    def __init__(
237        self,
238        schema: Schema,
239        annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
240        coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
241    ) -> None:
242        self.schema = schema
243        self.annotators = annotators or self.ANNOTATORS
244        self.coerces_to = coerces_to or self.COERCES_TO
245
246    def annotate(self, expression: E) -> E:
247        for scope in traverse_scope(expression):
248            selects = {}
249            for name, source in scope.sources.items():
250                if not isinstance(source, Scope):
251                    continue
252                if isinstance(source.expression, exp.UDTF):
253                    values = []
254
255                    if isinstance(source.expression, exp.Lateral):
256                        if isinstance(source.expression.this, exp.Explode):
257                            values = [source.expression.this.this]
258                    else:
259                        values = source.expression.expressions[0].expressions
260
261                    if not values:
262                        continue
263
264                    selects[name] = {
265                        alias: column
266                        for alias, column in zip(
267                            source.expression.alias_column_names,
268                            values,
269                        )
270                    }
271                else:
272                    selects[name] = {
273                        select.alias_or_name: select for select in source.expression.selects
274                    }
275
276            # First annotate the current scope's column references
277            for col in scope.columns:
278                if not col.table:
279                    continue
280
281                source = scope.sources.get(col.table)
282                if isinstance(source, exp.Table):
283                    col.type = self.schema.get_column_type(source, col)
284                elif source and col.table in selects and col.name in selects[col.table]:
285                    col.type = selects[col.table][col.name].type
286
287            # Then (possibly) annotate the remaining expressions in the scope
288            self._maybe_annotate(scope.expression)
289
290        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
291
292    def _maybe_annotate(self, expression: E) -> E:
293        if expression.type:
294            return expression  # We've already inferred the expression's type
295
296        annotator = self.annotators.get(expression.__class__)
297
298        return (
299            annotator(self, expression)
300            if annotator
301            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
302        )
303
304    def _annotate_args(self, expression: E) -> E:
305        for _, value in expression.iter_expressions():
306            self._maybe_annotate(value)
307
308        return expression
309
310    def _maybe_coerce(
311        self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
312    ) -> exp.DataType | exp.DataType.Type:
313        type1_value = type1.this if isinstance(type1, exp.DataType) else type1
314        type2_value = type2.this if isinstance(type2, exp.DataType) else type2
315
316        # We propagate the NULL / UNKNOWN types upwards if found
317        if exp.DataType.Type.NULL in (type1_value, type2_value):
318            return exp.DataType.Type.NULL
319        if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
320            return exp.DataType.Type.UNKNOWN
321
322        if type1_value in self.NESTED_TYPES:
323            return type1
324        if type2_value in self.NESTED_TYPES:
325            return type2
326
327        return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value  # type: ignore
328
329    # Note: the following "no_type_check" decorators were added because mypy was yelling due
330    # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
331    # This is a known mypy issue: https://github.com/python/mypy/issues/3004
332
333    @t.no_type_check
334    def _annotate_binary(self, expression: B) -> B:
335        self._annotate_args(expression)
336
337        left_type = expression.left.type.this
338        right_type = expression.right.type.this
339
340        if isinstance(expression, exp.Connector):
341            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
342                expression.type = exp.DataType.Type.NULL
343            elif exp.DataType.Type.NULL in (left_type, right_type):
344                expression.type = exp.DataType.build(
345                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
346                )
347            else:
348                expression.type = exp.DataType.Type.BOOLEAN
349        elif isinstance(expression, exp.Predicate):
350            expression.type = exp.DataType.Type.BOOLEAN
351        else:
352            expression.type = self._maybe_coerce(left_type, right_type)
353
354        return expression
355
356    @t.no_type_check
357    def _annotate_unary(self, expression: E) -> E:
358        self._annotate_args(expression)
359
360        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
361            expression.type = exp.DataType.Type.BOOLEAN
362        else:
363            expression.type = expression.this.type
364
365        return expression
366
367    @t.no_type_check
368    def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
369        if expression.is_string:
370            expression.type = exp.DataType.Type.VARCHAR
371        elif expression.is_int:
372            expression.type = exp.DataType.Type.INT
373        else:
374            expression.type = exp.DataType.Type.DOUBLE
375
376        return expression
377
378    @t.no_type_check
379    def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
380        expression.type = target_type
381        return self._annotate_args(expression)
382
383    @t.no_type_check
384    def _annotate_by_args(
385        self, expression: E, *args: str, promote: bool = False, array: bool = False
386    ) -> E:
387        self._annotate_args(expression)
388
389        expressions: t.List[exp.Expression] = []
390        for arg in args:
391            arg_expr = expression.args.get(arg)
392            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
393
394        last_datatype = None
395        for expr in expressions:
396            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
397
398        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
399
400        if promote:
401            if expression.type.this in exp.DataType.INTEGER_TYPES:
402                expression.type = exp.DataType.Type.BIGINT
403            elif expression.type.this in exp.DataType.FLOAT_TYPES:
404                expression.type = exp.DataType.Type.DOUBLE
405
406        if array:
407            expression.type = exp.DataType(
408                this=exp.DataType.Type.ARRAY, expressions=[expression.type], nested=True
409            )
410
411        return expression
TypeAnnotator( schema: sqlglot.schema.Schema, annotators: Optional[Dict[Type[~E], Callable[[TypeAnnotator, ~E], ~E]]] = None, coerces_to: Optional[Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]]] = None)
def __init__(
    self,
    schema: Schema,
    annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
    coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
) -> None:
    """Bind the schema and the annotation/coercion tables used by this annotator.

    Args:
        schema: Schema consulted (via ``get_column_type``) when typing table columns.
        annotators: Maps expression classes to annotation functions. A falsy
            value (``None`` or an empty mapping) falls back to the class-level
            ``ANNOTATORS``.
        coerces_to: Maps a type to the set of types it can be coerced into. A
            falsy value falls back to the class-level ``COERCES_TO``.
    """
    self.schema = schema
    self.annotators = annotators if annotators else self.ANNOTATORS
    self.coerces_to = coerces_to if coerces_to else self.COERCES_TO
TYPE_TO_EXPRESSIONS: Dict[sqlglot.expressions.DataType.Type, Set[Type[sqlglot.expressions.Expression]]] = {<Type.BIGINT: 'BIGINT'>: {<class 'sqlglot.expressions.ArraySize'>, <class 'sqlglot.expressions.ApproxDistinct'>, <class 'sqlglot.expressions.Length'>, <class 'sqlglot.expressions.Count'>}, <Type.BOOLEAN: 'BOOLEAN'>: {<class 'sqlglot.expressions.Between'>, <class 'sqlglot.expressions.In'>, <class 'sqlglot.expressions.Boolean'>, <class 'sqlglot.expressions.RegexpLike'>}, <Type.DATE: 'DATE'>: {<class 'sqlglot.expressions.DiToDate'>, <class 'sqlglot.expressions.Date'>, <class 'sqlglot.expressions.StrToDate'>, <class 'sqlglot.expressions.CurrentDate'>, <class 'sqlglot.expressions.DateTrunc'>, <class 'sqlglot.expressions.TimeStrToDate'>, <class 'sqlglot.expressions.DateSub'>, <class 'sqlglot.expressions.TsOrDsToDate'>, <class 'sqlglot.expressions.DateStrToDate'>, <class 'sqlglot.expressions.DateAdd'>, <class 'sqlglot.expressions.DateFromParts'>}, <Type.DATETIME: 'DATETIME'>: {<class 'sqlglot.expressions.CurrentDatetime'>, <class 'sqlglot.expressions.DatetimeAdd'>, <class 'sqlglot.expressions.DatetimeSub'>}, <Type.DOUBLE: 'DOUBLE'>: {<class 'sqlglot.expressions.Pow'>, <class 'sqlglot.expressions.Log2'>, <class 'sqlglot.expressions.Stddev'>, <class 'sqlglot.expressions.Log'>, <class 'sqlglot.expressions.Sqrt'>, <class 'sqlglot.expressions.Avg'>, <class 'sqlglot.expressions.StddevPop'>, <class 'sqlglot.expressions.VariancePop'>, <class 'sqlglot.expressions.SafeDivide'>, <class 'sqlglot.expressions.ApproxQuantile'>, <class 'sqlglot.expressions.Quantile'>, <class 'sqlglot.expressions.Exp'>, <class 'sqlglot.expressions.StddevSamp'>, <class 'sqlglot.expressions.Log10'>, <class 'sqlglot.expressions.Variance'>, <class 'sqlglot.expressions.Ln'>, <class 'sqlglot.expressions.Round'>}, <Type.INT: 'INT'>: {<class 'sqlglot.expressions.Extract'>, <class 'sqlglot.expressions.Levenshtein'>, <class 'sqlglot.expressions.DatetimeDiff'>, <class 'sqlglot.expressions.Ceil'>, <class 
'sqlglot.expressions.TimeDiff'>, <class 'sqlglot.expressions.DateToDi'>, <class 'sqlglot.expressions.Floor'>, <class 'sqlglot.expressions.DateDiff'>, <class 'sqlglot.expressions.StrPosition'>, <class 'sqlglot.expressions.TimestampDiff'>, <class 'sqlglot.expressions.TsOrDiToDi'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<class 'sqlglot.expressions.UnixToTime'>, <class 'sqlglot.expressions.CurrentTimestamp'>, <class 'sqlglot.expressions.CurrentTime'>, <class 'sqlglot.expressions.TimeSub'>, <class 'sqlglot.expressions.StrToTime'>, <class 'sqlglot.expressions.TimeStrToTime'>, <class 'sqlglot.expressions.TimeAdd'>, <class 'sqlglot.expressions.TimestampAdd'>, <class 'sqlglot.expressions.TimestampSub'>}, <Type.TINYINT: 'TINYINT'>: {<class 'sqlglot.expressions.Day'>, <class 'sqlglot.expressions.Month'>, <class 'sqlglot.expressions.Year'>, <class 'sqlglot.expressions.Week'>}, <Type.VARCHAR: 'VARCHAR'>: {<class 'sqlglot.expressions.GroupConcat'>, <class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.Trim'>, <class 'sqlglot.expressions.UnixToTimeStr'>, <class 'sqlglot.expressions.UnixToStr'>, <class 'sqlglot.expressions.ArrayConcat'>, <class 'sqlglot.expressions.Lower'>, <class 'sqlglot.expressions.DateToDateStr'>, <class 'sqlglot.expressions.TsOrDsToDateStr'>, <class 'sqlglot.expressions.TimeToTimeStr'>, <class 'sqlglot.expressions.Initcap'>, <class 'sqlglot.expressions.Upper'>, <class 'sqlglot.expressions.ConcatWs'>, <class 'sqlglot.expressions.TimeToStr'>, <class 'sqlglot.expressions.Substring'>, <class 'sqlglot.expressions.SafeConcat'>}}
ANNOTATORS: Dict = {<class 'sqlglot.expressions.Alias'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseNot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Neg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Not'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Paren'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Unary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Add'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.And'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContained'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArrayOverlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Binary'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseAnd'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseLeftShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseOr'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseRightShift'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.BitwiseXor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Collate'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Connector'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.DPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Distance'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Div'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Dot'>: <function TypeAnnotator.<dictcomp>.<lambda>>, 
<class 'sqlglot.expressions.EQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Escape'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.GTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Glob'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILike'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ILikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.IntDiv'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Is'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONArrayContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBContains'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONBExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtract'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.JSONExtractScalar'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Kwarg'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LT'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LTE'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Like'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.LikeAny'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mod'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Mul'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeEQ'>: 
<function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.NullSafeNEQ'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Or'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Overlaps'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Pow'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.RegexpLike'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDPipe'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.SimilarTo'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Slice'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Sub'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.Xor'>: <function TypeAnnotator.<dictcomp>.<lambda>>, <class 'sqlglot.expressions.ArraySize'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxDistinct'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Length'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Count'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Between'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.In'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Boolean'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DiToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Date'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.DateTrunc'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateStrToDate'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateFromParts'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentDatetime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log2'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Stddev'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Sqrt'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Avg'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevPop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.VariancePop'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeDivide'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ApproxQuantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Quantile'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Exp'>: 
<function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StddevSamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Log10'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Variance'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ln'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Round'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Extract'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Levenshtein'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DatetimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Ceil'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Floor'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrPosition'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampDiff'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDiToDi'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTimestamp'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.CurrentTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeSub'>: <function 
_annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.StrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeStrToTime'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampAdd'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimestampSub'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Day'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Month'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Year'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Week'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.GroupConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Concat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Trim'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.UnixToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ArrayConcat'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Lower'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.DateToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TsOrDsToDateStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToTimeStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Initcap'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 
'sqlglot.expressions.Upper'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.ConcatWs'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.TimeToStr'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Substring'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.SafeConcat'>: <function _annotate_with_type_lambda.<locals>.<lambda>>, <class 'sqlglot.expressions.Anonymous'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Array'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.ArrayAgg'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Cast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Case'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Coalesce'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.DataType'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Distinct'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Filter'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.If'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Interval'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Least'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Literal'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Map'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Max'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Min'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Null'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.Sum'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.TryCast'>: <function TypeAnnotator.<lambda>>, <class 'sqlglot.expressions.VarMap'>: <function TypeAnnotator.<lambda>>}
NESTED_TYPES = {<Type.ARRAY: 'ARRAY'>}
COERCES_TO: Dict[sqlglot.expressions.DataType.Type, Set[sqlglot.expressions.DataType.Type]] = {<Type.TEXT: 'TEXT'>: set(), <Type.NVARCHAR: 'NVARCHAR'>: {<Type.TEXT: 'TEXT'>}, <Type.VARCHAR: 'VARCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>}, <Type.NCHAR: 'NCHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.CHAR: 'CHAR'>: {<Type.NVARCHAR: 'NVARCHAR'>, <Type.TEXT: 'TEXT'>, <Type.NCHAR: 'NCHAR'>, <Type.VARCHAR: 'VARCHAR'>}, <Type.DOUBLE: 'DOUBLE'>: set(), <Type.FLOAT: 'FLOAT'>: {<Type.DOUBLE: 'DOUBLE'>}, <Type.DECIMAL: 'DECIMAL'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>}, <Type.BIGINT: 'BIGINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>}, <Type.INT: 'INT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.BIGINT: 'BIGINT'>}, <Type.SMALLINT: 'SMALLINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>}, <Type.TINYINT: 'TINYINT'>: {<Type.FLOAT: 'FLOAT'>, <Type.DOUBLE: 'DOUBLE'>, <Type.DECIMAL: 'DECIMAL'>, <Type.INT: 'INT'>, <Type.BIGINT: 'BIGINT'>, <Type.SMALLINT: 'SMALLINT'>}, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>: set(), <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>: {<Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.TIMESTAMP: 'TIMESTAMP'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>}, <Type.DATETIME: 'DATETIME'>: {<Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}, <Type.DATE: 'DATE'>: {<Type.DATETIME: 'DATETIME'>, <Type.TIMESTAMPTZ: 'TIMESTAMPTZ'>, <Type.TIMESTAMPLTZ: 'TIMESTAMPLTZ'>, <Type.TIMESTAMP: 'TIMESTAMP'>}}
schema
annotators
coerces_to
def annotate(self, expression: ~E) -> ~E:
def annotate(self, expression: E) -> E:
    """Annotate ``expression`` and its sub-expressions with inferred types.

    Scopes are visited in the order yielded by ``traverse_scope``. Within each
    scope, column references are typed first:

    - Columns of a physical ``exp.Table`` are looked up in ``self.schema``.
    - Columns of a derived source (a child ``Scope``) take the type of the
      matching projection in that source.

    The rest of the scope is then handed to ``self._maybe_annotate``.

    Args:
        expression: The root expression to annotate.

    Returns:
        The result of ``self._maybe_annotate`` applied to the root expression.
    """
    for scope in traverse_scope(expression):
        # Derived source name -> {output column name: expression producing it}.
        projections = {}

        for source_name, child in scope.sources.items():
            if not isinstance(child, Scope):
                continue

            child_expr = child.expression
            if not isinstance(child_expr, exp.UDTF):
                projections[source_name] = {
                    projection.alias_or_name: projection for projection in child_expr.selects
                }
                continue

            # Table-valued function source: pair its alias column names with
            # the expressions that produce those columns.
            if isinstance(child_expr, exp.Lateral):
                # Only LATERAL EXPLODE(...) is handled; other laterals get no entry.
                produced = [child_expr.this.this] if isinstance(child_expr.this, exp.Explode) else []
            else:
                produced = child_expr.expressions[0].expressions

            if produced:
                projections[source_name] = dict(zip(child_expr.alias_column_names, produced))

        # Type the column references whose qualifier resolves to a known source.
        for column in scope.columns:
            table_name = column.table
            if not table_name:
                continue

            origin = scope.sources.get(table_name)
            if isinstance(origin, exp.Table):
                column.type = self.schema.get_column_type(origin, column)
                continue

            derived = projections.get(table_name)
            if origin and derived is not None and column.name in derived:
                column.type = derived[column.name].type

        # Then (possibly) annotate the remaining expressions in the scope.
        self._maybe_annotate(scope.expression)

    # Covers non-traversable expressions that no scope visited.
    return self._maybe_annotate(expression)