diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 37 |
1 files changed, 28 insertions, 9 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 191ea52..be17f15 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None): >>> schema = {"y": {"cola": "SMALLINT"}} >>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x" >>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema) - >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola" + >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola" <Type.DOUBLE: 'DOUBLE'> Args: @@ -41,9 +41,12 @@ class TypeAnnotator: expr_type: lambda self, expr: self._annotate_binary(expr) for expr_type in subclasses(exp.__name__, exp.Binary) }, - exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this), + exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), exp.Alias: lambda self, expr: self._annotate_unary(expr), + exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), + exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Literal: lambda self, expr: self._annotate_literal(expr), exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL), @@ -52,6 +55,9 @@ class TypeAnnotator: expr, exp.DataType.Type.BIGINT ), exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE), + exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"), + exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True), exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT), exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT), exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE), @@ -263,10 +269,10 @@ class TypeAnnotator: } # First annotate the current scope's column references for col in scope.columns: - source = scope.sources[col.table] + source = scope.sources.get(col.table) if isinstance(source, exp.Table): col.type = self.schema.get_column_type(source, col) - else: + elif source: col.type = selects[col.table][col.name].type # Then (possibly) annotate the remaining expressions in the scope self._maybe_annotate(scope.expression) @@ -280,6 +286,7 @@ class TypeAnnotator: return expression # We've already inferred the expression's type annotator = self.annotators.get(expression.__class__) + return ( annotator(self, expression) if annotator @@ -295,18 +302,23 @@ class TypeAnnotator: def _maybe_coerce(self, type1, type2): # We propagate the NULL / UNKNOWN types upwards if found + if isinstance(type1, exp.DataType): + type1 = type1.this + if isinstance(type2, exp.DataType): + type2 = type2.this + if exp.DataType.Type.NULL in (type1, type2): return exp.DataType.Type.NULL if exp.DataType.Type.UNKNOWN in (type1, type2): return exp.DataType.Type.UNKNOWN - return type2 if type2 in self.coerces_to[type1] else type1 + return type2 if type2 in self.coerces_to.get(type1, {}) else type1 def _annotate_binary(self, expression): self._annotate_args(expression) - left_type = expression.left.type - right_type = expression.right.type + left_type = expression.left.type.this + right_type = expression.right.type.this if isinstance(expression, (exp.And, exp.Or)): if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL: @@ -348,7 +360,7 @@ class TypeAnnotator: expression.type = target_type return self._annotate_args(expression) - def _annotate_by_args(self, expression, *args): + def _annotate_by_args(self, expression, *args, promote=False): self._annotate_args(expression) expressions = [] for arg in args: @@ -360,4 +372,11 @@ class TypeAnnotator: last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type) expression.type = last_datatype or exp.DataType.Type.UNKNOWN + + if promote: + if expression.type.this in exp.DataType.INTEGER_TYPES: + expression.type = exp.DataType.Type.BIGINT + elif expression.type.this in exp.DataType.FLOAT_TYPES: + expression.type = exp.DataType.Type.DOUBLE + return expression |