summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--sqlglot/optimizer/annotate_types.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 17af6ac..69d4567 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -6,7 +6,7 @@ import typing as t
from sqlglot import exp
from sqlglot._typing import E
-from sqlglot.helper import ensure_list, subclasses
+from sqlglot.helper import ensure_list, seq_get, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
@@ -271,6 +271,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
+ exp.Bracket: lambda self, e: self._annotate_bracket(e),
exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
@@ -287,6 +288,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
+ exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
@@ -524,3 +526,24 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, datatype)
return expression
+
+ def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
+ self._annotate_args(expression)
+
+ bracket_arg = expression.expressions[0]
+ this = expression.this
+
+ if isinstance(bracket_arg, exp.Slice):
+ self._set_type(expression, this.type)
+ elif this.type.is_type(exp.DataType.Type.ARRAY):
+ contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
+ self._set_type(expression, contained_type)
+ elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
+ index = this.keys.index(bracket_arg)
+ value = seq_get(this.values, index)
+ value_type = value.type if value else exp.DataType.Type.UNKNOWN
+ self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
+ else:
+ self._set_type(expression, exp.DataType.Type.UNKNOWN)
+
+ return expression