Edit on GitHub

sqlglot.optimizer.annotate_types

  1from sqlglot import exp
  2from sqlglot.helper import ensure_collection, ensure_list, subclasses
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.schema import ensure_schema
  5
  6
  7def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
  8    """
  9    Recursively infer & annotate types in an expression syntax tree against a schema.
 10    Assumes that we've already executed the optimizer's qualify_columns step.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> schema = {"y": {"cola": "SMALLINT"}}
 15        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 16        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 17        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 18        <Type.DOUBLE: 'DOUBLE'>
 19
 20    Args:
 21        expression (sqlglot.Expression): Expression to annotate.
 22        schema (dict|sqlglot.optimizer.Schema): Database schema.
 23        annotators (dict): Maps expression type to corresponding annotation function.
 24        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
 25    Returns:
 26        sqlglot.Expression: expression annotated with types
 27    """
 28
 29    schema = ensure_schema(schema)
 30
 31    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 32
 33
 34class TypeAnnotator:
 35    ANNOTATORS = {
 36        **{
 37            expr_type: lambda self, expr: self._annotate_unary(expr)
 38            for expr_type in subclasses(exp.__name__, exp.Unary)
 39        },
 40        **{
 41            expr_type: lambda self, expr: self._annotate_binary(expr)
 42            for expr_type in subclasses(exp.__name__, exp.Binary)
 43        },
 44        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 45        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 47        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 48        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 49        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 51        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 52        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 53        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 54        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 55            expr, exp.DataType.Type.BIGINT
 56        ),
 57        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 58        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
 59        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
 60        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
 61        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 62        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 63        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 64        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 65            expr, exp.DataType.Type.DATETIME
 66        ),
 67        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 68            expr, exp.DataType.Type.TIMESTAMP
 69        ),
 70        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 71            expr, exp.DataType.Type.TIMESTAMP
 72        ),
 73        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 74        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 75        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 76        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 77            expr, exp.DataType.Type.DATETIME
 78        ),
 79        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 80            expr, exp.DataType.Type.DATETIME
 81        ),
 82        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 83        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 84        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 85            expr, exp.DataType.Type.TIMESTAMP
 86        ),
 87        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 88            expr, exp.DataType.Type.TIMESTAMP
 89        ),
 90        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 91        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 92        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 93        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 94        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 95            expr, exp.DataType.Type.DATE
 96        ),
 97        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
 98            expr, exp.DataType.Type.VARCHAR
 99        ),
100        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
101        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
102        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
103        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
104        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
105        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
106        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
107        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
108        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
109        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
110        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
111            expr, exp.DataType.Type.VARCHAR
112        ),
113        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
114            expr, exp.DataType.Type.VARCHAR
115        ),
116        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
117        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
118        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
119        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
120        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
121        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
122        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
123        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
124        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
125        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
126        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
128            expr, exp.DataType.Type.DOUBLE
129        ),
130        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
131            expr, exp.DataType.Type.BOOLEAN
132        ),
133        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
134        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
135        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
136        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
137        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
138        exp.StrToTime: lambda self, expr: self._annotate_with_type(
139            expr, exp.DataType.Type.TIMESTAMP
140        ),
141        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
142        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
143        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
144        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
146        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
147            expr, exp.DataType.Type.VARCHAR
148        ),
149        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
150            expr, exp.DataType.Type.DATE
151        ),
152        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
153            expr, exp.DataType.Type.TIMESTAMP
154        ),
155        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
156        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
157            expr, exp.DataType.Type.VARCHAR
158        ),
159        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
160        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
161        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
162        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
163            expr, exp.DataType.Type.TIMESTAMP
164        ),
165        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
166            expr, exp.DataType.Type.VARCHAR
167        ),
168        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
169        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
170        exp.VariancePop: lambda self, expr: self._annotate_with_type(
171            expr, exp.DataType.Type.DOUBLE
172        ),
173        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
174        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
175    }
176
177    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
178    COERCES_TO = {
179        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
180        exp.DataType.Type.TEXT: set(),
181        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
182        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
183        exp.DataType.Type.NCHAR: {
184            exp.DataType.Type.VARCHAR,
185            exp.DataType.Type.NVARCHAR,
186            exp.DataType.Type.TEXT,
187        },
188        exp.DataType.Type.CHAR: {
189            exp.DataType.Type.NCHAR,
190            exp.DataType.Type.VARCHAR,
191            exp.DataType.Type.NVARCHAR,
192            exp.DataType.Type.TEXT,
193        },
194        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
195        exp.DataType.Type.DOUBLE: set(),
196        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
197        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
198        exp.DataType.Type.BIGINT: {
199            exp.DataType.Type.DECIMAL,
200            exp.DataType.Type.FLOAT,
201            exp.DataType.Type.DOUBLE,
202        },
203        exp.DataType.Type.INT: {
204            exp.DataType.Type.BIGINT,
205            exp.DataType.Type.DECIMAL,
206            exp.DataType.Type.FLOAT,
207            exp.DataType.Type.DOUBLE,
208        },
209        exp.DataType.Type.SMALLINT: {
210            exp.DataType.Type.INT,
211            exp.DataType.Type.BIGINT,
212            exp.DataType.Type.DECIMAL,
213            exp.DataType.Type.FLOAT,
214            exp.DataType.Type.DOUBLE,
215        },
216        exp.DataType.Type.TINYINT: {
217            exp.DataType.Type.SMALLINT,
218            exp.DataType.Type.INT,
219            exp.DataType.Type.BIGINT,
220            exp.DataType.Type.DECIMAL,
221            exp.DataType.Type.FLOAT,
222            exp.DataType.Type.DOUBLE,
223        },
224        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
225        exp.DataType.Type.TIMESTAMPLTZ: set(),
226        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
227        exp.DataType.Type.TIMESTAMP: {
228            exp.DataType.Type.TIMESTAMPTZ,
229            exp.DataType.Type.TIMESTAMPLTZ,
230        },
231        exp.DataType.Type.DATETIME: {
232            exp.DataType.Type.TIMESTAMP,
233            exp.DataType.Type.TIMESTAMPTZ,
234            exp.DataType.Type.TIMESTAMPLTZ,
235        },
236        exp.DataType.Type.DATE: {
237            exp.DataType.Type.DATETIME,
238            exp.DataType.Type.TIMESTAMP,
239            exp.DataType.Type.TIMESTAMPTZ,
240            exp.DataType.Type.TIMESTAMPLTZ,
241        },
242    }
243
244    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
245
246    def __init__(self, schema=None, annotators=None, coerces_to=None):
247        self.schema = schema
248        self.annotators = annotators or self.ANNOTATORS
249        self.coerces_to = coerces_to or self.COERCES_TO
250
251    def annotate(self, expression):
252        if isinstance(expression, self.TRAVERSABLES):
253            for scope in traverse_scope(expression):
254                selects = {}
255                for name, source in scope.sources.items():
256                    if not isinstance(source, Scope):
257                        continue
258                    if isinstance(source.expression, exp.UDTF):
259                        values = []
260
261                        if isinstance(source.expression, exp.Lateral):
262                            if isinstance(source.expression.this, exp.Explode):
263                                values = [source.expression.this.this]
264                        else:
265                            values = source.expression.expressions[0].expressions
266
267                        if not values:
268                            continue
269
270                        selects[name] = {
271                            alias: column
272                            for alias, column in zip(
273                                source.expression.alias_column_names,
274                                values,
275                            )
276                        }
277                    else:
278                        selects[name] = {
279                            select.alias_or_name: select for select in source.expression.selects
280                        }
281                # First annotate the current scope's column references
282                for col in scope.columns:
283                    source = scope.sources.get(col.table)
284                    if isinstance(source, exp.Table):
285                        col.type = self.schema.get_column_type(source, col)
286                    elif source and col.table in selects:
287                        col.type = selects[col.table][col.name].type
288                # Then (possibly) annotate the remaining expressions in the scope
289                self._maybe_annotate(scope.expression)
290        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
291
292    def _maybe_annotate(self, expression):
293        if not isinstance(expression, exp.Expression):
294            return None
295
296        if expression.type:
297            return expression  # We've already inferred the expression's type
298
299        annotator = self.annotators.get(expression.__class__)
300
301        return (
302            annotator(self, expression)
303            if annotator
304            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
305        )
306
307    def _annotate_args(self, expression):
308        for value in expression.args.values():
309            for v in ensure_collection(value):
310                self._maybe_annotate(v)
311
312        return expression
313
314    def _maybe_coerce(self, type1, type2):
315        # We propagate the NULL / UNKNOWN types upwards if found
316        if isinstance(type1, exp.DataType):
317            type1 = type1.this
318        if isinstance(type2, exp.DataType):
319            type2 = type2.this
320
321        if exp.DataType.Type.NULL in (type1, type2):
322            return exp.DataType.Type.NULL
323        if exp.DataType.Type.UNKNOWN in (type1, type2):
324            return exp.DataType.Type.UNKNOWN
325
326        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
327
328    def _annotate_binary(self, expression):
329        self._annotate_args(expression)
330
331        left_type = expression.left.type.this
332        right_type = expression.right.type.this
333
334        if isinstance(expression, (exp.And, exp.Or)):
335            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
336                expression.type = exp.DataType.Type.NULL
337            elif exp.DataType.Type.NULL in (left_type, right_type):
338                expression.type = exp.DataType.build(
339                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
340                )
341            else:
342                expression.type = exp.DataType.Type.BOOLEAN
343        elif isinstance(expression, (exp.Condition, exp.Predicate)):
344            expression.type = exp.DataType.Type.BOOLEAN
345        else:
346            expression.type = self._maybe_coerce(left_type, right_type)
347
348        return expression
349
350    def _annotate_unary(self, expression):
351        self._annotate_args(expression)
352
353        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
354            expression.type = exp.DataType.Type.BOOLEAN
355        else:
356            expression.type = expression.this.type
357
358        return expression
359
360    def _annotate_literal(self, expression):
361        if expression.is_string:
362            expression.type = exp.DataType.Type.VARCHAR
363        elif expression.is_int:
364            expression.type = exp.DataType.Type.INT
365        else:
366            expression.type = exp.DataType.Type.DOUBLE
367
368        return expression
369
370    def _annotate_with_type(self, expression, target_type):
371        expression.type = target_type
372        return self._annotate_args(expression)
373
374    def _annotate_by_args(self, expression, *args, promote=False):
375        self._annotate_args(expression)
376        expressions = []
377        for arg in args:
378            arg_expr = expression.args.get(arg)
379            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
380
381        last_datatype = None
382        for expr in expressions:
383            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
384
385        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
386
387        if promote:
388            if expression.type.this in exp.DataType.INTEGER_TYPES:
389                expression.type = exp.DataType.Type.BIGINT
390            elif expression.type.this in exp.DataType.FLOAT_TYPES:
391                expression.type = exp.DataType.Type.DOUBLE
392
393        return expression
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 8def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 9    """
10    Recursively infer & annotate types in an expression syntax tree against a schema.
11    Assumes that we've already executed the optimizer's qualify_columns step.
12
13    Example:
14        >>> import sqlglot
15        >>> schema = {"y": {"cola": "SMALLINT"}}
16        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
17        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
18        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
19        <Type.DOUBLE: 'DOUBLE'>
20
21    Args:
22        expression (sqlglot.Expression): Expression to annotate.
23        schema (dict|sqlglot.optimizer.Schema): Database schema.
24        annotators (dict): Maps expression type to corresponding annotation function.
25        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
26    Returns:
27        sqlglot.Expression: expression annotated with types
28    """
29
30    schema = ensure_schema(schema)
31
32    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression (sqlglot.Expression): Expression to annotate.
  • schema (dict|sqlglot.schema.Schema): Database schema.
  • annotators (dict): Maps expression type to corresponding annotation function.
  • coerces_to (dict): Maps each data type to the set of types it can be coerced into.
Returns:

sqlglot.Expression: expression annotated with types

class TypeAnnotator:
 35class TypeAnnotator:
 36    ANNOTATORS = {
 37        **{
 38            expr_type: lambda self, expr: self._annotate_unary(expr)
 39            for expr_type in subclasses(exp.__name__, exp.Unary)
 40        },
 41        **{
 42            expr_type: lambda self, expr: self._annotate_binary(expr)
 43            for expr_type in subclasses(exp.__name__, exp.Binary)
 44        },
 45        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 47        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 48        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 49        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 51        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 52        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 53        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 54        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 55        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 56            expr, exp.DataType.Type.BIGINT
 57        ),
 58        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 59        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
 60        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
 61        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
 62        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 63        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 64        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 65        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 66            expr, exp.DataType.Type.DATETIME
 67        ),
 68        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 69            expr, exp.DataType.Type.TIMESTAMP
 70        ),
 71        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 72            expr, exp.DataType.Type.TIMESTAMP
 73        ),
 74        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 75        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 76        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 77        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 78            expr, exp.DataType.Type.DATETIME
 79        ),
 80        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 81            expr, exp.DataType.Type.DATETIME
 82        ),
 83        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 84        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 85        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 86            expr, exp.DataType.Type.TIMESTAMP
 87        ),
 88        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 89            expr, exp.DataType.Type.TIMESTAMP
 90        ),
 91        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 92        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 93        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 94        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 95        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 96            expr, exp.DataType.Type.DATE
 97        ),
 98        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
 99            expr, exp.DataType.Type.VARCHAR
100        ),
101        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
102        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
103        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
104        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
105        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
106        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
107        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
108        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
109        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
110        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
111        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
112            expr, exp.DataType.Type.VARCHAR
113        ),
114        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
115            expr, exp.DataType.Type.VARCHAR
116        ),
117        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
118        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
119        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
120        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
121        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
122        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
123        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
124        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
125        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
126        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
128        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
129            expr, exp.DataType.Type.DOUBLE
130        ),
131        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
132            expr, exp.DataType.Type.BOOLEAN
133        ),
134        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
135        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
136        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
137        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
138        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
139        exp.StrToTime: lambda self, expr: self._annotate_with_type(
140            expr, exp.DataType.Type.TIMESTAMP
141        ),
142        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
143        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
144        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
146        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
147        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
148            expr, exp.DataType.Type.VARCHAR
149        ),
150        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
151            expr, exp.DataType.Type.DATE
152        ),
153        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
154            expr, exp.DataType.Type.TIMESTAMP
155        ),
156        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
157        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
158            expr, exp.DataType.Type.VARCHAR
159        ),
160        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
161        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
162        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
163        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
164            expr, exp.DataType.Type.TIMESTAMP
165        ),
166        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
167            expr, exp.DataType.Type.VARCHAR
168        ),
169        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
170        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
171        exp.VariancePop: lambda self, expr: self._annotate_with_type(
172            expr, exp.DataType.Type.DOUBLE
173        ),
174        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
175        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
176    }
177
178    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
179    COERCES_TO = {
180        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
181        exp.DataType.Type.TEXT: set(),
182        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
183        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
184        exp.DataType.Type.NCHAR: {
185            exp.DataType.Type.VARCHAR,
186            exp.DataType.Type.NVARCHAR,
187            exp.DataType.Type.TEXT,
188        },
189        exp.DataType.Type.CHAR: {
190            exp.DataType.Type.NCHAR,
191            exp.DataType.Type.VARCHAR,
192            exp.DataType.Type.NVARCHAR,
193            exp.DataType.Type.TEXT,
194        },
195        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
196        exp.DataType.Type.DOUBLE: set(),
197        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
198        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
199        exp.DataType.Type.BIGINT: {
200            exp.DataType.Type.DECIMAL,
201            exp.DataType.Type.FLOAT,
202            exp.DataType.Type.DOUBLE,
203        },
204        exp.DataType.Type.INT: {
205            exp.DataType.Type.BIGINT,
206            exp.DataType.Type.DECIMAL,
207            exp.DataType.Type.FLOAT,
208            exp.DataType.Type.DOUBLE,
209        },
210        exp.DataType.Type.SMALLINT: {
211            exp.DataType.Type.INT,
212            exp.DataType.Type.BIGINT,
213            exp.DataType.Type.DECIMAL,
214            exp.DataType.Type.FLOAT,
215            exp.DataType.Type.DOUBLE,
216        },
217        exp.DataType.Type.TINYINT: {
218            exp.DataType.Type.SMALLINT,
219            exp.DataType.Type.INT,
220            exp.DataType.Type.BIGINT,
221            exp.DataType.Type.DECIMAL,
222            exp.DataType.Type.FLOAT,
223            exp.DataType.Type.DOUBLE,
224        },
225        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
226        exp.DataType.Type.TIMESTAMPLTZ: set(),
227        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
228        exp.DataType.Type.TIMESTAMP: {
229            exp.DataType.Type.TIMESTAMPTZ,
230            exp.DataType.Type.TIMESTAMPLTZ,
231        },
232        exp.DataType.Type.DATETIME: {
233            exp.DataType.Type.TIMESTAMP,
234            exp.DataType.Type.TIMESTAMPTZ,
235            exp.DataType.Type.TIMESTAMPLTZ,
236        },
237        exp.DataType.Type.DATE: {
238            exp.DataType.Type.DATETIME,
239            exp.DataType.Type.TIMESTAMP,
240            exp.DataType.Type.TIMESTAMPTZ,
241            exp.DataType.Type.TIMESTAMPLTZ,
242        },
243    }
244
245    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
246
247    def __init__(self, schema=None, annotators=None, coerces_to=None):
248        self.schema = schema
249        self.annotators = annotators or self.ANNOTATORS
250        self.coerces_to = coerces_to or self.COERCES_TO
251
252    def annotate(self, expression):
253        if isinstance(expression, self.TRAVERSABLES):
254            for scope in traverse_scope(expression):
255                selects = {}
256                for name, source in scope.sources.items():
257                    if not isinstance(source, Scope):
258                        continue
259                    if isinstance(source.expression, exp.UDTF):
260                        values = []
261
262                        if isinstance(source.expression, exp.Lateral):
263                            if isinstance(source.expression.this, exp.Explode):
264                                values = [source.expression.this.this]
265                        else:
266                            values = source.expression.expressions[0].expressions
267
268                        if not values:
269                            continue
270
271                        selects[name] = {
272                            alias: column
273                            for alias, column in zip(
274                                source.expression.alias_column_names,
275                                values,
276                            )
277                        }
278                    else:
279                        selects[name] = {
280                            select.alias_or_name: select for select in source.expression.selects
281                        }
282                # First annotate the current scope's column references
283                for col in scope.columns:
284                    source = scope.sources.get(col.table)
285                    if isinstance(source, exp.Table):
286                        col.type = self.schema.get_column_type(source, col)
287                    elif source and col.table in selects:
288                        col.type = selects[col.table][col.name].type
289                # Then (possibly) annotate the remaining expressions in the scope
290                self._maybe_annotate(scope.expression)
291        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
292
293    def _maybe_annotate(self, expression):
294        if not isinstance(expression, exp.Expression):
295            return None
296
297        if expression.type:
298            return expression  # We've already inferred the expression's type
299
300        annotator = self.annotators.get(expression.__class__)
301
302        return (
303            annotator(self, expression)
304            if annotator
305            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
306        )
307
308    def _annotate_args(self, expression):
309        for value in expression.args.values():
310            for v in ensure_collection(value):
311                self._maybe_annotate(v)
312
313        return expression
314
315    def _maybe_coerce(self, type1, type2):
316        # We propagate the NULL / UNKNOWN types upwards if found
317        if isinstance(type1, exp.DataType):
318            type1 = type1.this
319        if isinstance(type2, exp.DataType):
320            type2 = type2.this
321
322        if exp.DataType.Type.NULL in (type1, type2):
323            return exp.DataType.Type.NULL
324        if exp.DataType.Type.UNKNOWN in (type1, type2):
325            return exp.DataType.Type.UNKNOWN
326
327        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
328
329    def _annotate_binary(self, expression):
330        self._annotate_args(expression)
331
332        left_type = expression.left.type.this
333        right_type = expression.right.type.this
334
335        if isinstance(expression, (exp.And, exp.Or)):
336            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
337                expression.type = exp.DataType.Type.NULL
338            elif exp.DataType.Type.NULL in (left_type, right_type):
339                expression.type = exp.DataType.build(
340                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
341                )
342            else:
343                expression.type = exp.DataType.Type.BOOLEAN
344        elif isinstance(expression, (exp.Condition, exp.Predicate)):
345            expression.type = exp.DataType.Type.BOOLEAN
346        else:
347            expression.type = self._maybe_coerce(left_type, right_type)
348
349        return expression
350
351    def _annotate_unary(self, expression):
352        self._annotate_args(expression)
353
354        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
355            expression.type = exp.DataType.Type.BOOLEAN
356        else:
357            expression.type = expression.this.type
358
359        return expression
360
361    def _annotate_literal(self, expression):
362        if expression.is_string:
363            expression.type = exp.DataType.Type.VARCHAR
364        elif expression.is_int:
365            expression.type = exp.DataType.Type.INT
366        else:
367            expression.type = exp.DataType.Type.DOUBLE
368
369        return expression
370
371    def _annotate_with_type(self, expression, target_type):
372        expression.type = target_type
373        return self._annotate_args(expression)
374
375    def _annotate_by_args(self, expression, *args, promote=False):
376        self._annotate_args(expression)
377        expressions = []
378        for arg in args:
379            arg_expr = expression.args.get(arg)
380            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
381
382        last_datatype = None
383        for expr in expressions:
384            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
385
386        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
387
388        if promote:
389            if expression.type.this in exp.DataType.INTEGER_TYPES:
390                expression.type = exp.DataType.Type.BIGINT
391            elif expression.type.this in exp.DataType.FLOAT_TYPES:
392                expression.type = exp.DataType.Type.DOUBLE
393
394        return expression
TypeAnnotator(schema=None, annotators=None, coerces_to=None)
247    def __init__(self, schema=None, annotators=None, coerces_to=None):
248        self.schema = schema
249        self.annotators = annotators or self.ANNOTATORS
250        self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
252    def annotate(self, expression):
253        if isinstance(expression, self.TRAVERSABLES):
254            for scope in traverse_scope(expression):
255                selects = {}
256                for name, source in scope.sources.items():
257                    if not isinstance(source, Scope):
258                        continue
259                    if isinstance(source.expression, exp.UDTF):
260                        values = []
261
262                        if isinstance(source.expression, exp.Lateral):
263                            if isinstance(source.expression.this, exp.Explode):
264                                values = [source.expression.this.this]
265                        else:
266                            values = source.expression.expressions[0].expressions
267
268                        if not values:
269                            continue
270
271                        selects[name] = {
272                            alias: column
273                            for alias, column in zip(
274                                source.expression.alias_column_names,
275                                values,
276                            )
277                        }
278                    else:
279                        selects[name] = {
280                            select.alias_or_name: select for select in source.expression.selects
281                        }
282                # First annotate the current scope's column references
283                for col in scope.columns:
284                    source = scope.sources.get(col.table)
285                    if isinstance(source, exp.Table):
286                        col.type = self.schema.get_column_type(source, col)
287                    elif source and col.table in selects:
288                        col.type = selects[col.table][col.name].type
289                # Then (possibly) annotate the remaining expressions in the scope
290                self._maybe_annotate(scope.expression)
291        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions