diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 158 |
1 files changed, 137 insertions, 21 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 3f5f089..a2cef37 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -1,16 +1,20 @@ from sqlglot import exp from sqlglot.helper import ensure_list, subclasses +from sqlglot.optimizer.schema import ensure_schema +from sqlglot.optimizer.scope import Scope, traverse_scope def annotate_types(expression, schema=None, annotators=None, coerces_to=None): """ Recursively infer & annotate types in an expression syntax tree against a schema. + Assumes that we've already executed the optimizer's qualify_columns step. - (TODO -- replace this with a better example after adding some functionality) Example: >>> import sqlglot - >>> annotated_expression = annotate_types(sqlglot.parse_one('5 + 5.3')) - >>> annotated_expression.type + >>> schema = {"y": {"cola": "SMALLINT"}} + >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" + >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) + >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'> Args: @@ -22,6 +26,8 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): sqlglot.Expression: expression annotated with types """ + schema = ensure_schema(schema) + return TypeAnnotator(schema, annotators, coerces_to).annotate(expression) @@ -35,10 +41,81 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_cast(expr), - exp.DataType: lambda self, expr: self._annotate_data_type(expr), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Literal: lambda self, expr: self._annotate_literal(expr), - exp.Boolean: lambda self, expr: self._annotate_boolean(expr), + exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), + exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN), + exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME), + exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), + exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.Ln: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log2: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Log10: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Lower: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), + exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), + exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP), + exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR), + exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), + exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT), } # Reference: https://spark.apache.org/docs/3.2.0/sql-ref-ansi-compliance.html @@ -97,43 +174,82 @@ class TypeAnnotator: }, } + TRAVERSABLES = (exp.Select, exp.Union, exp.UDTF, exp.Subquery) + def __init__(self, schema=None, annotators=None, coerces_to=None): self.schema = schema self.annotators = annotators or self.ANNOTATORS self.coerces_to = coerces_to or self.COERCES_TO def annotate(self, expression): + if isinstance(expression, self.TRAVERSABLES): + for scope in traverse_scope(expression): + subscope_selects = { + name: {select.alias_or_name: select for select in source.selects} + for name, source in scope.sources.items() + if isinstance(source, Scope) + } + + # First annotate the current scope's column references + for col in scope.columns: + source = scope.sources[col.table] + if isinstance(source, exp.Table): + col.type = self.schema.get_column_type(source, col) + else: + col.type = subscope_selects[col.table][col.name].type + + # Then (possibly) annotate the remaining expressions in the scope + self._maybe_annotate(scope.expression) + + return self._maybe_annotate(expression) # This takes care of non-traversable expressions + + def _maybe_annotate(self, expression): if not isinstance(expression, exp.Expression): return None + if expression.type: + return expression # We've already inferred the expression's type + annotator = self.annotators.get(expression.__class__) - return annotator(self, expression) if annotator else self._annotate_args(expression) + return ( + annotator(self, expression) + if annotator + else self._annotate_with_type(expression, exp.DataType.Type.UNKNOWN) + ) def _annotate_args(self, expression): for value in expression.args.values(): for v in ensure_list(value): - self.annotate(v) + self._maybe_annotate(v) return expression - def _annotate_cast(self, expression): - expression.type = expression.args["to"].this - return self._annotate_args(expression) - - def _annotate_data_type(self, expression): - expression.type = expression.this - return self._annotate_args(expression) - def _maybe_coerce(self, type1, type2): + # We propagate the NULL / UNKNOWN types upwards if found + if exp.DataType.Type.NULL in (type1, type2): + return exp.DataType.Type.NULL + if exp.DataType.Type.UNKNOWN in (type1, type2): + return exp.DataType.Type.UNKNOWN + return type2 if type2 in self.coerces_to[type1] else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - if isinstance(expression, (exp.Condition, exp.Predicate)): + left_type = expression.left.type + right_type = expression.right.type + + if isinstance(expression, (exp.And, exp.Or)): + if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: + expression.type = exp.DataType.Type.NULL + elif exp.DataType.Type.NULL in (left_type, right_type): + expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN")) + else: + expression.type = exp.DataType.Type.BOOLEAN + elif isinstance(expression, (exp.Condition, exp.Predicate)): expression.type = exp.DataType.Type.BOOLEAN else: - expression.type = self._maybe_coerce(expression.left.type, expression.right.type) + expression.type = self._maybe_coerce(left_type, right_type) return expression @@ -157,6 +273,6 @@ class TypeAnnotator: return expression - def _annotate_boolean(self, expression): - expression.type = exp.DataType.Type.BOOLEAN - return expression + def _annotate_with_type(self, expression, target_type): + expression.type = target_type + return self._annotate_args(expression) |