Edit on GitHub

sqlglot.optimizer.annotate_types

  1from sqlglot import exp
  2from sqlglot.helper import ensure_collection, ensure_list, subclasses
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.schema import ensure_schema
  5
  6
  7def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
  8    """
  9    Recursively infer & annotate types in an expression syntax tree against a schema.
 10    Assumes that we've already executed the optimizer's qualify_columns step.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> schema = {"y": {"cola": "SMALLINT"}}
 15        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 16        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 17        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 18        <Type.DOUBLE: 'DOUBLE'>
 19
 20    Args:
 21        expression (sqlglot.Expression): Expression to annotate.
 22        schema (dict|sqlglot.optimizer.Schema): Database schema.
 23        annotators (dict): Maps expression type to corresponding annotation function.
 24        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
 25    Returns:
 26        sqlglot.Expression: expression annotated with types
 27    """
 28
 29    schema = ensure_schema(schema)
 30
 31    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 32
 33
 34class TypeAnnotator:
 35    ANNOTATORS = {
 36        **{
 37            expr_type: lambda self, expr: self._annotate_unary(expr)
 38            for expr_type in subclasses(exp.__name__, exp.Unary)
 39        },
 40        **{
 41            expr_type: lambda self, expr: self._annotate_binary(expr)
 42            for expr_type in subclasses(exp.__name__, exp.Binary)
 43        },
 44        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 45        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 47        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 48        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 49        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 51        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 52        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 53        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 54        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 55            expr, exp.DataType.Type.BIGINT
 56        ),
 57        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 58        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 59        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Sum: lambda self, expr: self._annotate_by_args(
 61            expr, "this", "expressions", promote=True
 62        ),
 63        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 64        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 65        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 66        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 67            expr, exp.DataType.Type.DATETIME
 68        ),
 69        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 70            expr, exp.DataType.Type.TIMESTAMP
 71        ),
 72        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 73            expr, exp.DataType.Type.TIMESTAMP
 74        ),
 75        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 76        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 78        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 79            expr, exp.DataType.Type.DATETIME
 80        ),
 81        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 82            expr, exp.DataType.Type.DATETIME
 83        ),
 84        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 85        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 87            expr, exp.DataType.Type.TIMESTAMP
 88        ),
 89        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 90            expr, exp.DataType.Type.TIMESTAMP
 91        ),
 92        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 93        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 94        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 96        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 97            expr, exp.DataType.Type.DATE
 98        ),
 99        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
100            expr, exp.DataType.Type.VARCHAR
101        ),
102        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
103        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
104        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
105        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
106        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
107        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
108        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
109        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
110        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
111        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
112        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
113            expr, exp.DataType.Type.VARCHAR
114        ),
115        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
116            expr, exp.DataType.Type.VARCHAR
117        ),
118        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
119        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
120        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
121        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
122        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
123        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
124        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
125        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
126        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
127        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
128        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
129        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
130        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
131            expr, exp.DataType.Type.DOUBLE
132        ),
133        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
134            expr, exp.DataType.Type.BOOLEAN
135        ),
136        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
137        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
138        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
139        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
140        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
141        exp.StrToTime: lambda self, expr: self._annotate_with_type(
142            expr, exp.DataType.Type.TIMESTAMP
143        ),
144        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
146        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
147        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
148        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
149        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
150            expr, exp.DataType.Type.VARCHAR
151        ),
152        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
153            expr, exp.DataType.Type.DATE
154        ),
155        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
156            expr, exp.DataType.Type.TIMESTAMP
157        ),
158        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
159        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
160            expr, exp.DataType.Type.VARCHAR
161        ),
162        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
163        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
164        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
165        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
166            expr, exp.DataType.Type.TIMESTAMP
167        ),
168        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
169            expr, exp.DataType.Type.VARCHAR
170        ),
171        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
172        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
173        exp.VariancePop: lambda self, expr: self._annotate_with_type(
174            expr, exp.DataType.Type.DOUBLE
175        ),
176        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
177        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
178    }
179
180    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
181    COERCES_TO = {
182        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
183        exp.DataType.Type.TEXT: set(),
184        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
185        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
186        exp.DataType.Type.NCHAR: {
187            exp.DataType.Type.VARCHAR,
188            exp.DataType.Type.NVARCHAR,
189            exp.DataType.Type.TEXT,
190        },
191        exp.DataType.Type.CHAR: {
192            exp.DataType.Type.NCHAR,
193            exp.DataType.Type.VARCHAR,
194            exp.DataType.Type.NVARCHAR,
195            exp.DataType.Type.TEXT,
196        },
197        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
198        exp.DataType.Type.DOUBLE: set(),
199        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
200        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
201        exp.DataType.Type.BIGINT: {
202            exp.DataType.Type.DECIMAL,
203            exp.DataType.Type.FLOAT,
204            exp.DataType.Type.DOUBLE,
205        },
206        exp.DataType.Type.INT: {
207            exp.DataType.Type.BIGINT,
208            exp.DataType.Type.DECIMAL,
209            exp.DataType.Type.FLOAT,
210            exp.DataType.Type.DOUBLE,
211        },
212        exp.DataType.Type.SMALLINT: {
213            exp.DataType.Type.INT,
214            exp.DataType.Type.BIGINT,
215            exp.DataType.Type.DECIMAL,
216            exp.DataType.Type.FLOAT,
217            exp.DataType.Type.DOUBLE,
218        },
219        exp.DataType.Type.TINYINT: {
220            exp.DataType.Type.SMALLINT,
221            exp.DataType.Type.INT,
222            exp.DataType.Type.BIGINT,
223            exp.DataType.Type.DECIMAL,
224            exp.DataType.Type.FLOAT,
225            exp.DataType.Type.DOUBLE,
226        },
227        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
228        exp.DataType.Type.TIMESTAMPLTZ: set(),
229        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
230        exp.DataType.Type.TIMESTAMP: {
231            exp.DataType.Type.TIMESTAMPTZ,
232            exp.DataType.Type.TIMESTAMPLTZ,
233        },
234        exp.DataType.Type.DATETIME: {
235            exp.DataType.Type.TIMESTAMP,
236            exp.DataType.Type.TIMESTAMPTZ,
237            exp.DataType.Type.TIMESTAMPLTZ,
238        },
239        exp.DataType.Type.DATE: {
240            exp.DataType.Type.DATETIME,
241            exp.DataType.Type.TIMESTAMP,
242            exp.DataType.Type.TIMESTAMPTZ,
243            exp.DataType.Type.TIMESTAMPLTZ,
244        },
245    }
246
247    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
248
249    def __init__(self, schema=None, annotators=None, coerces_to=None):
250        self.schema = schema
251        self.annotators = annotators or self.ANNOTATORS
252        self.coerces_to = coerces_to or self.COERCES_TO
253
254    def annotate(self, expression):
255        if isinstance(expression, self.TRAVERSABLES):
256            for scope in traverse_scope(expression):
257                selects = {}
258                for name, source in scope.sources.items():
259                    if not isinstance(source, Scope):
260                        continue
261                    if isinstance(source.expression, exp.UDTF):
262                        values = []
263
264                        if isinstance(source.expression, exp.Lateral):
265                            if isinstance(source.expression.this, exp.Explode):
266                                values = [source.expression.this.this]
267                        else:
268                            values = source.expression.expressions[0].expressions
269
270                        if not values:
271                            continue
272
273                        selects[name] = {
274                            alias: column
275                            for alias, column in zip(
276                                source.expression.alias_column_names,
277                                values,
278                            )
279                        }
280                    else:
281                        selects[name] = {
282                            select.alias_or_name: select for select in source.expression.selects
283                        }
284                # First annotate the current scope's column references
285                for col in scope.columns:
286                    if not col.table:
287                        continue
288
289                    source = scope.sources.get(col.table)
290                    if isinstance(source, exp.Table):
291                        col.type = self.schema.get_column_type(source, col)
292                    elif source and col.table in selects and col.name in selects[col.table]:
293                        col.type = selects[col.table][col.name].type
294                # Then (possibly) annotate the remaining expressions in the scope
295                self._maybe_annotate(scope.expression)
296        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
297
298    def _maybe_annotate(self, expression):
299        if not isinstance(expression, exp.Expression):
300            return None
301
302        if expression.type:
303            return expression  # We've already inferred the expression's type
304
305        annotator = self.annotators.get(expression.__class__)
306
307        return (
308            annotator(self, expression)
309            if annotator
310            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
311        )
312
313    def _annotate_args(self, expression):
314        for value in expression.args.values():
315            for v in ensure_collection(value):
316                self._maybe_annotate(v)
317
318        return expression
319
320    def _maybe_coerce(self, type1, type2):
321        # We propagate the NULL / UNKNOWN types upwards if found
322        if isinstance(type1, exp.DataType):
323            type1 = type1.this
324        if isinstance(type2, exp.DataType):
325            type2 = type2.this
326
327        if exp.DataType.Type.NULL in (type1, type2):
328            return exp.DataType.Type.NULL
329        if exp.DataType.Type.UNKNOWN in (type1, type2):
330            return exp.DataType.Type.UNKNOWN
331
332        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
333
334    def _annotate_binary(self, expression):
335        self._annotate_args(expression)
336
337        left_type = expression.left.type.this
338        right_type = expression.right.type.this
339
340        if isinstance(expression, (exp.And, exp.Or)):
341            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
342                expression.type = exp.DataType.Type.NULL
343            elif exp.DataType.Type.NULL in (left_type, right_type):
344                expression.type = exp.DataType.build(
345                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
346                )
347            else:
348                expression.type = exp.DataType.Type.BOOLEAN
349        elif isinstance(expression, (exp.Condition, exp.Predicate)):
350            expression.type = exp.DataType.Type.BOOLEAN
351        else:
352            expression.type = self._maybe_coerce(left_type, right_type)
353
354        return expression
355
356    def _annotate_unary(self, expression):
357        self._annotate_args(expression)
358
359        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
360            expression.type = exp.DataType.Type.BOOLEAN
361        else:
362            expression.type = expression.this.type
363
364        return expression
365
366    def _annotate_literal(self, expression):
367        if expression.is_string:
368            expression.type = exp.DataType.Type.VARCHAR
369        elif expression.is_int:
370            expression.type = exp.DataType.Type.INT
371        else:
372            expression.type = exp.DataType.Type.DOUBLE
373
374        return expression
375
376    def _annotate_with_type(self, expression, target_type):
377        expression.type = target_type
378        return self._annotate_args(expression)
379
380    def _annotate_by_args(self, expression, *args, promote=False):
381        self._annotate_args(expression)
382        expressions = []
383        for arg in args:
384            arg_expr = expression.args.get(arg)
385            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
386
387        last_datatype = None
388        for expr in expressions:
389            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
390
391        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
392
393        if promote:
394            if expression.type.this in exp.DataType.INTEGER_TYPES:
395                expression.type = exp.DataType.Type.BIGINT
396            elif expression.type.this in exp.DataType.FLOAT_TYPES:
397                expression.type = exp.DataType.Type.DOUBLE
398
399        return expression
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 8def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 9    """
10    Recursively infer & annotate types in an expression syntax tree against a schema.
11    Assumes that we've already executed the optimizer's qualify_columns step.
12
13    Example:
14        >>> import sqlglot
15        >>> schema = {"y": {"cola": "SMALLINT"}}
16        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
17        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
18        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
19        <Type.DOUBLE: 'DOUBLE'>
20
21    Args:
22        expression (sqlglot.Expression): Expression to annotate.
23        schema (dict|sqlglot.optimizer.Schema): Database schema.
24        annotators (dict): Maps expression type to corresponding annotation function.
25        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
26    Returns:
27        sqlglot.Expression: expression annotated with types
28    """
29
30    schema = ensure_schema(schema)
31
32    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression (sqlglot.Expression): Expression to annotate.
  • schema (dict|sqlglot.optimizer.Schema): Database schema.
  • annotators (dict): Maps expression type to corresponding annotation function.
  • coerces_to (dict): Maps expression type to set of types that it can be coerced into.
Returns:
  • sqlglot.Expression: expression annotated with types

class TypeAnnotator:
 35class TypeAnnotator:
 36    ANNOTATORS = {
 37        **{
 38            expr_type: lambda self, expr: self._annotate_unary(expr)
 39            for expr_type in subclasses(exp.__name__, exp.Unary)
 40        },
 41        **{
 42            expr_type: lambda self, expr: self._annotate_binary(expr)
 43            for expr_type in subclasses(exp.__name__, exp.Binary)
 44        },
 45        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 47        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 48        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 49        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 51        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 52        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 53        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 54        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 55        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 56            expr, exp.DataType.Type.BIGINT
 57        ),
 58        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 59        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 61        exp.Sum: lambda self, expr: self._annotate_by_args(
 62            expr, "this", "expressions", promote=True
 63        ),
 64        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 65        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 66        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 67        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 68            expr, exp.DataType.Type.DATETIME
 69        ),
 70        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 71            expr, exp.DataType.Type.TIMESTAMP
 72        ),
 73        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 74            expr, exp.DataType.Type.TIMESTAMP
 75        ),
 76        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 78        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 79        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 80            expr, exp.DataType.Type.DATETIME
 81        ),
 82        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 83            expr, exp.DataType.Type.DATETIME
 84        ),
 85        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 87        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 88            expr, exp.DataType.Type.TIMESTAMP
 89        ),
 90        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 91            expr, exp.DataType.Type.TIMESTAMP
 92        ),
 93        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 94        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 96        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 97        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 98            expr, exp.DataType.Type.DATE
 99        ),
100        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
101            expr, exp.DataType.Type.VARCHAR
102        ),
103        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
104        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
105        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
106        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
107        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
108        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
109        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
110        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
111        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
112        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
113        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
114            expr, exp.DataType.Type.VARCHAR
115        ),
116        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
117            expr, exp.DataType.Type.VARCHAR
118        ),
119        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
120        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
121        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
122        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
123        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
124        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
125        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
126        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
128        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
129        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
130        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
131        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
132            expr, exp.DataType.Type.DOUBLE
133        ),
134        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
135            expr, exp.DataType.Type.BOOLEAN
136        ),
137        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
138        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
139        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
140        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
141        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
142        exp.StrToTime: lambda self, expr: self._annotate_with_type(
143            expr, exp.DataType.Type.TIMESTAMP
144        ),
145        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
146        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
147        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
148        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
149        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
150        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
151            expr, exp.DataType.Type.VARCHAR
152        ),
153        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
154            expr, exp.DataType.Type.DATE
155        ),
156        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
157            expr, exp.DataType.Type.TIMESTAMP
158        ),
159        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
160        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
161            expr, exp.DataType.Type.VARCHAR
162        ),
163        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
164        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
165        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
166        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
167            expr, exp.DataType.Type.TIMESTAMP
168        ),
169        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
170            expr, exp.DataType.Type.VARCHAR
171        ),
172        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
173        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
174        exp.VariancePop: lambda self, expr: self._annotate_with_type(
175            expr, exp.DataType.Type.DOUBLE
176        ),
177        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
178        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
179    }
180
181    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
182    COERCES_TO = {
183        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
184        exp.DataType.Type.TEXT: set(),
185        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
186        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
187        exp.DataType.Type.NCHAR: {
188            exp.DataType.Type.VARCHAR,
189            exp.DataType.Type.NVARCHAR,
190            exp.DataType.Type.TEXT,
191        },
192        exp.DataType.Type.CHAR: {
193            exp.DataType.Type.NCHAR,
194            exp.DataType.Type.VARCHAR,
195            exp.DataType.Type.NVARCHAR,
196            exp.DataType.Type.TEXT,
197        },
198        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
199        exp.DataType.Type.DOUBLE: set(),
200        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
201        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
202        exp.DataType.Type.BIGINT: {
203            exp.DataType.Type.DECIMAL,
204            exp.DataType.Type.FLOAT,
205            exp.DataType.Type.DOUBLE,
206        },
207        exp.DataType.Type.INT: {
208            exp.DataType.Type.BIGINT,
209            exp.DataType.Type.DECIMAL,
210            exp.DataType.Type.FLOAT,
211            exp.DataType.Type.DOUBLE,
212        },
213        exp.DataType.Type.SMALLINT: {
214            exp.DataType.Type.INT,
215            exp.DataType.Type.BIGINT,
216            exp.DataType.Type.DECIMAL,
217            exp.DataType.Type.FLOAT,
218            exp.DataType.Type.DOUBLE,
219        },
220        exp.DataType.Type.TINYINT: {
221            exp.DataType.Type.SMALLINT,
222            exp.DataType.Type.INT,
223            exp.DataType.Type.BIGINT,
224            exp.DataType.Type.DECIMAL,
225            exp.DataType.Type.FLOAT,
226            exp.DataType.Type.DOUBLE,
227        },
228        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
229        exp.DataType.Type.TIMESTAMPLTZ: set(),
230        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
231        exp.DataType.Type.TIMESTAMP: {
232            exp.DataType.Type.TIMESTAMPTZ,
233            exp.DataType.Type.TIMESTAMPLTZ,
234        },
235        exp.DataType.Type.DATETIME: {
236            exp.DataType.Type.TIMESTAMP,
237            exp.DataType.Type.TIMESTAMPTZ,
238            exp.DataType.Type.TIMESTAMPLTZ,
239        },
240        exp.DataType.Type.DATE: {
241            exp.DataType.Type.DATETIME,
242            exp.DataType.Type.TIMESTAMP,
243            exp.DataType.Type.TIMESTAMPTZ,
244            exp.DataType.Type.TIMESTAMPLTZ,
245        },
246    }
247
248    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
249
250    def __init__(self, schema=None, annotators=None, coerces_to=None):
251        self.schema = schema
252        self.annotators = annotators or self.ANNOTATORS
253        self.coerces_to = coerces_to or self.COERCES_TO
254
255    def annotate(self, expression):
256        if isinstance(expression, self.TRAVERSABLES):
257            for scope in traverse_scope(expression):
258                selects = {}
259                for name, source in scope.sources.items():
260                    if not isinstance(source, Scope):
261                        continue
262                    if isinstance(source.expression, exp.UDTF):
263                        values = []
264
265                        if isinstance(source.expression, exp.Lateral):
266                            if isinstance(source.expression.this, exp.Explode):
267                                values = [source.expression.this.this]
268                        else:
269                            values = source.expression.expressions[0].expressions
270
271                        if not values:
272                            continue
273
274                        selects[name] = {
275                            alias: column
276                            for alias, column in zip(
277                                source.expression.alias_column_names,
278                                values,
279                            )
280                        }
281                    else:
282                        selects[name] = {
283                            select.alias_or_name: select for select in source.expression.selects
284                        }
285                # First annotate the current scope's column references
286                for col in scope.columns:
287                    if not col.table:
288                        continue
289
290                    source = scope.sources.get(col.table)
291                    if isinstance(source, exp.Table):
292                        col.type = self.schema.get_column_type(source, col)
293                    elif source and col.table in selects and col.name in selects[col.table]:
294                        col.type = selects[col.table][col.name].type
295                # Then (possibly) annotate the remaining expressions in the scope
296                self._maybe_annotate(scope.expression)
297        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
298
299    def _maybe_annotate(self, expression):
300        if not isinstance(expression, exp.Expression):
301            return None
302
303        if expression.type:
304            return expression  # We've already inferred the expression's type
305
306        annotator = self.annotators.get(expression.__class__)
307
308        return (
309            annotator(self, expression)
310            if annotator
311            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
312        )
313
314    def _annotate_args(self, expression):
315        for value in expression.args.values():
316            for v in ensure_collection(value):
317                self._maybe_annotate(v)
318
319        return expression
320
321    def _maybe_coerce(self, type1, type2):
322        # We propagate the NULL / UNKNOWN types upwards if found
323        if isinstance(type1, exp.DataType):
324            type1 = type1.this
325        if isinstance(type2, exp.DataType):
326            type2 = type2.this
327
328        if exp.DataType.Type.NULL in (type1, type2):
329            return exp.DataType.Type.NULL
330        if exp.DataType.Type.UNKNOWN in (type1, type2):
331            return exp.DataType.Type.UNKNOWN
332
333        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
334
335    def _annotate_binary(self, expression):
336        self._annotate_args(expression)
337
338        left_type = expression.left.type.this
339        right_type = expression.right.type.this
340
341        if isinstance(expression, (exp.And, exp.Or)):
342            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
343                expression.type = exp.DataType.Type.NULL
344            elif exp.DataType.Type.NULL in (left_type, right_type):
345                expression.type = exp.DataType.build(
346                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
347                )
348            else:
349                expression.type = exp.DataType.Type.BOOLEAN
350        elif isinstance(expression, (exp.Condition, exp.Predicate)):
351            expression.type = exp.DataType.Type.BOOLEAN
352        else:
353            expression.type = self._maybe_coerce(left_type, right_type)
354
355        return expression
356
357    def _annotate_unary(self, expression):
358        self._annotate_args(expression)
359
360        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
361            expression.type = exp.DataType.Type.BOOLEAN
362        else:
363            expression.type = expression.this.type
364
365        return expression
366
367    def _annotate_literal(self, expression):
368        if expression.is_string:
369            expression.type = exp.DataType.Type.VARCHAR
370        elif expression.is_int:
371            expression.type = exp.DataType.Type.INT
372        else:
373            expression.type = exp.DataType.Type.DOUBLE
374
375        return expression
376
377    def _annotate_with_type(self, expression, target_type):
378        expression.type = target_type
379        return self._annotate_args(expression)
380
381    def _annotate_by_args(self, expression, *args, promote=False):
382        self._annotate_args(expression)
383        expressions = []
384        for arg in args:
385            arg_expr = expression.args.get(arg)
386            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
387
388        last_datatype = None
389        for expr in expressions:
390            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
391
392        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
393
394        if promote:
395            if expression.type.this in exp.DataType.INTEGER_TYPES:
396                expression.type = exp.DataType.Type.BIGINT
397            elif expression.type.this in exp.DataType.FLOAT_TYPES:
398                expression.type = exp.DataType.Type.DOUBLE
399
400        return expression
TypeAnnotator(schema=None, annotators=None, coerces_to=None)
250    def __init__(self, schema=None, annotators=None, coerces_to=None):
251        self.schema = schema
252        self.annotators = annotators or self.ANNOTATORS
253        self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
255    def annotate(self, expression):
256        if isinstance(expression, self.TRAVERSABLES):
257            for scope in traverse_scope(expression):
258                selects = {}
259                for name, source in scope.sources.items():
260                    if not isinstance(source, Scope):
261                        continue
262                    if isinstance(source.expression, exp.UDTF):
263                        values = []
264
265                        if isinstance(source.expression, exp.Lateral):
266                            if isinstance(source.expression.this, exp.Explode):
267                                values = [source.expression.this.this]
268                        else:
269                            values = source.expression.expressions[0].expressions
270
271                        if not values:
272                            continue
273
274                        selects[name] = {
275                            alias: column
276                            for alias, column in zip(
277                                source.expression.alias_column_names,
278                                values,
279                            )
280                        }
281                    else:
282                        selects[name] = {
283                            select.alias_or_name: select for select in source.expression.selects
284                        }
285                # First annotate the current scope's column references
286                for col in scope.columns:
287                    if not col.table:
288                        continue
289
290                    source = scope.sources.get(col.table)
291                    if isinstance(source, exp.Table):
292                        col.type = self.schema.get_column_type(source, col)
293                    elif source and col.table in selects and col.name in selects[col.table]:
294                        col.type = selects[col.table][col.name].type
295                # Then (possibly) annotate the remaining expressions in the scope
296                self._maybe_annotate(scope.expression)
297        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions