summary refs log tree commit diff stats
path: root/ipc/ipdl/ipdl/cgen.py
blob: 8ed8da4d8105dde34a7f89fd8802669f6d7b6aba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import sys

from ipdl.ast import Visitor


class CodePrinter:
    """Small indentation-aware text emitter.

    ``col`` holds the current indentation in spaces; ``indent``/``dedent``
    move it by ``indentCols``. The ``printdent*`` methods prefix their text
    with ``col`` spaces, whereas ``write`` and ``println`` never indent.
    """

    def __init__(self, outf=sys.stdout, indentCols=4):
        # Destination stream; only its ``write`` method is used.
        self.outf = outf
        # Current indentation column (number of leading spaces).
        self.col = 0
        # Number of spaces added/removed by one indent/dedent step.
        self.indentCols = indentCols

    def write(self, str):
        """Write ``str`` exactly as given: no indentation, no newline."""
        self.outf.write(str)

    def printdent(self, str=""):
        """Write ``str`` after the current indentation, without a newline."""
        padding = " " * self.col
        self.write(padding + str)

    def println(self, str=""):
        """Write ``str`` followed by a newline, without indentation."""
        self.write(str + "\n")

    def printdentln(self, str):
        """Write ``str`` after the current indentation, ending the line."""
        padding = " " * self.col
        self.write(padding + str + "\n")

    def indent(self):
        """Move the indentation one step to the right."""
        self.col = self.col + self.indentCols

    def dedent(self):
        """Move the indentation one step to the left."""
        self.col = self.col - self.indentCols


# -----------------------------------------------------------------------------
class IPDLCodeGen(CodePrinter, Visitor):
    """Spits back out equivalent IPDL to the code that generated this.
    Also known as pretty-printing.

    Each ``visit*`` method writes IPDL text for one AST node to the output
    stream. A protocol pulled in with ``include protocol`` is printed inline,
    between a ``/* Included file:`` line and a ``*/`` line, the first time
    its translation unit is encountered.
    """

    def __init__(self, outf=sys.stdout, indentCols=4, printed=None):
        """Create a pretty-printer.

        outf       -- stream the generated IPDL is written to.
        indentCols -- number of spaces per indentation level.
        printed    -- set of translation-unit filenames already emitted. A
                      fresh empty set is used when omitted; pass a parent
                      printer's set to share it (as visitProtocolInclude
                      does for nested printers).
        """
        CodePrinter.__init__(self, outf, indentCols)
        # Not ``printed=set()``: a default set would be created once and
        # shared by every instance built without the argument, so a later,
        # unrelated run would silently skip files printed by an earlier one.
        self.printed = set() if printed is None else printed

    def visitTranslationUnit(self, tu):
        """Emit a generated-file banner, then the contents of ``tu``.

        ``tu.filename`` is recorded in ``self.printed`` before traversal so a
        recursive ``include protocol`` of the same file is not printed again.
        """
        self.printed.add(tu.filename)
        self.println("//\n// Automatically generated by ipdlc\n//")
        # Delegate child traversal to the base Visitor. The old code named an
        # undefined ``CodeGen`` class (hidden behind NOQA: F821) and would
        # have raised NameError here.
        Visitor.visitTranslationUnit(self, tu)

    def visitCxxInclude(self, inc):
        """Emit ``include "<file>";`` for a C++ header include."""
        self.println('include "' + inc.file + '";')

    def visitProtocolInclude(self, inc):
        """Emit ``include protocol "<file>";``.

        The first time the included translation unit is seen, its full
        text is also printed, wrapped in a block comment.
        """
        self.println('include protocol "' + inc.file + '";')
        if inc.tu.filename in self.printed:
            return

        self.println("/* Included file:")
        # A fresh printer (indentation back at column 0) shares ``printed``
        # with this one, so every translation unit is emitted at most once.
        IPDLCodeGen(
            outf=self.outf, indentCols=self.indentCols, printed=self.printed
        ).visitTranslationUnit(inc.tu)
        self.println("*/")

    def visitProtocol(self, p):
        """Emit a protocol declaration.

        Output order: its namespaces, the header line and opening brace,
        indented ``manages`` statements and message declarations, then the
        protocol's closing brace and one extra ``}`` per namespace.
        """
        self.println()
        for namespace in p.namespaces:
            namespace.accept(self)

        # NOTE(review): ``sendSemantics[0]`` assumes the attribute is a
        # sequence whose first element is the keyword text; confirm against
        # the current ipdl.ast definitions.
        self.println("%s protocol %s\n{" % (p.sendSemantics[0], p.name))
        self.indent()

        for mgs in p.managesStmts:
            mgs.accept(self)
        if p.managesStmts:
            self.println()

        for msgDecl in p.messageDecls:
            msgDecl.accept(self)
        self.println()

        self.dedent()
        self.println("}")
        # One closing brace per namespace; presumably the namespace visitor
        # above opened a brace for each one (TODO: confirm).
        self.write("}\n" * len(p.namespaces))

    def visitManagerStmt(self, mgr):
        """Emit an indented ``manager <name>;`` line."""
        self.printdentln("manager " + mgr.name + ";")

    def visitManagesStmt(self, mgs):
        """Emit an indented ``manages <name>;`` line."""
        self.printdentln("manages " + mgs.name + ";")

    def visitMessageDecl(self, msg):
        """Emit one message declaration.

        The first line has the form
        ``<sendSemantics[0]> <direction[0]> <name>(<inParams>)``. It is
        terminated by ``;`` when there are no out-params. Otherwise an
        indented ``returns (<outParams>);`` line follows.
        """
        self.printdent(
            "%s %s %s(" % (msg.sendSemantics[0], msg.direction[0], msg.name)
        )
        self._writeParamList(msg.inParams)
        self.write(")")
        if not msg.outParams:
            self.println(";")
            return

        self.println()
        self.indent()
        self.printdent("returns (")
        self._writeParamList(msg.outParams)
        self.println(");")
        self.dedent()

    def _writeParamList(self, params):
        """Visit each parameter in ``params``, writing ``", "`` between them."""
        for i, param in enumerate(params):
            if i:
                self.write(", ")
            param.accept(self)