diff options
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r-- | sqlglot/optimizer/annotate_types.py | 30 |
1 files changed, 20 insertions, 10 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index d0168d5..a2a86cd 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -4,7 +4,6 @@ import functools import typing as t from sqlglot import exp -from sqlglot._typing import E from sqlglot.helper import ( ensure_list, is_date_unit, @@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope from sqlglot.schema import Schema, ensure_schema if t.TYPE_CHECKING: - B = t.TypeVar("B", bound=exp.Binary) + from sqlglot._typing import B, E BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type] BinaryCoercions = t.Dict[ @@ -480,6 +479,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator): return self._annotate_args(expression) @t.no_type_check + def _annotate_struct_value( + self, expression: exp.Expression + ) -> t.Optional[exp.DataType] | exp.ColumnDef: + alias = expression.args.get("alias") + if alias: + return exp.ColumnDef(this=alias.copy(), kind=expression.type) + + # Case: key = value or key := value + if expression.expression: + return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type) + + return expression.type + + @t.no_type_check def _annotate_by_args( self, expression: E, @@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator): ) if struct: - expressions = [ - expr.type - if not expr.args.get("alias") - else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type) - for expr in expressions - ] - self._set_type( expression, - exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True), + exp.DataType( + this=exp.DataType.Type.STRUCT, + expressions=[self._annotate_struct_value(expr) for expr in expressions], + nested=True, + ), ) return expression |