summaryrefslogtreecommitdiffstats
path: root/sqlglot/serde.py
blob: a47ffdbd5bc47c8aecb08de3ac9e87723d63cce1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import annotations

import importlib
import typing as t

from sqlglot import expressions as exp

if t.TYPE_CHECKING:
    # JSON-compatible values: what `dump` produces and what `load` consumes.
    JSON = t.Union[bool, int, float, str, list, dict]
    # What `dump` accepts and `load` returns: AST nodes, DataType.Type enum members,
    # (nested) lists of those, and plain JSON-compatible values.
    Node = t.Union[exp.Expression, exp.DataType.Type, t.List["Node"], JSON]


def dump(node: Node) -> JSON:
    """
    Recursively dump an AST into a JSON-serializable dict.

    - Lists are dumped element-wise.
    - `exp.DataType.Type` members become ``{"class": "DataType.Type", "value": ...}``.
    - `exp.Expression` nodes become ``{"class": ..., "args": {...}}``, plus an optional
      ``"type"`` (the node's type rendered with ``.sql()``) and ``"comments"`` entry.
    - Any other value is returned unchanged.

    Args:
        node: the AST, sub-tree or leaf value to serialize.

    Returns:
        A structure of dicts, lists and scalars that `load` turns back into an AST.
    """
    if isinstance(node, list):
        return [dump(i) for i in node]
    if isinstance(node, exp.DataType.Type):
        return {
            "class": "DataType.Type",
            "value": node.value,
        }
    if isinstance(node, exp.Expression):
        node_class = type(node)
        klass = node_class.__qualname__
        # Classes defined outside `sqlglot.expressions` are recorded with their module
        # path so that `load` can import them again.
        if node_class.__module__ != exp.__name__:
            klass = f"{node_class.__module__}.{klass}"
        obj = {
            "class": klass,
            # Args that are None or an empty list are omitted from the output.
            "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
        }
        if node.type:
            obj["type"] = node.type.sql()
        if node.comments:
            # Copy the list so mutating the dumped structure cannot alter the AST.
            obj["comments"] = list(node.comments)
        return obj
    return node


def load(obj: JSON) -> Node:
    """
    Recursively load a dict (as returned by `dump`) into an AST.

    The ``"class"`` field names either a class in `sqlglot.expressions` or, when it
    contains a dot, ``<module path>.<class name>``; the module is imported with
    `importlib.import_module`. Only `exp.Expression` subclasses are instantiated.
    Anything else is rejected rather than called, so a crafted payload cannot use
    this function to invoke arbitrary callables reachable by module path.

    Args:
        obj: a value produced by `dump` (typically after a JSON round-trip).

    Returns:
        The reconstructed AST node, a list of loaded values, or `obj` unchanged
        when it is neither a list nor a dict.

    Raises:
        ValueError: if ``"class"`` does not resolve to an `exp.Expression` subclass.
    """
    if isinstance(obj, list):
        return [load(i) for i in obj]
    if isinstance(obj, dict):
        qualified_name = obj["class"]

        if qualified_name == "DataType.Type":
            return exp.DataType.Type(obj["value"])

        if "." in qualified_name:
            module_path, class_name = qualified_name.rsplit(".", maxsplit=1)
            module = importlib.import_module(module_path)
        else:
            class_name = qualified_name
            module = exp

        klass = getattr(module, class_name)
        # Validate before calling: `klass(**args)` on an arbitrary attribute would let
        # untrusted input execute code, e.g. {"class": "os.system", "args": {...}}.
        if not (isinstance(klass, type) and issubclass(klass, exp.Expression)):
            raise ValueError(f"{qualified_name!r} is not a sqlglot Expression class")

        expression = klass(**{k: load(v) for k, v in obj["args"].items()})
        type_ = obj.get("type")
        if type_:
            expression.type = exp.DataType.build(type_)
        comments = obj.get("comments")
        if comments:
            expression.comments = load(comments)
        return expression
    return obj