summaryrefslogtreecommitdiffstats
path: root/sqlglot/optimizer/annotate_types.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/optimizer/annotate_types.py')
-rw-r--r--sqlglot/optimizer/annotate_types.py24
1 files changed, 23 insertions, 1 deletions
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 7b990f1..d0168d5 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -195,6 +195,9 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.StrPosition,
exp.TsOrDiToDi,
},
+ exp.DataType.Type.JSON: {
+ exp.ParseJSON,
+ },
exp.DataType.Type.TIMESTAMP: {
exp.CurrentTime,
exp.CurrentTimestamp,
@@ -275,6 +278,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
+ exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
}
NESTED_TYPES = {
@@ -477,7 +481,12 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
@t.no_type_check
def _annotate_by_args(
- self, expression: E, *args: str, promote: bool = False, array: bool = False
+ self,
+ expression: E,
+ *args: str,
+ promote: bool = False,
+ array: bool = False,
+ struct: bool = False,
) -> E:
self._annotate_args(expression)
@@ -506,6 +515,19 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
),
)
+ if struct:
+ expressions = [
+ expr.type
+ if not expr.args.get("alias")
+ else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
+ for expr in expressions
+ ]
+
+ self._set_type(
+ expression,
+ exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
+ )
+
return expression
def _annotate_timeunit(