diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 516 |
1 files changed, 253 insertions, 263 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 6238759..39e2c53 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,13 +1,25 @@ +from __future__ import annotations + +import typing as t + from sqlglot import exp +from sqlglot._typing import E from sqlglot.helper import ensure_list, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope -from sqlglot.schema import ensure_schema +from sqlglot.schema import Schema, ensure_schema + +if t.TYPE_CHECKING: + B = t.TypeVar("B", bound=exp.Binary) -def annotate_types(expression, schema=None, annotators=None, coerces_to=None): +def annotate_types( + expression: E, + schema: t.Optional[t.Dict | Schema] = None, + annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, +) -> E: """ - Recursively infer & annotate types in an expression syntax tree against a schema. - Assumes that we've already executed the optimizer's qualify_columns step. + Infers the types of an expression, annotating its AST accordingly. Example: >>> import sqlglot @@ -18,12 +30,13 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): <Type.DOUBLE: 'DOUBLE'> Args: - expression (sqlglot.Expression): Expression to annotate. - schema (dict|sqlglot.optimizer.Schema): Database schema. - annotators (dict): Maps expression type to corresponding annotation function. - coerces_to (dict): Maps expression type to set of types that it can be coerced into. + expression: Expression to annotate. + schema: Database schema. + annotators: Maps expression type to corresponding annotation function. + coerces_to: Maps expression type to set of types that it can be coerced into. + Returns: - sqlglot.Expression: expression annotated with types + The expression annotated with types. """ schema = ensure_schema(schema) @@ -31,276 +44,241 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) -class TypeAnnotator: - ANNOTATORS = { - **{ - expr_type: lambda self, expr: self._annotate_unary(expr) - for expr_type in subclasses(exp.__name__, exp.Unary) - }, - **{ - expr_type: lambda self, expr: self._annotate_binary(expr) - for expr_type in subclasses(exp.__name__, exp.Binary) - }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), - exp.Alias: lambda self, expr: self._annotate_unary(expr), - exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), - exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), - exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), - exp.ApproxDistinct: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.BIGINT - ), - exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Min: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.Max: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.Sum: lambda self, expr: self._annotate_by_args( - expr, "this", "expressions", promote=True - ), - exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.CurrentDatetime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.CurrentTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DatetimeAdd: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.DatetimeSub: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATETIME - ), - exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimestampAdd: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.TimestampSub: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), - exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.DateStrToDate: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATE - ), - exp.DateToDateStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"), - exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"), - exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"), - exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"), - exp.Concat: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.GroupConcat: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.ArrayConcat: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.ArraySize: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.Map: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), - exp.VarMap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.MAP), - exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Interval: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INTERVAL), - exp.Least: lambda self, expr: self._annotate_by_args(expr, "expressions"), - exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), - exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.ApproxQuantile: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DOUBLE - ), - exp.RegexpLike: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.BOOLEAN - ), - exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.StrToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.TimeStrToDate: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DATE - ), - exp.TimeStrToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), - exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), - exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.UnixToTime: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.TIMESTAMP - ), - exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.VARCHAR - ), - exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), - exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), - exp.VariancePop: lambda self, expr: self._annotate_with_type( - expr, exp.DataType.Type.DOUBLE - ), - exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), - } +def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]: + return lambda self, e: self._annotate_with_type(e, data_type) - # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html - COERCES_TO = { - # CHAR < NCHAR < VARCHAR < NVARCHAR < TEXT - exp.DataType.Type.TEXT: set(), - exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT}, - exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT}, - exp.DataType.Type.NCHAR: { - exp.DataType.Type.VARCHAR, - exp.DataType.Type.NVARCHAR, + +class _TypeAnnotator(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + # Highest-to-lowest type precedence, as specified in Spark's docs (ANSI): + # https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html + text_precedence = ( exp.DataType.Type.TEXT, - }, - exp.DataType.Type.CHAR: { - exp.DataType.Type.NCHAR, - exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, - exp.DataType.Type.TEXT, - }, - # TINYINT < SMALLINT < INT < BIGINT < DECIMAL < FLOAT < DOUBLE - exp.DataType.Type.DOUBLE: set(), - exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE}, - exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}, - exp.DataType.Type.BIGINT: { - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, + exp.DataType.Type.VARCHAR, + exp.DataType.Type.NCHAR, + exp.DataType.Type.CHAR, + ) + numeric_precedence = ( exp.DataType.Type.DOUBLE, + exp.DataType.Type.FLOAT, + exp.DataType.Type.DECIMAL, + exp.DataType.Type.BIGINT, + exp.DataType.Type.INT, + exp.DataType.Type.SMALLINT, + exp.DataType.Type.TINYINT, + ) + timelike_precedence = ( + exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMP, + exp.DataType.Type.DATETIME, + exp.DataType.Type.DATE, + ) + + for type_precedence in (text_precedence, numeric_precedence, timelike_precedence): + coerces_to = set() + for data_type in type_precedence: + klass.COERCES_TO[data_type] = coerces_to.copy() + coerces_to |= {data_type} + + return klass + + +class TypeAnnotator(metaclass=_TypeAnnotator): + TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { + exp.DataType.Type.BIGINT: { + exp.ApproxDistinct, + exp.ArraySize, + exp.Count, + exp.Length, + }, + exp.DataType.Type.BOOLEAN: { + exp.Between, + exp.Boolean, + exp.In, + exp.RegexpLike, + }, + exp.DataType.Type.DATE: { + exp.CurrentDate, + exp.Date, + exp.DateAdd, + exp.DateStrToDate, + exp.DateSub, + exp.DateTrunc, + exp.DiToDate, + exp.StrToDate, + exp.TimeStrToDate, + exp.TsOrDsToDate, + }, + exp.DataType.Type.DATETIME: { + exp.CurrentDatetime, + exp.DatetimeAdd, + exp.DatetimeSub, + }, + exp.DataType.Type.DOUBLE: { + exp.ApproxQuantile, + exp.Avg, + exp.Exp, + exp.Ln, + exp.Log, + exp.Log2, + exp.Log10, + exp.Pow, + exp.Quantile, + exp.Round, + exp.SafeDivide, + exp.Sqrt, + exp.Stddev, + exp.StddevPop, + exp.StddevSamp, + exp.Variance, + exp.VariancePop, }, exp.DataType.Type.INT: { - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.Ceil, + exp.DateDiff, + exp.DatetimeDiff, + exp.Extract, + exp.TimestampDiff, + exp.TimeDiff, + exp.DateToDi, + exp.Floor, + exp.Levenshtein, + exp.StrPosition, + exp.TsOrDiToDi, }, - exp.DataType.Type.SMALLINT: { - exp.DataType.Type.INT, - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.DataType.Type.TIMESTAMP: { + exp.CurrentTime, + exp.CurrentTimestamp, + exp.StrToTime, + exp.TimeAdd, + exp.TimeStrToTime, + exp.TimeSub, + exp.TimestampAdd, + exp.TimestampSub, + exp.UnixToTime, }, exp.DataType.Type.TINYINT: { - exp.DataType.Type.SMALLINT, - exp.DataType.Type.INT, - exp.DataType.Type.BIGINT, - exp.DataType.Type.DECIMAL, - exp.DataType.Type.FLOAT, - exp.DataType.Type.DOUBLE, + exp.Day, + exp.Month, + exp.Week, + exp.Year, }, - # DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ - exp.DataType.Type.TIMESTAMPLTZ: set(), - exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ}, - exp.DataType.Type.TIMESTAMP: { - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.VARCHAR: { + exp.ArrayConcat, + exp.Concat, + exp.ConcatWs, + exp.DateToDateStr, + exp.GroupConcat, + exp.Initcap, + exp.Lower, + exp.SafeConcat, + exp.Substring, + exp.TimeToStr, + exp.TimeToTimeStr, + exp.Trim, + exp.TsOrDsToDateStr, + exp.UnixToStr, + exp.UnixToTimeStr, + exp.Upper, }, - exp.DataType.Type.DATETIME: { - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + } + + ANNOTATORS = { + **{ + expr_type: lambda self, e: self._annotate_unary(e) + for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) }, - exp.DataType.Type.DATE: { - exp.DataType.Type.DATETIME, - exp.DataType.Type.TIMESTAMP, - exp.DataType.Type.TIMESTAMPTZ, - exp.DataType.Type.TIMESTAMPLTZ, + **{ + expr_type: lambda self, e: self._annotate_binary(e) + for expr_type in subclasses(exp.__name__, exp.Binary) + }, + **{ + expr_type: _annotate_with_type_lambda(data_type) + for data_type, expressions in TYPE_TO_EXPRESSIONS.items() + for expr_type in expressions }, + exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), + exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), + exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), + exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), + exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), + exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), + exp.Literal: lambda self, e: self._annotate_literal(e), + exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), + exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), + exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), + exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), } - TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + # Specifies what types a given type can be coerced into (autofilled) + COERCES_TO: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} - def __init__(self, schema=None, annotators=None, coerces_to=None): + def __init__( + self, + schema: Schema, + annotators: t.Optional[t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]] = None, + coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None, + ) -> None: self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO - def annotate(self, expression): - if isinstance(expression, self.TRAVERSABLES): - for scope in traverse_scope(expression): - selects = {} - for name, source in scope.sources.items(): - if not isinstance(source, Scope): - continue - if isinstance(source.expression, exp.UDTF): - values = [] - - if isinstance(source.expression, exp.Lateral): - if isinstance(source.expression.this, exp.Explode): - values = [source.expression.this.this] - else: - values = source.expression.expressions[0].expressions - - if not values: - continue - - selects[name] = { - alias: column - for alias, column in zip( - source.expression.alias_column_names, - values, - ) - } + def annotate(self, expression: E) -> E: + for scope in traverse_scope(expression): + selects = {} + for name, source in scope.sources.items(): + if not isinstance(source, Scope): + continue + if isinstance(source.expression, exp.UDTF): + values = [] + + if isinstance(source.expression, exp.Lateral): + if isinstance(source.expression.this, exp.Explode): + values = [source.expression.this.this] else: - selects[name] = { - select.alias_or_name: select for select in source.expression.selects - } - # First annotate the current scope's column references - for col in scope.columns: - if not col.table: + values = source.expression.expressions[0].expressions + + if not values: continue - source = scope.sources.get(col.table) - if isinstance(source, exp.Table): - col.type = self.schema.get_column_type(source, col) - elif source and col.table in selects and col.name in selects[col.table]: - col.type = selects[col.table][col.name].type - # Then (possibly) annotate the remaining expressions in the scope - self._maybe_annotate(scope.expression) + selects[name] = { + alias: column + for alias, column in zip( + source.expression.alias_column_names, + values, + ) + } + else: + selects[name] = { + select.alias_or_name: select for select in source.expression.selects + } + + # First annotate the current scope's column references + for col in scope.columns: + if not col.table: + continue + + source = scope.sources.get(col.table) + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + elif source and col.table in selects and col.name in selects[col.table]: + col.type = selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + return self._maybe_annotate(expression) # This takes care of non-traversable expressions - def _maybe_annotate(self, expression): + def _maybe_annotate(self, expression: E) -> E: if expression.type: return expression # We've already inferred the expression's type @@ -312,13 +290,15 @@ class TypeAnnotator: else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) ) - def _annotate_args(self, expression): + def _annotate_args(self, expression: E) -> E: for _, value in expression.iter_expressions(): self._maybe_annotate(value) return expression - def _maybe_coerce(self, type1, type2): + def _maybe_coerce( + self, type1: exp.DataType | exp.DataType.Type, type2: exp.DataType | exp.DataType.Type + ) -> exp.DataType.Type: # We propagate the NULL / UNKNOWN types upwards if found if isinstance(type1, exp.DataType): type1 = type1.this @@ -330,9 +310,14 @@ class TypeAnnotator: if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to.get(type1, {}) else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 # type: ignore - def _annotate_binary(self, expression): + # Note: the following "no_type_check" decorators were added because mypy was yelling due + # to assigning Type values to expression.type (since its getter returns Optional[DataType]). + # This is a known mypy issue: https://github.com/python/mypy/issues/3004 + + @t.no_type_check + def _annotate_binary(self, expression: B) -> B: self._annotate_args(expression) left_type = expression.left.type.this @@ -354,7 +339,8 @@ class TypeAnnotator: return expression - def _annotate_unary(self, expression): + @t.no_type_check + def _annotate_unary(self, expression: E) -> E: self._annotate_args(expression) if isinstance(expression, exp.Condition) and not isinstance(expression, exp.Paren): @@ -364,7 +350,8 @@ class TypeAnnotator: return expression - def _annotate_literal(self, expression): + @t.no_type_check + def _annotate_literal(self, expression: exp.Literal) -> exp.Literal: if expression.is_string: expression.type = exp.DataType.Type.VARCHAR elif expression.is_int: @@ -374,13 +361,16 @@ class TypeAnnotator: return expression - def _annotate_with_type(self, expression, target_type): + @t.no_type_check + def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args, promote=False): + @t.no_type_check + def _annotate_by_args(self, expression: E, *args: str, promote: bool = False) -> E: self._annotate_args(expression) - expressions = [] + + expressions: t.List[exp.Expression] = [] for arg in args: arg_expr = expression.args.get(arg) expressions.extend(expr for expr in ensure_list(arg_expr) if expr) |