summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- sqlglot/optimizer/annotate_types.py | 516
1 files changed, 253 insertions, 263 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 6238759..39e2c53 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,13 +1,25 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import exp
+from sqlglot._typing import E
from sqlglot.helper import ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
-from sqlglot.schema import ensure_schema
+from sqlglot.schema import Schema, ensure_schema
+
+if t.TYPE_CHECKING:
+ B = t.TypeVar("B", bound=exp.Binary)
-def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
+def annotate_types(
+ expression: E,
+ schema: t.Optional[t.Dict | Schema] = None,
+ annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+) -> E:
"""
- Recursively infer & annotate types in an expression syntax tree against a schema.
- Assumes that we've already executed the optimizer's qualify_columns step.
+ Infers the types of an expression, annotating its AST accordingly.
Example:
>>> import sqlglot
@@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
<Type.DOUBLE: 'DOUBLE'>
Args:
- expression (sqlglot.Expression): Expression to annotate.
- schema (dict|sqlglot.optimizer.Schema): Database schema.
- annotators (dict): Maps expression type to corresponding annotation function.
- coerces_to (dict): Maps expression type to set of types that it can be coerced into.
+ expression: Expression to annotate.
+ schema: Database schema.
+ annotators: Maps expression type to corresponding annotation function.
+ coerces_to: Maps expression type to set of types that it can be coerced into.
+
Returns:
- sqlglot.Expression: expression annotated with types
+ The expression annotated with types.
"""
schema = ensure_schema(schema)
@@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
return TypeAnnotator(schema, annotators, coerces_to).annotate(expression)
-class TypeAnnotator:
- ANNOTATORS = {
- **{
- expr_type: lambda self, expr: self._annotate_unary(expr)
- for expr_type in subclasses(exp.__name__, exp.Unary)
- },
- **{
- expr_type: lambda self, expr: self._annotate_binary(expr)
- for expr_type in subclasses(exp.__name__, exp.Binary)
- },
- exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
- exp.Alias: lambda self, expr: self._annotate_unary(expr),
- exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.Literal: lambda self, expr: self._annotate_literal(expr),
- exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
- exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
- exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
- exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.BIGINT
- ),
- exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.Sum: lambda self, expr: self._annotate_by_args(
- expr, "this", "expressions", promote=True
- ),
- exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.CurrentTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATETIME
- ),
- exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.TimestampSub: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATE
- ),
- exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
- exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
- exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
- exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
- exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.GroupConcat: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
- exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP),
- exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL),
- exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"),
- exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
- exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DOUBLE
- ),
- exp.RegexpLike: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.BOOLEAN
- ),
- exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.StrToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DATE
- ),
- exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.UnixToTime: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.TIMESTAMP
- ),
- exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.VARCHAR
- ),
- exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.VariancePop: lambda self, expr: self._annotate_with_type(
- expr, exp.DataType.Type.DOUBLE
- ),
- exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
- }
+def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
+ return lambda self, e: self._annotate_with_type(e, data_type)
- # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
- COERCES_TO = {
- # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT
- exp.DataType.Type.TEXT: set(),
- exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
- exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
- exp.DataType.Type.NCHAR: {
- exp.DataType.Type.VARCHAR,
- exp.DataType.Type.NVARCHAR,
+
+class _TypeAnnotator(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+
+ # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI):
+ # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html
+ text_precedence = (
exp.DataType.Type.TEXT,
- },
- exp.DataType.Type.CHAR: {
- exp.DataType.Type.NCHAR,
- exp.DataType.Type.VARCHAR,
exp.DataType.Type.NVARCHAR,
- exp.DataType.Type.TEXT,
- },
- # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE
- exp.DataType.Type.DOUBLE: set(),
- exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
- exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
- exp.DataType.Type.BIGINT: {
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NCHAR,
+ exp.DataType.Type.CHAR,
+ )
+ numeric_precedence = (
exp.DataType.Type.DOUBLE,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.BIGINT,
+ exp.DataType.Type.INT,
+ exp.DataType.Type.SMALLINT,
+ exp.DataType.Type.TINYINT,
+ )
+ timelike_precedence = (
+ exp.DataType.Type.TIMESTAMPLTZ,
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMP,
+ exp.DataType.Type.DATETIME,
+ exp.DataType.Type.DATE,
+ )
+
+ for type_precedence in (text_precedence, numeric_precedence, timelike_precedence):
+ coerces_to = set()
+ for data_type in type_precedence:
+ klass.COERCES_TO[data_type] = coerces_to.copy()
+ coerces_to |= {data_type}
+
+ return klass
+
+
+class TypeAnnotator(metaclass=_TypeAnnotator):
+ TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
+ exp.DataType.Type.BIGINT: {
+ exp.ApproxDistinct,
+ exp.ArraySize,
+ exp.Count,
+ exp.Length,
+ },
+ exp.DataType.Type.BOOLEAN: {
+ exp.Between,
+ exp.Boolean,
+ exp.In,
+ exp.RegexpLike,
+ },
+ exp.DataType.Type.DATE: {
+ exp.CurrentDate,
+ exp.Date,
+ exp.DateAdd,
+ exp.DateStrToDate,
+ exp.DateSub,
+ exp.DateTrunc,
+ exp.DiToDate,
+ exp.StrToDate,
+ exp.TimeStrToDate,
+ exp.TsOrDsToDate,
+ },
+ exp.DataType.Type.DATETIME: {
+ exp.CurrentDatetime,
+ exp.DatetimeAdd,
+ exp.DatetimeSub,
+ },
+ exp.DataType.Type.DOUBLE: {
+ exp.ApproxQuantile,
+ exp.Avg,
+ exp.Exp,
+ exp.Ln,
+ exp.Log,
+ exp.Log2,
+ exp.Log10,
+ exp.Pow,
+ exp.Quantile,
+ exp.Round,
+ exp.SafeDivide,
+ exp.Sqrt,
+ exp.Stddev,
+ exp.StddevPop,
+ exp.StddevSamp,
+ exp.Variance,
+ exp.VariancePop,
},
exp.DataType.Type.INT: {
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.Ceil,
+ exp.DateDiff,
+ exp.DatetimeDiff,
+ exp.Extract,
+ exp.TimestampDiff,
+ exp.TimeDiff,
+ exp.DateToDi,
+ exp.Floor,
+ exp.Levenshtein,
+ exp.StrPosition,
+ exp.TsOrDiToDi,
},
- exp.DataType.Type.SMALLINT: {
- exp.DataType.Type.INT,
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.DataType.Type.TIMESTAMP: {
+ exp.CurrentTime,
+ exp.CurrentTimestamp,
+ exp.StrToTime,
+ exp.TimeAdd,
+ exp.TimeStrToTime,
+ exp.TimeSub,
+ exp.TimestampAdd,
+ exp.TimestampSub,
+ exp.UnixToTime,
},
exp.DataType.Type.TINYINT: {
- exp.DataType.Type.SMALLINT,
- exp.DataType.Type.INT,
- exp.DataType.Type.BIGINT,
- exp.DataType.Type.DECIMAL,
- exp.DataType.Type.FLOAT,
- exp.DataType.Type.DOUBLE,
+ exp.Day,
+ exp.Month,
+ exp.Week,
+ exp.Year,
},
- # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
- exp.DataType.Type.TIMESTAMPLTZ: set(),
- exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
- exp.DataType.Type.TIMESTAMP: {
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ exp.DataType.Type.VARCHAR: {
+ exp.ArrayConcat,
+ exp.Concat,
+ exp.ConcatWs,
+ exp.DateToDateStr,
+ exp.GroupConcat,
+ exp.Initcap,
+ exp.Lower,
+ exp.SafeConcat,
+ exp.Substring,
+ exp.TimeToStr,
+ exp.TimeToTimeStr,
+ exp.Trim,
+ exp.TsOrDsToDateStr,
+ exp.UnixToStr,
+ exp.UnixToTimeStr,
+ exp.Upper,
},
- exp.DataType.Type.DATETIME: {
- exp.DataType.Type.TIMESTAMP,
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ }
+
+ ANNOTATORS = {
+ **{
+ expr_type: lambda self, e: self._annotate_unary(e)
+ for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
},
- exp.DataType.Type.DATE: {
- exp.DataType.Type.DATETIME,
- exp.DataType.Type.TIMESTAMP,
- exp.DataType.Type.TIMESTAMPTZ,
- exp.DataType.Type.TIMESTAMPLTZ,
+ **{
+ expr_type: lambda self, e: self._annotate_binary(e)
+ for expr_type in subclasses(exp.__name__, exp.Binary)
+ },
+ **{
+ expr_type: _annotate_with_type_lambda(data_type)
+ for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
+ for expr_type in expressions
},
+ exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
+ exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
+ exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
+ exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
+ exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
+ exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
+ exp.Literal: lambda self, e: self._annotate_literal(e),
+ exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
+ exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
+ exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
}
- TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery)
+ # Specifies what types a given type can be coerced into (autofilled)
+ COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {}
- def __init__(self, schema=None, annotators=None, coerces_to=None):
+ def __init__(
+ self,
+ schema: Schema,
+ annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None,
+ coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
+ ) -> None:
self.schema = schema
self.annotators = annotators or self.ANNOTATORS
self.coerces_to = coerces_to or self.COERCES_TO
- def annotate(self, expression):
- if isinstance(expression, self.TRAVERSABLES):
- for scope in traverse_scope(expression):
- selects = {}
- for name, source in scope.sources.items():
- if not isinstance(source, Scope):
- continue
- if isinstance(source.expression, exp.UDTF):
- values = []
-
- if isinstance(source.expression, exp.Lateral):
- if isinstance(source.expression.this, exp.Explode):
- values = [source.expression.this.this]
- else:
- values = source.expression.expressions[0].expressions
-
- if not values:
- continue
-
- selects[name] = {
- alias: column
- for alias, column in zip(
- source.expression.alias_column_names,
- values,
- )
- }
+ def annotate(self, expression: E) -> E:
+ for scope in traverse_scope(expression):
+ selects = {}
+ for name, source in scope.sources.items():
+ if not isinstance(source, Scope):
+ continue
+ if isinstance(source.expression, exp.UDTF):
+ values = []
+
+ if isinstance(source.expression, exp.Lateral):
+ if isinstance(source.expression.this, exp.Explode):
+ values = [source.expression.this.this]
else:
- selects[name] = {
- select.alias_or_name: select for select in source.expression.selects
- }
- # First annotate the current scope's column references
- for col in scope.columns:
- if not col.table:
+ values = source.expression.expressions[0].expressions
+
+ if not values:
continue
- source = scope.sources.get(col.table)
- if isinstance(source, exp.Table):
- col.type = self.schema.get_column_type(source, col)
- elif source and col.table in selects and col.name in selects[col.table]:
- col.type = selects[col.table][col.name].type
- # Then (possibly) annotate the remaining expressions in the scope
- self._maybe_annotate(scope.expression)
+ selects[name] = {
+ alias: column
+ for alias, column in zip(
+ source.expression.alias_column_names,
+ values,
+ )
+ }
+ else:
+ selects[name] = {
+ select.alias_or_name: select for select in source.expression.selects
+ }
+
+ # First annotate the current scope's column references
+ for col in scope.columns:
+ if not col.table:
+ continue
+
+ source = scope.sources.get(col.table)
+ if isinstance(source, exp.Table):
+ col.type = self.schema.get_column_type(source, col)
+ elif source and col.table in selects and col.name in selects[col.table]:
+ col.type = selects[col.table][col.name].type
+
+ # Then (possibly) annotate the remaining expressions in the scope
+ self._maybe_annotate(scope.expression)
+
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
- def _maybe_annotate(self, expression):
+ def _maybe_annotate(self, expression: E) -> E:
if expression.type:
return expression # We've already inferred the expression's type
@@ -312,13 +290,15 @@ class TypeAnnotator:
else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN)
)
- def _annotate_args(self, expression):
+ def _annotate_args(self, expression: E) -> E:
for _, value in expression.iter_expressions():
self._maybe_annotate(value)
return expression
- def _maybe_coerce(self, type1, type2):
+ def _maybe_coerce(
+ self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type
+ ) -> exp.DataType.Type:
# We propagate the NULL / UNKNOWN types upwards if found
if isinstance(type1, exp.DataType):
type1 = type1.this
@@ -330,9 +310,14 @@ class TypeAnnotator:
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to.get(type1, {}) else type1
+ return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore
- def _annotate_binary(self, expression):
+ # Note: the following "no_type_check" decorators were added because mypy was yelling due
+ # to assigning Type values to expression.type (since its getter returns Optional[DataType]).
+ # This is a known mypy issue: https://github.com/python/mypy/issues/3004
+
+ @t.no_type_check
+ def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)
left_type = expression.left.type.this
@@ -354,7 +339,8 @@ class TypeAnnotator:
return expression
- def _annotate_unary(self, expression):
+ @t.no_type_check
+ def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)
if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren):
@@ -364,7 +350,8 @@ class TypeAnnotator:
return expression
- def _annotate_literal(self, expression):
+ @t.no_type_check
+ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
expression.type = exp.DataType.Type.VARCHAR
elif expression.is_int:
@@ -374,13 +361,16 @@ class TypeAnnotator:
return expression
- def _annotate_with_type(self, expression, target_type):
+ @t.no_type_check
+ def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
expression.type = target_type
return self._annotate_args(expression)
- def _annotate_by_args(self, expression, *args, promote=False):
+ @t.no_type_check
+ def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E:
self._annotate_args(expression)
- expressions = []
+
+ expressions: t.List[exp.Expression] = []
for arg in args:
arg_expr = expression.args.get(arg)
expressions.extend(expr for expr in ensure_list(arg_expr) if expr)