summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py30
1 files changed, 20 insertions, 10 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index d0168d5..a2a86cd 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -4,7 +4,6 @@ import functools
import typing as t
from sqlglot import exp
-from sqlglot._typing import E
from sqlglot.helper import (
ensure_list,
is_date_unit,
@@ -17,7 +16,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import Schema, ensure_schema
if t.TYPE_CHECKING:
- B = t.TypeVar("B", bound=exp.Binary)
+ from sqlglot._typing import B, E
BinaryCoercionFunc = t.Callable[[exp.Expression, exp.Expression], exp.DataType.Type]
BinaryCoercions = t.Dict[
@@ -480,6 +479,20 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
return self._annotate_args(expression)
@t.no_type_check
+ def _annotate_struct_value(
+ self, expression: exp.Expression
+ ) -> t.Optional[exp.DataType] | exp.ColumnDef:
+ alias = expression.args.get("alias")
+ if alias:
+ return exp.ColumnDef(this=alias.copy(), kind=expression.type)
+
+ # Case: key = value or key := value
+ if expression.expression:
+ return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)
+
+ return expression.type
+
+ @t.no_type_check
def _annotate_by_args(
self,
expression: E,
@@ -516,16 +529,13 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
)
if struct:
- expressions = [
- expr.type
- if not expr.args.get("alias")
- else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
- for expr in expressions
- ]
-
self._set_type(
expression,
- exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
+ exp.DataType(
+ this=exp.DataType.Type.STRUCT,
+ expressions=[self._annotate_struct_value(expr) for expr in expressions],
+ nested=True,
+ ),
)
return expression