summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index a2a86cd..cb9312c 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -263,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
+ exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
@@ -333,9 +334,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._visited: t.Set[int] = set()
def _set_type(
- self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
+ self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
) -> None:
- expression.type = target_type # type: ignore
+ expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore
self._visited.add(id(expression))
def annotate(self, expression: E) -> E:
@@ -564,13 +565,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
if isinstance(bracket_arg, exp.Slice):
self._set_type(expression, this.type)
elif this.type.is_type(exp.DataType.Type.ARRAY):
- contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
- self._set_type(expression, contained_type)
+ self._set_type(expression, seq_get(this.type.expressions, 0))
elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
index = this.keys.index(bracket_arg)
value = seq_get(this.values, index)
- value_type = value.type if value else exp.DataType.Type.UNKNOWN
- self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
+ self._set_type(expression, value.type if value else None)
else:
self._set_type(expression, exp.DataType.Type.UNKNOWN)
@@ -591,3 +590,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
self._set_type(expression, self._maybe_coerce(left_type, right_type))
return expression
+
+ def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
+ self._annotate_args(expression)
+ self._set_type(expression, seq_get(expression.this.type.expressions, 0))
+ return expression