Edit on GitHub

sqlglot.optimizer.annotate_types

  1from sqlglot import exp
  2from sqlglot.helper import ensure_list, subclasses
  3from sqlglot.optimizer.scope import Scope, traverse_scope
  4from sqlglot.schema import ensure_schema
  5
  6
  7def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
  8    """
  9    Recursively infer & annotate types in an expression syntax tree against a schema.
 10    Assumes that we've already executed the optimizer's qualify_columns step.
 11
 12    Example:
 13        >>> import sqlglot
 14        >>> schema = {"y": {"cola": "SMALLINT"}}
 15        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
 16        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
 17        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
 18        <Type.DOUBLE: 'DOUBLE'>
 19
 20    Args:
 21        expression (sqlglot.Expression): Expression to annotate.
 22        schema (dict|sqlglot.optimizer.Schema): Database schema.
 23        annotators (dict): Maps expression type to corresponding annotation function.
 24        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
 25    Returns:
 26        sqlglot.Expression: expression annotated with types
 27    """
 28
 29    schema = ensure_schema(schema)
 30
 31    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
 32
 33
 34class TypeAnnotator:
 35    ANNOTATORS = {
 36        **{
 37            expr_type: lambda self, expr: self._annotate_unary(expr)
 38            for expr_type in subclasses(exp.__name__, exp.Unary)
 39        },
 40        **{
 41            expr_type: lambda self, expr: self._annotate_binary(expr)
 42            for expr_type in subclasses(exp.__name__, exp.Binary)
 43        },
 44        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 45        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 47        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 48        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 49        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 51        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 52        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 53        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 54        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 55            expr, exp.DataType.Type.BIGINT
 56        ),
 57        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 58        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 59        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Sum: lambda self, expr: self._annotate_by_args(
 61            expr, "this", "expressions", promote=True
 62        ),
 63        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 64        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 65        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 66        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 67            expr, exp.DataType.Type.DATETIME
 68        ),
 69        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 70            expr, exp.DataType.Type.TIMESTAMP
 71        ),
 72        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 73            expr, exp.DataType.Type.TIMESTAMP
 74        ),
 75        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 76        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 78        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 79            expr, exp.DataType.Type.DATETIME
 80        ),
 81        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 82            expr, exp.DataType.Type.DATETIME
 83        ),
 84        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 85        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 87            expr, exp.DataType.Type.TIMESTAMP
 88        ),
 89        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 90            expr, exp.DataType.Type.TIMESTAMP
 91        ),
 92        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 93        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 94        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 96        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 97            expr, exp.DataType.Type.DATE
 98        ),
 99        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
100            expr, exp.DataType.Type.VARCHAR
101        ),
102        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
103        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
104        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
105        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
106        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
107        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
108        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
109        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
110        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
111        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
112        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
113        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
114            expr, exp.DataType.Type.VARCHAR
115        ),
116        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
117            expr, exp.DataType.Type.VARCHAR
118        ),
119        exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
120        exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
121        exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
122        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
123        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
124        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
125        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
126        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
127        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
128        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
129        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
130        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
131        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
132        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
133        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
134        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
135        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
136            expr, exp.DataType.Type.DOUBLE
137        ),
138        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
139            expr, exp.DataType.Type.BOOLEAN
140        ),
141        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
142        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
143        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
144        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
145        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
146        exp.StrToTime: lambda self, expr: self._annotate_with_type(
147            expr, exp.DataType.Type.TIMESTAMP
148        ),
149        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
150        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
151        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
152        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
153        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
154        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
155            expr, exp.DataType.Type.VARCHAR
156        ),
157        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
158            expr, exp.DataType.Type.DATE
159        ),
160        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
161            expr, exp.DataType.Type.TIMESTAMP
162        ),
163        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
164        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
165            expr, exp.DataType.Type.VARCHAR
166        ),
167        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
168        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
169        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
170        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
171            expr, exp.DataType.Type.TIMESTAMP
172        ),
173        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
174            expr, exp.DataType.Type.VARCHAR
175        ),
176        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
177        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
178        exp.VariancePop: lambda self, expr: self._annotate_with_type(
179            expr, exp.DataType.Type.DOUBLE
180        ),
181        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
182        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
183    }
184
185    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
186    COERCES_TO = {
187        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
188        exp.DataType.Type.TEXT: set(),
189        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
190        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
191        exp.DataType.Type.NCHAR: {
192            exp.DataType.Type.VARCHAR,
193            exp.DataType.Type.NVARCHAR,
194            exp.DataType.Type.TEXT,
195        },
196        exp.DataType.Type.CHAR: {
197            exp.DataType.Type.NCHAR,
198            exp.DataType.Type.VARCHAR,
199            exp.DataType.Type.NVARCHAR,
200            exp.DataType.Type.TEXT,
201        },
202        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
203        exp.DataType.Type.DOUBLE: set(),
204        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
205        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
206        exp.DataType.Type.BIGINT: {
207            exp.DataType.Type.DECIMAL,
208            exp.DataType.Type.FLOAT,
209            exp.DataType.Type.DOUBLE,
210        },
211        exp.DataType.Type.INT: {
212            exp.DataType.Type.BIGINT,
213            exp.DataType.Type.DECIMAL,
214            exp.DataType.Type.FLOAT,
215            exp.DataType.Type.DOUBLE,
216        },
217        exp.DataType.Type.SMALLINT: {
218            exp.DataType.Type.INT,
219            exp.DataType.Type.BIGINT,
220            exp.DataType.Type.DECIMAL,
221            exp.DataType.Type.FLOAT,
222            exp.DataType.Type.DOUBLE,
223        },
224        exp.DataType.Type.TINYINT: {
225            exp.DataType.Type.SMALLINT,
226            exp.DataType.Type.INT,
227            exp.DataType.Type.BIGINT,
228            exp.DataType.Type.DECIMAL,
229            exp.DataType.Type.FLOAT,
230            exp.DataType.Type.DOUBLE,
231        },
232        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
233        exp.DataType.Type.TIMESTAMPLTZ: set(),
234        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
235        exp.DataType.Type.TIMESTAMP: {
236            exp.DataType.Type.TIMESTAMPTZ,
237            exp.DataType.Type.TIMESTAMPLTZ,
238        },
239        exp.DataType.Type.DATETIME: {
240            exp.DataType.Type.TIMESTAMP,
241            exp.DataType.Type.TIMESTAMPTZ,
242            exp.DataType.Type.TIMESTAMPLTZ,
243        },
244        exp.DataType.Type.DATE: {
245            exp.DataType.Type.DATETIME,
246            exp.DataType.Type.TIMESTAMP,
247            exp.DataType.Type.TIMESTAMPTZ,
248            exp.DataType.Type.TIMESTAMPLTZ,
249        },
250    }
251
252    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
253
254    def __init__(self, schema=None, annotators=None, coerces_to=None):
255        self.schema = schema
256        self.annotators = annotators or self.ANNOTATORS
257        self.coerces_to = coerces_to or self.COERCES_TO
258
259    def annotate(self, expression):
260        if isinstance(expression, self.TRAVERSABLES):
261            for scope in traverse_scope(expression):
262                selects = {}
263                for name, source in scope.sources.items():
264                    if not isinstance(source, Scope):
265                        continue
266                    if isinstance(source.expression, exp.UDTF):
267                        values = []
268
269                        if isinstance(source.expression, exp.Lateral):
270                            if isinstance(source.expression.this, exp.Explode):
271                                values = [source.expression.this.this]
272                        else:
273                            values = source.expression.expressions[0].expressions
274
275                        if not values:
276                            continue
277
278                        selects[name] = {
279                            alias: column
280                            for alias, column in zip(
281                                source.expression.alias_column_names,
282                                values,
283                            )
284                        }
285                    else:
286                        selects[name] = {
287                            select.alias_or_name: select for select in source.expression.selects
288                        }
289                # First annotate the current scope's column references
290                for col in scope.columns:
291                    if not col.table:
292                        continue
293
294                    source = scope.sources.get(col.table)
295                    if isinstance(source, exp.Table):
296                        col.type = self.schema.get_column_type(source, col)
297                    elif source and col.table in selects and col.name in selects[col.table]:
298                        col.type = selects[col.table][col.name].type
299                # Then (possibly) annotate the remaining expressions in the scope
300                self._maybe_annotate(scope.expression)
301        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
302
303    def _maybe_annotate(self, expression):
304        if expression.type:
305            return expression  # We've already inferred the expression's type
306
307        annotator = self.annotators.get(expression.__class__)
308
309        return (
310            annotator(self, expression)
311            if annotator
312            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
313        )
314
315    def _annotate_args(self, expression):
316        for _, value in expression.iter_expressions():
317            self._maybe_annotate(value)
318
319        return expression
320
321    def _maybe_coerce(self, type1, type2):
322        # We propagate the NULL / UNKNOWN types upwards if found
323        if isinstance(type1, exp.DataType):
324            type1 = type1.this
325        if isinstance(type2, exp.DataType):
326            type2 = type2.this
327
328        if exp.DataType.Type.NULL in (type1, type2):
329            return exp.DataType.Type.NULL
330        if exp.DataType.Type.UNKNOWN in (type1, type2):
331            return exp.DataType.Type.UNKNOWN
332
333        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
334
335    def _annotate_binary(self, expression):
336        self._annotate_args(expression)
337
338        left_type = expression.left.type.this
339        right_type = expression.right.type.this
340
341        if isinstance(expression, (exp.And, exp.Or)):
342            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
343                expression.type = exp.DataType.Type.NULL
344            elif exp.DataType.Type.NULL in (left_type, right_type):
345                expression.type = exp.DataType.build(
346                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
347                )
348            else:
349                expression.type = exp.DataType.Type.BOOLEAN
350        elif isinstance(expression, (exp.Condition, exp.Predicate)):
351            expression.type = exp.DataType.Type.BOOLEAN
352        else:
353            expression.type = self._maybe_coerce(left_type, right_type)
354
355        return expression
356
357    def _annotate_unary(self, expression):
358        self._annotate_args(expression)
359
360        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
361            expression.type = exp.DataType.Type.BOOLEAN
362        else:
363            expression.type = expression.this.type
364
365        return expression
366
367    def _annotate_literal(self, expression):
368        if expression.is_string:
369            expression.type = exp.DataType.Type.VARCHAR
370        elif expression.is_int:
371            expression.type = exp.DataType.Type.INT
372        else:
373            expression.type = exp.DataType.Type.DOUBLE
374
375        return expression
376
377    def _annotate_with_type(self, expression, target_type):
378        expression.type = target_type
379        return self._annotate_args(expression)
380
381    def _annotate_by_args(self, expression, *args, promote=False):
382        self._annotate_args(expression)
383        expressions = []
384        for arg in args:
385            arg_expr = expression.args.get(arg)
386            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
387
388        last_datatype = None
389        for expr in expressions:
390            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
391
392        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
393
394        if promote:
395            if expression.type.this in exp.DataType.INTEGER_TYPES:
396                expression.type = exp.DataType.Type.BIGINT
397            elif expression.type.this in exp.DataType.FLOAT_TYPES:
398                expression.type = exp.DataType.Type.DOUBLE
399
400        return expression
def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 8def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
 9    """
10    Recursively infer & annotate types in an expression syntax tree against a schema.
11    Assumes that we've already executed the optimizer's qualify_columns step.
12
13    Example:
14        >>> import sqlglot
15        >>> schema = {"y": {"cola": "SMALLINT"}}
16        >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
17        >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
18        >>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
19        <Type.DOUBLE: 'DOUBLE'>
20
21    Args:
22        expression (sqlglot.Expression): Expression to annotate.
23        schema (dict|sqlglot.optimizer.Schema): Database schema.
24        annotators (dict): Maps expression type to corresponding annotation function.
25        coerces_to (dict): Maps expression type to set of types that it can be coerced into.
26    Returns:
27        sqlglot.Expression: expression annotated with types
28    """
29
30    schema = ensure_schema(schema)
31
32    return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)

Recursively infer & annotate types in an expression syntax tree against a schema. Assumes that we've already executed the optimizer's qualify_columns step.

Example:
>>> import sqlglot
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
>>> annotated_expr.expressions[0].type.this  # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Arguments:
  • expression (sqlglot.Expression): Expression to annotate.
  • schema (dict|sqlglot.schema.Schema): Database schema.
  • annotators (dict): Maps expression type to corresponding annotation function.
  • coerces_to (dict): Maps a data type to the set of data types that it can be coerced into.
Returns:

sqlglot.Expression: expression annotated with types

class TypeAnnotator:
 35class TypeAnnotator:
 36    ANNOTATORS = {
 37        **{
 38            expr_type: lambda self, expr: self._annotate_unary(expr)
 39            for expr_type in subclasses(exp.__name__, exp.Unary)
 40        },
 41        **{
 42            expr_type: lambda self, expr: self._annotate_binary(expr)
 43            for expr_type in subclasses(exp.__name__, exp.Binary)
 44        },
 45        exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 46        exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
 47        exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
 48        exp.Alias: lambda self, expr: self._annotate_unary(expr),
 49        exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 50        exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 51        exp.Literal: lambda self, expr: self._annotate_literal(expr),
 52        exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
 53        exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
 54        exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
 55        exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
 56            expr, exp.DataType.Type.BIGINT
 57        ),
 58        exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
 59        exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 60        exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
 61        exp.Sum: lambda self, expr: self._annotate_by_args(
 62            expr, "this", "expressions", promote=True
 63        ),
 64        exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 65        exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
 66        exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 67        exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
 68            expr, exp.DataType.Type.DATETIME
 69        ),
 70        exp.CurrentTime: lambda self, expr: self._annotate_with_type(
 71            expr, exp.DataType.Type.TIMESTAMP
 72        ),
 73        exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
 74            expr, exp.DataType.Type.TIMESTAMP
 75        ),
 76        exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 77        exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
 78        exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 79        exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
 80            expr, exp.DataType.Type.DATETIME
 81        ),
 82        exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
 83            expr, exp.DataType.Type.DATETIME
 84        ),
 85        exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 86        exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 87        exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
 88            expr, exp.DataType.Type.TIMESTAMP
 89        ),
 90        exp.TimestampSub: lambda self, expr: self._annotate_with_type(
 91            expr, exp.DataType.Type.TIMESTAMP
 92        ),
 93        exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 94        exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 95        exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
 96        exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
 97        exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
 98            expr, exp.DataType.Type.DATE
 99        ),
100        exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
101            expr, exp.DataType.Type.VARCHAR
102        ),
103        exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
104        exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
105        exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
106        exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
107        exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
108        exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
109        exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
110        exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
111        exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
112        exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
113        exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
114        exp.GroupConcat: lambda self, expr: self._annotate_with_type(
115            expr, exp.DataType.Type.VARCHAR
116        ),
117        exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
118            expr, exp.DataType.Type.VARCHAR
119        ),
120        exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
121        exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
122        exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
123        exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
124        exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
125        exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
126        exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
127        exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
128        exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
129        exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
130        exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
131        exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
132        exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
133        exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
134        exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
135        exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
136        exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
137            expr, exp.DataType.Type.DOUBLE
138        ),
139        exp.RegexpLike: lambda self, expr: self._annotate_with_type(
140            expr, exp.DataType.Type.BOOLEAN
141        ),
142        exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
143        exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
144        exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
145        exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
146        exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
147        exp.StrToTime: lambda self, expr: self._annotate_with_type(
148            expr, exp.DataType.Type.TIMESTAMP
149        ),
150        exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
151        exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
152        exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
153        exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
154        exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
155        exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
156            expr, exp.DataType.Type.VARCHAR
157        ),
158        exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
159            expr, exp.DataType.Type.DATE
160        ),
161        exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
162            expr, exp.DataType.Type.TIMESTAMP
163        ),
164        exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
165        exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
166            expr, exp.DataType.Type.VARCHAR
167        ),
168        exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
169        exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
170        exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
171        exp.UnixToTime: lambda self, expr: self._annotate_with_type(
172            expr, exp.DataType.Type.TIMESTAMP
173        ),
174        exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
175            expr, exp.DataType.Type.VARCHAR
176        ),
177        exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
178        exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
179        exp.VariancePop: lambda self, expr: self._annotate_with_type(
180            expr, exp.DataType.Type.DOUBLE
181        ),
182        exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
183        exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
184    }
185
186    # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
187    COERCES_TO = {
188        # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
189        exp.DataType.Type.TEXT: set(),
190        exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
191        exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
192        exp.DataType.Type.NCHAR: {
193            exp.DataType.Type.VARCHAR,
194            exp.DataType.Type.NVARCHAR,
195            exp.DataType.Type.TEXT,
196        },
197        exp.DataType.Type.CHAR: {
198            exp.DataType.Type.NCHAR,
199            exp.DataType.Type.VARCHAR,
200            exp.DataType.Type.NVARCHAR,
201            exp.DataType.Type.TEXT,
202        },
203        # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
204        exp.DataType.Type.DOUBLE: set(),
205        exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
206        exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
207        exp.DataType.Type.BIGINT: {
208            exp.DataType.Type.DECIMAL,
209            exp.DataType.Type.FLOAT,
210            exp.DataType.Type.DOUBLE,
211        },
212        exp.DataType.Type.INT: {
213            exp.DataType.Type.BIGINT,
214            exp.DataType.Type.DECIMAL,
215            exp.DataType.Type.FLOAT,
216            exp.DataType.Type.DOUBLE,
217        },
218        exp.DataType.Type.SMALLINT: {
219            exp.DataType.Type.INT,
220            exp.DataType.Type.BIGINT,
221            exp.DataType.Type.DECIMAL,
222            exp.DataType.Type.FLOAT,
223            exp.DataType.Type.DOUBLE,
224        },
225        exp.DataType.Type.TINYINT: {
226            exp.DataType.Type.SMALLINT,
227            exp.DataType.Type.INT,
228            exp.DataType.Type.BIGINT,
229            exp.DataType.Type.DECIMAL,
230            exp.DataType.Type.FLOAT,
231            exp.DataType.Type.DOUBLE,
232        },
233        # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
234        exp.DataType.Type.TIMESTAMPLTZ: set(),
235        exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
236        exp.DataType.Type.TIMESTAMP: {
237            exp.DataType.Type.TIMESTAMPTZ,
238            exp.DataType.Type.TIMESTAMPLTZ,
239        },
240        exp.DataType.Type.DATETIME: {
241            exp.DataType.Type.TIMESTAMP,
242            exp.DataType.Type.TIMESTAMPTZ,
243            exp.DataType.Type.TIMESTAMPLTZ,
244        },
245        exp.DataType.Type.DATE: {
246            exp.DataType.Type.DATETIME,
247            exp.DataType.Type.TIMESTAMP,
248            exp.DataType.Type.TIMESTAMPTZ,
249            exp.DataType.Type.TIMESTAMPLTZ,
250        },
251    }
252
253    TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
254
255    def __init__(self, schema=None, annotators=None, coerces_to=None):
256        self.schema = schema
257        self.annotators = annotators or self.ANNOTATORS
258        self.coerces_to = coerces_to or self.COERCES_TO
259
260    def annotate(self, expression):
261        if isinstance(expression, self.TRAVERSABLES):
262            for scope in traverse_scope(expression):
263                selects = {}
264                for name, source in scope.sources.items():
265                    if not isinstance(source, Scope):
266                        continue
267                    if isinstance(source.expression, exp.UDTF):
268                        values = []
269
270                        if isinstance(source.expression, exp.Lateral):
271                            if isinstance(source.expression.this, exp.Explode):
272                                values = [source.expression.this.this]
273                        else:
274                            values = source.expression.expressions[0].expressions
275
276                        if not values:
277                            continue
278
279                        selects[name] = {
280                            alias: column
281                            for alias, column in zip(
282                                source.expression.alias_column_names,
283                                values,
284                            )
285                        }
286                    else:
287                        selects[name] = {
288                            select.alias_or_name: select for select in source.expression.selects
289                        }
290                # First annotate the current scope's column references
291                for col in scope.columns:
292                    if not col.table:
293                        continue
294
295                    source = scope.sources.get(col.table)
296                    if isinstance(source, exp.Table):
297                        col.type = self.schema.get_column_type(source, col)
298                    elif source and col.table in selects and col.name in selects[col.table]:
299                        col.type = selects[col.table][col.name].type
300                # Then (possibly) annotate the remaining expressions in the scope
301                self._maybe_annotate(scope.expression)
302        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions
303
304    def _maybe_annotate(self, expression):
305        if expression.type:
306            return expression  # We've already inferred the expression's type
307
308        annotator = self.annotators.get(expression.__class__)
309
310        return (
311            annotator(self, expression)
312            if annotator
313            else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
314        )
315
316    def _annotate_args(self, expression):
317        for _, value in expression.iter_expressions():
318            self._maybe_annotate(value)
319
320        return expression
321
322    def _maybe_coerce(self, type1, type2):
323        # We propagate the NULL / UNKNOWN types upwards if found
324        if isinstance(type1, exp.DataType):
325            type1 = type1.this
326        if isinstance(type2, exp.DataType):
327            type2 = type2.this
328
329        if exp.DataType.Type.NULL in (type1, type2):
330            return exp.DataType.Type.NULL
331        if exp.DataType.Type.UNKNOWN in (type1, type2):
332            return exp.DataType.Type.UNKNOWN
333
334        return type2 if type2 in self.coerces_to.get(type1, {}) else type1
335
336    def _annotate_binary(self, expression):
337        self._annotate_args(expression)
338
339        left_type = expression.left.type.this
340        right_type = expression.right.type.this
341
342        if isinstance(expression, (exp.And, exp.Or)):
343            if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
344                expression.type = exp.DataType.Type.NULL
345            elif exp.DataType.Type.NULL in (left_type, right_type):
346                expression.type = exp.DataType.build(
347                    "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
348                )
349            else:
350                expression.type = exp.DataType.Type.BOOLEAN
351        elif isinstance(expression, (exp.Condition, exp.Predicate)):
352            expression.type = exp.DataType.Type.BOOLEAN
353        else:
354            expression.type = self._maybe_coerce(left_type, right_type)
355
356        return expression
357
358    def _annotate_unary(self, expression):
359        self._annotate_args(expression)
360
361        if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
362            expression.type = exp.DataType.Type.BOOLEAN
363        else:
364            expression.type = expression.this.type
365
366        return expression
367
368    def _annotate_literal(self, expression):
369        if expression.is_string:
370            expression.type = exp.DataType.Type.VARCHAR
371        elif expression.is_int:
372            expression.type = exp.DataType.Type.INT
373        else:
374            expression.type = exp.DataType.Type.DOUBLE
375
376        return expression
377
378    def _annotate_with_type(self, expression, target_type):
379        expression.type = target_type
380        return self._annotate_args(expression)
381
382    def _annotate_by_args(self, expression, *args, promote=False):
383        self._annotate_args(expression)
384        expressions = []
385        for arg in args:
386            arg_expr = expression.args.get(arg)
387            expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
388
389        last_datatype = None
390        for expr in expressions:
391            last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
392
393        expression.type = last_datatype or exp.DataType.Type.UNKNOWN
394
395        if promote:
396            if expression.type.this in exp.DataType.INTEGER_TYPES:
397                expression.type = exp.DataType.Type.BIGINT
398            elif expression.type.this in exp.DataType.FLOAT_TYPES:
399                expression.type = exp.DataType.Type.DOUBLE
400
401        return expression
TypeAnnotator(schema=None, annotators=None, coerces_to=None)
255    def __init__(self, schema=None, annotators=None, coerces_to=None):
256        self.schema = schema
257        self.annotators = annotators or self.ANNOTATORS
258        self.coerces_to = coerces_to or self.COERCES_TO
def annotate(self, expression):
260    def annotate(self, expression):
261        if isinstance(expression, self.TRAVERSABLES):
262            for scope in traverse_scope(expression):
263                selects = {}
264                for name, source in scope.sources.items():
265                    if not isinstance(source, Scope):
266                        continue
267                    if isinstance(source.expression, exp.UDTF):
268                        values = []
269
270                        if isinstance(source.expression, exp.Lateral):
271                            if isinstance(source.expression.this, exp.Explode):
272                                values = [source.expression.this.this]
273                        else:
274                            values = source.expression.expressions[0].expressions
275
276                        if not values:
277                            continue
278
279                        selects[name] = {
280                            alias: column
281                            for alias, column in zip(
282                                source.expression.alias_column_names,
283                                values,
284                            )
285                        }
286                    else:
287                        selects[name] = {
288                            select.alias_or_name: select for select in source.expression.selects
289                        }
290                # First annotate the current scope's column references
291                for col in scope.columns:
292                    if not col.table:
293                        continue
294
295                    source = scope.sources.get(col.table)
296                    if isinstance(source, exp.Table):
297                        col.type = self.schema.get_column_type(source, col)
298                    elif source and col.table in selects and col.name in selects[col.table]:
299                        col.type = selects[col.table][col.name].type
300                # Then (possibly) annotate the remaining expressions in the scope
301                self._maybe_annotate(scope.expression)
302        return self._maybe_annotate(expression)  # This takes care of non-traversable expressions