diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index a2a86cd..cb9312c 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -263,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator): exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), exp.Div: lambda self, e: self._annotate_div(e), + exp.Explode: lambda self, e: self._annotate_explode(e), exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), @@ -333,9 +334,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._visited: t.Set[int] = set() def _set_type( - self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type + self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type] ) -> None: - expression.type = target_type # type: ignore + expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore self._visited.add(id(expression)) def annotate(self, expression: E) -> E: @@ -564,13 +565,11 @@ class TypeAnnotator(metaclass=_TypeAnnotator): if isinstance(bracket_arg, exp.Slice): self._set_type(expression, this.type) elif this.type.is_type(exp.DataType.Type.ARRAY): - contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN - self._set_type(expression, contained_type) + self._set_type(expression, seq_get(this.type.expressions, 0)) elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys: index = this.keys.index(bracket_arg) value = seq_get(this.values, index) - value_type = value.type if value else exp.DataType.Type.UNKNOWN - self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN) + self._set_type(expression, value.type if value else None) else: self._set_type(expression, exp.DataType.Type.UNKNOWN) @@ -591,3 +590,8 @@ class TypeAnnotator(metaclass=_TypeAnnotator): self._set_type(expression, self._maybe_coerce(left_type, right_type)) return expression + + def _annotate_explode(self, expression: exp.Explode) -> exp.Explode: + self._annotate_args(expression) + self._set_type(expression, seq_get(expression.this.type.expressions, 0)) + return expression |