diff options
Diffstat (limited to 'sqlglot/serde.py')
-rw-r--r-- | sqlglot/serde.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/sqlglot/serde.py b/sqlglot/serde.py new file mode 100644 index 0000000..a47ffdb --- /dev/null +++ b/sqlglot/serde.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp + +if t.TYPE_CHECKING: + JSON = t.Union[dict, list, str, float, int, bool] + Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON] + + +def dump(node: Node) -> JSON: + """ + Recursively dump an AST into a JSON-serializable dict. + """ + if isinstance(node, list): + return [dump(i) for i in node] + if isinstance(node, exp.DataType.Type): + return { + "class": "DataType.Type", + "value": node.value, + } + if isinstance(node, exp.Expression): + klass = node.__class__.__qualname__ + if node.__class__.__module__ != exp.__name__: + klass = f"{node.__module__}.{klass}" + obj = { + "class": klass, + "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []}, + } + if node.type: + obj["type"] = node.type.sql() + if node.comments: + obj["comments"] = node.comments + return obj + return node + + +def load(obj: JSON) -> Node: + """ + Recursively load a dict (as returned by `dump`) into an AST. + """ + if isinstance(obj, list): + return [load(i) for i in obj] + if isinstance(obj, dict): + class_name = obj["class"] + + if class_name == "DataType.Type": + return exp.DataType.Type(obj["value"]) + + if "." in class_name: + module_path, class_name = class_name.rsplit(".", maxsplit=1) + module = __import__(module_path, fromlist=[class_name]) + else: + module = exp + + klass = getattr(module, class_name) + + expression = klass(**{k: load(v) for k, v in obj["args"].items()}) + type_ = obj.get("type") + if type_: + expression.type = exp.DataType.build(type_) + comments = obj.get("comments") + if comments: + expression.comments = load(comments) + return expression + return obj |