diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index 17af6ac..69d4567 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -6,7 +6,7 @@ import typing as t from sqlglot import exp from sqlglot._typing import E -from sqlglot.helper import ensure_list, subclasses +from sqlglot.helper import ensure_list, seq_get, subclasses from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema @@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), + exp.Bracket: lambda self, e: self._annotate_bracket(e), exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), @@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), + exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP), @@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, datatype) return expression + + def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket: + self._annotate_args(expression) + + bracket_arg = expression.expressions[0] + this = expression.this + + if isinstance(bracket_arg, exp.Slice): + self._set_type(expression, this.type) + elif this.type.is_type(exp.DataType.Type.ARRAY): + contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN + self._set_type(expression, contained_type) + elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: + index = this.keys.index(bracket_arg) + value = seq_get(this.values, index) + value_type = value.type if value else exp.DataType.Type.UNKNOWN + self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN) + else: + self._set_type(expression, exp.DataType.Type.UNKNOWN) + + return expression |