Edit on GitHub

sqlglot.optimizer.annotate_types

  1from sqlglot import exp
  2from sqlglot.helper import ensure_collection, ensure_list, subclasses
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.schema import ensure_schema
  5
  6
  7def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
  8    """
  9    Recursively infer & annotate types in an expression syntax tree against a schema.
 10    Assumes that we've already executed the optimizer's qualify_columns step.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> schema = {"y": {"cola": "SMALLINT"}}
 15        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 16        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 17        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 18        <Type.DOUBLE: 'DOUBLE'>
 19
 20    Args:
 21        expression (sqlglot.Expression): Expression to annotate.
 22        schema (dict|sqlglot.optimizer.Schema): Database schema.
 23        annotators (dict): Maps expression type to corresponding annotation function.
 24        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
 25    Returns:
 26        sqlglot.Expression: expression annotated with types
 27    """
 28
 29    schema = ensure_schema(schema)
 30
 31    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 32
 33
 34class TypeAnnotator:
 35    ANNOTATORS = {
 36        **{
 37            expr_type: lambda self, expr: self._annotate_unary(expr)
 38            for expr_type in subclasses(exp.__name__, exp.Unary)
 39        },
 40        **{
 41            expr_type: lambda self, expr: self._annotate_binary(expr)
 42            for expr_type in subclasses(exp.__name__, exp.Binary)
 43        },
 44        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 45        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 47        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 48        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 49        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 51        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 52        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 53        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 54        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 55            expr, exp.DataType.Type.BIGINT
 56        ),
 57        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 58        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
 59        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
 60        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
 61        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 62        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 63        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 64        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 65            expr, exp.DataType.Type.DATETIME
 66        ),
 67        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 68            expr, exp.DataType.Type.TIMESTAMP
 69        ),
 70        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 71            expr, exp.DataType.Type.TIMESTAMP
 72        ),
 73        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 74        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 75        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 76        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 77            expr, exp.DataType.Type.DATETIME
 78        ),
 79        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 80            expr, exp.DataType.Type.DATETIME
 81        ),
 82        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 83        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 84        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 85            expr, exp.DataType.Type.TIMESTAMP
 86        ),
 87        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 88            expr, exp.DataType.Type.TIMESTAMP
 89        ),
 90        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 91        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 92        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 93        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 94        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 95            expr, exp.DataType.Type.DATE
 96        ),
 97        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
 98            expr, exp.DataType.Type.VARCHAR
 99        ),
100        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
101        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
102        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
103        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
104        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
105        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
106        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
107        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
108        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
109        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
110        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
111            expr, exp.DataType.Type.VARCHAR
112        ),
113        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
114            expr, exp.DataType.Type.VARCHAR
115        ),
116        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
117        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
118        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
119        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
120        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
121        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
122        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
123        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
124        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
125        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
126        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
128            expr, exp.DataType.Type.DOUBLE
129        ),
130        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
131            expr, exp.DataType.Type.BOOLEAN
132        ),
133        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
134        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
135        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
136        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
137        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
138        exp.StrToTime: lambda self, expr: self._annotate_with_type(
139            expr, exp.DataType.Type.TIMESTAMP
140        ),
141        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
142        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
143        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
144        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
146        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
147            expr, exp.DataType.Type.VARCHAR
148        ),
149        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
150            expr, exp.DataType.Type.DATE
151        ),
152        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
153            expr, exp.DataType.Type.TIMESTAMP
154        ),
155        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
156        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
157            expr, exp.DataType.Type.VARCHAR
158        ),
159        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
160        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
161        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
162        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
163            expr, exp.DataType.Type.TIMESTAMP
164        ),
165        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
166            expr, exp.DataType.Type.VARCHAR
167        ),
168        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
169        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
170        exp.VariancePop: lambda self, expr: self._annotate_with_type(
171            expr, exp.DataType.Type.DOUBLE
172        ),
173        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
174        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
175    }
176
177    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
178    COERCES_TO = {
179        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
180        exp.DataType.Type.TEXT: set(),
181        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
182        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
183        exp.DataType.Type.NCHAR: {
184            exp.DataType.Type.VARCHAR,
185            exp.DataType.Type.NVARCHAR,
186            exp.DataType.Type.TEXT,
187        },
188        exp.DataType.Type.CHAR: {
189            exp.DataType.Type.NCHAR,
190            exp.DataType.Type.VARCHAR,
191            exp.DataType.Type.NVARCHAR,
192            exp.DataType.Type.TEXT,
193        },
194        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
195        exp.DataType.Type.DOUBLE: set(),
196        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
197        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
198        exp.DataType.Type.BIGINT: {
199            exp.DataType.Type.DECIMAL,
200            exp.DataType.Type.FLOAT,
201            exp.DataType.Type.DOUBLE,
202        },
203        exp.DataType.Type.INT: {
204            exp.DataType.Type.BIGINT,
205            exp.DataType.Type.DECIMAL,
206            exp.DataType.Type.FLOAT,
207            exp.DataType.Type.DOUBLE,
208        },
209        exp.DataType.Type.SMALLINT: {
210            exp.DataType.Type.INT,
211            exp.DataType.Type.BIGINT,
212            exp.DataType.Type.DECIMAL,
213            exp.DataType.Type.FLOAT,
214            exp.DataType.Type.DOUBLE,
215        },
216        exp.DataType.Type.TINYINT: {
217            exp.DataType.Type.SMALLINT,
218            exp.DataType.Type.INT,
219            exp.DataType.Type.BIGINT,
220            exp.DataType.Type.DECIMAL,
221            exp.DataType.Type.FLOAT,
222            exp.DataType.Type.DOUBLE,
223        },
224        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
225        exp.DataType.Type.TIMESTAMPLTZ: set(),
226        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
227        exp.DataType.Type.TIMESTAMP: {
228            exp.DataType.Type.TIMESTAMPTZ,
229            exp.DataType.Type.TIMESTAMPLTZ,
230        },
231        exp.DataType.Type.DATETIME: {
232            exp.DataType.Type.TIMESTAMP,
233            exp.DataType.Type.TIMESTAMPTZ,
234            exp.DataType.Type.TIMESTAMPLTZ,
235        },
236        exp.DataType.Type.DATE: {
237            exp.DataType.Type.DATETIME,
238            exp.DataType.Type.TIMESTAMP,
239            exp.DataType.Type.TIMESTAMPTZ,
240            exp.DataType.Type.TIMESTAMPLTZ,
241        },
242    }
243
244    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
245
246    def __init__(self, schema=None, annotators=None, coerces_to=None):
247        self.schema = schema
248        self.annotators = annotators or self.ANNOTATORS
249        self.coerces_to = coerces_to or self.COERCES_TO
250
251    def annotate(self, expression):
252        if isinstance(expression, self.TRAVERSABLES):
253            for scope in traverse_scope(expression):
254                selects = {}
255                for name, source in scope.sources.items():
256                    if not isinstance(source, Scope):
257                        continue
258                    if isinstance(source.expression, exp.UDTF):
259                        values = []
260
261                        if isinstance(source.expression, exp.Lateral):
262                            if isinstance(source.expression.this, exp.Explode):
263                                values = [source.expression.this.this]
264                        else:
265                            values = source.expression.expressions[0].expressions
266
267                        if not values:
268                            continue
269
270                        selects[name] = {
271                            alias: column
272                            for alias, column in zip(
273                                source.expression.alias_column_names,
274                                values,
275                            )
276                        }
277                    else:
278                        selects[name] = {
279                            select.alias_or_name: select for select in source.expression.selects
280                        }
281                # First annotate the current scope's column references
282                for col in scope.columns:
283                    if not col.table:
284                        continue
285
286                    source = scope.sources.get(col.table)
287                    if isinstance(source, exp.Table):
288                        col.type = self.schema.get_column_type(source, col)
289                    elif source and col.table in selects:
290                        col.type = selects[col.table][col.name].type
291                # Then (possibly) annotate the remaining expressions in the scope
292                self._maybe_annotate(scope.expression)
293        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
294
295    def _maybe_annotate(self, expression):
296        if not isinstance(expression, exp.Expression):
297            return None
298
299        if expression.type:
300            return expression  # We've already inferred the expression's type
301
302        annotator = self.annotators.get(expression.__class__)
303
304        return (
305            annotator(self, expression)
306            if annotator
307            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
308        )
309
310    def _annotate_args(self, expression):
311        for value in expression.args.values():
312            for v in ensure_collection(value):
313                self._maybe_annotate(v)
314
315        return expression
316
317    def _maybe_coerce(self, type1, type2):
318        # We propagate the NULL / UNKNOWN types upwards if found
319        if isinstance(type1, exp.DataType):
320            type1 = type1.this
321        if isinstance(type2, exp.DataType):
322            type2 = type2.this
323
324        if exp.DataType.Type.NULL in (type1, type2):
325            return exp.DataType.Type.NULL
326        if exp.DataType.Type.UNKNOWN in (type1, type2):
327            return exp.DataType.Type.UNKNOWN
328
329        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
330
331    def _annotate_binary(self, expression):
332        self._annotate_args(expression)
333
334        left_type = expression.left.type.this
335        right_type = expression.right.type.this
336
337        if isinstance(expression, (exp.And, exp.Or)):
338            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
339                expression.type = exp.DataType.Type.NULL
340            elif exp.DataType.Type.NULL in (left_type, right_type):
341                expression.type = exp.DataType.build(
342                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
343                )
344            else:
345                expression.type = exp.DataType.Type.BOOLEAN
346        elif isinstance(expression, (exp.Condition, exp.Predicate)):
347            expression.type = exp.DataType.Type.BOOLEAN
348        else:
349            expression.type = self._maybe_coerce(left_type, right_type)
350
351        return expression
352
353    def _annotate_unary(self, expression):
354        self._annotate_args(expression)
355
356        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
357            expression.type = exp.DataType.Type.BOOLEAN
358        else:
359            expression.type = expression.this.type
360
361        return expression
362
363    def _annotate_literal(self, expression):
364        if expression.is_string:
365            expression.type = exp.DataType.Type.VARCHAR
366        elif expression.is_int:
367            expression.type = exp.DataType.Type.INT
368        else:
369            expression.type = exp.DataType.Type.DOUBLE
370
371        return expression
372
373    def _annotate_with_type(self, expression, target_type):
374        expression.type = target_type
375        return self._annotate_args(expression)
376
377    def _annotate_by_args(self, expression, *args, promote=False):
378        self._annotate_args(expression)
379        expressions = []
380        for arg in args:
381            arg_expr = expression.args.get(arg)
382            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
383
384        last_datatype = None
385        for expr in expressions:
386            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
387
388        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
389
390        if promote:
391            if expression.type.this in exp.DataType.INTEGER_TYPES:
392                expression.type = exp.DataType.Type.BIGINT
393            elif expression.type.this in exp.DataType.FLOAT_TYPES:
394                expression.type = exp.DataType.Type.DOUBLE
395
396        return expression
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
    """
    Recursively infer & annotate types in an expression syntax tree against a schema.
    Assumes that we've already executed the optimizer's qualify_columns step.

    Example:
        >>> import sqlglot
        >>> schema = {"y": {"cola": "SMALLINT"}}
        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
        <Type.DOUBLE: 'DOUBLE'>

    Args:
        expression (sqlglot.Expression): Expression to annotate.
        schema (dict|schema object): Database schema; it is normalized with
            `sqlglot.schema.ensure_schema` before use.
        annotators (dict): Maps an expression class to its annotation function.
            Falls back to `TypeAnnotator.ANNOTATORS` when omitted or empty.
        coerces_to (dict): Maps a data type to the set of data types it can be coerced into.
            Falls back to `TypeAnnotator.COERCES_TO` when omitted or empty.
    Returns:
        sqlglot.Expression: expression annotated with types
    """

    # Normalize whatever the caller passed (possibly None) into a schema object
    # before handing it to the annotator.
    schema = ensure_schema(schema)

    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression (sqlglot.Expression): Expression to annotate.
  • schema (dict|sqlglot.optimizer.Schema): Database schema.
  • annotators (dict): Maps expression type to corresponding annotation function.
  • coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:

sqlglot.Expression: expression annotated with types

class TypeAnnotator:
 35class TypeAnnotator:
 36    ANNOTATORS = {
 37        **{
 38            expr_type: lambda self, expr: self._annotate_unary(expr)
 39            for expr_type in subclasses(exp.__name__, exp.Unary)
 40        },
 41        **{
 42            expr_type: lambda self, expr: self._annotate_binary(expr)
 43            for expr_type in subclasses(exp.__name__, exp.Binary)
 44        },
 45        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 47        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 48        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 49        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 51        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 52        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 53        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 54        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 55        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 56            expr, exp.DataType.Type.BIGINT
 57        ),
 58        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 59        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
 60        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
 61        exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
 62        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 63        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 64        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 65        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 66            expr, exp.DataType.Type.DATETIME
 67        ),
 68        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 69            expr, exp.DataType.Type.TIMESTAMP
 70        ),
 71        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 72            expr, exp.DataType.Type.TIMESTAMP
 73        ),
 74        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 75        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 76        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 77        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 78            expr, exp.DataType.Type.DATETIME
 79        ),
 80        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 81            expr, exp.DataType.Type.DATETIME
 82        ),
 83        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 84        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 85        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 86            expr, exp.DataType.Type.TIMESTAMP
 87        ),
 88        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 89            expr, exp.DataType.Type.TIMESTAMP
 90        ),
 91        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 92        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 93        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 94        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 95        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 96            expr, exp.DataType.Type.DATE
 97        ),
 98        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
 99            expr, exp.DataType.Type.VARCHAR
100        ),
101        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
102        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
103        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
104        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
105        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
106        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
107        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
108        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
109        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
110        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
111        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
112            expr, exp.DataType.Type.VARCHAR
113        ),
114        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
115            expr, exp.DataType.Type.VARCHAR
116        ),
117        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
118        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
119        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
120        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
121        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
122        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
123        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
124        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
125        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
126        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
127        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
128        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
129            expr, exp.DataType.Type.DOUBLE
130        ),
131        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
132            expr, exp.DataType.Type.BOOLEAN
133        ),
134        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
135        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
136        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
137        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
138        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
139        exp.StrToTime: lambda self, expr: self._annotate_with_type(
140            expr, exp.DataType.Type.TIMESTAMP
141        ),
142        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
143        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
144        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
145        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
146        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
147        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
148            expr, exp.DataType.Type.VARCHAR
149        ),
150        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
151            expr, exp.DataType.Type.DATE
152        ),
153        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
154            expr, exp.DataType.Type.TIMESTAMP
155        ),
156        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
157        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
158            expr, exp.DataType.Type.VARCHAR
159        ),
160        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
161        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
162        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
163        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
164            expr, exp.DataType.Type.TIMESTAMP
165        ),
166        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
167            expr, exp.DataType.Type.VARCHAR
168        ),
169        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
170        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
171        exp.VariancePop: lambda self, expr: self._annotate_with_type(
172            expr, exp.DataType.Type.DOUBLE
173        ),
174        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
175        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
176    }
177
178    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
179    COERCES_TO = {
180        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
181        exp.DataType.Type.TEXT: set(),
182        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
183        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
184        exp.DataType.Type.NCHAR: {
185            exp.DataType.Type.VARCHAR,
186            exp.DataType.Type.NVARCHAR,
187            exp.DataType.Type.TEXT,
188        },
189        exp.DataType.Type.CHAR: {
190            exp.DataType.Type.NCHAR,
191            exp.DataType.Type.VARCHAR,
192            exp.DataType.Type.NVARCHAR,
193            exp.DataType.Type.TEXT,
194        },
195        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
196        exp.DataType.Type.DOUBLE: set(),
197        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
198        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
199        exp.DataType.Type.BIGINT: {
200            exp.DataType.Type.DECIMAL,
201            exp.DataType.Type.FLOAT,
202            exp.DataType.Type.DOUBLE,
203        },
204        exp.DataType.Type.INT: {
205            exp.DataType.Type.BIGINT,
206            exp.DataType.Type.DECIMAL,
207            exp.DataType.Type.FLOAT,
208            exp.DataType.Type.DOUBLE,
209        },
210        exp.DataType.Type.SMALLINT: {
211            exp.DataType.Type.INT,
212            exp.DataType.Type.BIGINT,
213            exp.DataType.Type.DECIMAL,
214            exp.DataType.Type.FLOAT,
215            exp.DataType.Type.DOUBLE,
216        },
217        exp.DataType.Type.TINYINT: {
218            exp.DataType.Type.SMALLINT,
219            exp.DataType.Type.INT,
220            exp.DataType.Type.BIGINT,
221            exp.DataType.Type.DECIMAL,
222            exp.DataType.Type.FLOAT,
223            exp.DataType.Type.DOUBLE,
224        },
225        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
226        exp.DataType.Type.TIMESTAMPLTZ: set(),
227        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
228        exp.DataType.Type.TIMESTAMP: {
229            exp.DataType.Type.TIMESTAMPTZ,
230            exp.DataType.Type.TIMESTAMPLTZ,
231        },
232        exp.DataType.Type.DATETIME: {
233            exp.DataType.Type.TIMESTAMP,
234            exp.DataType.Type.TIMESTAMPTZ,
235            exp.DataType.Type.TIMESTAMPLTZ,
236        },
237        exp.DataType.Type.DATE: {
238            exp.DataType.Type.DATETIME,
239            exp.DataType.Type.TIMESTAMP,
240            exp.DataType.Type.TIMESTAMPTZ,
241            exp.DataType.Type.TIMESTAMPLTZ,
242        },
243    }
244
245    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
246
247    def __init__(self, schema=None, annotators=None, coerces_to=None):
248        self.schema = schema
249        self.annotators = annotators or self.ANNOTATORS
250        self.coerces_to = coerces_to or self.COERCES_TO
251
252    def annotate(self, expression):
253        if isinstance(expression, self.TRAVERSABLES):
254            for scope in traverse_scope(expression):
255                selects = {}
256                for name, source in scope.sources.items():
257                    if not isinstance(source, Scope):
258                        continue
259                    if isinstance(source.expression, exp.UDTF):
260                        values = []
261
262                        if isinstance(source.expression, exp.Lateral):
263                            if isinstance(source.expression.this, exp.Explode):
264                                values = [source.expression.this.this]
265                        else:
266                            values = source.expression.expressions[0].expressions
267
268                        if not values:
269                            continue
270
271                        selects[name] = {
272                            alias: column
273                            for alias, column in zip(
274                                source.expression.alias_column_names,
275                                values,
276                            )
277                        }
278                    else:
279                        selects[name] = {
280                            select.alias_or_name: select for select in source.expression.selects
281                        }
282                # First annotate the current scope's column references
283                for col in scope.columns:
284                    if not col.table:
285                        continue
286
287                    source = scope.sources.get(col.table)
288                    if isinstance(source, exp.Table):
289                        col.type = self.schema.get_column_type(source, col)
290                    elif source and col.table in selects:
291                        col.type = selects[col.table][col.name].type
292                # Then (possibly) annotate the remaining expressions in the scope
293                self._maybe_annotate(scope.expression)
294        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
295
296    def _maybe_annotate(self, expression):
297        if not isinstance(expression, exp.Expression):
298            return None
299
300        if expression.type:
301            return expression  # We've already inferred the expression's type
302
303        annotator = self.annotators.get(expression.__class__)
304
305        return (
306            annotator(self, expression)
307            if annotator
308            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
309        )
310
311    def _annotate_args(self, expression):
312        for value in expression.args.values():
313            for v in ensure_collection(value):
314                self._maybe_annotate(v)
315
316        return expression
317
318    def _maybe_coerce(self, type1, type2):
319        # We propagate the NULL / UNKNOWN types upwards if found
320        if isinstance(type1, exp.DataType):
321            type1 = type1.this
322        if isinstance(type2, exp.DataType):
323            type2 = type2.this
324
325        if exp.DataType.Type.NULL in (type1, type2):
326            return exp.DataType.Type.NULL
327        if exp.DataType.Type.UNKNOWN in (type1, type2):
328            return exp.DataType.Type.UNKNOWN
329
330        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
331
332    def _annotate_binary(self, expression):
333        self._annotate_args(expression)
334
335        left_type = expression.left.type.this
336        right_type = expression.right.type.this
337
338        if isinstance(expression, (exp.And, exp.Or)):
339            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
340                expression.type = exp.DataType.Type.NULL
341            elif exp.DataType.Type.NULL in (left_type, right_type):
342                expression.type = exp.DataType.build(
343                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
344                )
345            else:
346                expression.type = exp.DataType.Type.BOOLEAN
347        elif isinstance(expression, (exp.Condition, exp.Predicate)):
348            expression.type = exp.DataType.Type.BOOLEAN
349        else:
350            expression.type = self._maybe_coerce(left_type, right_type)
351
352        return expression
353
354    def _annotate_unary(self, expression):
355        self._annotate_args(expression)
356
357        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
358            expression.type = exp.DataType.Type.BOOLEAN
359        else:
360            expression.type = expression.this.type
361
362        return expression
363
364    def _annotate_literal(self, expression):
365        if expression.is_string:
366            expression.type = exp.DataType.Type.VARCHAR
367        elif expression.is_int:
368            expression.type = exp.DataType.Type.INT
369        else:
370            expression.type = exp.DataType.Type.DOUBLE
371
372        return expression
373
374    def _annotate_with_type(self, expression, target_type):
375        expression.type = target_type
376        return self._annotate_args(expression)
377
378    def _annotate_by_args(self, expression, *args, promote=False):
379        self._annotate_args(expression)
380        expressions = []
381        for arg in args:
382            arg_expr = expression.args.get(arg)
383            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
384
385        last_datatype = None
386        for expr in expressions:
387            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
388
389        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
390
391        if promote:
392            if expression.type.this in exp.DataType.INTEGER_TYPES:
393                expression.type = exp.DataType.Type.BIGINT
394            elif expression.type.this in exp.DataType.FLOAT_TYPES:
395                expression.type = exp.DataType.Type.DOUBLE
396
397        return expression
TypeAnnotator(schema=None, annotators=None, coerces_to=None)
247    def __init__(self, schema=None, annotators=None, coerces_to=None):
248        self.schema = schema
249        self.annotators = annotators or self.ANNOTATORS
250        self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
252    def annotate(self, expression):
253        if isinstance(expression, self.TRAVERSABLES):
254            for scope in traverse_scope(expression):
255                selects = {}
256                for name, source in scope.sources.items():
257                    if not isinstance(source, Scope):
258                        continue
259                    if isinstance(source.expression, exp.UDTF):
260                        values = []
261
262                        if isinstance(source.expression, exp.Lateral):
263                            if isinstance(source.expression.this, exp.Explode):
264                                values = [source.expression.this.this]
265                        else:
266                            values = source.expression.expressions[0].expressions
267
268                        if not values:
269                            continue
270
271                        selects[name] = {
272                            alias: column
273                            for alias, column in zip(
274                                source.expression.alias_column_names,
275                                values,
276                            )
277                        }
278                    else:
279                        selects[name] = {
280                            select.alias_or_name: select for select in source.expression.selects
281                        }
282                # First annotate the current scope's column references
283                for col in scope.columns:
284                    if not col.table:
285                        continue
286
287                    source = scope.sources.get(col.table)
288                    if isinstance(source, exp.Table):
289                        col.type = self.schema.get_column_type(source, col)
290                    elif source and col.table in selects:
291                        col.type = selects[col.table][col.name].type
292                # Then (possibly) annotate the remaining expressions in the scope
293                self._maybe_annotate(scope.expression)
294        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions