Edit on GitHub

sqlglot.optimizer.annotate_types

  1from sqlglot import exp
  2from sqlglot.helper import ensure_list, subclasses
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.schema import ensure_schema
  5
  6
  7def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
  8    """
  9    Recursively infer & annotate types in an expression syntax tree against a schema.
 10    Assumes that we've already executed the optimizer's qualify_columns step.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> schema = {"y": {"cola": "SMALLINT"}}
 15        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 16        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 17        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 18        <Type.DOUBLE: 'DOUBLE'>
 19
 20    Args:
 21        expression (sqlglot.Expression): Expression to annotate.
 22        schema (dict|sqlglot.optimizer.Schema): Database schema.
 23        annotators (dict): Maps expression type to corresponding annotation function.
 24        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
 25    Returns:
 26        sqlglot.Expression: expression annotated with types
 27    """
 28
 29    schema = ensure_schema(schema)
 30
 31    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 32
 33
 34class TypeAnnotator:
 35    ANNOTATORS = {
 36        **{
 37            expr_type: lambda self, expr: self._annotate_unary(expr)
 38            for expr_type in subclasses(exp.__name__, exp.Unary)
 39        },
 40        **{
 41            expr_type: lambda self, expr: self._annotate_binary(expr)
 42            for expr_type in subclasses(exp.__name__, exp.Binary)
 43        },
 44        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 45        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 47        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 48        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 49        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 51        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 52        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 53        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 54        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 55            expr, exp.DataType.Type.BIGINT
 56        ),
 57        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 58        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 59        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Sum: lambda self, expr: self._annotate_by_args(
 61            expr, "this", "expressions", promote=True
 62        ),
 63        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 64        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 65        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 66        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 67            expr, exp.DataType.Type.DATETIME
 68        ),
 69        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 70            expr, exp.DataType.Type.TIMESTAMP
 71        ),
 72        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 73            expr, exp.DataType.Type.TIMESTAMP
 74        ),
 75        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 76        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 78        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 79            expr, exp.DataType.Type.DATETIME
 80        ),
 81        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 82            expr, exp.DataType.Type.DATETIME
 83        ),
 84        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 85        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 87            expr, exp.DataType.Type.TIMESTAMP
 88        ),
 89        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 90            expr, exp.DataType.Type.TIMESTAMP
 91        ),
 92        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 93        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 94        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 96        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 97            expr, exp.DataType.Type.DATE
 98        ),
 99        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
100            expr, exp.DataType.Type.VARCHAR
101        ),
102        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
103        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
104        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
105        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
106        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
107        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
108        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
109        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
110        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
111        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
112        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
113        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
114            expr, exp.DataType.Type.VARCHAR
115        ),
116        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
117            expr, exp.DataType.Type.VARCHAR
118        ),
119        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
120        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
121        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
122        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
123        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
124        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
125        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
126        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
128        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
129        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
130        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
131        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
132        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
133            expr, exp.DataType.Type.DOUBLE
134        ),
135        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
136            expr, exp.DataType.Type.BOOLEAN
137        ),
138        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
139        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
140        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
141        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
142        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
143        exp.StrToTime: lambda self, expr: self._annotate_with_type(
144            expr, exp.DataType.Type.TIMESTAMP
145        ),
146        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
147        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
148        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
149        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
150        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
151        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
152            expr, exp.DataType.Type.VARCHAR
153        ),
154        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
155            expr, exp.DataType.Type.DATE
156        ),
157        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
158            expr, exp.DataType.Type.TIMESTAMP
159        ),
160        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
161        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
162            expr, exp.DataType.Type.VARCHAR
163        ),
164        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
165        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
166        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
167        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
168            expr, exp.DataType.Type.TIMESTAMP
169        ),
170        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
171            expr, exp.DataType.Type.VARCHAR
172        ),
173        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
174        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
175        exp.VariancePop: lambda self, expr: self._annotate_with_type(
176            expr, exp.DataType.Type.DOUBLE
177        ),
178        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
179        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
180    }
181
182    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
183    COERCES_TO = {
184        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
185        exp.DataType.Type.TEXT: set(),
186        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
187        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
188        exp.DataType.Type.NCHAR: {
189            exp.DataType.Type.VARCHAR,
190            exp.DataType.Type.NVARCHAR,
191            exp.DataType.Type.TEXT,
192        },
193        exp.DataType.Type.CHAR: {
194            exp.DataType.Type.NCHAR,
195            exp.DataType.Type.VARCHAR,
196            exp.DataType.Type.NVARCHAR,
197            exp.DataType.Type.TEXT,
198        },
199        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
200        exp.DataType.Type.DOUBLE: set(),
201        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
202        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
203        exp.DataType.Type.BIGINT: {
204            exp.DataType.Type.DECIMAL,
205            exp.DataType.Type.FLOAT,
206            exp.DataType.Type.DOUBLE,
207        },
208        exp.DataType.Type.INT: {
209            exp.DataType.Type.BIGINT,
210            exp.DataType.Type.DECIMAL,
211            exp.DataType.Type.FLOAT,
212            exp.DataType.Type.DOUBLE,
213        },
214        exp.DataType.Type.SMALLINT: {
215            exp.DataType.Type.INT,
216            exp.DataType.Type.BIGINT,
217            exp.DataType.Type.DECIMAL,
218            exp.DataType.Type.FLOAT,
219            exp.DataType.Type.DOUBLE,
220        },
221        exp.DataType.Type.TINYINT: {
222            exp.DataType.Type.SMALLINT,
223            exp.DataType.Type.INT,
224            exp.DataType.Type.BIGINT,
225            exp.DataType.Type.DECIMAL,
226            exp.DataType.Type.FLOAT,
227            exp.DataType.Type.DOUBLE,
228        },
229        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
230        exp.DataType.Type.TIMESTAMPLTZ: set(),
231        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
232        exp.DataType.Type.TIMESTAMP: {
233            exp.DataType.Type.TIMESTAMPTZ,
234            exp.DataType.Type.TIMESTAMPLTZ,
235        },
236        exp.DataType.Type.DATETIME: {
237            exp.DataType.Type.TIMESTAMP,
238            exp.DataType.Type.TIMESTAMPTZ,
239            exp.DataType.Type.TIMESTAMPLTZ,
240        },
241        exp.DataType.Type.DATE: {
242            exp.DataType.Type.DATETIME,
243            exp.DataType.Type.TIMESTAMP,
244            exp.DataType.Type.TIMESTAMPTZ,
245            exp.DataType.Type.TIMESTAMPLTZ,
246        },
247    }
248
249    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
250
251    def __init__(self, schema=None, annotators=None, coerces_to=None):
252        self.schema = schema
253        self.annotators = annotators or self.ANNOTATORS
254        self.coerces_to = coerces_to or self.COERCES_TO
255
256    def annotate(self, expression):
257        if isinstance(expression, self.TRAVERSABLES):
258            for scope in traverse_scope(expression):
259                selects = {}
260                for name, source in scope.sources.items():
261                    if not isinstance(source, Scope):
262                        continue
263                    if isinstance(source.expression, exp.UDTF):
264                        values = []
265
266                        if isinstance(source.expression, exp.Lateral):
267                            if isinstance(source.expression.this, exp.Explode):
268                                values = [source.expression.this.this]
269                        else:
270                            values = source.expression.expressions[0].expressions
271
272                        if not values:
273                            continue
274
275                        selects[name] = {
276                            alias: column
277                            for alias, column in zip(
278                                source.expression.alias_column_names,
279                                values,
280                            )
281                        }
282                    else:
283                        selects[name] = {
284                            select.alias_or_name: select for select in source.expression.selects
285                        }
286                # First annotate the current scope's column references
287                for col in scope.columns:
288                    if not col.table:
289                        continue
290
291                    source = scope.sources.get(col.table)
292                    if isinstance(source, exp.Table):
293                        col.type = self.schema.get_column_type(source, col)
294                    elif source and col.table in selects and col.name in selects[col.table]:
295                        col.type = selects[col.table][col.name].type
296                # Then (possibly) annotate the remaining expressions in the scope
297                self._maybe_annotate(scope.expression)
298        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
299
300    def _maybe_annotate(self, expression):
301        if expression.type:
302            return expression  # We've already inferred the expression's type
303
304        annotator = self.annotators.get(expression.__class__)
305
306        return (
307            annotator(self, expression)
308            if annotator
309            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
310        )
311
312    def _annotate_args(self, expression):
313        for _, value in expression.iter_expressions():
314            self._maybe_annotate(value)
315
316        return expression
317
318    def _maybe_coerce(self, type1, type2):
319        # We propagate the NULL / UNKNOWN types upwards if found
320        if isinstance(type1, exp.DataType):
321            type1 = type1.this
322        if isinstance(type2, exp.DataType):
323            type2 = type2.this
324
325        if exp.DataType.Type.NULL in (type1, type2):
326            return exp.DataType.Type.NULL
327        if exp.DataType.Type.UNKNOWN in (type1, type2):
328            return exp.DataType.Type.UNKNOWN
329
330        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
331
332    def _annotate_binary(self, expression):
333        self._annotate_args(expression)
334
335        left_type = expression.left.type.this
336        right_type = expression.right.type.this
337
338        if isinstance(expression, (exp.And, exp.Or)):
339            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
340                expression.type = exp.DataType.Type.NULL
341            elif exp.DataType.Type.NULL in (left_type, right_type):
342                expression.type = exp.DataType.build(
343                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
344                )
345            else:
346                expression.type = exp.DataType.Type.BOOLEAN
347        elif isinstance(expression, (exp.Condition, exp.Predicate)):
348            expression.type = exp.DataType.Type.BOOLEAN
349        else:
350            expression.type = self._maybe_coerce(left_type, right_type)
351
352        return expression
353
354    def _annotate_unary(self, expression):
355        self._annotate_args(expression)
356
357        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
358            expression.type = exp.DataType.Type.BOOLEAN
359        else:
360            expression.type = expression.this.type
361
362        return expression
363
364    def _annotate_literal(self, expression):
365        if expression.is_string:
366            expression.type = exp.DataType.Type.VARCHAR
367        elif expression.is_int:
368            expression.type = exp.DataType.Type.INT
369        else:
370            expression.type = exp.DataType.Type.DOUBLE
371
372        return expression
373
374    def _annotate_with_type(self, expression, target_type):
375        expression.type = target_type
376        return self._annotate_args(expression)
377
378    def _annotate_by_args(self, expression, *args, promote=False):
379        self._annotate_args(expression)
380        expressions = []
381        for arg in args:
382            arg_expr = expression.args.get(arg)
383            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
384
385        last_datatype = None
386        for expr in expressions:
387            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
388
389        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
390
391        if promote:
392            if expression.type.this in exp.DataType.INTEGER_TYPES:
393                expression.type = exp.DataType.Type.BIGINT
394            elif expression.type.this in exp.DataType.FLOAT_TYPES:
395                expression.type = exp.DataType.Type.DOUBLE
396
397        return expression
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 8def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 9    """
10    Recursively infer & annotate types in an expression syntax tree against a schema.
11    Assumes that we've already executed the optimizer's qualify_columns step.
12
13    Example:
14        >>> import sqlglot
15        >>> schema = {"y": {"cola": "SMALLINT"}}
16        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
17        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
18        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
19        <Type.DOUBLE: 'DOUBLE'>
20
21    Args:
22        expression (sqlglot.Expression): Expression to annotate.
23        schema (dict|sqlglot.optimizer.Schema): Database schema.
24        annotators (dict): Maps expression type to corresponding annotation function.
25        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
26    Returns:
27        sqlglot.Expression: expression annotated with types
28    """
29
30    schema = ensure_schema(schema)
31
32    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

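Recursively infer and annotate types in an expression syntax tree against a schema. Assumes that the optimizer's qualify_columns step has already been run, so that column references are qualified with their source table; unqualified columns are not resolved against the schema.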

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression (sqlglot.Expression): Expression to annotate.
  • schema (dict|sqlglot.schema.Schema): Database schema; a dict is normalized via ensure_schema.
  • annotators (dict): Maps expression type to corresponding annotation function.
  • coerces_to (dict): Maps expression type to set of types that it can be coerced into.
Returns:

sqlglot.Expression: expression annotated with types

class TypeAnnotator:
 35class TypeAnnotator:
 36    ANNOTATORS = {
 37        **{
 38            expr_type: lambda self, expr: self._annotate_unary(expr)
 39            for expr_type in subclasses(exp.__name__, exp.Unary)
 40        },
 41        **{
 42            expr_type: lambda self, expr: self._annotate_binary(expr)
 43            for expr_type in subclasses(exp.__name__, exp.Binary)
 44        },
 45        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 47        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 48        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 49        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 51        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 52        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 53        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 54        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 55        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 56            expr, exp.DataType.Type.BIGINT
 57        ),
 58        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 59        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 61        exp.Sum: lambda self, expr: self._annotate_by_args(
 62            expr, "this", "expressions", promote=True
 63        ),
 64        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 65        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 66        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 67        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 68            expr, exp.DataType.Type.DATETIME
 69        ),
 70        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 71            expr, exp.DataType.Type.TIMESTAMP
 72        ),
 73        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 74            expr, exp.DataType.Type.TIMESTAMP
 75        ),
 76        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 78        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 79        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 80            expr, exp.DataType.Type.DATETIME
 81        ),
 82        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 83            expr, exp.DataType.Type.DATETIME
 84        ),
 85        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 87        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 88            expr, exp.DataType.Type.TIMESTAMP
 89        ),
 90        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 91            expr, exp.DataType.Type.TIMESTAMP
 92        ),
 93        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 94        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 96        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 97        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 98            expr, exp.DataType.Type.DATE
 99        ),
100        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
101            expr, exp.DataType.Type.VARCHAR
102        ),
103        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
104        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
105        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
106        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
107        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
108        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
109        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
110        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
111        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
112        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
113        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
114        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
115            expr, exp.DataType.Type.VARCHAR
116        ),
117        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
118            expr, exp.DataType.Type.VARCHAR
119        ),
120        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
121        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
122        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
123        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
124        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
125        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
126        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
128        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
129        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
130        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
131        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
132        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
133        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
134            expr, exp.DataType.Type.DOUBLE
135        ),
136        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
137            expr, exp.DataType.Type.BOOLEAN
138        ),
139        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
140        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
141        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
142        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
143        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
144        exp.StrToTime: lambda self, expr: self._annotate_with_type(
145            expr, exp.DataType.Type.TIMESTAMP
146        ),
147        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
148        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
149        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
150        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
151        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
152        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
153            expr, exp.DataType.Type.VARCHAR
154        ),
155        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
156            expr, exp.DataType.Type.DATE
157        ),
158        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
159            expr, exp.DataType.Type.TIMESTAMP
160        ),
161        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
162        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
163            expr, exp.DataType.Type.VARCHAR
164        ),
165        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
166        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
167        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
168        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
169            expr, exp.DataType.Type.TIMESTAMP
170        ),
171        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
172            expr, exp.DataType.Type.VARCHAR
173        ),
174        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
175        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
176        exp.VariancePop: lambda self, expr: self._annotate_with_type(
177            expr, exp.DataType.Type.DOUBLE
178        ),
179        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
180        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
181    }
182
183    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
184    COERCES_TO = {
185        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
186        exp.DataType.Type.TEXT: set(),
187        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
188        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
189        exp.DataType.Type.NCHAR: {
190            exp.DataType.Type.VARCHAR,
191            exp.DataType.Type.NVARCHAR,
192            exp.DataType.Type.TEXT,
193        },
194        exp.DataType.Type.CHAR: {
195            exp.DataType.Type.NCHAR,
196            exp.DataType.Type.VARCHAR,
197            exp.DataType.Type.NVARCHAR,
198            exp.DataType.Type.TEXT,
199        },
200        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
201        exp.DataType.Type.DOUBLE: set(),
202        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
203        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
204        exp.DataType.Type.BIGINT: {
205            exp.DataType.Type.DECIMAL,
206            exp.DataType.Type.FLOAT,
207            exp.DataType.Type.DOUBLE,
208        },
209        exp.DataType.Type.INT: {
210            exp.DataType.Type.BIGINT,
211            exp.DataType.Type.DECIMAL,
212            exp.DataType.Type.FLOAT,
213            exp.DataType.Type.DOUBLE,
214        },
215        exp.DataType.Type.SMALLINT: {
216            exp.DataType.Type.INT,
217            exp.DataType.Type.BIGINT,
218            exp.DataType.Type.DECIMAL,
219            exp.DataType.Type.FLOAT,
220            exp.DataType.Type.DOUBLE,
221        },
222        exp.DataType.Type.TINYINT: {
223            exp.DataType.Type.SMALLINT,
224            exp.DataType.Type.INT,
225            exp.DataType.Type.BIGINT,
226            exp.DataType.Type.DECIMAL,
227            exp.DataType.Type.FLOAT,
228            exp.DataType.Type.DOUBLE,
229        },
230        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
231        exp.DataType.Type.TIMESTAMPLTZ: set(),
232        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
233        exp.DataType.Type.TIMESTAMP: {
234            exp.DataType.Type.TIMESTAMPTZ,
235            exp.DataType.Type.TIMESTAMPLTZ,
236        },
237        exp.DataType.Type.DATETIME: {
238            exp.DataType.Type.TIMESTAMP,
239            exp.DataType.Type.TIMESTAMPTZ,
240            exp.DataType.Type.TIMESTAMPLTZ,
241        },
242        exp.DataType.Type.DATE: {
243            exp.DataType.Type.DATETIME,
244            exp.DataType.Type.TIMESTAMP,
245            exp.DataType.Type.TIMESTAMPTZ,
246            exp.DataType.Type.TIMESTAMPLTZ,
247        },
248    }
249
250    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
251
252    def __init__(self, schema=None, annotators=None, coerces_to=None):
253        self.schema = schema
254        self.annotators = annotators or self.ANNOTATORS
255        self.coerces_to = coerces_to or self.COERCES_TO
256
257    def annotate(self, expression):
258        if isinstance(expression, self.TRAVERSABLES):
259            for scope in traverse_scope(expression):
260                selects = {}
261                for name, source in scope.sources.items():
262                    if not isinstance(source, Scope):
263                        continue
264                    if isinstance(source.expression, exp.UDTF):
265                        values = []
266
267                        if isinstance(source.expression, exp.Lateral):
268                            if isinstance(source.expression.this, exp.Explode):
269                                values = [source.expression.this.this]
270                        else:
271                            values = source.expression.expressions[0].expressions
272
273                        if not values:
274                            continue
275
276                        selects[name] = {
277                            alias: column
278                            for alias, column in zip(
279                                source.expression.alias_column_names,
280                                values,
281                            )
282                        }
283                    else:
284                        selects[name] = {
285                            select.alias_or_name: select for select in source.expression.selects
286                        }
287                # First annotate the current scope's column references
288                for col in scope.columns:
289                    if not col.table:
290                        continue
291
292                    source = scope.sources.get(col.table)
293                    if isinstance(source, exp.Table):
294                        col.type = self.schema.get_column_type(source, col)
295                    elif source and col.table in selects and col.name in selects[col.table]:
296                        col.type = selects[col.table][col.name].type
297                # Then (possibly) annotate the remaining expressions in the scope
298                self._maybe_annotate(scope.expression)
299        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
300
301    def _maybe_annotate(self, expression):
302        if expression.type:
303            return expression  # We've already inferred the expression's type
304
305        annotator = self.annotators.get(expression.__class__)
306
307        return (
308            annotator(self, expression)
309            if annotator
310            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
311        )
312
313    def _annotate_args(self, expression):
314        for _, value in expression.iter_expressions():
315            self._maybe_annotate(value)
316
317        return expression
318
319    def _maybe_coerce(self, type1, type2):
320        # We propagate the NULL / UNKNOWN types upwards if found
321        if isinstance(type1, exp.DataType):
322            type1 = type1.this
323        if isinstance(type2, exp.DataType):
324            type2 = type2.this
325
326        if exp.DataType.Type.NULL in (type1, type2):
327            return exp.DataType.Type.NULL
328        if exp.DataType.Type.UNKNOWN in (type1, type2):
329            return exp.DataType.Type.UNKNOWN
330
331        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
332
333    def _annotate_binary(self, expression):
334        self._annotate_args(expression)
335
336        left_type = expression.left.type.this
337        right_type = expression.right.type.this
338
339        if isinstance(expression, (exp.And, exp.Or)):
340            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
341                expression.type = exp.DataType.Type.NULL
342            elif exp.DataType.Type.NULL in (left_type, right_type):
343                expression.type = exp.DataType.build(
344                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
345                )
346            else:
347                expression.type = exp.DataType.Type.BOOLEAN
348        elif isinstance(expression, (exp.Condition, exp.Predicate)):
349            expression.type = exp.DataType.Type.BOOLEAN
350        else:
351            expression.type = self._maybe_coerce(left_type, right_type)
352
353        return expression
354
355    def _annotate_unary(self, expression):
356        self._annotate_args(expression)
357
358        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
359            expression.type = exp.DataType.Type.BOOLEAN
360        else:
361            expression.type = expression.this.type
362
363        return expression
364
365    def _annotate_literal(self, expression):
366        if expression.is_string:
367            expression.type = exp.DataType.Type.VARCHAR
368        elif expression.is_int:
369            expression.type = exp.DataType.Type.INT
370        else:
371            expression.type = exp.DataType.Type.DOUBLE
372
373        return expression
374
375    def _annotate_with_type(self, expression, target_type):
376        expression.type = target_type
377        return self._annotate_args(expression)
378
379    def _annotate_by_args(self, expression, *args, promote=False):
380        self._annotate_args(expression)
381        expressions = []
382        for arg in args:
383            arg_expr = expression.args.get(arg)
384            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
385
386        last_datatype = None
387        for expr in expressions:
388            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
389
390        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
391
392        if promote:
393            if expression.type.this in exp.DataType.INTEGER_TYPES:
394                expression.type = exp.DataType.Type.BIGINT
395            elif expression.type.this in exp.DataType.FLOAT_TYPES:
396                expression.type = exp.DataType.Type.DOUBLE
397
398        return expression
TypeAnnotator(schema=None, annotators=None, coerces_to=None)
252    def __init__(self, schema=None, annotators=None, coerces_to=None):
253        self.schema = schema
254        self.annotators = annotators or self.ANNOTATORS
255        self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
257    def annotate(self, expression):
258        if isinstance(expression, self.TRAVERSABLES):
259            for scope in traverse_scope(expression):
260                selects = {}
261                for name, source in scope.sources.items():
262                    if not isinstance(source, Scope):
263                        continue
264                    if isinstance(source.expression, exp.UDTF):
265                        values = []
266
267                        if isinstance(source.expression, exp.Lateral):
268                            if isinstance(source.expression.this, exp.Explode):
269                                values = [source.expression.this.this]
270                        else:
271                            values = source.expression.expressions[0].expressions
272
273                        if not values:
274                            continue
275
276                        selects[name] = {
277                            alias: column
278                            for alias, column in zip(
279                                source.expression.alias_column_names,
280                                values,
281                            )
282                        }
283                    else:
284                        selects[name] = {
285                            select.alias_or_name: select for select in source.expression.selects
286                        }
287                # First annotate the current scope's column references
288                for col in scope.columns:
289                    if not col.table:
290                        continue
291
292                    source = scope.sources.get(col.table)
293                    if isinstance(source, exp.Table):
294                        col.type = self.schema.get_column_type(source, col)
295                    elif source and col.table in selects and col.name in selects[col.table]:
296                        col.type = selects[col.table][col.name].type
297                # Then (possibly) annotate the remaining expressions in the scope
298                self._maybe_annotate(scope.expression)
299        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions