diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-17 10:32:12 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-01-17 10:32:12 +0000 |
commit | 244a05de60c9417daab9528b51788c3d2a00dc5f (patch) | |
tree | 89a9c82aa41d397e1b81c320ad7a287b6c80f313 /sqlglot/serde.py | |
parent | Adding upstream version 10.4.2. (diff) | |
download | sqlglot-upstream/10.5.2.tar.xz sqlglot-upstream/10.5.2.zip |
Adding upstream version 10.5.2.upstream/10.5.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/serde.py')
-rw-r--r-- | sqlglot/serde.py | 67 |
1 files changed, 67 insertions, 0 deletions
diff --git a/sqlglot/serde.py b/sqlglot/serde.py new file mode 100644 index 0000000..a47ffdb --- /dev/null +++ b/sqlglot/serde.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp + +if t.TYPE_CHECKING: + JSON = t.Union[dict, list, str, float, int, bool] + Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON] + + +def dump(node: Node) -> JSON: + """ + Recursively dump an AST into a JSON-serializable dict. + """ + if isinstance(node, list): + return [dump(i) for i in node] + if isinstance(node, exp.DataType.Type): + return { + "class": "DataType.Type", + "value": node.value, + } + if isinstance(node, exp.Expression): + klass = node.__class__.__qualname__ + if node.__class__.__module__ != exp.__name__: + klass = f"{node.__module__}.{klass}" + obj = { + "class": klass, + "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []}, + } + if node.type: + obj["type"] = node.type.sql() + if node.comments: + obj["comments"] = node.comments + return obj + return node + + +def load(obj: JSON) -> Node: + """ + Recursively load a dict (as returned by `dump`) into an AST. + """ + if isinstance(obj, list): + return [load(i) for i in obj] + if isinstance(obj, dict): + class_name = obj["class"] + + if class_name == "DataType.Type": + return exp.DataType.Type(obj["value"]) + + if "." in class_name: + module_path, class_name = class_name.rsplit(".", maxsplit=1) + module = __import__(module_path, fromlist=[class_name]) + else: + module = exp + + klass = getattr(module, class_name) + + expression = klass(**{k: load(v) for k, v in obj["args"].items()}) + type_ = obj.get("type") + if type_: + expression.type = exp.DataType.build(type_) + comments = obj.get("comments") + if comments: + expression.comments = load(comments) + return expression + return obj |