summaryrefslogtreecommitdiffstats
path: root/sqlglot/serde.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/serde.py')
-rw-r--r--sqlglot/serde.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/sqlglot/serde.py b/sqlglot/serde.py
index c5203a7..b019035 100644
--- a/sqlglot/serde.py
+++ b/sqlglot/serde.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
if t.TYPE_CHECKING:
- JSON = t.Union[dict, list, str, float, int, bool]
+ JSON = t.Union[dict, list, str, float, int, bool, None]
Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
@@ -24,12 +24,12 @@ def dump(node: Node) -> JSON:
klass = node.__class__.__qualname__
if node.__class__.__module__ != exp.__name__:
klass = f"{node.__module__}.{klass}"
- obj = {
+ obj: t.Dict = {
"class": klass,
"args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
}
if node.type:
- obj["type"] = node.type.sql()
+ obj["type"] = dump(node.type)
if node.comments:
obj["comments"] = node.comments
if node._meta is not None:
@@ -60,7 +60,7 @@ def load(obj: JSON) -> Node:
klass = getattr(module, class_name)
expression = klass(**{k: load(v) for k, v in obj["args"].items()})
- expression.type = obj.get("type")
+ expression.type = t.cast(exp.DataType, load(obj.get("type")))
expression.comments = obj.get("comments")
expression._meta = obj.get("meta")