summaryrefslogtreecommitdiffstats
path: root/ipc/ipdl/ipdl
diff options
context:
space:
mode:
Diffstat (limited to 'ipc/ipdl/ipdl')
-rw-r--r--ipc/ipdl/ipdl/__init__.py98
-rw-r--r--ipc/ipdl/ipdl/ast.py468
-rw-r--r--ipc/ipdl/ipdl/builtin.py76
-rw-r--r--ipc/ipdl/ipdl/cgen.py108
-rw-r--r--ipc/ipdl/ipdl/checker.py79
-rw-r--r--ipc/ipdl/ipdl/cxx/__init__.py3
-rw-r--r--ipc/ipdl/ipdl/cxx/ast.py1033
-rw-r--r--ipc/ipdl/ipdl/cxx/cgen.py557
-rw-r--r--ipc/ipdl/ipdl/cxx/code.py187
-rw-r--r--ipc/ipdl/ipdl/lower.py5688
-rw-r--r--ipc/ipdl/ipdl/parser.py680
-rw-r--r--ipc/ipdl/ipdl/type.py1748
-rw-r--r--ipc/ipdl/ipdl/util.py12
13 files changed, 10737 insertions, 0 deletions
diff --git a/ipc/ipdl/ipdl/__init__.py b/ipc/ipdl/ipdl/__init__.py
new file mode 100644
index 0000000000..50ceb4f953
--- /dev/null
+++ b/ipc/ipdl/ipdl/__init__.py
@@ -0,0 +1,98 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+__all__ = [
+ "gencxx",
+ "genipdl",
+ "parse",
+ "typecheck",
+ "writeifmodified",
+ "checkSyncMessage",
+ "checkFixedSyncMessages",
+]
+
+import os
+import sys
+from io import StringIO
+
+from ipdl.cgen import IPDLCodeGen
+from ipdl.lower import LowerToCxx, msgenums
+from ipdl.parser import Parser, ParseError
+from ipdl.type import TypeCheck
+from ipdl.checker import checkSyncMessage, checkFixedSyncMessages
+
+from ipdl.cxx.cgen import CxxCodeGen
+
+
+def parse(specstring, filename="/stdin", includedirs=[], errout=sys.stderr):
+ """Return an IPDL AST if parsing was successful. Print errors to |errout|
+ if it is not."""
+ # The file type and name are later enforced by the type checker.
+ # This is just a hint to the parser.
+ prefix, ext = os.path.splitext(filename)
+ name = os.path.basename(prefix)
+ if ext == ".ipdlh":
+ type = "header"
+ else:
+ type = "protocol"
+
+ try:
+ return Parser(type, name).parse(
+ specstring, os.path.abspath(filename), includedirs
+ )
+ except ParseError as p:
+ print(p, file=errout)
+ return None
+
+
+def typecheck(ast, errout=sys.stderr):
+ """Return True iff |ast| is well typed. Print errors to |errout| if
+ it is not."""
+ return TypeCheck().check(ast, errout)
+
+
+def gencxx(ipdlfilename, ast, outheadersdir, outcppdir, segmentcapacitydict):
+ headers, cpps = LowerToCxx().lower(ast, segmentcapacitydict)
+
+ def resolveHeader(hdr):
+ return [
+ hdr,
+ os.path.join(
+ outheadersdir, *([ns.name for ns in ast.namespaces] + [hdr.name])
+ ),
+ ]
+
+ def resolveCpp(cpp):
+ return [cpp, os.path.join(outcppdir, cpp.name)]
+
+ for ast, filename in [resolveHeader(hdr) for hdr in headers] + [
+ resolveCpp(cpp) for cpp in cpps
+ ]:
+ tempfile = StringIO()
+ CxxCodeGen(tempfile).cgen(ast)
+ writeifmodified(tempfile.getvalue(), filename)
+
+
+def genipdl(ast, outdir):
+ return IPDLCodeGen().cgen(ast)
+
+
+def genmsgenum(ast):
+ return msgenums(ast.protocol, pretty=True)
+
+
+def writeifmodified(contents, file):
+ contents = contents.encode("utf-8")
+ dir = os.path.dirname(file)
+ os.path.exists(dir) or os.makedirs(dir)
+
+ oldcontents = None
+ if os.path.exists(file):
+ fd = open(file, "rb")
+ oldcontents = fd.read()
+ fd.close()
+ if oldcontents != contents:
+ fd = open(file, "wb")
+ fd.write(contents)
+ fd.close()
diff --git a/ipc/ipdl/ipdl/ast.py b/ipc/ipdl/ipdl/ast.py
new file mode 100644
index 0000000000..e50cbb5c65
--- /dev/null
+++ b/ipc/ipdl/ipdl/ast.py
@@ -0,0 +1,468 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+from .util import hash_str
+
+
+NOT_NESTED = 1
+INSIDE_SYNC_NESTED = 2
+INSIDE_CPOW_NESTED = 3
+
+NESTED_ATTR_MAP = {
+ "not": NOT_NESTED,
+ "inside_sync": INSIDE_SYNC_NESTED,
+ "inside_cpow": INSIDE_CPOW_NESTED,
+}
+
+# Each element of this list is the IPDL source representation of a priority.
+priorityList = ["normal", "input", "vsync", "mediumhigh", "control"]
+
+priorityAttrMap = {src: idx for idx, src in enumerate(priorityList)}
+
+NORMAL_PRIORITY = priorityAttrMap["normal"]
+
+
+class Visitor:
+ def defaultVisit(self, node):
+ raise Exception(
+ "INTERNAL ERROR: no visitor for node type `%s'" % (node.__class__.__name__)
+ )
+
+ def visitTranslationUnit(self, tu):
+ for cxxInc in tu.cxxIncludes:
+ cxxInc.accept(self)
+ for inc in tu.includes:
+ inc.accept(self)
+ for su in tu.structsAndUnions:
+ su.accept(self)
+ for using in tu.builtinUsing:
+ using.accept(self)
+ for using in tu.using:
+ using.accept(self)
+ if tu.protocol:
+ tu.protocol.accept(self)
+
+ def visitCxxInclude(self, inc):
+ pass
+
+ def visitInclude(self, inc):
+ # Note: we don't visit the child AST here, because that needs delicate
+ # and pass-specific handling
+ pass
+
+ def visitStructDecl(self, struct):
+ for f in struct.fields:
+ f.accept(self)
+ for a in struct.attributes.values():
+ a.accept(self)
+
+ def visitStructField(self, field):
+ field.typespec.accept(self)
+
+ def visitUnionDecl(self, union):
+ for t in union.components:
+ t.accept(self)
+ for a in union.attributes.values():
+ a.accept(self)
+
+ def visitUsingStmt(self, using):
+ for a in using.attributes.values():
+ a.accept(self)
+
+ def visitProtocol(self, p):
+ for namespace in p.namespaces:
+ namespace.accept(self)
+ for mgr in p.managers:
+ mgr.accept(self)
+ for managed in p.managesStmts:
+ managed.accept(self)
+ for msgDecl in p.messageDecls:
+ msgDecl.accept(self)
+ for a in p.attributes.values():
+ a.accept(self)
+
+ def visitNamespace(self, ns):
+ pass
+
+ def visitManager(self, mgr):
+ pass
+
+ def visitManagesStmt(self, mgs):
+ pass
+
+ def visitMessageDecl(self, md):
+ for inParam in md.inParams:
+ inParam.accept(self)
+ for outParam in md.outParams:
+ outParam.accept(self)
+ for a in md.attributes.values():
+ a.accept(self)
+
+ def visitParam(self, decl):
+ for a in decl.attributes.values():
+ a.accept(self)
+
+ def visitTypeSpec(self, ts):
+ pass
+
+ def visitAttribute(self, a):
+ if isinstance(a.value, Node):
+ a.value.accept(self)
+
+ def visitStringLiteral(self, sl):
+ pass
+
+ def visitDecl(self, d):
+ for a in d.attributes.values():
+ a.accept(self)
+
+
+class Loc:
+ def __init__(self, filename="<??>", lineno=0):
+ assert filename
+ self.filename = filename
+ self.lineno = lineno
+
+ def __repr__(self):
+ return "%r:%r" % (self.filename, self.lineno)
+
+ def __str__(self):
+ return "%s:%s" % (self.filename, self.lineno)
+
+
+Loc.NONE = Loc(filename="<??>", lineno=0)
+
+
+class _struct:
+ pass
+
+
+class Node:
+ def __init__(self, loc=Loc.NONE):
+ self.loc = loc
+
+ def accept(self, visitor):
+ visit = getattr(visitor, "visit" + self.__class__.__name__, None)
+ if visit is None:
+ return getattr(visitor, "defaultVisit")(self)
+ return visit(self)
+
+ def addAttrs(self, attrsName):
+ if not hasattr(self, attrsName):
+ setattr(self, attrsName, _struct())
+
+
+class NamespacedNode(Node):
+ def __init__(self, loc=Loc.NONE, name=None):
+ Node.__init__(self, loc)
+ self.name = name
+ self.namespaces = []
+
+ def addOuterNamespace(self, namespace):
+ self.namespaces.insert(0, namespace)
+
+ def qname(self):
+ return QualifiedId(self.loc, self.name, [ns.name for ns in self.namespaces])
+
+
+class TranslationUnit(NamespacedNode):
+ def __init__(self, type, name):
+ NamespacedNode.__init__(self, name=name)
+ self.filetype = type
+ self.filename = None
+ self.cxxIncludes = []
+ self.includes = []
+ self.builtinUsing = []
+ self.using = []
+ self.structsAndUnions = []
+ self.protocol = None
+
+ def addCxxInclude(self, cxxInclude):
+ self.cxxIncludes.append(cxxInclude)
+
+ def addInclude(self, inc):
+ self.includes.append(inc)
+
+ def addStructDecl(self, struct):
+ self.structsAndUnions.append(struct)
+
+ def addUnionDecl(self, union):
+ self.structsAndUnions.append(union)
+
+ def addUsingStmt(self, using):
+ self.using.append(using)
+
+ def setProtocol(self, protocol):
+ self.protocol = protocol
+
+
+class CxxInclude(Node):
+ def __init__(self, loc, cxxFile):
+ Node.__init__(self, loc)
+ self.file = cxxFile
+
+
+class Include(Node):
+ def __init__(self, loc, type, name):
+ Node.__init__(self, loc)
+ suffix = "ipdl"
+ if type == "header":
+ suffix += "h"
+ self.file = "%s.%s" % (name, suffix)
+
+
+class UsingStmt(Node):
+ def __init__(
+ self,
+ loc,
+ cxxTypeSpec,
+ cxxHeader=None,
+ kind=None,
+ attributes={},
+ ):
+ Node.__init__(self, loc)
+ assert isinstance(cxxTypeSpec, QualifiedId)
+ assert cxxHeader is None or isinstance(cxxHeader, str)
+ assert kind is None or kind == "class" or kind == "struct"
+ self.type = cxxTypeSpec
+ self.header = cxxHeader
+ self.kind = kind
+ self.attributes = attributes
+
+ def canBeForwardDeclared(self):
+ return self.isClass() or self.isStruct()
+
+ def isClass(self):
+ return self.kind == "class"
+
+ def isStruct(self):
+ return self.kind == "struct"
+
+ def isRefcounted(self):
+ return "RefCounted" in self.attributes
+
+ def isSendMoveOnly(self):
+ moveonly = self.attributes.get("MoveOnly")
+ return moveonly and moveonly.value in (None, "send")
+
+ def isDataMoveOnly(self):
+ moveonly = self.attributes.get("MoveOnly")
+ return moveonly and moveonly.value in (None, "data")
+
+
+# "singletons"
+
+
+class PrettyPrinted:
+ @classmethod
+ def __hash__(cls):
+ return hash_str(cls.pretty)
+
+ @classmethod
+ def __str__(cls):
+ return cls.pretty
+
+
+class ASYNC(PrettyPrinted):
+ pretty = "async"
+
+
+class INTR(PrettyPrinted):
+ pretty = "intr"
+
+
+class SYNC(PrettyPrinted):
+ pretty = "sync"
+
+
+class INOUT(PrettyPrinted):
+ pretty = "inout"
+
+
+class IN(PrettyPrinted):
+ pretty = "in"
+
+
+class OUT(PrettyPrinted):
+ pretty = "out"
+
+
+class Namespace(Node):
+ def __init__(self, loc, namespace):
+ Node.__init__(self, loc)
+ self.name = namespace
+
+
+class Protocol(NamespacedNode):
+ def __init__(self, loc):
+ NamespacedNode.__init__(self, loc)
+ self.attributes = {}
+ self.sendSemantics = ASYNC
+ self.managers = []
+ self.managesStmts = []
+ self.messageDecls = []
+
+ def nestedUpTo(self):
+ if "NestedUpTo" not in self.attributes:
+ return NOT_NESTED
+
+ return NESTED_ATTR_MAP.get(self.attributes["NestedUpTo"].value, NOT_NESTED)
+
+ def implAttribute(self, side):
+ assert side in ("parent", "child")
+ attr = self.attributes.get(side.capitalize() + "Impl")
+ if attr is not None:
+ return attr.value
+ return None
+
+
+class StructField(Node):
+ def __init__(self, loc, type, name):
+ Node.__init__(self, loc)
+ self.typespec = type
+ self.name = name
+
+
+class StructDecl(NamespacedNode):
+ def __init__(self, loc, name, fields, attributes):
+ NamespacedNode.__init__(self, loc, name)
+ self.fields = fields
+ self.attributes = attributes
+ # A list of indices into `fields` for determining the order in
+ # which fields are laid out in memory. We don't just reorder
+ # `fields` itself so as to keep the ordering reasonably stable
+ # for e.g. C++ constructors when new fields are added.
+ self.packed_field_ordering = []
+
+
+class UnionDecl(NamespacedNode):
+ def __init__(self, loc, name, components, attributes):
+ NamespacedNode.__init__(self, loc, name)
+ self.components = components
+ self.attributes = attributes
+
+
+class Manager(Node):
+ def __init__(self, loc, managerName):
+ Node.__init__(self, loc)
+ self.name = managerName
+
+
+class ManagesStmt(Node):
+ def __init__(self, loc, managedName):
+ Node.__init__(self, loc)
+ self.name = managedName
+
+
+class MessageDecl(Node):
+ def __init__(self, loc):
+ Node.__init__(self, loc)
+ self.name = None
+ self.attributes = {}
+ self.sendSemantics = ASYNC
+ self.direction = None
+ self.inParams = []
+ self.outParams = []
+
+ def addInParams(self, inParamsList):
+ self.inParams += inParamsList
+
+ def addOutParams(self, outParamsList):
+ self.outParams += outParamsList
+
+ def nested(self):
+ if "Nested" not in self.attributes:
+ return NOT_NESTED
+
+ return NESTED_ATTR_MAP.get(self.attributes["Nested"].value, NOT_NESTED)
+
+ def priority(self):
+ if "Priority" in self.attributes:
+ sourcePriority = self.attributes["Priority"].value
+ else:
+ sourcePriority = "normal"
+ return priorityAttrMap.get(sourcePriority, NORMAL_PRIORITY)
+
+ def replyPriority(self):
+ if "ReplyPriority" in self.attributes:
+ sourcePriority = self.attributes["ReplyPriority"].value
+ if sourcePriority in priorityAttrMap:
+ return priorityAttrMap[sourcePriority]
+ return self.priority()
+
+
+class Param(Node):
+ def __init__(self, loc, typespec, name, attributes={}):
+ Node.__init__(self, loc)
+ self.name = name
+ self.typespec = typespec
+ self.attributes = attributes
+
+
+class TypeSpec(Node):
+ def __init__(self, loc, spec):
+ Node.__init__(self, loc)
+ assert isinstance(spec, str)
+ self.spec = spec # str
+ self.array = False # bool
+ self.maybe = False # bool
+ self.nullable = False # bool
+ self.uniqueptr = False # bool
+
+ def basename(self):
+ return self.spec
+
+ def __str__(self):
+ return self.spec
+
+
+class Attribute(Node):
+ def __init__(self, loc, name, value):
+ Node.__init__(self, loc)
+ self.name = name
+ self.value = value
+
+
+class StringLiteral(Node):
+ def __init__(self, loc, value):
+ Node.__init__(self, loc)
+ self.value = value
+
+ def __str__(self):
+ return '"%s"' % self.value
+
+
+class QualifiedId: # FIXME inherit from node?
+ def __init__(self, loc, baseid, quals=[]):
+ assert isinstance(baseid, str)
+ for qual in quals:
+ assert isinstance(qual, str)
+
+ self.loc = loc
+ self.baseid = baseid
+ self.quals = quals
+
+ def qualify(self, id):
+ self.quals.append(self.baseid)
+ self.baseid = id
+
+ def __str__(self):
+ # NOTE: include a leading "::" in order to force all QualifiedIds to be
+ # fully qualified types in C++
+ return "::" + "::".join(self.quals + [self.baseid])
+
+
+# added by type checking passes
+
+
+class Decl(Node):
+ def __init__(self, loc):
+ Node.__init__(self, loc)
+ self.progname = None # what the programmer typed, if relevant
+ self.shortname = None # shortest way to refer to this decl
+ self.fullname = None # full way to refer to this decl
+ self.loc = loc
+ self.type = None
+ self.scope = None
+ self.attributes = {}
diff --git a/ipc/ipdl/ipdl/builtin.py b/ipc/ipdl/ipdl/builtin.py
new file mode 100644
index 0000000000..b1bab64af8
--- /dev/null
+++ b/ipc/ipdl/ipdl/builtin.py
@@ -0,0 +1,76 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+# WARNING: the syntax of the builtin types is not checked, so please
+# don't add something syntactically invalid. It will not be fun to
+# track down the bug.
+
+# C types
+# These types don't live in any namespace, so can't be imported with `using`
+# statements like normal C++ types.
+CTypes = (
+ "bool",
+ "char",
+ "short",
+ "int",
+ "long",
+ "float",
+ "double",
+)
+
+# C++ types
+# These types must be fully qualified, and will be `typedef`-ed into IPDL
+# structs to make them readily available when used.
+Types = (
+ # stdint types
+ "int8_t",
+ "uint8_t",
+ "int16_t",
+ "uint16_t",
+ "int32_t",
+ "uint32_t",
+ "int64_t",
+ "uint64_t",
+ "intptr_t",
+ "uintptr_t",
+ # You may be tempted to add size_t. Do not! See bug 1525199.
+ # Mozilla types: "less" standard things we know how serialize/deserialize
+ "nsresult",
+ "nsString",
+ "nsCString",
+ "mozilla::ipc::Shmem",
+ "mozilla::ipc::ByteBuf",
+ "mozilla::UniquePtr",
+ "mozilla::ipc::FileDescriptor",
+)
+
+
+# XXX(Bug 1677487) Can we restrict including ByteBuf.h, FileDescriptor.h,
+# MozPromise.h and Shmem.h to those protocols that really use them?
+HeaderIncludes = (
+ "mozilla/Attributes.h",
+ "IPCMessageStart.h",
+ "mozilla/RefPtr.h",
+ "nsString.h",
+ "nsTArray.h",
+ "nsTHashtable.h",
+ "mozilla/MozPromise.h",
+ "mozilla/OperatorNewExtensions.h",
+ "mozilla/UniquePtr.h",
+ "mozilla/ipc/ByteBuf.h",
+ "mozilla/ipc/FileDescriptor.h",
+ "mozilla/ipc/IPCForwards.h",
+ "mozilla/ipc/Shmem.h",
+)
+
+CppIncludes = (
+ "ipc/IPCMessageUtils.h",
+ "ipc/IPCMessageUtilsSpecializations.h",
+ "nsIFile.h",
+ "mozilla/ipc/Endpoint.h",
+ "mozilla/ipc/ProtocolMessageUtils.h",
+ "mozilla/ipc/ProtocolUtils.h",
+ "mozilla/ipc/ShmemMessageUtils.h",
+ "mozilla/ipc/TaintingIPCUtils.h",
+)
diff --git a/ipc/ipdl/ipdl/cgen.py b/ipc/ipdl/ipdl/cgen.py
new file mode 100644
index 0000000000..8ed8da4d81
--- /dev/null
+++ b/ipc/ipdl/ipdl/cgen.py
@@ -0,0 +1,108 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import sys
+
+from ipdl.ast import Visitor
+
+
+class CodePrinter:
+ def __init__(self, outf=sys.stdout, indentCols=4):
+ self.outf = outf
+ self.col = 0
+ self.indentCols = indentCols
+
+ def write(self, str):
+ self.outf.write(str)
+
+ def printdent(self, str=""):
+ self.write((" " * self.col) + str)
+
+ def println(self, str=""):
+ self.write(str + "\n")
+
+ def printdentln(self, str):
+ self.write((" " * self.col) + str + "\n")
+
+ def indent(self):
+ self.col += self.indentCols
+
+ def dedent(self):
+ self.col -= self.indentCols
+
+
+# -----------------------------------------------------------------------------
+class IPDLCodeGen(CodePrinter, Visitor):
+ """Spits back out equivalent IPDL to the code that generated this.
+ Also known as pretty-printing."""
+
+ def __init__(self, outf=sys.stdout, indentCols=4, printed=set()):
+ CodePrinter.__init__(self, outf, indentCols)
+ self.printed = printed
+
+ def visitTranslationUnit(self, tu):
+ self.printed.add(tu.filename)
+ self.println("//\n// Automatically generated by ipdlc\n//")
+ CodeGen.visitTranslationUnit(self, tu) # NOQA: F821
+
+ def visitCxxInclude(self, inc):
+ self.println('include "' + inc.file + '";')
+
+ def visitProtocolInclude(self, inc):
+ self.println('include protocol "' + inc.file + '";')
+ if inc.tu.filename not in self.printed:
+ self.println("/* Included file:")
+ IPDLCodeGen(
+ outf=self.outf, indentCols=self.indentCols, printed=self.printed
+ ).visitTranslationUnit(inc.tu)
+
+ self.println("*/")
+
+ def visitProtocol(self, p):
+ self.println()
+ for namespace in p.namespaces:
+ namespace.accept(self)
+
+ self.println("%s protocol %s\n{" % (p.sendSemantics[0], p.name))
+ self.indent()
+
+ for mgs in p.managesStmts:
+ mgs.accept(self)
+ if len(p.managesStmts):
+ self.println()
+
+ for msgDecl in p.messageDecls:
+ msgDecl.accept(self)
+ self.println()
+
+ self.dedent()
+ self.println("}")
+ self.write("}\n" * len(p.namespaces))
+
+ def visitManagerStmt(self, mgr):
+ self.printdentln("manager " + mgr.name + ";")
+
+ def visitManagesStmt(self, mgs):
+ self.printdentln("manages " + mgs.name + ";")
+
+ def visitMessageDecl(self, msg):
+ self.printdent("%s %s %s(" % (msg.sendSemantics[0], msg.direction[0], msg.name))
+ for i, inp in enumerate(msg.inParams):
+ inp.accept(self)
+ if i != (len(msg.inParams) - 1):
+ self.write(", ")
+ self.write(")")
+ if 0 == len(msg.outParams):
+ self.println(";")
+ return
+
+ self.println()
+ self.indent()
+ self.printdent("returns (")
+ for i, outp in enumerate(msg.outParams):
+ outp.accept(self)
+ if i != (len(msg.outParams) - 1):
+ self.write(", ")
+ self.println(");")
+ self.dedent()
diff --git a/ipc/ipdl/ipdl/checker.py b/ipc/ipdl/ipdl/checker.py
new file mode 100644
index 0000000000..c93dc3e0e8
--- /dev/null
+++ b/ipc/ipdl/ipdl/checker.py
@@ -0,0 +1,79 @@
+# vim: set ts=4 sw=4 tw=99 et:
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import sys
+from ipdl.ast import Visitor, ASYNC
+
+
+class SyncMessageChecker(Visitor):
+ syncMsgList = []
+ seenProtocols = []
+ seenSyncMessages = []
+
+ def __init__(self, syncMsgList):
+ SyncMessageChecker.syncMsgList = syncMsgList
+ self.errors = []
+
+ def prettyMsgName(self, msg):
+ return "%s::%s" % (self.currentProtocol, msg)
+
+ def errorUnknownSyncMessage(self, loc, msg):
+ self.errors.append("%s: error: Unknown sync IPC message %s" % (str(loc), msg))
+
+ def errorAsyncMessageCanRemove(self, loc, msg):
+ self.errors.append(
+ "%s: error: IPC message %s is async, can be delisted" % (str(loc), msg)
+ )
+
+ def visitProtocol(self, p):
+ self.errors = []
+ self.currentProtocol = p.name
+ SyncMessageChecker.seenProtocols.append(p.name)
+ Visitor.visitProtocol(self, p)
+
+ def visitMessageDecl(self, md):
+ pn = self.prettyMsgName(md.name)
+ if md.sendSemantics is not ASYNC:
+ if pn not in SyncMessageChecker.syncMsgList:
+ self.errorUnknownSyncMessage(md.loc, pn)
+ SyncMessageChecker.seenSyncMessages.append(pn)
+ elif pn in SyncMessageChecker.syncMsgList:
+ self.errorAsyncMessageCanRemove(md.loc, pn)
+
+ @staticmethod
+ def getFixedSyncMessages():
+ return set(SyncMessageChecker.syncMsgList) - set(
+ SyncMessageChecker.seenSyncMessages
+ )
+
+
+def checkSyncMessage(tu, syncMsgList, errout=sys.stderr):
+ checker = SyncMessageChecker(syncMsgList)
+ tu.accept(checker)
+ if len(checker.errors):
+ for error in checker.errors:
+ print(error, file=errout)
+ return False
+ return True
+
+
+def checkFixedSyncMessages(config, errout=sys.stderr):
+ fixed = SyncMessageChecker.getFixedSyncMessages()
+ error_free = True
+ for item in fixed:
+ protocol = item.split("::")[0]
+ # Ignore things like sync messages in test protocols we didn't compile.
+ # Also, ignore platform-specific IPC messages.
+ if (
+ protocol in SyncMessageChecker.seenProtocols
+ and "platform" not in config.options(item)
+ ):
+ print(
+ "Error: Sync IPC message %s not found, it appears to be fixed.\n"
+ "Please remove it from sync-messages.ini." % item,
+ file=errout,
+ )
+ error_free = False
+ return error_free
diff --git a/ipc/ipdl/ipdl/cxx/__init__.py b/ipc/ipdl/ipdl/cxx/__init__.py
new file mode 100644
index 0000000000..6fbe8159b2
--- /dev/null
+++ b/ipc/ipdl/ipdl/cxx/__init__.py
@@ -0,0 +1,3 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
diff --git a/ipc/ipdl/ipdl/cxx/ast.py b/ipc/ipdl/ipdl/cxx/ast.py
new file mode 100644
index 0000000000..3b9448859b
--- /dev/null
+++ b/ipc/ipdl/ipdl/cxx/ast.py
@@ -0,0 +1,1033 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import copy
+import functools
+
+from ipdl.util import hash_str
+
+
+class Visitor:
+ def defaultVisit(self, node):
+ raise Exception(
+ "INTERNAL ERROR: no visitor for node type `%s'" % (node.__class__.__name__)
+ )
+
+ def visitWhitespace(self, ws):
+ pass
+
+ def visitVerbatimNode(self, verb):
+ pass
+
+ def visitGroupNode(self, group):
+ for node in group.nodes:
+ node.accept(self)
+
+ def visitFile(self, f):
+ for thing in f.stuff:
+ thing.accept(self)
+
+ def visitCppDirective(self, ppd):
+ pass
+
+ def visitBlock(self, block):
+ for stmt in block.stmts:
+ stmt.accept(self)
+
+ def visitNamespace(self, ns):
+ self.visitBlock(ns)
+
+ def visitType(self, type):
+ pass
+
+ def visitTypeArray(self, ta):
+ ta.basetype.accept(self)
+ ta.nmemb.accept(self)
+
+ def visitTypeEnum(self, enum):
+ pass
+
+ def visitTypeFunction(self, fn):
+ pass
+
+ def visitTypeUnion(self, union):
+ for t, name in union.components:
+ t.accept(self)
+
+ def visitTypedef(self, tdef):
+ tdef.fromtype.accept(self)
+
+ def visitUsing(self, us):
+ us.type.accept(self)
+
+ def visitForwardDecl(self, fd):
+ pass
+
+ def visitDecl(self, decl):
+ decl.type.accept(self)
+
+ def visitParam(self, param):
+ self.visitDecl(param)
+ if param.default is not None:
+ param.default.accept(self)
+
+ def visitClass(self, cls):
+ for inherit in cls.inherits:
+ inherit.accept(self)
+ self.visitBlock(cls)
+
+ def visitInherit(self, inh):
+ pass
+
+ def visitFriendClassDecl(self, fcd):
+ pass
+
+ def visitMethodDecl(self, meth):
+ for param in meth.params:
+ param.accept(self)
+ if meth.ret is not None:
+ meth.ret.accept(self)
+ if meth.typeop is not None:
+ meth.typeop.accept(self)
+ if meth.T is not None:
+ meth.T.accept(self)
+
+ def visitMethodDefn(self, meth):
+ meth.decl.accept(self)
+ self.visitBlock(meth)
+
+ def visitFunctionDecl(self, fun):
+ self.visitMethodDecl(fun)
+
+ def visitFunctionDefn(self, fd):
+ self.visitMethodDefn(fd)
+
+ def visitConstructorDecl(self, ctor):
+ self.visitMethodDecl(ctor)
+
+ def visitConstructorDefn(self, cd):
+ cd.decl.accept(self)
+ for init in cd.memberinits:
+ init.accept(self)
+ self.visitBlock(cd)
+
+ def visitDestructorDecl(self, dtor):
+ self.visitMethodDecl(dtor)
+
+ def visitDestructorDefn(self, dd):
+ dd.decl.accept(self)
+ self.visitBlock(dd)
+
+ def visitExprLiteral(self, l):
+ pass
+
+ def visitExprVar(self, v):
+ pass
+
+ def visitExprPrefixUnop(self, e):
+ e.expr.accept(self)
+
+ def visitExprBinary(self, e):
+ e.left.accept(self)
+ e.right.accept(self)
+
+ def visitExprConditional(self, c):
+ c.cond.accept(self)
+ c.ife.accept(self)
+ c.elsee.accept(self)
+
+ def visitExprAddrOf(self, eao):
+ self.visitExprPrefixUnop(eao)
+
+ def visitExprDeref(self, ed):
+ self.visitExprPrefixUnop(ed)
+
+ def visitExprNot(self, en):
+ self.visitExprPrefixUnop(en)
+
+ def visitExprCast(self, ec):
+ ec.expr.accept(self)
+
+ def visitExprSelect(self, es):
+ es.obj.accept(self)
+
+ def visitExprAssn(self, ea):
+ ea.lhs.accept(self)
+ ea.rhs.accept(self)
+
+ def visitExprCall(self, ec):
+ ec.func.accept(self)
+ for arg in ec.args:
+ arg.accept(self)
+
+ def visitExprMove(self, ec):
+ self.visitExprCall(ec)
+
+ def visitExprNothing(self, ec):
+ self.visitExprCall(ec)
+
+ def visitExprSome(self, ec):
+ self.visitExprCall(ec)
+
+ def visitExprNew(self, en):
+ en.ctype.accept(self)
+ if en.newargs is not None:
+ for arg in en.newargs:
+ arg.accept(self)
+ if en.args is not None:
+ for arg in en.args:
+ arg.accept(self)
+
+ def visitExprDelete(self, ed):
+ ed.obj.accept(self)
+
+ def visitExprMemberInit(self, minit):
+ self.visitExprCall(minit)
+
+ def visitExprLambda(self, l):
+ self.visitBlock(l)
+
+ def visitStmtBlock(self, sb):
+ self.visitBlock(sb)
+
+ def visitStmtDecl(self, sd):
+ sd.decl.accept(self)
+ if sd.init is not None:
+ sd.init.accept(self)
+
+ def visitLabel(self, label):
+ pass
+
+ def visitCaseLabel(self, case):
+ pass
+
+ def visitDefaultLabel(self, dl):
+ pass
+
+ def visitStmtIf(self, si):
+ si.cond.accept(self)
+ si.ifb.accept(self)
+ if si.elseb is not None:
+ si.elseb.accept(self)
+
+ def visitStmtFor(self, sf):
+ if sf.init is not None:
+ sf.init.accept(self)
+ if sf.cond is not None:
+ sf.cond.accept(self)
+ if sf.update is not None:
+ sf.update.accept(self)
+
+ def visitStmtSwitch(self, ss):
+ ss.expr.accept(self)
+ self.visitBlock(ss)
+
+ def visitStmtBreak(self, sb):
+ pass
+
+ def visitStmtExpr(self, se):
+ se.expr.accept(self)
+
+ def visitStmtReturn(self, sr):
+ if sr.expr is not None:
+ sr.expr.accept(self)
+
+
+# ------------------------------
+
+
+class Node:
+ def __init__(self):
+ pass
+
+ def accept(self, visitor):
+ visit = getattr(visitor, "visit" + self.__class__.__name__, None)
+ if visit is None:
+ return getattr(visitor, "defaultVisit")(self)
+ return visit(self)
+
+
+class Whitespace(Node):
+ # yes, this is silly. but we need to stick comments in the
+ # generated code without resorting to more serious hacks
+ def __init__(self, ws, indent=False):
+ Node.__init__(self)
+ self.ws = ws
+ self.indent = indent
+
+
+Whitespace.NL = Whitespace("\n")
+
+
+class VerbatimNode(Node):
+ # A block of text to be written verbatim to the output file.
+ #
+ # NOTE: This node is usually created by `code`. See `code.py` for details.
+ # FIXME: Merge Whitespace and VerbatimNode? They're identical.
+ def __init__(self, text, indent=0):
+ Node.__init__(self)
+ self.text = text
+ self.indent = indent
+
+
+class GroupNode(Node):
+ # A group of nodes to be treated as a single node. These nodes have an
+ # optional indentation level which should be applied when generating them.
+ #
+ # NOTE: This node is usually created by `code`. See `code.py` for details.
+ def __init__(self, nodes, offset=0):
+ Node.__init__(self)
+ self.nodes = nodes
+ self.offset = offset
+
+
+class File(Node):
+ def __init__(self, filename):
+ Node.__init__(self)
+ self.name = filename
+ # array of stuff in the file --- stmts and preprocessor thingies
+ self.stuff = []
+
+ def addthing(self, thing):
+ assert thing is not None
+ assert not isinstance(thing, list)
+ self.stuff.append(thing)
+
+ def addthings(self, things):
+ for t in things:
+ self.addthing(t)
+
+ # "look like" a Block so code doesn't have to care whether they're
+ # in global scope or not
+ def addstmt(self, stmt):
+ assert stmt is not None
+ assert not isinstance(stmt, list)
+ self.stuff.append(stmt)
+
+ def addstmts(self, stmts):
+ for s in stmts:
+ self.addstmt(s)
+
+ def addcode(self, tmpl, **context):
+ from ipdl.cxx.code import StmtCode
+
+ self.addstmt(StmtCode(tmpl, **context))
+
+
+class CppDirective(Node):
+ """represents |#[directive] [rest]|, where |rest| is any string"""
+
+ def __init__(self, directive, rest=None):
+ Node.__init__(self)
+ self.directive = directive
+ self.rest = rest
+
+
+class Block(Node):
+ def __init__(self):
+ Node.__init__(self)
+ self.stmts = []
+
+ def addstmt(self, stmt):
+ assert stmt is not None
+ assert not isinstance(stmt, tuple)
+ self.stmts.append(stmt)
+
+ def addstmts(self, stmts):
+ for s in stmts:
+ self.addstmt(s)
+
+ def addcode(self, tmpl, **context):
+ from ipdl.cxx.code import StmtCode
+
+ self.addstmt(StmtCode(tmpl, **context))
+
+
+# ------------------------------
+# type and decl thingies
+
+
+class Namespace(Block):
+ def __init__(self, name):
+ assert isinstance(name, str)
+
+ Block.__init__(self)
+ self.name = name
+
+
+class Type(Node):
+ def __init__(
+ self,
+ name,
+ const=False,
+ ptr=False,
+ ptrptr=False,
+ ptrconstptr=False,
+ ref=False,
+ rvalref=False,
+ rightconst=False,
+ hasimplicitcopyctor=True,
+ T=None,
+ inner=None,
+ ):
+ """
+ Represents the type |name<T>::inner| with the ptr and const
+ modifiers as specified.
+
+ To avoid getting fancy with recursive types, we limit the kinds
+ of pointer types that can be be constructed.
+
+ ptr => T*
+ ptrptr => T**
+ ptrconstptr => T* const*
+ ref => T&
+ rvalref => T&&
+
+ Any type, naked or pointer, can be const (const T) or ref (T&)."""
+ # XXX(nika): This type is complex enough at this point, perhaps we
+ # should get "fancy with recursive types" to simplify it.
+ assert isinstance(name, str)
+ assert isinstance(const, bool)
+ assert isinstance(ptr, bool)
+ assert isinstance(ptrptr, bool)
+ assert isinstance(ptrconstptr, bool)
+ assert isinstance(ref, bool)
+ assert isinstance(rvalref, bool)
+ assert isinstance(rightconst, bool)
+ assert not isinstance(T, str)
+
+ Node.__init__(self)
+ self.name = name
+ self.const = const
+ self.ptr = ptr
+ self.ptrptr = ptrptr
+ self.ptrconstptr = ptrconstptr
+ self.ref = ref
+ self.rvalref = rvalref
+ self.rightconst = rightconst
+ self.hasimplicitcopyctor = hasimplicitcopyctor
+ self.T = T
+ self.inner = inner
+ # XXX could get serious here with recursive types, but shouldn't
+ # need that for this codegen
+
+ def __deepcopy__(self, memo):
+ return Type(
+ self.name,
+ const=self.const,
+ ptr=self.ptr,
+ ptrptr=self.ptrptr,
+ ptrconstptr=self.ptrconstptr,
+ ref=self.ref,
+ rvalref=self.rvalref,
+ rightconst=self.rightconst,
+ T=copy.deepcopy(self.T, memo),
+ inner=copy.deepcopy(self.inner, memo),
+ )
+
+
+Type.BOOL = Type("bool")
+Type.INT = Type("int")
+Type.INT32 = Type("int32_t")
+Type.INTPTR = Type("intptr_t")
+Type.NSRESULT = Type("nsresult")
+Type.UINT32 = Type("uint32_t")
+Type.UINT32PTR = Type("uint32_t", ptr=True)
+Type.SIZE = Type("size_t")
+Type.VOID = Type("void")
+Type.VOIDPTR = Type("void", ptr=True)
+Type.AUTO = Type("auto")
+Type.AUTORVAL = Type("auto", rvalref=True)
+
+
+class TypeArray(Node):
+ def __init__(self, basetype, nmemb):
+ """the type |basetype DECLNAME[nmemb]|. |nmemb| is an Expr"""
+ self.basetype = basetype
+ self.nmemb = nmemb
+
+
+class TypeEnum(Node):
+ def __init__(self, name=None):
+ """name can be None"""
+ Node.__init__(self)
+ self.name = name
+ self.idnums = [] # pairs of ('Foo', [num]) or ('Foo', None)
+
+ def addId(self, id, num=None):
+ self.idnums.append((id, num))
+
+
+class TypeUnion(Node):
+ def __init__(self, name=None):
+ Node.__init__(self)
+ self.name = name
+ self.components = [] # [ Decl ]
+
+ def addComponent(self, type, name):
+ self.components.append(Decl(type, name))
+
+
+class TypeFunction(Node):
+ def __init__(self, params=[], ret=Type("void")):
+ """Anonymous function type std::function<>"""
+ self.params = params
+ self.ret = ret
+
+
+@functools.total_ordering
+class Typedef(Node):
+ def __init__(self, fromtype, totypename, templateargs=[]):
+ assert isinstance(totypename, str)
+
+ Node.__init__(self)
+ self.fromtype = fromtype
+ self.totypename = totypename
+ self.templateargs = templateargs
+
+ def __lt__(self, other):
+ return self.totypename < other.totypename
+
+ def __eq__(self, other):
+ return self.__class__ == other.__class__ and self.totypename == other.totypename
+
+ def __hash__(self):
+ return hash_str(self.totypename)
+
+
+class Using(Node):
+ def __init__(self, type):
+ Node.__init__(self)
+ self.type = type
+
+
+class ForwardDecl(Node):
+ def __init__(self, pqname, cls=False, struct=False):
+ # Exactly one of cls and struct must be set
+ assert cls ^ struct
+
+ self.pqname = pqname
+ self.cls = cls
+ self.struct = struct
+
+
+class Decl(Node):
+ """represents |Foo bar|, e.g. in a function signature"""
+
+ def __init__(self, type, name):
+ assert type is not None
+ assert not isinstance(type, str)
+ assert isinstance(name, str)
+
+ Node.__init__(self)
+ self.type = type
+ self.name = name
+
+ def __deepcopy__(self, memo):
+ return Decl(copy.deepcopy(self.type, memo), self.name)
+
+
+class Param(Decl):
+ def __init__(self, type, name, default=None):
+ Decl.__init__(self, type, name)
+ self.default = default
+
+ def __deepcopy__(self, memo):
+ return Param(
+ copy.deepcopy(self.type, memo), self.name, copy.deepcopy(self.default, memo)
+ )
+
+
+# ------------------------------
+# class stuff
+
+
+class Class(Block):
+ def __init__(
+ self,
+ name,
+ inherits=[],
+ interface=False,
+ abstract=False,
+ final=False,
+ specializes=None,
+ struct=False,
+ ):
+ assert not (interface and abstract)
+ assert not (abstract and final)
+ assert not (interface and final)
+ assert not (inherits and specializes)
+
+ Block.__init__(self)
+ self.name = name
+ self.inherits = inherits # [ Type ]
+ self.interface = interface # bool
+ self.abstract = abstract # bool
+ self.final = final # bool
+ self.specializes = specializes # Type or None
+ self.struct = struct # bool
+
+
+class Inherit(Node):
+ def __init__(self, type, viz="public"):
+ assert isinstance(viz, str)
+ Node.__init__(self)
+ self.type = type
+ self.viz = viz
+
+
+class FriendClassDecl(Node):
+ def __init__(self, friend):
+ Node.__init__(self)
+ self.friend = friend
+
+
+# Python2 polyfill for Python3's Enum() functional API.
+
+
+def make_enum(name, members_str):
+ members_list = members_str.split()
+ members_dict = {}
+ for member_value, member in enumerate(members_list, start=1):
+ members_dict[member] = member_value
+ return type(name, (), members_dict)
+
+
+MethodSpec = make_enum("MethodSpec", "NONE VIRTUAL PURE OVERRIDE STATIC")
+
+
+class MethodDecl(Node):
+ def __init__(
+ self,
+ name,
+ params=[],
+ ret=Type("void"),
+ methodspec=MethodSpec.NONE,
+ const=False,
+ warn_unused=False,
+ force_inline=False,
+ typeop=None,
+ T=None,
+ cls=None,
+ ):
+ assert not (name and typeop)
+ assert name is None or isinstance(name, str)
+ assert not isinstance(ret, list)
+ for decl in params:
+ assert not isinstance(decl, str)
+ assert not isinstance(T, int)
+ assert isinstance(const, bool)
+ assert isinstance(warn_unused, bool)
+ assert isinstance(force_inline, bool)
+
+ if typeop is not None:
+ assert methodspec == MethodSpec.NONE
+ ret = None
+
+ Node.__init__(self)
+ self.name = name
+ self.params = params # [ Param ]
+ self.ret = ret # Type or None
+ self.methodspec = methodspec # enum
+ self.const = const # bool
+ self.warn_unused = warn_unused # bool
+ self.force_inline = force_inline or bool(T) # bool
+ self.typeop = typeop # Type or None
+ self.T = T # Type or None
+ self.cls = cls # Class or None
+ self.only_for_definition = False
+
+ def __deepcopy__(self, memo):
+ return MethodDecl(
+ self.name,
+ params=copy.deepcopy(self.params, memo),
+ ret=copy.deepcopy(self.ret, memo),
+ methodspec=self.methodspec,
+ const=self.const,
+ warn_unused=self.warn_unused,
+ force_inline=self.force_inline,
+ typeop=copy.deepcopy(self.typeop, memo),
+ T=copy.deepcopy(self.T, memo),
+ )
+
+
+class MethodDefn(Block):
+ def __init__(self, decl):
+ Block.__init__(self)
+ self.decl = decl
+
+
+class FunctionDecl(MethodDecl):
+ def __init__(
+ self,
+ name,
+ params=[],
+ ret=Type("void"),
+ methodspec=MethodSpec.NONE,
+ warn_unused=False,
+ force_inline=False,
+ T=None,
+ ):
+ assert methodspec == MethodSpec.NONE or methodspec == MethodSpec.STATIC
+ MethodDecl.__init__(
+ self,
+ name,
+ params=params,
+ ret=ret,
+ methodspec=methodspec,
+ warn_unused=warn_unused,
+ force_inline=force_inline,
+ T=T,
+ )
+
+
+class FunctionDefn(MethodDefn):
+ def __init__(self, decl):
+ MethodDefn.__init__(self, decl)
+
+
+class ConstructorDecl(MethodDecl):
+ def __init__(self, name, params=[], explicit=False, force_inline=False):
+ MethodDecl.__init__(
+ self, name, params=params, ret=None, force_inline=force_inline
+ )
+ self.explicit = explicit
+
+ def __deepcopy__(self, memo):
+ return ConstructorDecl(
+ self.name, copy.deepcopy(self.params, memo), self.explicit
+ )
+
+
+class ConstructorDefn(MethodDefn):
+ def __init__(self, decl, memberinits=[]):
+ MethodDefn.__init__(self, decl)
+ self.memberinits = memberinits
+
+
+class DestructorDecl(MethodDecl):
+ def __init__(self, name, methodspec=MethodSpec.NONE, force_inline=False):
+ # C++ allows pure or override destructors, but ipdl cgen does not.
+ assert methodspec == MethodSpec.NONE or methodspec == MethodSpec.VIRTUAL
+ MethodDecl.__init__(
+ self,
+ name,
+ params=[],
+ ret=None,
+ methodspec=methodspec,
+ force_inline=force_inline,
+ )
+
+ def __deepcopy__(self, memo):
+ return DestructorDecl(
+ self.name, methodspec=self.methodspec, force_inline=self.force_inline
+ )
+
+
+class DestructorDefn(MethodDefn):
+ def __init__(self, decl):
+ MethodDefn.__init__(self, decl)
+
+
+# ------------------------------
+# expressions
+
+
+class ExprVar(Node):
+ def __init__(self, name):
+ assert isinstance(name, str)
+
+ Node.__init__(self)
+ self.name = name
+
+
+ExprVar.THIS = ExprVar("this")
+
+
+class ExprLiteral(Node):
+ def __init__(self, value, type):
+ """|type| is a Python format specifier; 'd' for example"""
+ Node.__init__(self)
+ self.value = value
+ self.type = type
+
+ @staticmethod
+ def Int(i):
+ return ExprLiteral(i, "d")
+
+ @staticmethod
+ def String(s):
+ return ExprLiteral('"' + s + '"', "s")
+
+ def __str__(self):
+ return ("%" + self.type) % (self.value)
+
+
+ExprLiteral.ZERO = ExprLiteral.Int(0)
+ExprLiteral.ONE = ExprLiteral.Int(1)
+ExprLiteral.NULL = ExprVar("nullptr")
+ExprLiteral.TRUE = ExprVar("true")
+ExprLiteral.FALSE = ExprVar("false")
+
+
+class ExprPrefixUnop(Node):
+ def __init__(self, expr, op):
+ assert not isinstance(expr, tuple)
+ self.expr = expr
+ self.op = op
+
+
+class ExprNot(ExprPrefixUnop):
+ def __init__(self, expr):
+ ExprPrefixUnop.__init__(self, expr, "!")
+
+
+class ExprAddrOf(ExprPrefixUnop):
+ def __init__(self, expr):
+ ExprPrefixUnop.__init__(self, expr, "&")
+
+
+class ExprDeref(ExprPrefixUnop):
+ def __init__(self, expr):
+ ExprPrefixUnop.__init__(self, expr, "*")
+
+
+class ExprCast(Node):
+ def __init__(self, expr, type, static=False, const=False):
+ # Exactly one of these should be set
+ assert static ^ const
+
+ Node.__init__(self)
+ self.expr = expr
+ self.type = type
+ self.static = static
+ self.const = const
+
+
+class ExprBinary(Node):
+ def __init__(self, left, op, right):
+ Node.__init__(self)
+ self.left = left
+ self.op = op
+ self.right = right
+
+
+class ExprConditional(Node):
+ def __init__(self, cond, ife, elsee):
+ Node.__init__(self)
+ self.cond = cond
+ self.ife = ife
+ self.elsee = elsee
+
+
+class ExprSelect(Node):
+ def __init__(self, obj, op, field):
+ assert obj and op and field
+ assert not isinstance(obj, str)
+ assert isinstance(op, str)
+
+ Node.__init__(self)
+ self.obj = obj
+ self.op = op
+ if isinstance(field, str):
+ self.field = ExprVar(field)
+ else:
+ self.field = field
+
+
+class ExprAssn(Node):
+ def __init__(self, lhs, rhs, op="="):
+ Node.__init__(self)
+ self.lhs = lhs
+ self.op = op
+ self.rhs = rhs
+
+
+class ExprCall(Node):
+ def __init__(self, func, args=[]):
+ assert hasattr(func, "accept")
+ assert isinstance(args, list)
+ for arg in args:
+ assert arg and not isinstance(arg, str)
+
+ Node.__init__(self)
+ self.func = func
+ self.args = args
+
+
+class ExprMove(ExprCall):
+ def __init__(self, arg):
+ ExprCall.__init__(self, ExprVar("std::move"), args=[arg])
+
+
+class ExprNew(Node):
+ # XXX taking some poetic license ...
+ def __init__(self, ctype, args=[], newargs=None):
+ assert not (ctype.const or ctype.ref or ctype.rvalref)
+
+ Node.__init__(self)
+ self.ctype = ctype
+ self.args = args
+ self.newargs = newargs
+
+
+class ExprDelete(Node):
+ def __init__(self, obj):
+ Node.__init__(self)
+ self.obj = obj
+
+
+class ExprMemberInit(ExprCall):
+ def __init__(self, member, args=[]):
+ ExprCall.__init__(self, member, args)
+
+
+class ExprLambda(Block):
+ def __init__(self, captures=[], params=[], ret=None):
+ Block.__init__(self)
+ assert isinstance(captures, list)
+ assert isinstance(params, list)
+ self.captures = captures
+ self.params = params
+ self.ret = ret
+
+
+# ------------------------------
+# statements etc.
+
+
+class StmtBlock(Block):
+ def __init__(self, stmts=[]):
+ Block.__init__(self)
+ self.addstmts(stmts)
+
+
+class StmtDecl(Node):
+ def __init__(self, decl, init=None, initargs=None):
+ assert not (init and initargs)
+ assert not isinstance(init, str) # easy to confuse with Decl
+ assert not isinstance(init, list)
+ assert not isinstance(decl, tuple)
+
+ Node.__init__(self)
+ self.decl = decl
+ self.init = init
+ self.initargs = initargs
+
+
+class Label(Node):
+ def __init__(self, name):
+ Node.__init__(self)
+ self.name = name
+
+
+Label.PUBLIC = Label("public")
+Label.PROTECTED = Label("protected")
+Label.PRIVATE = Label("private")
+
+
+class CaseLabel(Node):
+ def __init__(self, name):
+ Node.__init__(self)
+ self.name = name
+
+
+class DefaultLabel(Node):
+ def __init__(self):
+ Node.__init__(self)
+
+
+class StmtIf(Node):
+ def __init__(self, cond):
+ Node.__init__(self)
+ self.cond = cond
+ self.ifb = Block()
+ self.elseb = None
+
+ def addifstmt(self, stmt):
+ self.ifb.addstmt(stmt)
+
+ def addifstmts(self, stmts):
+ self.ifb.addstmts(stmts)
+
+ def addelsestmt(self, stmt):
+ if self.elseb is None:
+ self.elseb = Block()
+ self.elseb.addstmt(stmt)
+
+ def addelsestmts(self, stmts):
+ if self.elseb is None:
+ self.elseb = Block()
+ self.elseb.addstmts(stmts)
+
+
+class StmtFor(Block):
+ def __init__(self, init=None, cond=None, update=None):
+ Block.__init__(self)
+ self.init = init
+ self.cond = cond
+ self.update = update
+
+
+class StmtRangedFor(Block):
+ def __init__(self, var, iteree):
+ assert isinstance(var, ExprVar)
+ assert iteree
+
+ Block.__init__(self)
+ self.var = var
+ self.iteree = iteree
+
+
+class StmtSwitch(Block):
+ def __init__(self, expr):
+ Block.__init__(self)
+ self.expr = expr
+ self.nr_cases = 0
+
+ def addcase(self, case, block):
+ """NOTE: |case| is not checked for uniqueness"""
+ assert not isinstance(case, str)
+ assert (
+ isinstance(block, StmtBreak)
+ or isinstance(block, StmtReturn)
+ or isinstance(block, StmtSwitch)
+ or isinstance(block, GroupNode)
+ or isinstance(block, VerbatimNode)
+ or (
+ hasattr(block, "stmts")
+ and (
+ isinstance(block.stmts[-1], StmtBreak)
+ or isinstance(block.stmts[-1], StmtReturn)
+ or isinstance(block.stmts[-1], GroupNode)
+ or isinstance(block.stmts[-1], VerbatimNode)
+ )
+ )
+ )
+ self.addstmt(case)
+ self.addstmt(block)
+ self.nr_cases += 1
+
+
+class StmtBreak(Node):
+ def __init__(self):
+ Node.__init__(self)
+
+
+class StmtExpr(Node):
+ def __init__(self, expr):
+ assert expr is not None
+
+ Node.__init__(self)
+ self.expr = expr
+
+
+class StmtReturn(Node):
+ def __init__(self, expr=None):
+ Node.__init__(self)
+ self.expr = expr
+
+
+StmtReturn.TRUE = StmtReturn(ExprLiteral.TRUE)
+StmtReturn.FALSE = StmtReturn(ExprLiteral.FALSE)
diff --git a/ipc/ipdl/ipdl/cxx/cgen.py b/ipc/ipdl/ipdl/cxx/cgen.py
new file mode 100644
index 0000000000..fecb90e97b
--- /dev/null
+++ b/ipc/ipdl/ipdl/cxx/cgen.py
@@ -0,0 +1,557 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import sys
+
+from ipdl.cgen import CodePrinter
+from ipdl.cxx.ast import MethodSpec, TypeArray, Visitor, DestructorDecl
+
+
+class CxxCodeGen(CodePrinter, Visitor):
+ def __init__(self, outf=sys.stdout, indentCols=4):
+ CodePrinter.__init__(self, outf, indentCols)
+
+ def cgen(self, cxxfile):
+ cxxfile.accept(self)
+
+ def visitWhitespace(self, ws):
+ if ws.indent:
+ self.printdent("")
+ self.write(ws.ws)
+
+ def visitVerbatimNode(self, verb):
+ if verb.indent:
+ self.printdent("")
+ self.write(verb.text)
+
+ def visitGroupNode(self, group):
+ offsetCols = self.indentCols * group.offset
+ self.col += offsetCols
+ for node in group.nodes:
+ node.accept(self)
+ self.col -= offsetCols
+
+ def visitCppDirective(self, cd):
+ if cd.rest:
+ self.println("#%s %s" % (cd.directive, cd.rest))
+ else:
+ self.println("#%s" % (cd.directive))
+
+ def visitNamespace(self, ns):
+ self.println("namespace " + ns.name + " {")
+ self.visitBlock(ns)
+ self.println("} // namespace " + ns.name)
+
+ def visitType(self, t):
+ if t.const:
+ self.write("const ")
+
+ self.write(t.name)
+
+ if t.T is not None:
+ self.write("<")
+ if type(t.T) is list:
+ t.T[0].accept(self)
+ for tt in t.T[1:]:
+ self.write(", ")
+ tt.accept(self)
+ else:
+ t.T.accept(self)
+ self.write(">")
+
+ if t.inner is not None:
+ self.write("::")
+ t.inner.accept(self)
+
+ ts = ""
+ if t.ptr:
+ ts += "*"
+ elif t.ptrptr:
+ ts += "**"
+ elif t.ptrconstptr:
+ ts += "* const*"
+
+ if t.ref:
+ ts += "&"
+ elif t.rvalref:
+ ts += "&&"
+ elif t.rightconst:
+ ts += " const"
+
+ self.write(ts)
+
+ def visitTypeEnum(self, te):
+ self.write("enum")
+ if te.name:
+ self.write(" " + te.name)
+ self.println(" {")
+
+ self.indent()
+ nids = len(te.idnums)
+ for i, (id, num) in enumerate(te.idnums):
+ self.printdent(id)
+ if num:
+ self.write(" = " + str(num))
+ if i != (nids - 1):
+ self.write(",")
+ self.println()
+ self.dedent()
+ self.printdent("}")
+
+ def visitTypeUnion(self, u):
+ self.write("union")
+ if u.name:
+ self.write(" " + u.name)
+ self.println(" {")
+
+ self.indent()
+ for decl in u.components:
+ self.printdent()
+ decl.accept(self)
+ self.println(";")
+ self.dedent()
+
+ self.printdent("}")
+
+ def visitTypeFunction(self, fn):
+ self.write("std::function<")
+ fn.ret.accept(self)
+ self.write("(")
+ self.writeDeclList(fn.params)
+ self.write(")>")
+
+ def visitTypedef(self, td):
+ if td.templateargs:
+ formals = ", ".join(["class " + T for T in td.templateargs])
+ args = ", ".join(td.templateargs)
+ self.printdent("template<" + formals + "> using " + td.totypename + " = ")
+ td.fromtype.accept(self)
+ self.println("<" + args + ">;")
+ else:
+ self.printdent("typedef ")
+ td.fromtype.accept(self)
+ self.println(" " + td.totypename + ";")
+
+ def visitUsing(self, us):
+ self.printdent("using ")
+ us.type.accept(self)
+ self.println(";")
+
+ def visitForwardDecl(self, fd):
+ if fd.cls:
+ self.printdent("class ")
+ elif fd.struct:
+ self.printdent("struct ")
+ self.write(str(fd.pqname))
+ self.println(";")
+
+ def visitDecl(self, d):
+ # C-syntax arrays make code generation much more annoying
+ if isinstance(d.type, TypeArray):
+ d.type.basetype.accept(self)
+ else:
+ d.type.accept(self)
+
+ if d.name:
+ self.write(" " + d.name)
+
+ if isinstance(d.type, TypeArray):
+ self.write("[")
+ d.type.nmemb.accept(self)
+ self.write("]")
+
+ def visitParam(self, p):
+ self.visitDecl(p)
+ if p.default is not None:
+ self.write(" = ")
+ p.default.accept(self)
+
+ def visitClass(self, c):
+ if c.specializes is not None:
+ self.printdentln("template<>")
+
+ if c.struct:
+ self.printdent("struct")
+ else:
+ self.printdent("class")
+ self.write(" " + c.name)
+ if c.final:
+ self.write(" final")
+
+ if c.specializes is not None:
+ self.write(" <")
+ c.specializes.accept(self)
+ self.write(">")
+
+ ninh = len(c.inherits)
+ if 0 < ninh:
+ self.println(" :")
+ self.indent()
+ for i, inherit in enumerate(c.inherits):
+ self.printdent()
+ inherit.accept(self)
+ if i != (ninh - 1):
+ self.println(",")
+ self.dedent()
+ self.println()
+
+ self.printdentln("{")
+ self.indent()
+
+ self.visitBlock(c)
+
+ self.dedent()
+ self.printdentln("};")
+
+ def visitInherit(self, inh):
+ self.write(inh.viz + " ")
+ inh.type.accept(self)
+
+ def visitFriendClassDecl(self, fcd):
+ self.printdentln("friend class " + fcd.friend + ";")
+
+ def visitMethodDecl(self, md):
+ if md.T:
+ self.write("template<")
+ self.write("typename ")
+ md.T.accept(self)
+ self.println(">")
+ self.printdent()
+
+ if md.warn_unused:
+ self.write("[[nodiscard]] ")
+
+ if md.methodspec == MethodSpec.STATIC:
+ self.write("static ")
+ elif md.methodspec == MethodSpec.VIRTUAL or md.methodspec == MethodSpec.PURE:
+ self.write("virtual ")
+
+ if md.ret:
+ if md.only_for_definition:
+ self.write("auto ")
+ else:
+ md.ret.accept(self)
+ self.println()
+ self.printdent()
+
+ if md.cls is not None:
+ assert md.only_for_definition
+
+ self.write(md.cls.name)
+ if md.cls.specializes is not None:
+ self.write("<")
+ md.cls.specializes.accept(self)
+ self.write(">")
+ self.write("::")
+
+ if md.typeop is not None:
+ self.write("operator ")
+ md.typeop.accept(self)
+ else:
+ if isinstance(md, DestructorDecl):
+ self.write("~")
+ self.write(md.name)
+
+ self.write("(")
+ self.writeDeclList(md.params)
+ self.write(")")
+
+ if md.const:
+ self.write(" const")
+ if md.ret and md.only_for_definition:
+ self.write(" -> ")
+ md.ret.accept(self)
+
+ if md.methodspec == MethodSpec.OVERRIDE:
+ self.write(" override")
+ elif md.methodspec == MethodSpec.PURE:
+ self.write(" = 0")
+
+ def visitMethodDefn(self, md):
+ # Method specifiers are for decls, not defns.
+ assert md.decl.methodspec == MethodSpec.NONE
+
+ self.printdent()
+ md.decl.accept(self)
+ self.println()
+
+ self.printdentln("{")
+ self.indent()
+ self.visitBlock(md)
+ self.dedent()
+ self.printdentln("}")
+
+ def visitConstructorDecl(self, cd):
+ if cd.explicit:
+ self.write("explicit ")
+ else:
+ self.write("MOZ_IMPLICIT ")
+ self.visitMethodDecl(cd)
+
+ def visitConstructorDefn(self, cd):
+ self.printdent()
+ cd.decl.accept(self)
+ if len(cd.memberinits):
+ self.println(" :")
+ self.indent()
+ ninits = len(cd.memberinits)
+ for i, init in enumerate(cd.memberinits):
+ self.printdent()
+ init.accept(self)
+ if i != (ninits - 1):
+ self.println(",")
+ self.dedent()
+ self.println()
+
+ self.printdentln("{")
+ self.indent()
+
+ self.visitBlock(cd)
+
+ self.dedent()
+ self.printdentln("}")
+
+ def visitDestructorDecl(self, dd):
+ self.visitMethodDecl(dd)
+
+ def visitDestructorDefn(self, dd):
+ self.printdent()
+ dd.decl.accept(self)
+ self.println()
+
+ self.printdentln("{")
+ self.indent()
+
+ self.visitBlock(dd)
+
+ self.dedent()
+ self.printdentln("}")
+
+ def visitExprLiteral(self, el):
+ self.write(str(el))
+
+ def visitExprVar(self, ev):
+ self.write(ev.name)
+
+ def visitExprPrefixUnop(self, e):
+ self.write("(")
+ self.write(e.op)
+ self.write("(")
+ e.expr.accept(self)
+ self.write(")")
+ self.write(")")
+
+ def visitExprCast(self, c):
+ if c.static:
+ pfx, sfx = "static_cast<", ">"
+ else:
+ assert c.const
+ pfx, sfx = "const_cast<", ">"
+ self.write(pfx)
+ c.type.accept(self)
+ self.write(sfx + "(")
+ c.expr.accept(self)
+ self.write(")")
+
+ def visitExprBinary(self, e):
+ self.write("(")
+ e.left.accept(self)
+ self.write(") " + e.op + " (")
+ e.right.accept(self)
+ self.write(")")
+
+ def visitExprConditional(self, c):
+ self.write("(")
+ c.cond.accept(self)
+ self.write(" ? ")
+ c.ife.accept(self)
+ self.write(" : ")
+ c.elsee.accept(self)
+ self.write(")")
+
+ def visitExprSelect(self, es):
+ self.write("(")
+ es.obj.accept(self)
+ self.write(")")
+ self.write(es.op)
+ es.field.accept(self)
+
+ def visitExprAssn(self, ea):
+ ea.lhs.accept(self)
+ self.write(" " + ea.op + " ")
+ ea.rhs.accept(self)
+
+ def visitExprCall(self, ec):
+ ec.func.accept(self)
+ self.write("(")
+ self.writeExprList(ec.args)
+ self.write(")")
+
+ def visitExprNew(self, en):
+ self.write("new ")
+ if en.newargs is not None:
+ self.write("(")
+ self.writeExprList(en.newargs)
+ self.write(") ")
+ en.ctype.accept(self)
+ if en.args is not None:
+ self.write("(")
+ self.writeExprList(en.args)
+ self.write(")")
+
+ def visitExprDelete(self, ed):
+ self.write("delete ")
+ ed.obj.accept(self)
+
+ def visitExprLambda(self, l):
+ self.write("[")
+ ncaptures = len(l.captures)
+ for i, c in enumerate(l.captures):
+ c.accept(self)
+ if i != (ncaptures - 1):
+ self.write(", ")
+ self.write("](")
+ self.writeDeclList(l.params)
+ self.write(")")
+ if l.ret:
+ self.write(" -> ")
+ l.ret.accept(self)
+ self.println(" {")
+ self.indent()
+ self.visitBlock(l)
+ self.dedent()
+ self.printdent("}")
+
+ def visitStmtBlock(self, b):
+ self.printdentln("{")
+ self.indent()
+ self.visitBlock(b)
+ self.dedent()
+ self.printdentln("}")
+
+ def visitLabel(self, label):
+ self.dedent() # better not be at global scope ...
+ self.printdentln(label.name + ":")
+ self.indent()
+
+ def visitCaseLabel(self, cl):
+ self.dedent()
+ self.printdentln("case " + cl.name + ":")
+ self.indent()
+
+ def visitDefaultLabel(self, dl):
+ self.dedent()
+ self.printdentln("default:")
+ self.indent()
+
+ def visitStmtIf(self, si):
+ self.printdent("if (")
+ si.cond.accept(self)
+ self.println(") {")
+ self.indent()
+ si.ifb.accept(self)
+ self.dedent()
+ self.printdentln("}")
+
+ if si.elseb is not None:
+ self.printdentln("else {")
+ self.indent()
+ si.elseb.accept(self)
+ self.dedent()
+ self.printdentln("}")
+
+ def visitStmtFor(self, sf):
+ self.printdent("for (")
+ if sf.init is not None:
+ sf.init.accept(self)
+ self.write("; ")
+ if sf.cond is not None:
+ sf.cond.accept(self)
+ self.write("; ")
+ if sf.update is not None:
+ sf.update.accept(self)
+ self.println(") {")
+
+ self.indent()
+ self.visitBlock(sf)
+ self.dedent()
+ self.printdentln("}")
+
+ def visitStmtRangedFor(self, rf):
+ self.printdent("for (auto& ")
+ rf.var.accept(self)
+ self.write(" : ")
+ rf.iteree.accept(self)
+ self.println(") {")
+
+ self.indent()
+ self.visitBlock(rf)
+ self.dedent()
+ self.printdentln("}")
+
+ def visitStmtSwitch(self, sw):
+ self.printdent("switch (")
+ sw.expr.accept(self)
+ self.println(") {")
+ self.indent()
+ self.visitBlock(sw)
+ self.dedent()
+ self.printdentln("}")
+
+ def visitStmtBreak(self, sb):
+ self.printdentln("break;")
+
+ def visitStmtDecl(self, sd):
+ self.printdent()
+ sd.decl.accept(self)
+ if sd.initargs is not None:
+ self.write("{")
+ self.writeDeclList(sd.initargs)
+ self.write("}")
+ if sd.init is not None:
+ self.write(" = ")
+ sd.init.accept(self)
+ self.println(";")
+
+ def visitStmtExpr(self, se):
+ self.printdent()
+ se.expr.accept(self)
+ self.println(";")
+
+ def visitStmtReturn(self, sr):
+ self.printdent("return")
+ if sr.expr:
+ self.write(" ")
+ sr.expr.accept(self)
+ self.println(";")
+
+ def writeDeclList(self, decls):
+ # FIXME/cjones: try to do nice formatting of these guys
+
+ ndecls = len(decls)
+ if 0 == ndecls:
+ return
+ elif 1 == ndecls:
+ decls[0].accept(self)
+ return
+
+ self.indent()
+ self.indent()
+ for i, decl in enumerate(decls):
+ self.println()
+ self.printdent()
+ decl.accept(self)
+ if i != (ndecls - 1):
+ self.write(",")
+ self.dedent()
+ self.dedent()
+
+ def writeExprList(self, exprs):
+ # FIXME/cjones: try to do nice formatting and share code with
+ # writeDeclList()
+ nexprs = len(exprs)
+ for i, expr in enumerate(exprs):
+ expr.accept(self)
+ if i != (nexprs - 1):
+ self.write(", ")
diff --git a/ipc/ipdl/ipdl/cxx/code.py b/ipc/ipdl/ipdl/cxx/code.py
new file mode 100644
index 0000000000..0b5019b623
--- /dev/null
+++ b/ipc/ipdl/ipdl/cxx/code.py
@@ -0,0 +1,187 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+# This module contains functionality for adding formatted, opaque "code" blocks
+# into the IPDL ast. These code objects follow IPDL C++ ast patterns, and
+# perform lowering in much the same way.
+
+# In general it is recommended to use these for blocks of code which would
+# otherwise be specified by building a hardcoded IPDL-C++ AST, as users of this
+# API are often easier to read than users of the AST APIs in these cases.
+
+import re
+import math
+import textwrap
+
+from ipdl.cxx.ast import Node, Whitespace, GroupNode, VerbatimNode
+
+
+# -----------------------------------------------------------------------------
+# Public API.
+
+
+def StmtCode(tmpl, **kwargs):
+ """Perform template substitution to build opaque C++ AST nodes. See the
+ module documentation for more information on the templating syntax.
+
+ StmtCode nodes should be used where Stmt* nodes are used. They are placed
+ on their own line and indented."""
+ return _code(tmpl, False, kwargs)
+
+
+def ExprCode(tmpl, **kwargs):
+ """Perform template substitution to build opaque C++ AST nodes. See the
+ module documentation for more information on the templating syntax.
+
+ ExprCode nodes should be used where Expr* nodes are used. They are placed
+ inline, and no trailing newline is added."""
+ return _code(tmpl, True, kwargs)
+
+
+def StmtVerbatim(text):
+ """Build an opaque C++ AST node which emits input text verbatim.
+
+ StmtVerbatim nodes should be used where Stmt* nodes are used. They are placed
+ on their own line and indented."""
+ return _verbatim(text, False)
+
+
+def ExprVerbatim(text):
+ """Build an opaque C++ AST node which emits input text verbatim.
+
+ ExprVerbatim nodes should be used where Expr* nodes are used. They are
+ placed inline, and no trailing newline is added."""
+ return _verbatim(text, True)
+
+
+# -----------------------------------------------------------------------------
+# Implementation
+
+
+def _code(tmpl, inline, context):
+ # Remove common indentation, and strip the preceding newline from
+ # '''-quoting, because we usually don't want it.
+ if tmpl.startswith("\n"):
+ tmpl = tmpl[1:]
+ tmpl = textwrap.dedent(tmpl)
+
+ # Process each line in turn, building up a list of nodes.
+ nodes = []
+ for idx, line in enumerate(tmpl.splitlines()):
+ # Place newline tokens between lines in the input.
+ if idx > 0:
+ nodes.append(Whitespace.NL)
+
+ # Don't indent the first line if `inline` is set.
+ skip_indent = inline and idx == 0
+ nodes.append(_line(line.rstrip(), skip_indent, idx + 1, context))
+
+ # If we're inline, don't add the final trailing newline.
+ if not inline:
+ nodes.append(Whitespace.NL)
+ return GroupNode(nodes)
+
+
+def _verbatim(text, inline):
+ # For simplicitly, _verbatim is implemented using the same logic as _code,
+ # but with '$' characters escaped. This ensures we only need to worry about
+ # a single, albeit complex, codepath.
+ return _code(text.replace("$", "$$"), inline, {})
+
+
+# Pattern used to identify substitutions.
+_substPat = re.compile(
+ r"""
+ \$(?:
+ (?P<escaped>\$) | # '$$' is an escaped '$'
+ (?P<list>[*,])?{(?P<expr>[^}]+)} | # ${expr}, $*{expr}, or $,{expr}
+ (?P<invalid>) # For error reporting
+ )
+ """,
+ re.IGNORECASE | re.VERBOSE,
+)
+
+
+def _line(raw, skip_indent, lineno, context):
+ assert "\n" not in raw
+
+ # Determine the level of indentation used for this line
+ line = raw.lstrip()
+ offset = int(math.ceil((len(raw) - len(line)) / 4))
+
+ # If line starts with a directive, don't indent it.
+ if line.startswith("#"):
+ skip_indent = True
+
+ column = 0
+ children = []
+ for match in _substPat.finditer(line):
+ if match.group("invalid") is not None:
+ raise ValueError("Invalid substitution on line %d" % lineno)
+
+ # Any text from before the current entry should be written, and column
+ # advanced.
+ if match.start() > column:
+ before = line[column : match.start()]
+ children.append(VerbatimNode(before))
+ column = match.end()
+
+ # If we have an escaped group, emit a '$' node.
+ if match.group("escaped") is not None:
+ children.append(VerbatimNode("$"))
+ continue
+
+ # At this point we should have an expression.
+ list_chr = match.group("list")
+ expr = match.group("expr")
+ assert expr is not None
+
+ # Evaluate our expression in the context to get the values.
+ try:
+ values = eval(expr, context, {})
+ except Exception as e:
+ msg = "%s in substitution on line %d" % (repr(e), lineno)
+ raise ValueError(msg) from e
+
+ # If we aren't dealing with lists, wrap the result into a
+ # single-element list.
+ if list_chr is None:
+ values = [values]
+
+ # Check if this substitution is inline, or the entire line.
+ inline = match.span() != (0, len(line))
+
+ for idx, value in enumerate(values):
+ # If we're using ',' as list mode, put a comma between each node.
+ if idx > 0 and list_chr == ",":
+ children.append(VerbatimNode(", "))
+
+ # If our value isn't a node, turn it into one. Verbatim should be
+ # inline unless indent isn't being skipped, and the match isn't
+ # inline.
+ if not isinstance(value, Node):
+ value = _verbatim(str(value), skip_indent or inline)
+ children.append(value)
+
+ # If we were the entire line, indentation is handled by the added child
+ # nodes. Do this after the above loop such that created verbatims have
+ # the correct inline-ness.
+ if not inline:
+ skip_indent = True
+
+ # Add any remaining text in the line.
+ if len(line) > column:
+ children.append(VerbatimNode(line[column:]))
+
+ # If we have no children, just emit the empty string. This will become a
+ # blank line.
+ if len(children) == 0:
+ return VerbatimNode("")
+
+ # Add the initial indent if we aren't skipping it.
+ if not skip_indent:
+ children.insert(0, VerbatimNode("", indent=True))
+
+ # Wrap ourselves into a group node with the correct indent offset
+ return GroupNode(children, offset=offset)
diff --git a/ipc/ipdl/ipdl/lower.py b/ipc/ipdl/ipdl/lower.py
new file mode 100644
index 0000000000..b8ef219b9b
--- /dev/null
+++ b/ipc/ipdl/ipdl/lower.py
@@ -0,0 +1,5688 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import re
+from copy import deepcopy
+from collections import OrderedDict
+import itertools
+
+import ipdl.ast
+import ipdl.builtin
+from ipdl.cxx.ast import *
+from ipdl.cxx.code import *
+from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
+from ipdl.util import hash_str
+
+
+# -----------------------------------------------------------------------------
+# "Public" interface to lowering
+##
+
+
+class LowerToCxx:
+ def lower(self, tu, segmentcapacitydict):
+ """returns |[ header: File ], [ cpp : File ]| representing the
+ lowered form of |tu|"""
+ # annotate the AST with IPDL/C++ IR-type stuff used later
+ tu.accept(_DecorateWithCxxStuff())
+
+ # Any modifications to the filename scheme here need corresponding
+ # modifications in the ipdl.py driver script.
+ name = tu.name
+ pheader, pcpp = File(name + ".h"), File(name + ".cpp")
+
+ _GenerateProtocolCode().lower(tu, pheader, pcpp, segmentcapacitydict)
+ headers = [pheader]
+ cpps = [pcpp]
+
+ if tu.protocol:
+ pname = tu.protocol.name
+
+ parentheader, parentcpp = (
+ File(pname + "Parent.h"),
+ File(pname + "Parent.cpp"),
+ )
+ _GenerateProtocolParentCode().lower(
+ tu, pname + "Parent", parentheader, parentcpp
+ )
+
+ childheader, childcpp = File(pname + "Child.h"), File(pname + "Child.cpp")
+ _GenerateProtocolChildCode().lower(
+ tu, pname + "Child", childheader, childcpp
+ )
+
+ headers += [parentheader, childheader]
+ cpps += [parentcpp, childcpp]
+
+ return headers, cpps
+
+
+# -----------------------------------------------------------------------------
+# Helper code
+##
+
+
+def hashfunc(value):
+ h = hash_str(value) % 2 ** 32
+ if h < 0:
+ h += 2 ** 32
+ return h
+
+
+_NULL_ACTOR_ID = ExprLiteral.ZERO
+_FREED_ACTOR_ID = ExprLiteral.ONE
+
+_DISCLAIMER = Whitespace(
+ """//
+// Automatically generated by ipdlc.
+// Edit at your own risk
+//
+
+"""
+)
+
+
+class _struct:
+ pass
+
+
+def _namespacedHeaderName(name, namespaces):
+ pfx = "/".join([ns.name for ns in namespaces])
+ if pfx:
+ return pfx + "/" + name
+ else:
+ return name
+
+
+def _ipdlhHeaderName(tu):
+ assert tu.filetype == "header"
+ return _namespacedHeaderName(tu.name, tu.namespaces)
+
+
+def _protocolHeaderName(p, side=""):
+ if side:
+ side = side.title()
+ base = p.name + side
+ return _namespacedHeaderName(base, p.namespaces)
+
+
+def _includeGuardMacroName(headerfile):
+ return re.sub(r"[./]", "_", headerfile.name)
+
+
+def _includeGuardStart(headerfile):
+ guard = _includeGuardMacroName(headerfile)
+ return [CppDirective("ifndef", guard), CppDirective("define", guard)]
+
+
+def _includeGuardEnd(headerfile):
+ guard = _includeGuardMacroName(headerfile)
+ return [CppDirective("endif", "// ifndef " + guard)]
+
+
+def _messageStartName(ptype):
+ return ptype.name() + "MsgStart"
+
+
+def _protocolId(ptype):
+ return ExprVar(_messageStartName(ptype))
+
+
+def _protocolIdType():
+ return Type.INT32
+
+
+def _actorName(pname, side):
+ """|pname| is the protocol name. |side| is 'Parent' or 'Child'."""
+ tag = side
+ if not tag[0].isupper():
+ tag = side.title()
+ return pname + tag
+
+
+def _actorIdType():
+ return Type.INT32
+
+
+def _actorTypeTagType():
+ return Type.INT32
+
+
+def _actorId(actor=None):
+ if actor is not None:
+ return ExprCall(ExprSelect(actor, "->", "Id"))
+ return ExprCall(ExprVar("Id"))
+
+
+def _actorHId(actorhandle):
+ return ExprSelect(actorhandle, ".", "mId")
+
+
+def _backstagePass():
+ return ExprCall(ExprVar("mozilla::ipc::PrivateIPDLInterface"))
+
+
+def _deleteId():
+ return ExprVar("Msg___delete____ID")
+
+
+def _deleteReplyId():
+ return ExprVar("Reply___delete____ID")
+
+
+def _lookupListener(idexpr):
+ return ExprCall(ExprVar("Lookup"), args=[idexpr])
+
+
+def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
+ fd = ForwardDecl(clsname, cls=cls, struct=struct)
+ if 0 == len(quals):
+ return fd
+
+ outerns = Namespace(quals[0])
+ innerns = outerns
+ for ns in quals[1:]:
+ tmpns = Namespace(ns)
+ innerns.addstmt(tmpns)
+ innerns = tmpns
+
+ innerns.addstmt(fd)
+ return outerns
+
+
+def _makeForwardDeclForActor(ptype, side):
+ return _makeForwardDeclForQClass(
+ _actorName(ptype.qname.baseid, side), ptype.qname.quals
+ )
+
+
+def _makeForwardDecl(type):
+ return _makeForwardDeclForQClass(type.name(), type.qname.quals)
+
+
+def _putInNamespaces(cxxthing, namespaces):
+ """|namespaces| is in order [ outer, ..., inner ]"""
+ if 0 == len(namespaces):
+ return cxxthing
+
+ outerns = Namespace(namespaces[0].name)
+ innerns = outerns
+ for ns in namespaces[1:]:
+ newns = Namespace(ns.name)
+ innerns.addstmt(newns)
+ innerns = newns
+ innerns.addstmt(cxxthing)
+ return outerns
+
+
+def _sendPrefix(msgtype):
+ """Prefix of the name of the C++ method that sends |msgtype|."""
+ if msgtype.isInterrupt():
+ return "Call"
+ return "Send"
+
+
+def _recvPrefix(msgtype):
+ """Prefix of the name of the C++ method that handles |msgtype|."""
+ if msgtype.isInterrupt():
+ return "Answer"
+ return "Recv"
+
+
+def _flatTypeName(ipdltype):
+ """Return a 'flattened' IPDL type name that can be used as an
+ identifier.
+ E.g., |Foo[]| --> |ArrayOfFoo|."""
+ # NB: this logic depends heavily on what IPDL types are allowed to
+ # be constructed; e.g., Foo[][] is disallowed. needs to be kept in
+ # sync with grammar.
+ if ipdltype.isIPDL() and ipdltype.isArray():
+ return "ArrayOf" + _flatTypeName(ipdltype.basetype)
+ if ipdltype.isIPDL() and ipdltype.isMaybe():
+ return "Maybe" + _flatTypeName(ipdltype.basetype)
+ # NotNull types just assume the underlying variant name to avoid unnecessary
+ # noise, as a NotNull<T> and T should never exist in the same union.
+ if ipdltype.isIPDL() and ipdltype.isNotNull():
+ return _flatTypeName(ipdltype.basetype)
+ return ipdltype.name()
+
+
+def _hasVisibleActor(ipdltype):
+ """Return true iff a C++ decl of |ipdltype| would have an Actor* type.
+ For example: |Actor[]| would turn into |Array<ActorParent*>|, so this
+ function would return true for |Actor[]|."""
+ return ipdltype.isIPDL() and (
+ ipdltype.isActor()
+ or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype))
+ )
+
+
+def _abortIfFalse(cond, msg):
+ return StmtExpr(
+ ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)])
+ )
+
+
+def _refptr(T):
+ return Type("RefPtr", T=T)
+
+
+def _alreadyaddrefed(T):
+ return Type("already_AddRefed", T=T)
+
+
+def _tuple(types, const=False, ref=False):
+ return Type("std::tuple", T=types, const=const, ref=ref)
+
+
+def _promise(resolvetype, rejecttype, tail, resolver=False):
+ inner = Type("Private") if resolver else None
+ return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner)
+
+
+def _makePromise(returns, side, resolver=False):
+ if len(returns) > 1:
+ resolvetype = _tuple([d.bareType(side) for d in returns])
+ else:
+ resolvetype = returns[0].bareType(side)
+
+ # MozPromise is purposefully made to be exclusive only. Really, we mean it.
+ return _promise(
+ resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver
+ )
+
+
+def _resolveType(returns, side):
+ if len(returns) > 1:
+ return _tuple([d.inType(side, "send") for d in returns])
+ return returns[0].inType(side, "send")
+
+
+def _makeResolver(returns, side):
+ return TypeFunction([Decl(_resolveType(returns, side), "")])
+
+
+def _cxxArrayType(basetype, const=False, ref=False):
+ return Type("nsTArray", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False)
+
+
+def _cxxSpanType(basetype, const=False, ref=False):
+ basetype = deepcopy(basetype)
+ basetype.rightconst = True
+ return Type(
+ "mozilla::Span", T=basetype, const=const, ref=ref, hasimplicitcopyctor=True
+ )
+
+
+def _cxxMaybeType(basetype, const=False, ref=False):
+ return Type(
+ "mozilla::Maybe",
+ T=basetype,
+ const=const,
+ ref=ref,
+ hasimplicitcopyctor=basetype.hasimplicitcopyctor,
+ )
+
+
+def _cxxReadResultType(basetype, const=False, ref=False):
+ return Type(
+ "IPC::ReadResult",
+ T=basetype,
+ const=const,
+ ref=ref,
+ hasimplicitcopyctor=basetype.hasimplicitcopyctor,
+ )
+
+
+def _cxxNotNullType(basetype, const=False, ref=False):
+ return Type(
+ "mozilla::NotNull",
+ T=basetype,
+ const=const,
+ ref=ref,
+ hasimplicitcopyctor=basetype.hasimplicitcopyctor,
+ )
+
+
+def _cxxManagedContainerType(basetype, const=False, ref=False):
+ return Type(
+ "ManagedContainer", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False
+ )
+
+
+def _cxxLifecycleProxyType(ptr=False):
+ return Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr)
+
+
+def _otherSide(side):
+ if side == "child":
+ return "parent"
+ if side == "parent":
+ return "child"
+ assert 0
+
+
+def _ifLogging(topLevelProtocol, stmts):
+ return StmtCode(
+ """
+ if (mozilla::ipc::LoggingEnabledFor(${proto})) {
+ $*{stmts}
+ }
+ """,
+ proto=topLevelProtocol,
+ stmts=stmts,
+ )
+
+
+# XXX we need to remove these and install proper error handling
+
+
+def _printErrorMessage(msg):
+ if isinstance(msg, str):
+ msg = ExprLiteral.String(msg)
+ return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[msg]))
+
+
+def _protocolErrorBreakpoint(msg):
+ if isinstance(msg, str):
+ msg = ExprLiteral.String(msg)
+ return StmtExpr(
+ ExprCall(ExprVar("mozilla::ipc::ProtocolErrorBreakpoint"), args=[msg])
+ )
+
+
+def _printWarningMessage(msg):
+ if isinstance(msg, str):
+ msg = ExprLiteral.String(msg)
+ return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[msg]))
+
+
+def _fatalError(msg):
+ return StmtExpr(ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)]))
+
+
+def _logicError(msg):
+ return StmtExpr(
+ ExprCall(ExprVar("mozilla::ipc::LogicError"), args=[ExprLiteral.String(msg)])
+ )
+
+
+def _sentinelReadError(classname):
+ return StmtExpr(
+ ExprCall(
+ ExprVar("mozilla::ipc::SentinelReadError"),
+ args=[ExprLiteral.String(classname)],
+ )
+ )
+
+
+# Results that IPDL-generated code returns back to *Channel code.
+# Users never see these
+
+
+class _Result:
+ @staticmethod
+ def Type():
+ return Type("Result")
+
+ Processed = ExprVar("MsgProcessed")
+ NotKnown = ExprVar("MsgNotKnown")
+ NotAllowed = ExprVar("MsgNotAllowed")
+ PayloadError = ExprVar("MsgPayloadError")
+ ProcessingError = ExprVar("MsgProcessingError")
+ RouteError = ExprVar("MsgRouteError")
+ ValuError = ExprVar("MsgValueError") # [sic]
+
+
+# these |errfn*| are functions that generate code to be executed on an
+# error, such as "bad actor ID". each is given a Python string
+# containing a description of the error
+
+# used in user-facing Send*() methods
+
+
+def errfnSend(msg, errcode=ExprLiteral.FALSE):
+ return [_fatalError(msg), StmtReturn(errcode)]
+
+
+def errfnSendCtor(msg):
+ return errfnSend(msg, errcode=ExprLiteral.NULL)
+
+
+# TODO should this error handling be strengthened for dtors?
+
+
+def errfnSendDtor(msg):
+ return [_printErrorMessage(msg), StmtReturn.FALSE]
+
+
+# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
+# interface methods
+
+
+def errfnRecv(msg, errcode=_Result.ValuError):
+ return [_fatalError(msg), StmtReturn(errcode)]
+
+
+def errfnSentinel(rvalue=ExprLiteral.FALSE):
+ def inner(msg):
+ return [_sentinelReadError(msg), StmtReturn(rvalue)]
+
+ return inner
+
+
+def _destroyMethod():
+ return ExprVar("ActorDestroy")
+
+
+def errfnUnreachable(msg):
+ return [_logicError(msg)]
+
+
+def readResultError():
+ return ExprCode("{}")
+
+
+class _DestroyReason:
+ @staticmethod
+ def Type():
+ return Type("ActorDestroyReason")
+
+ Deletion = ExprVar("Deletion")
+ AncestorDeletion = ExprVar("AncestorDeletion")
+ NormalShutdown = ExprVar("NormalShutdown")
+ AbnormalShutdown = ExprVar("AbnormalShutdown")
+ FailedConstructor = ExprVar("FailedConstructor")
+ ManagedEndpointDropped = ExprVar("ManagedEndpointDropped")
+
+
+class _ResponseRejectReason:
+ @staticmethod
+ def Type():
+ return Type("ResponseRejectReason")
+
+ SendError = ExprVar("ResponseRejectReason::SendError")
+ ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
+ HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
+ ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")
+
+
+# -----------------------------------------------------------------------------
+# Intermediate representation (IR) nodes used during lowering
+
+
+class _ConvertToCxxType(TypeVisitor):
+ def __init__(self, side, fq):
+ self.side = side
+ self.fq = fq
+
+ def typename(self, thing):
+ if self.fq:
+ return thing.fullname()
+ return thing.name()
+
+ def visitImportedCxxType(self, t):
+ cxxtype = Type(self.typename(t))
+ if t.isRefcounted():
+ cxxtype = _refptr(cxxtype)
+ return cxxtype
+
+ def visitBuiltinCType(self, b):
+ return Type(self.typename(b))
+
+ def visitActorType(self, a):
+ if self.side is None:
+ return Type(
+ "::mozilla::ipc::SideVariant",
+ T=[
+ _cxxBareType(a, "parent", self.fq),
+ _cxxBareType(a, "child", self.fq),
+ ],
+ )
+ return Type(_actorName(self.typename(a.protocol), self.side), ptr=True)
+
+ def visitStructType(self, s):
+ return Type(self.typename(s))
+
+ def visitUnionType(self, u):
+ return Type(self.typename(u))
+
+ def visitArrayType(self, a):
+ basecxxtype = a.basetype.accept(self)
+ return _cxxArrayType(basecxxtype)
+
+ def visitMaybeType(self, m):
+ basecxxtype = m.basetype.accept(self)
+ return _cxxMaybeType(basecxxtype)
+
+ def visitShmemType(self, s):
+ return Type(self.typename(s))
+
+ def visitByteBufType(self, s):
+ return Type(self.typename(s))
+
+ def visitFDType(self, s):
+ return Type(self.typename(s))
+
+ def visitEndpointType(self, s):
+ return Type(self.typename(s))
+
+ def visitManagedEndpointType(self, s):
+ return Type(self.typename(s))
+
+ def visitUniquePtrType(self, s):
+ return Type(self.typename(s))
+
+ def visitNotNullType(self, n):
+ basecxxtype = n.basetype.accept(self)
+ return _cxxNotNullType(basecxxtype)
+
+ def visitProtocolType(self, p):
+ assert 0
+
+ def visitMessageType(self, m):
+ assert 0
+
+ def visitVoidType(self, v):
+ assert 0
+
+
+def _cxxBareType(ipdltype, side, fq=False):
+ return ipdltype.accept(_ConvertToCxxType(side, fq))
+
+
+def _cxxRefType(ipdltype, side):
+ t = _cxxBareType(ipdltype, side)
+ t.ref = True
+ return t
+
+
+def _cxxConstRefType(ipdltype, side):
+ t = _cxxBareType(ipdltype, side)
+ if ipdltype.isIPDL() and ipdltype.isActor():
+ return t
+ if ipdltype.isIPDL() and ipdltype.isShmem():
+ t.ref = True
+ return t
+ if ipdltype.isIPDL() and ipdltype.isNotNull():
+ # If the inner type chooses to use a raw pointer, wrap that instead.
+ inner = _cxxConstRefType(ipdltype.basetype, side)
+ if inner.ptr:
+ t = _cxxNotNullType(inner)
+ return t
+ if ipdltype.isIPDL() and ipdltype.hasBaseType():
+ # Keep same constness as inner type.
+ inner = _cxxConstRefType(ipdltype.basetype, side)
+ t.const = inner.const or not inner.ref
+ t.ref = True
+ return t
+ if ipdltype.isCxx() and (ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly()):
+ t.const = True
+ t.ref = True
+ return t
+ if ipdltype.isCxx() and ipdltype.isRefcounted():
+ # Use T* instead of const RefPtr<T>&
+ t = t.T
+ t.ptr = True
+ return t
+ t.const = True
+ t.ref = True
+ return t
+
+
+def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None):
+ """Returns `True` if serializing ipdltype requires a mutable reference, e.g.
+ because the underlying resource represented by the value is being
+ transferred to another process. This is occasionally distinct from whether
+ the C++ type exposes a copy constructor, such as for types which are not
+ cheaply copiable, but are not mutated when serialized."""
+
+ if visited is None:
+ visited = set()
+
+ visited.add(ipdltype)
+
+ if ipdltype.isCxx():
+ return ipdltype.isSendMoveOnly()
+
+ if ipdltype.isIPDL():
+ if ipdltype.hasBaseType():
+ return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited)
+ if ipdltype.isStruct() or ipdltype.isUnion():
+ return any(
+ _cxxTypeNeedsMoveForSend(t, "compound", visited)
+ for t in ipdltype.itercomponents()
+ if t not in visited
+ )
+
+ # For historical reasons, shmem is `const_cast` to a mutable reference
+ # when being stored in a struct or union (see
+ # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning
+ # that they do not cause the containing struct to require move for
+ # sending.
+ if ipdltype.isShmem():
+ return context != "compound"
+
+ return (
+ ipdltype.isByteBuf()
+ or ipdltype.isEndpoint()
+ or ipdltype.isManagedEndpoint()
+ )
+
+ return False
+
+
+def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None):
+ """Returns `True` if the bare C++ type corresponding to ipdltype does not
+ satisfy std::is_copy_constructible_v<T>. All C++ types supported by IPDL
+ must support std::is_move_constructible_v<T>, so non-movable types must be
+ passed behind a `UniquePtr`."""
+
+ if visited is None:
+ visited = set()
+
+ visited.add(ipdltype)
+
+ if ipdltype.isCxx():
+ return ipdltype.isDataMoveOnly()
+
+ if ipdltype.isIPDL():
+ if ipdltype.isUniquePtr():
+ return True
+
+ # When nested within a maybe or array, arrays are no longer copyable.
+ if context == "wrapper" and ipdltype.isArray():
+ return True
+ if ipdltype.hasBaseType():
+ return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited)
+ if ipdltype.isStruct() or ipdltype.isUnion():
+ return any(
+ _cxxTypeNeedsMoveForData(t, "compound", visited)
+ for t in ipdltype.itercomponents()
+ if t not in visited
+ )
+ return (
+ ipdltype.isByteBuf()
+ or ipdltype.isEndpoint()
+ or ipdltype.isManagedEndpoint()
+ )
+
+ return False
+
+
+def _cxxTypeCanMove(ipdltype):
+ return not (ipdltype.isIPDL() and ipdltype.isActor())
+
+
+def _cxxForceMoveRefType(ipdltype, side):
+ assert _cxxTypeCanMove(ipdltype)
+ t = _cxxBareType(ipdltype, side)
+ t.rvalref = True
+ return t
+
+
+def _cxxPtrToType(ipdltype, side):
+ t = _cxxBareType(ipdltype, side)
+ if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
+ t.ptr = False
+ t.ptrptr = True
+ return t
+ t.ptr = True
+ return t
+
+
+def _cxxConstPtrToType(ipdltype, side):
+ t = _cxxBareType(ipdltype, side)
+ if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
+ t.ptr = False
+ t.ptrconstptr = True
+ return t
+ t.const = True
+ t.ptr = True
+ return t
+
+
+def _cxxInType(ipdltype, side, direction):
+ t = _cxxBareType(ipdltype, side)
+ if ipdltype.isIPDL() and ipdltype.isActor():
+ return t
+ if ipdltype.isIPDL() and ipdltype.isNotNull():
+ # If the inner type chooses to use a raw pointer, wrap that instead.
+ inner = _cxxInType(ipdltype.basetype, side, direction)
+ if inner.ptr:
+ t = _cxxNotNullType(inner)
+ return t
+ if _cxxTypeNeedsMoveForSend(ipdltype):
+ t.rvalref = True
+ return t
+ if ipdltype.isCxx():
+ if ipdltype.isRefcounted():
+ # Use T* instead of const RefPtr<T>&
+ t = t.T
+ t.ptr = True
+ return t
+ if ipdltype.name() == "nsCString":
+ t = Type("nsACString")
+ if ipdltype.name() == "nsString":
+ t = Type("nsAString")
+ # Use Span<T const> rather than nsTArray<T> for array types which aren't
+ # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not
+ # for recv signatures.
+ if direction == "send" and ipdltype.isIPDL() and ipdltype.isArray():
+ inner = _cxxBareType(ipdltype.basetype, side)
+ return _cxxSpanType(inner)
+
+ t.const = True
+ t.ref = True
+ return t
+
+
+def _allocMethod(ptype, side):
+ return "Alloc" + ptype.name() + side.title()
+
+
+def _deallocMethod(ptype, side):
+ return "Dealloc" + ptype.name() + side.title()
+
+
+##
+# A _HybridDecl straddles IPDL and C++ decls. It knows which C++
+# types correspond to which IPDL types, and it also knows how
+# serialize and deserialize "special" IPDL C++ types.
+##
+
+
+class _HybridDecl:
+ """A hybrid decl stores both an IPDL type and all the C++ type
+ info needed by later passes, along with a basic name for the decl."""
+
+ def __init__(self, ipdltype, name, attributes={}):
+ self.ipdltype = ipdltype
+ self.name = name
+ self.attributes = attributes
+
+ def var(self):
+ return ExprVar(self.name)
+
+ def bareType(self, side, fq=False):
+ """Return this decl's unqualified C++ type."""
+ return _cxxBareType(self.ipdltype, side, fq=fq)
+
+ def refType(self, side):
+ """Return this decl's C++ type as a 'reference' type, which is not
+ necessarily a C++ reference."""
+ return _cxxRefType(self.ipdltype, side)
+
+ def constRefType(self, side):
+ """Return this decl's C++ type as a const, 'reference' type."""
+ return _cxxConstRefType(self.ipdltype, side)
+
+ def ptrToType(self, side):
+ return _cxxPtrToType(self.ipdltype, side)
+
+ def constPtrToType(self, side):
+ return _cxxConstPtrToType(self.ipdltype, side)
+
+ def inType(self, side, direction):
+ """Return this decl's C++ Type with sending inparam semantics."""
+ return _cxxInType(self.ipdltype, side, direction)
+
+ def outType(self, side):
+ """Return this decl's C++ Type with outparam semantics."""
+ t = self.bareType(side)
+ if self.ipdltype.isIPDL() and self.ipdltype.isActor():
+ t.ptr = False
+ t.ptrptr = True
+ return t
+ t.ptr = True
+ return t
+
+ def forceMoveType(self, side):
+ """Return this decl's C++ Type with forced move semantics."""
+ assert _cxxTypeCanMove(self.ipdltype)
+ return _cxxForceMoveRefType(self.ipdltype, side)
+
+
+# --------------------------------------------------
+
+
+class HasFQName:
+ def fqClassName(self):
+ return self.decl.type.fullname()
+
+
+class _CompoundTypeComponent(_HybridDecl):
+ # @override the following methods to make the side argument optional.
+ def bareType(self, side=None, fq=False):
+ return _HybridDecl.bareType(self, side, fq=fq)
+
+ def refType(self, side=None):
+ return _HybridDecl.refType(self, side)
+
+ def constRefType(self, side=None):
+ return _HybridDecl.constRefType(self, side)
+
+ def ptrToType(self, side=None):
+ return _HybridDecl.ptrToType(self, side)
+
+ def constPtrToType(self, side=None):
+ return _HybridDecl.constPtrToType(self, side)
+
+ def forceMoveType(self, side=None):
+ return _HybridDecl.forceMoveType(self, side)
+
+
+class StructDecl(ipdl.ast.StructDecl, HasFQName):
+ def fields_ipdl_order(self):
+ for f in self.fields:
+ yield f
+
+ def fields_member_order(self):
+ assert len(self.packed_field_order) == len(self.fields)
+
+ for i in self.packed_field_order:
+ yield self.fields[i]
+
+ @staticmethod
+ def upgrade(structDecl):
+ assert isinstance(structDecl, ipdl.ast.StructDecl)
+ structDecl.__class__ = StructDecl
+
+
+class _StructField(_CompoundTypeComponent):
+ def __init__(self, ipdltype, name, sd):
+ self.basename = name
+
+ _CompoundTypeComponent.__init__(self, ipdltype, name)
+
+ def getMethod(self, thisexpr=None, sel="."):
+ meth = self.var()
+ if thisexpr is not None:
+ return ExprSelect(thisexpr, sel, meth.name)
+ return meth
+
+ def refExpr(self, thisexpr=None):
+ ref = self.memberVar()
+ if thisexpr is not None:
+ ref = ExprSelect(thisexpr, ".", ref.name)
+ return ref
+
+ def constRefExpr(self, thisexpr=None):
+ # sigh, gross hack
+ refexpr = self.refExpr(thisexpr)
+ if "Shmem" == self.ipdltype.name():
+ refexpr = ExprCast(refexpr, Type("Shmem", ref=True), const=True)
+ return refexpr
+
+ def argVar(self):
+ return ExprVar("_" + self.name)
+
+ def memberVar(self):
+ return ExprVar(self.name + "_")
+
+
+class UnionDecl(ipdl.ast.UnionDecl, HasFQName):
+ def callType(self, var=None):
+ func = ExprVar("type")
+ if var is not None:
+ func = ExprSelect(var, ".", func.name)
+ return ExprCall(func)
+
+ @staticmethod
+ def upgrade(unionDecl):
+ assert isinstance(unionDecl, ipdl.ast.UnionDecl)
+ unionDecl.__class__ = UnionDecl
+
+
+class _UnionMember(_CompoundTypeComponent):
+ """Not in the AFL sense, but rather a member (e.g. |int;|) of an
+ IPDL union type."""
+
+ def __init__(self, ipdltype, ud):
+ flatname = _flatTypeName(ipdltype)
+
+ _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname)
+ self.flattypename = flatname
+
+ # To create a finite object with a mutually recursive type, a union must
+ # be present somewhere in the recursive loop. Because of that we only
+ # need to care about introducing indirections inside unions.
+ self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype)
+
+ def enum(self):
+ return "T" + self.flattypename
+
+ def enumvar(self):
+ return ExprVar(self.enum())
+
+ def internalType(self):
+ if self.recursive:
+ return self.ptrToType()
+ else:
+ return self.bareType()
+
+ def unionType(self):
+ """Type used for storage in generated C union decl."""
+ if self.recursive:
+ return self.ptrToType()
+ else:
+ return Type("mozilla::AlignedStorage2", T=self.internalType())
+
+ def unionValue(self):
+ # NB: knows that Union's storage C union is named |mValue|
+ return ExprSelect(ExprVar("mValue"), ".", self.name)
+
+ def typedef(self):
+ return self.flattypename + "__tdef"
+
+ def callGetConstPtr(self):
+ """Return an expression of type self.constptrToSelfType()"""
+ return ExprCall(ExprVar(self.getConstPtrName()))
+
+ def callGetPtr(self):
+ """Return an expression of type self.ptrToSelfType()"""
+ return ExprCall(ExprVar(self.getPtrName()))
+
+ def callCtor(self, expr=None):
+ assert not isinstance(expr, list)
+
+ if expr is None:
+ args = None
+ elif (
+ self.ipdltype.isIPDL()
+ and self.ipdltype.isArray()
+ and not isinstance(expr, ExprMove)
+ ):
+ args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])]
+ else:
+ args = [expr]
+
+ if self.recursive:
+ return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args))
+ else:
+ return ExprNew(
+ self.bareType(),
+ args=args,
+ newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()],
+ )
+
+ def callDtor(self):
+ if self.recursive:
+ return ExprDelete(self.callGetPtr())
+ else:
+ return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef()))
+
+ def getTypeName(self):
+ return "get_" + self.flattypename
+
+ def getConstTypeName(self):
+ return "get_" + self.flattypename
+
+ def getOtherTypeName(self):
+ return "get_" + self.otherflattypename
+
+ def getPtrName(self):
+ return "ptr_" + self.flattypename
+
+ def getConstPtrName(self):
+ return "constptr_" + self.flattypename
+
+ def ptrToSelfExpr(self):
+ """|*ptrToSelfExpr()| has type |self.bareType()|"""
+ v = self.unionValue()
+ if self.recursive:
+ return v
+ else:
+ return ExprCall(ExprSelect(v, ".", "addr"))
+
+ def constptrToSelfExpr(self):
+ """|*constptrToSelfExpr()| has type |self.constType()|"""
+ v = self.unionValue()
+ if self.recursive:
+ return v
+ return ExprCall(ExprSelect(v, ".", "addr"))
+
+ def ptrToInternalType(self):
+ t = self.ptrToType()
+ if self.recursive:
+ t.ref = True
+ return t
+
+ def defaultValue(self, fq=False):
+ # Use the default constructor for any class that does not have an
+ # implicit copy constructor.
+ if not self.bareType().hasimplicitcopyctor:
+ return None
+
+ if self.ipdltype.isIPDL() and self.ipdltype.isActor():
+ return ExprLiteral.NULL
+ # XXX sneaky here, maybe need ExprCtor()?
+ return ExprCall(self.bareType(fq=fq))
+
+ def getConstValue(self):
+ v = ExprDeref(self.callGetConstPtr())
+ # sigh
+ if "Shmem" == self.ipdltype.name():
+ v = ExprCast(v, Type("Shmem", ref=True), const=True)
+ return v
+
+
+# --------------------------------------------------
+
+
+class MessageDecl(ipdl.ast.MessageDecl):
+ def baseName(self):
+ return self.name
+
+ def recvMethod(self):
+ name = _recvPrefix(self.decl.type) + self.baseName()
+ if self.decl.type.isCtor():
+ name += "Constructor"
+ return name
+
+ def sendMethod(self):
+ name = _sendPrefix(self.decl.type) + self.baseName()
+ if self.decl.type.isCtor():
+ name += "Constructor"
+ return name
+
+ def hasReply(self):
+ return (
+ self.decl.type.hasReply()
+ or self.decl.type.isCtor()
+ or self.decl.type.isDtor()
+ )
+
+ def hasAsyncReturns(self):
+ return self.decl.type.isAsync() and self.returns
+
+ def msgCtorFunc(self):
+ return "Msg_%s" % (self.decl.progname)
+
+ def prettyMsgName(self, pfx=""):
+ return pfx + self.msgCtorFunc()
+
+ def pqMsgCtorFunc(self):
+ return "%s::%s" % (self.namespace, self.msgCtorFunc())
+
+ def msgId(self):
+ return self.msgCtorFunc() + "__ID"
+
+ def pqMsgId(self):
+ return "%s::%s" % (self.namespace, self.msgId())
+
+ def replyCtorFunc(self):
+ return "Reply_%s" % (self.decl.progname)
+
+ def pqReplyCtorFunc(self):
+ return "%s::%s" % (self.namespace, self.replyCtorFunc())
+
+ def replyId(self):
+ return self.replyCtorFunc() + "__ID"
+
+ def pqReplyId(self):
+ return "%s::%s" % (self.namespace, self.replyId())
+
+ def prettyReplyName(self, pfx=""):
+ return pfx + self.replyCtorFunc()
+
+ def promiseName(self):
+ name = self.baseName()
+ if self.decl.type.isCtor():
+ name += "Constructor"
+ name += "Promise"
+ return name
+
+ def resolverName(self):
+ return self.baseName() + "Resolver"
+
+ def actorDecl(self):
+ return self.params[0]
+
+ def makeCxxParams(
+ self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None
+ ):
+ """Return a list of C++ decls per the spec'd configuration.
+ |params| and |returns| is the C++ semantics of those: 'in', 'out', or None."""
+
+ def makeDecl(d, sems):
+ if (
+ self.decl.type.tainted
+ and "NoTaint" not in d.attributes
+ and direction == "recv"
+ ):
+ # Tainted types are passed by-value, allowing the receiver to move them if desired.
+ assert sems != "out"
+ return Decl(Type("Tainted", T=d.bareType(side)), d.name)
+
+ if sems == "in":
+ t = d.inType(side, direction)
+ # If this is the `recv` side, and we're not using "move"
+ # semantics, that means we're an alloc method, and cannot accept
+ # values by rvalue reference. Downgrade to an lvalue reference.
+ if direction == "recv" and t.rvalref:
+ t.rvalref = False
+ t.ref = True
+ return Decl(t, d.name)
+ elif sems == "move":
+ assert direction == "recv"
+ # For legacy reasons, use an rvalue reference when generating
+ # parameters for recv methods which accept arrays.
+ if d.ipdltype.isIPDL() and d.ipdltype.isArray():
+ t = d.bareType(side)
+ t.rvalref = True
+ return Decl(t, d.name)
+ return Decl(d.inType(side, direction), d.name)
+ elif sems == "out":
+ return Decl(d.outType(side), d.name)
+ else:
+ assert 0
+
+ def makeResolverDecl(returns):
+ return Decl(Type(self.resolverName(), rvalref=True), "aResolve")
+
+ def makeCallbackResolveDecl(returns):
+ if len(returns) > 1:
+ resolvetype = _tuple([d.bareType(side) for d in returns])
+ else:
+ resolvetype = returns[0].bareType(side)
+
+ return Decl(
+ Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True),
+ "aResolve",
+ )
+
+ def makeCallbackRejectDecl(returns):
+ return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject")
+
+ cxxparams = []
+ if paramsems is not None:
+ cxxparams.extend([makeDecl(d, paramsems) for d in self.params])
+
+ if returnsems == "promise" and self.returns:
+ pass
+ elif returnsems == "callback" and self.returns:
+ cxxparams.extend(
+ [
+ makeCallbackResolveDecl(self.returns),
+ makeCallbackRejectDecl(self.returns),
+ ]
+ )
+ elif returnsems == "resolver" and self.returns:
+ cxxparams.extend([makeResolverDecl(self.returns)])
+ elif returnsems is not None:
+ cxxparams.extend([makeDecl(r, returnsems) for r in self.returns])
+
+ if not implicit and self.decl.type.hasImplicitActorParam():
+ cxxparams = cxxparams[1:]
+
+ return cxxparams
+
+ def makeCxxArgs(
+ self, paramsems="in", retsems="out", retcallsems="out", implicit=True
+ ):
+ assert not retcallsems or retsems # retcallsems => returnsems
+ cxxargs = []
+
+ if paramsems == "move":
+ # We don't std::move() RefPtr<T> types because current Recv*()
+ # implementors take these parameters as T*, and
+ # std::move(RefPtr<T>) doesn't coerce to T*.
+ # We also don't move NotNull, as it has no move constructor.
+ cxxargs.extend(
+ [
+ p.var()
+ if p.ipdltype.isRefcounted()
+ or (p.ipdltype.isIPDL() and p.ipdltype.isNotNull())
+ else ExprMove(p.var())
+ for p in self.params
+ ]
+ )
+ elif paramsems == "in":
+ cxxargs.extend([p.var() for p in self.params])
+ else:
+ assert False
+
+ for ret in self.returns:
+ if retsems == "in":
+ if retcallsems == "in":
+ cxxargs.append(ret.var())
+ elif retcallsems == "out":
+ cxxargs.append(ExprAddrOf(ret.var()))
+ else:
+ assert 0
+ elif retsems == "out":
+ if retcallsems == "in":
+ cxxargs.append(ExprDeref(ret.var()))
+ elif retcallsems == "out":
+ cxxargs.append(ret.var())
+ else:
+ assert 0
+ elif retsems == "resolver":
+ pass
+ if retsems == "resolver":
+ cxxargs.append(ExprMove(ExprVar("resolver")))
+
+ if not implicit:
+ assert self.decl.type.hasImplicitActorParam()
+ cxxargs = cxxargs[1:]
+
+ return cxxargs
+
+ @staticmethod
+ def upgrade(messageDecl):
+ assert isinstance(messageDecl, ipdl.ast.MessageDecl)
+ if messageDecl.decl.type.hasImplicitActorParam():
+ messageDecl.params.insert(
+ 0,
+ _HybridDecl(
+ ipdl.type.ActorType(messageDecl.decl.type.constructedType()),
+ "actor",
+ ),
+ )
+ messageDecl.__class__ = MessageDecl
+
+
+# --------------------------------------------------
+def _usesShmem(p):
+ for md in p.messageDecls:
+ for param in md.inParams:
+ if ipdl.type.hasshmem(param.type):
+ return True
+ for ret in md.outParams:
+ if ipdl.type.hasshmem(ret.type):
+ return True
+ return False
+
+
+def _subtreeUsesShmem(p):
+ if _usesShmem(p):
+ return True
+
+ ptype = p.decl.type
+ for mgd in ptype.manages:
+ if ptype is not mgd:
+ if _subtreeUsesShmem(mgd._ast):
+ return True
+ return False
+
+
+class Protocol(ipdl.ast.Protocol):
+ def managerInterfaceType(self, ptr=False):
+ return Type("mozilla::ipc::IProtocol", ptr=ptr)
+
+ def openedProtocolInterfaceType(self, ptr=False):
+ return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr)
+
+ def _ipdlmgrtype(self):
+ assert 1 == len(self.decl.type.managers)
+ for mgr in self.decl.type.managers:
+ return mgr
+
+ def managerActorType(self, side, ptr=False):
+ return Type(_actorName(self._ipdlmgrtype().name(), side), ptr=ptr)
+
+ def unregisterMethod(self, actorThis=None):
+ if actorThis is not None:
+ return ExprSelect(actorThis, "->", "Unregister")
+ return ExprVar("Unregister")
+
+ def removeManageeMethod(self):
+ return ExprVar("RemoveManagee")
+
+ def deallocManageeMethod(self):
+ return ExprVar("DeallocManagee")
+
+ def getChannelMethod(self):
+ return ExprVar("GetIPCChannel")
+
+ def callGetChannel(self, actorThis=None):
+ fn = self.getChannelMethod()
+ if actorThis is not None:
+ fn = ExprSelect(actorThis, "->", fn.name)
+ return ExprCall(fn)
+
+ def processingErrorVar(self):
+ assert self.decl.type.isToplevel()
+ return ExprVar("ProcessingError")
+
+ def shouldContinueFromTimeoutVar(self):
+ assert self.decl.type.isToplevel()
+ return ExprVar("ShouldContinueFromReplyTimeout")
+
+ def routingId(self, actorThis=None):
+ if self.decl.type.isToplevel():
+ return ExprVar("MSG_ROUTING_CONTROL")
+ if actorThis is not None:
+ return ExprCall(ExprSelect(actorThis, "->", "Id"))
+ return ExprCall(ExprVar("Id"))
+
+ def managerVar(self, thisexpr=None):
+ assert thisexpr is not None or not self.decl.type.isToplevel()
+ mvar = ExprCall(ExprVar("Manager"), args=[])
+ if thisexpr is not None:
+ mvar = ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[])
+ return mvar
+
+ def managedCxxType(self, actortype, side):
+ assert self.decl.type.isManagerOf(actortype)
+ return Type(_actorName(actortype.name(), side), ptr=True)
+
+ def managedMethod(self, actortype, side):
+ assert self.decl.type.isManagerOf(actortype)
+ return ExprVar("Managed" + _actorName(actortype.name(), side))
+
+ def managedVar(self, actortype, side):
+ assert self.decl.type.isManagerOf(actortype)
+ return ExprVar("mManaged" + _actorName(actortype.name(), side))
+
+ def managedVarType(self, actortype, side, const=False, ref=False):
+ assert self.decl.type.isManagerOf(actortype)
+ return _cxxManagedContainerType(
+ Type(_actorName(actortype.name(), side)), const=const, ref=ref
+ )
+
+ def subtreeUsesShmem(self):
+ return _subtreeUsesShmem(self)
+
+ @staticmethod
+ def upgrade(protocol):
+ assert isinstance(protocol, ipdl.ast.Protocol)
+ protocol.__class__ = Protocol
+
+
+class TranslationUnit(ipdl.ast.TranslationUnit):
+ @staticmethod
+ def upgrade(tu):
+ assert isinstance(tu, ipdl.ast.TranslationUnit)
+ tu.__class__ = TranslationUnit
+
+
+# -----------------------------------------------------------------------------
+
+pod_types = {
+ "::int8_t": 1,
+ "::uint8_t": 1,
+ "::int16_t": 2,
+ "::uint16_t": 2,
+ "::int32_t": 4,
+ "::uint32_t": 4,
+ "::int64_t": 8,
+ "::uint64_t": 8,
+ "float": 4,
+ "double": 8,
+}
+max_pod_size = max(pod_types.values())
+# We claim that all types we don't recognize are automatically "bigger"
+# than pod types for ease of sorting.
+pod_size_sentinel = max_pod_size * 2
+
+
+def pod_size(ipdltype):
+ if not ipdltype.isCxx():
+ return pod_size_sentinel
+
+ return pod_types.get(ipdltype.fullname(), pod_size_sentinel)
+
+
+class _DecorateWithCxxStuff(ipdl.ast.Visitor):
+ """Phase 1 of lowering: decorate the IPDL AST with information
+ relevant to C++ code generation.
+
+ This pass results in an AST that is a poor man's "IR"; in reality, a
+ "hybrid" AST mainly consisting of IPDL nodes with new C++ info along
+ with some new IPDL/C++ nodes that are tuned for C++ codegen."""
+
+ def __init__(self):
+ self.visitedTus = set()
+ self.protocolName = None
+
+ def visitTranslationUnit(self, tu):
+ if tu not in self.visitedTus:
+ self.visitedTus.add(tu)
+ ipdl.ast.Visitor.visitTranslationUnit(self, tu)
+ if not isinstance(tu, TranslationUnit):
+ TranslationUnit.upgrade(tu)
+
+ def visitInclude(self, inc):
+ if inc.tu.filetype == "header":
+ inc.tu.accept(self)
+
+ def visitProtocol(self, pro):
+ self.protocolName = pro.name
+ Protocol.upgrade(pro)
+ return ipdl.ast.Visitor.visitProtocol(self, pro)
+
+ def visitStructDecl(self, sd):
+ if not isinstance(sd, StructDecl):
+ newfields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields]
+
+ # Compute a permutation of the fields for in-memory storage such
+ # that the memory layout of the structure will be well-packed.
+ permutation = list(range(len(newfields)))
+
+ # Note that the results of `pod_size` ensure that non-POD fields
+ # sort before POD ones.
+ def size(idx):
+ return pod_size(newfields[idx].ipdltype)
+
+ permutation.sort(key=size, reverse=True)
+
+ sd.fields = newfields
+ sd.packed_field_order = permutation
+ StructDecl.upgrade(sd)
+
+ def visitUnionDecl(self, ud):
+ ud.components = [_UnionMember(ctype, ud) for ctype in ud.decl.type.components]
+ UnionDecl.upgrade(ud)
+
+ def visitDecl(self, decl):
+ return _HybridDecl(decl.type, decl.progname, decl.attributes)
+
+ def visitMessageDecl(self, md):
+ md.namespace = self.protocolName
+ md.params = [param.accept(self) for param in md.inParams]
+ md.returns = [ret.accept(self) for ret in md.outParams]
+ MessageDecl.upgrade(md)
+
+
+# -----------------------------------------------------------------------------
+
+
+def msgenums(protocol, pretty=False):
+ msgenum = TypeEnum("MessageType")
+ msgstart = _messageStartName(protocol.decl.type) + " << 16"
+ msgenum.addId(protocol.name + "Start", msgstart)
+
+ for md in protocol.messageDecls:
+ msgenum.addId(md.prettyMsgName() if pretty else md.msgId())
+ if md.hasReply():
+ msgenum.addId(md.prettyReplyName() if pretty else md.replyId())
+
+ msgenum.addId(protocol.name + "End")
+ return msgenum
+
+
+class _GenerateProtocolCode(ipdl.ast.Visitor):
+ """Creates code common to both the parent and child actors."""
+
+ def __init__(self):
+ self.protocol = None # protocol we're generating a class for
+ self.hdrfile = None # what will become Protocol.h
+ self.cppfile = None # what will become Protocol.cpp
+ self.cppIncludeHeaders = []
+ self.structUnionDefns = []
+ self.funcDefns = []
+
+ def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict):
+ self.protocol = tu.protocol
+ self.hdrfile = cxxHeaderFile
+ self.cppfile = cxxFile
+ self.segmentcapacitydict = segmentcapacitydict
+ tu.accept(self)
+
+ def visitTranslationUnit(self, tu):
+ hf = self.hdrfile
+
+ hf.addthing(_DISCLAIMER)
+ hf.addthings(_includeGuardStart(hf))
+ hf.addthing(Whitespace.NL)
+
+ for inc in builtinHeaderIncludes:
+ self.visitBuiltinCxxInclude(inc)
+
+ # Compute the set of includes we need for declared structure/union
+ # classes for this protocol.
+ typesToIncludes = {}
+ for using in tu.using:
+ typestr = str(using.type)
+ if typestr not in typesToIncludes:
+ typesToIncludes[typestr] = using.header
+ else:
+ assert typesToIncludes[typestr] == using.header
+
+ aggregateTypeIncludes = set()
+ for su in tu.structsAndUnions:
+ typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes)
+ if isinstance(su, ipdl.ast.StructDecl):
+ aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h")
+ for f in su.fields:
+ f.ipdltype.accept(typedeps)
+ elif isinstance(su, ipdl.ast.UnionDecl):
+ for c in su.components:
+ c.ipdltype.accept(typedeps)
+
+ aggregateTypeIncludes.update(typedeps.includeHeaders)
+
+ if len(aggregateTypeIncludes) != 0:
+ hf.addthing(Whitespace.NL)
+ hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL])
+
+ for headername in sorted(iter(aggregateTypeIncludes)):
+ hf.addthing(CppDirective("include", '"' + headername + '"'))
+
+ # Manually run Visitor.visitTranslationUnit. For dependency resolution
+ # we need to handle structs and unions separately.
+ for cxxInc in tu.cxxIncludes:
+ cxxInc.accept(self)
+ for inc in tu.includes:
+ inc.accept(self)
+ self.generateStructsAndUnions(tu)
+ for using in tu.builtinUsing:
+ using.accept(self)
+ for using in tu.using:
+ using.accept(self)
+ if tu.protocol:
+ tu.protocol.accept(self)
+
+ if tu.filetype == "header":
+ self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h")
+
+ hf.addthing(Whitespace.NL)
+ hf.addthings(_includeGuardEnd(hf))
+
+ cf = self.cppfile
+ cf.addthings(
+ (
+ [_DISCLAIMER, Whitespace.NL]
+ + [
+ CppDirective("include", '"' + h + '"')
+ for h in self.cppIncludeHeaders
+ ]
+ + [Whitespace.NL]
+ + [
+ CppDirective("include", '"%s"' % filename)
+ for filename in ipdl.builtin.CppIncludes
+ ]
+ + [Whitespace.NL]
+ )
+ )
+
+ if self.protocol:
+ # construct the namespace into which we'll stick all our defns
+ ns = Namespace(self.protocol.name)
+ cf.addthing(_putInNamespaces(ns, self.protocol.namespaces))
+ ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL]))
+
+ cf.addthings(self.structUnionDefns)
+
+ def visitBuiltinCxxInclude(self, inc):
+ self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"'))
+
+ def visitCxxInclude(self, inc):
+ self.cppIncludeHeaders.append(inc.file)
+
+ def visitInclude(self, inc):
+ if inc.tu.filetype == "header":
+ self.hdrfile.addthing(
+ CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"')
+ )
+ # Inherit cpp includes defined by imported header files, as they may
+ # be required to serialize an imported `using` type.
+ for cxxinc in inc.tu.cxxIncludes:
+ cxxinc.accept(self)
+ else:
+ self.cppIncludeHeaders += [
+ _protocolHeaderName(inc.tu.protocol, "parent") + ".h",
+ _protocolHeaderName(inc.tu.protocol, "child") + ".h",
+ ]
+
+ def generateStructsAndUnions(self, tu):
+ """Generate the definitions for all structs and unions. This will
+ re-order the declarations if needed in the C++ code such that
+ dependencies have already been defined."""
+ decls = OrderedDict()
+ for su in tu.structsAndUnions:
+ if isinstance(su, StructDecl):
+ which = "struct"
+ forwarddecls, fulldecltypes, cls = _generateCxxStruct(su)
+ traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type)
+ else:
+ assert isinstance(su, UnionDecl)
+ which = "union"
+ forwarddecls, fulldecltypes, cls = _generateCxxUnion(su)
+ traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type)
+
+ clsdecl, methoddefns = _splitClassDeclDefn(cls)
+
+ # Store the declarations in the decls map so we can emit in
+ # dependency order.
+ decls[su.decl.type] = (
+ fulldecltypes,
+ [Whitespace.NL]
+ + forwarddecls
+ + [
+ Whitespace(
+ """
+//-----------------------------------------------------------------------------
+// Declaration of the IPDL type |%s %s|
+//
+"""
+ % (which, su.name)
+ ),
+ _putInNamespaces(clsdecl, su.namespaces),
+ ]
+ + [Whitespace.NL, traitsdecl],
+ )
+
+ self.structUnionDefns.extend(
+ [
+ Whitespace(
+ """
+//-----------------------------------------------------------------------------
+// Method definitions for the IPDL type |%s %s|
+//
+"""
+ % (which, su.name)
+ ),
+ _putInNamespaces(methoddefns, su.namespaces),
+ Whitespace.NL,
+ traitsdefns,
+ ]
+ )
+
+ # Generate the declarations structs in dependency order.
+ def gen_struct(deps, defn):
+ for dep in deps:
+ if dep in decls:
+ d, t = decls[dep]
+ del decls[dep]
+ gen_struct(d, t)
+ self.hdrfile.addthings(defn)
+
+ while len(decls) > 0:
+ _, (d, t) = decls.popitem(False)
+ gen_struct(d, t)
+
+ def visitProtocol(self, p):
+ self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h")
+ self.cppIncludeHeaders.append(
+ _protocolHeaderName(self.protocol, "Parent") + ".h"
+ )
+ self.cppIncludeHeaders.append(
+ _protocolHeaderName(self.protocol, "Child") + ".h"
+ )
+
+ # Forward declare our own actors.
+ self.hdrfile.addthings(
+ [
+ Whitespace.NL,
+ _makeForwardDeclForActor(p.decl.type, "Parent"),
+ _makeForwardDeclForActor(p.decl.type, "Child"),
+ ]
+ )
+
+ self.hdrfile.addthing(
+ Whitespace(
+ """
+//-----------------------------------------------------------------------------
+// Code common to %sChild and %sParent
+//
+"""
+ % (p.name, p.name)
+ )
+ )
+
+ # construct the namespace into which we'll stick all our decls
+ ns = Namespace(self.protocol.name)
+ self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces))
+ ns.addstmt(Whitespace.NL)
+
+ for func in self.genEndpointFuncs():
+ edecl, edefn = _splitFuncDeclDefn(func)
+ ns.addstmts([edecl, Whitespace.NL])
+ self.funcDefns.append(edefn)
+
+ # spit out message type enum and classes
+ msgenum = msgenums(self.protocol)
+ ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL])
+
+ for md in p.messageDecls:
+ decls = []
+
+ # Look up the segment capacity used for serializing this
+ # message. If the capacity is not specified, use '0' for
+ # the default capacity (defined in ipc_message.cc)
+ name = "%s::%s" % (md.namespace, md.decl.progname)
+ segmentcapacity = self.segmentcapacitydict.get(name, 0)
+
+ mfDecl, mfDefn = _splitFuncDeclDefn(
+ _generateMessageConstructor(md, segmentcapacity, p, forReply=False)
+ )
+ decls.append(mfDecl)
+ self.funcDefns.append(mfDefn)
+
+ if md.hasReply():
+ rfDecl, rfDefn = _splitFuncDeclDefn(
+ _generateMessageConstructor(md, 0, p, forReply=True)
+ )
+ decls.append(rfDecl)
+ self.funcDefns.append(rfDefn)
+
+ decls.append(Whitespace.NL)
+ ns.addstmts(decls)
+
+ ns.addstmts([Whitespace.NL, Whitespace.NL])
+
+ # Generate code for PFoo::CreateEndpoints.
+ def genEndpointFuncs(self):
+ p = self.protocol.decl.type
+ tparent = _cxxBareType(ActorType(p), "Parent", fq=True)
+ tchild = _cxxBareType(ActorType(p), "Child", fq=True)
+
+ def mkOverload(includepids):
+ params = []
+ if includepids:
+ params = [
+ Decl(Type("base::ProcessId"), "aParentDestPid"),
+ Decl(Type("base::ProcessId"), "aChildDestPid"),
+ ]
+ params += [
+ Decl(
+ Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True),
+ "aParent",
+ ),
+ Decl(
+ Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True),
+ "aChild",
+ ),
+ ]
+ openfunc = MethodDefn(
+ MethodDecl("CreateEndpoints", params=params, ret=Type.NSRESULT)
+ )
+ openfunc.addcode(
+ """
+ return mozilla::ipc::CreateEndpoints(
+ mozilla::ipc::PrivateIPDLInterface(),
+ $,{args});
+ """,
+ args=[ExprVar(d.name) for d in params],
+ )
+ return openfunc
+
+ funcs = [mkOverload(True)]
+ if not p.hasOtherPid():
+ funcs.append(mkOverload(False))
+ return funcs
+
+
+# --------------------------------------------------
+
+cppPriorityList = list(
+ map(lambda src: src.upper() + "_PRIORITY", ipdl.ast.priorityList)
+)
+
+
+def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
+ if forReply:
+ clsname = md.replyCtorFunc()
+ msgid = md.replyId()
+ replyEnum = "REPLY"
+ prioEnum = cppPriorityList[md.decl.type.replyPrio]
+ else:
+ clsname = md.msgCtorFunc()
+ msgid = md.msgId()
+ replyEnum = "NOT_REPLY"
+ prioEnum = cppPriorityList[md.decl.type.prio]
+
+ nested = md.decl.type.nested
+ compress = md.decl.type.compress
+ lazySend = md.decl.type.lazySend
+
+ routingId = ExprVar("routingId")
+
+ func = FunctionDefn(
+ FunctionDecl(
+ clsname,
+ params=[Decl(Type("int32_t"), routingId.name)],
+ ret=Type("mozilla::UniquePtr<IPC::Message>"),
+ )
+ )
+
+ if not compress:
+ compression = "COMPRESSION_NONE"
+ elif compress.value == "all":
+ compression = "COMPRESSION_ALL"
+ else:
+ assert compress.value is None
+ compression = "COMPRESSION_ENABLED"
+
+ if lazySend:
+ lazySendEnum = "LAZY_SEND"
+ else:
+ lazySendEnum = "EAGER_SEND"
+
+ if nested == ipdl.ast.NOT_NESTED:
+ nestedEnum = "NOT_NESTED"
+ elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
+ nestedEnum = "NESTED_INSIDE_SYNC"
+ else:
+ assert nested == ipdl.ast.INSIDE_CPOW_NESTED
+ nestedEnum = "NESTED_INSIDE_CPOW"
+
+ if md.decl.type.isSync():
+ syncEnum = "SYNC"
+ else:
+ syncEnum = "ASYNC"
+
+ # FIXME(bug ???) - remove support for interrupt messages from the IPDL compiler.
+ if md.decl.type.isInterrupt():
+ func.addcode(
+ """
+ static_assert(
+ false,
+ "runtime support for intr messages has been removed from IPDL");
+ """
+ )
+
+ if md.decl.type.isCtor():
+ ctorEnum = "CONSTRUCTOR"
+ else:
+ ctorEnum = "NOT_CONSTRUCTOR"
+
+ def messageEnum(valname):
+ return ExprVar("IPC::Message::" + valname)
+
+ flags = ExprCall(
+ ExprVar("IPC::Message::HeaderFlags"),
+ args=[
+ messageEnum(nestedEnum),
+ messageEnum(prioEnum),
+ messageEnum(compression),
+ messageEnum(lazySendEnum),
+ messageEnum(ctorEnum),
+ messageEnum(syncEnum),
+ messageEnum(replyEnum),
+ ],
+ )
+
+ segmentSize = int(segmentSize)
+ if not segmentSize:
+ segmentSize = 0
+ func.addstmt(
+ StmtReturn(
+ ExprCall(
+ ExprVar("IPC::Message::IPDLMessage"),
+ args=[
+ routingId,
+ ExprVar(msgid),
+ ExprLiteral.Int(int(segmentSize)),
+ flags,
+ ],
+ )
+ )
+ )
+
+ return func
+
+
+# --------------------------------------------------
+
+
+class _ParamTraits:
+ var = ExprVar("aVar")
+ writervar = ExprVar("aWriter")
+ readervar = ExprVar("aReader")
+
+ @classmethod
+ def ifsideis(cls, rdrwtr, side, then, els=None):
+ cxxside = ExprVar("mozilla::ipc::ChildSide")
+ if side == "parent":
+ cxxside = ExprVar("mozilla::ipc::ParentSide")
+
+ ifstmt = StmtIf(
+ ExprBinary(
+ cxxside,
+ "==",
+ ExprCode("${rdrwtr}->GetActor()->GetSide()", rdrwtr=rdrwtr),
+ )
+ )
+ ifstmt.addifstmt(then)
+ if els is not None:
+ ifstmt.addelsestmt(els)
+ return ifstmt
+
+ @classmethod
+ def fatalError(cls, rdrwtr, reason):
+ return StmtCode(
+ "${rdrwtr}->FatalError(${reason});",
+ rdrwtr=rdrwtr,
+ reason=ExprLiteral.String(reason),
+ )
+
+ @classmethod
+ def writeSentinel(cls, writervar, sentinelKey):
+ return [
+ Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
+ StmtExpr(
+ ExprCall(
+ ExprSelect(writervar, "->", "WriteSentinel"),
+ args=[ExprLiteral.Int(hashfunc(sentinelKey))],
+ )
+ ),
+ ]
+
+ @classmethod
+ def readSentinel(cls, readervar, sentinelKey, sentinelFail):
+ # Read the sentinel
+ read = ExprCall(
+ ExprSelect(readervar, "->", "ReadSentinel"),
+ args=[ExprLiteral.Int(hashfunc(sentinelKey))],
+ )
+ ifsentinel = StmtIf(ExprNot(read))
+ ifsentinel.addifstmts(sentinelFail)
+
+ return [
+ Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True),
+ ifsentinel,
+ ]
+
+ @classmethod
+ def write(cls, var, writervar, ipdltype=None):
+ if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype):
+ var = ExprMove(var)
+ return ExprCall(ExprVar("IPC::WriteParam"), args=[writervar, var])
+
+ @classmethod
+ def checkedWrite(cls, ipdltype, var, writervar, sentinelKey):
+ assert sentinelKey
+ block = Block()
+
+ block.addstmts(
+ [
+ StmtExpr(cls.write(var, writervar, ipdltype)),
+ ]
+ )
+ block.addstmts(cls.writeSentinel(writervar, sentinelKey))
+ return block
+
+ @classmethod
+ def bulkSentinelKey(cls, fields):
+ return " | ".join(f.basename for f in fields)
+
+ @classmethod
+ def checkedBulkWrite(cls, var, size, fields):
+ block = Block()
+ first = fields[0]
+
+ block.addstmts(
+ [
+ StmtExpr(
+ ExprCall(
+ ExprSelect(cls.writervar, "->", "WriteBytes"),
+ args=[
+ ExprAddrOf(
+ ExprCall(first.getMethod(thisexpr=var, sel="."))
+ ),
+ ExprLiteral.Int(size * len(fields)),
+ ],
+ )
+ )
+ ]
+ )
+ block.addstmts(cls.writeSentinel(cls.writervar, cls.bulkSentinelKey(fields)))
+
+ return block
+
+ @classmethod
+ def checkedBulkRead(cls, var, size, fields):
+ block = Block()
+ first = fields[0]
+
+ readbytes = ExprCall(
+ ExprSelect(cls.readervar, "->", "ReadBytesInto"),
+ args=[
+ ExprAddrOf(ExprCall(first.getMethod(thisexpr=var, sel="->"))),
+ ExprLiteral.Int(size * len(fields)),
+ ],
+ )
+ ifbad = StmtIf(ExprNot(readbytes))
+ errmsg = "Error bulk reading fields from %s" % first.ipdltype.name()
+ ifbad.addifstmts(
+ [cls.fatalError(cls.readervar, errmsg), StmtReturn(readResultError())]
+ )
+ block.addstmt(ifbad)
+ block.addstmts(
+ cls.readSentinel(
+ cls.readervar,
+ cls.bulkSentinelKey(fields),
+ errfnSentinel(readResultError())(errmsg),
+ )
+ )
+
+ return block
+
+ @classmethod
+ def checkedRead(
+ cls,
+ ipdltype,
+ cxxtype,
+ var,
+ readervar,
+ errfn,
+ paramtype,
+ sentinelKey,
+ errfnSentinel,
+ ):
+ assert isinstance(var, ExprVar)
+
+ if not isinstance(paramtype, list):
+ paramtype = ["Error deserializing " + paramtype]
+
+ block = Block()
+
+ # Read the data
+ block.addcode(
+ """
+ auto ${maybevar} = IPC::ReadParam<${ty}>(${reader});
+ if (!${maybevar}) {
+ $*{errfn}
+ }
+ auto& ${var} = *${maybevar};
+ """,
+ maybevar=ExprVar("maybe__" + var.name),
+ ty=cxxtype,
+ reader=readervar,
+ errfn=errfn(*paramtype),
+ var=var,
+ )
+
+ block.addstmts(
+ cls.readSentinel(readervar, sentinelKey, errfnSentinel(*paramtype))
+ )
+
+ return block
+
+ # Helper wrapper for checkedRead for use within _ParamTraits
+ @classmethod
+ def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what):
+ def errfn(msg):
+ return [cls.fatalError(cls.readervar, msg), StmtReturn(readResultError())]
+
+ return cls.checkedRead(
+ ipdltype,
+ cxxtype,
+ var,
+ cls.readervar,
+ errfn=errfn,
+ paramtype=what,
+ sentinelKey=sentinelKey,
+ errfnSentinel=errfnSentinel(readResultError()),
+ )
+
+ @classmethod
+ def generateDecl(cls, fortype, write, read, needsmove=False):
+ # ParamTraits impls are selected ignoring constness, and references.
+ pt = Class(
+ "ParamTraits",
+ specializes=Type(
+ fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr
+ ),
+ struct=True,
+ )
+
+ # typedef T paramType;
+ pt.addstmt(Typedef(fortype, "paramType"))
+
+ # static void Write(Message*, const T&);
+ if needsmove:
+ intype = Type("paramType", rvalref=True)
+ else:
+ intype = Type("paramType", ref=True, const=True)
+ writemthd = MethodDefn(
+ MethodDecl(
+ "Write",
+ params=[
+ Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name),
+ Decl(intype, cls.var.name),
+ ],
+ methodspec=MethodSpec.STATIC,
+ )
+ )
+ writemthd.addstmts(write)
+ pt.addstmt(writemthd)
+
+ # static ReadResult<T> Read(MessageReader*);
+ readmthd = MethodDefn(
+ MethodDecl(
+ "Read",
+ params=[
+ Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name),
+ ],
+ ret=Type("IPC::ReadResult<paramType>"),
+ methodspec=MethodSpec.STATIC,
+ )
+ )
+ readmthd.addstmts(read)
+ pt.addstmt(readmthd)
+
+ # Split the class into declaration and definition
+ clsdecl, methoddefns = _splitClassDeclDefn(pt)
+
+ namespaces = [Namespace("IPC")]
+ clsns = _putInNamespaces(clsdecl, namespaces)
+ defns = _putInNamespaces(methoddefns, namespaces)
+ return clsns, defns
+
+ @classmethod
+ def actorPickling(cls, actortype, side):
+ """Generates pickling for IPDL actors. This is a |nullable| deserializer.
+ Write and read callers will perform nullability validation."""
+
+ cxxtype = _cxxBareType(actortype, side, fq=True)
+
+ write = StmtCode(
+ """
+ MOZ_RELEASE_ASSERT(
+ ${writervar}->GetActor(),
+ "Cannot serialize managed actors without an actor");
+
+ int32_t id;
+ if (!${var}) {
+ id = 0; // kNullActorId
+ } else {
+ id = ${var}->Id();
+ if (id == 1) { // kFreedActorId
+ ${var}->FatalError("Actor has been |delete|d");
+ }
+ MOZ_RELEASE_ASSERT(
+ ${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(),
+ "Actor must be from the same channel as the"
+ " actor it's being sent over");
+ MOZ_RELEASE_ASSERT(
+ ${var}->CanSend(),
+ "Actor must still be open when sending");
+ }
+
+ ${write};
+ """,
+ var=cls.var,
+ writervar=cls.writervar,
+ write=cls.write(ExprVar("id"), cls.writervar),
+ )
+
+ # bool Read(..) impl
+ read = StmtCode(
+ """
+ MOZ_RELEASE_ASSERT(
+ ${readervar}->GetActor(),
+ "Cannot deserialize managed actors without an actor");
+ mozilla::Maybe<mozilla::ipc::IProtocol*> actor = ${readervar}->GetActor()
+ ->ReadActor(${readervar}, true, ${actortype}, ${protocolid});
+ if (actor.isSome()) {
+ return static_cast<${cxxtype}>(actor.ref());
+ }
+ return {};
+ """,
+ readervar=cls.readervar,
+ actortype=ExprLiteral.String(actortype.name()),
+ protocolid=_protocolId(actortype),
+ cxxtype=cxxtype,
+ )
+
+ return cls.generateDecl(cxxtype, [write], [read])
+
+ @classmethod
+ def structPickling(cls, structtype):
+ sd = structtype._ast
+ # NOTE: Not using _cxxBareType here as we don't have a side
+ cxxtype = Type(structtype.fullname())
+
+ write = []
+ read = []
+
+ # First serialize/deserialize all non-pod data in IPDL order. These need
+ # to be read/written first because they'll be used to invoke the IPDL
+ # struct's constructor.
+ ctorargs = []
+ for f in sd.fields_ipdl_order():
+ if pod_size(f.ipdltype) == pod_size_sentinel:
+ write.append(
+ cls.checkedWrite(
+ f.ipdltype,
+ ExprCall(f.getMethod(thisexpr=cls.var, sel=".")),
+ cls.writervar,
+ sentinelKey=f.basename,
+ )
+ )
+ read.append(
+ cls._checkedRead(
+ f.ipdltype,
+ f.bareType(fq=True),
+ f.argVar(),
+ f.basename,
+ "'"
+ + f.getMethod().name
+ + "' "
+ + "("
+ + f.ipdltype.name()
+ + ") member of "
+ + "'"
+ + structtype.name()
+ + "'",
+ )
+ )
+ if _cxxTypeCanMove(f.ipdltype):
+ ctorargs.append(ExprMove(f.argVar()))
+ else:
+ ctorargs.append(f.argVar())
+ else:
+ # We're going to bulk-read in this value later, so we'll just
+ # zero-initialize it for now.
+ ctorargs.append(ExprCode("${type}{0}", type=f.bareType(fq=True)))
+
+ resultvar = ExprVar("result__")
+ read.append(
+ StmtDecl(
+ Decl(_cxxReadResultType(Type("paramType")), resultvar.name),
+ initargs=[ExprVar("std::in_place")] + ctorargs,
+ )
+ )
+
+ # After non-pod data, bulk read/write pod data in member order. This has
+ # to be done after the result has been constructed, so that we have
+ # somewhere to read into.
+ for (size, fields) in itertools.groupby(
+ sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
+ ):
+ if size != pod_size_sentinel:
+ fields = list(fields)
+ write.append(cls.checkedBulkWrite(cls.var, size, fields))
+ read.append(cls.checkedBulkRead(resultvar, size, fields))
+
+ read.append(StmtReturn(resultvar))
+
+ return cls.generateDecl(
+ cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(structtype)
+ )
+
+ @classmethod
+ def unionPickling(cls, uniontype):
+ # NOTE: Not using _cxxBareType here as we don't have a side
+ cxxtype = Type(uniontype.fullname())
+ ud = uniontype._ast
+
+ # Use typedef to set up an alias so it's easier to reference the struct type.
+ alias = "union__"
+ typevar = ExprVar("type")
+
+ prelude = [
+ Typedef(cxxtype, alias),
+ ]
+
+ writeswitch = StmtSwitch(typevar)
+ write = prelude + [
+ StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)),
+ cls.checkedWrite(
+ None, typevar, cls.writervar, sentinelKey=uniontype.name()
+ ),
+ Whitespace.NL,
+ writeswitch,
+ ]
+
+ readswitch = StmtSwitch(typevar)
+ read = prelude + [
+ cls._checkedRead(
+ None,
+ Type.INT,
+ typevar,
+ uniontype.name(),
+ "type of union " + uniontype.name(),
+ ),
+ Whitespace.NL,
+ readswitch,
+ ]
+
+ for c in ud.components:
+ caselabel = CaseLabel(alias + "::" + c.enum())
+ origenum = c.enum()
+
+ writecase = StmtBlock()
+ wstmt = cls.checkedWrite(
+ c.ipdltype,
+ ExprCall(ExprSelect(cls.var, ".", c.getTypeName())),
+ cls.writervar,
+ sentinelKey=c.enum(),
+ )
+ writecase.addstmts([wstmt, StmtReturn()])
+ writeswitch.addcase(caselabel, writecase)
+
+ readcase = StmtBlock()
+ tmpvar = ExprVar("tmp")
+ readcase.addstmts(
+ [
+ cls._checkedRead(
+ c.ipdltype,
+ c.bareType(fq=True),
+ tmpvar,
+ origenum,
+ "variant " + origenum + " of union " + uniontype.name(),
+ ),
+ StmtReturn(ExprMove(tmpvar)),
+ ]
+ )
+ readswitch.addcase(caselabel, readcase)
+
+ # Add the error default case
+ writeswitch.addcase(
+ DefaultLabel(),
+ StmtBlock(
+ [
+ cls.fatalError(
+ cls.writervar, "unknown variant of union " + uniontype.name()
+ ),
+ StmtReturn(),
+ ]
+ ),
+ )
+ readswitch.addcase(
+ DefaultLabel(),
+ StmtBlock(
+ [
+ cls.fatalError(
+ cls.readervar, "unknown variant of union " + uniontype.name()
+ ),
+ StmtReturn(readResultError()),
+ ]
+ ),
+ )
+
+ return cls.generateDecl(
+ cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(uniontype)
+ )
+
+
+# --------------------------------------------------
+
+
+class _ComputeTypeDeps(TypeVisitor):
+ """Pass that gathers the C++ types that a particular IPDL type
+ (recursively) depends on. There are three kinds of dependencies: (i)
+ types that need forward declaration; (ii) types that need a |using|
+ stmt; (iii) IPDL structs or unions which must be fully declared
+ before this struct. Some types generate multiple kinds."""
+
+ def __init__(self, fortype, typesToIncludes=None):
+ ipdl.type.TypeVisitor.__init__(self)
+ self.usingTypedefs = []
+ self.forwardDeclStmts = []
+ self.fullDeclTypes = []
+ self.includeHeaders = set()
+ self.fortype = fortype
+ self.typesToIncludes = typesToIncludes
+
+ def maybeTypedef(self, fqname, name, templateargs=[]):
+ assert fqname.startswith("::")
+ if fqname != name:
+ self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs))
+ if self.typesToIncludes is not None and fqname in self.typesToIncludes:
+ self.includeHeaders.add(self.typesToIncludes[fqname])
+
+ def visitImportedCxxType(self, t):
+ if t in self.visited:
+ return
+ self.visited.add(t)
+ self.maybeTypedef(t.fullname(), t.name())
+
+ def visitActorType(self, t):
+ if t in self.visited:
+ return
+ self.visited.add(t)
+
+ fqname, name = t.fullname(), t.name()
+
+ self.includeHeaders.add("mozilla/ipc/SideVariant.h")
+ self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent"))
+ self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child"))
+
+ self.forwardDeclStmts.extend(
+ [
+ _makeForwardDeclForActor(t.protocol, "parent"),
+ Whitespace.NL,
+ _makeForwardDeclForActor(t.protocol, "child"),
+ Whitespace.NL,
+ ]
+ )
+
+ def visitStructOrUnionType(self, su, defaultVisit):
+ if su in self.visited or su == self.fortype:
+ return
+ self.visited.add(su)
+ self.maybeTypedef(su.fullname(), su.name())
+
+ # Mutually recursive fields in unions are behind indirection, so we only
+ # need a forward decl, and don't need a full type declaration.
+ if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith(
+ su
+ ):
+ self.forwardDeclStmts.append(_makeForwardDecl(su))
+ else:
+ self.fullDeclTypes.append(su)
+
+ return defaultVisit(self, su)
+
+ def visitStructType(self, t):
+ return self.visitStructOrUnionType(t, TypeVisitor.visitStructType)
+
+ def visitUnionType(self, t):
+ return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType)
+
+ def visitArrayType(self, t):
+ return TypeVisitor.visitArrayType(self, t)
+
+ def visitMaybeType(self, m):
+ return TypeVisitor.visitMaybeType(self, m)
+
+ def visitShmemType(self, s):
+ if s in self.visited:
+ return
+ self.visited.add(s)
+ self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem")
+
+ def visitByteBufType(self, s):
+ if s in self.visited:
+ return
+ self.visited.add(s)
+ self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf")
+
+ def visitFDType(self, s):
+ if s in self.visited:
+ return
+ self.visited.add(s)
+ self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor")
+
+ def visitEndpointType(self, s):
+ if s in self.visited:
+ return
+ self.visited.add(s)
+ self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"])
+ self.visitActorType(s.actor)
+
+ def visitManagedEndpointType(self, s):
+ if s in self.visited:
+ return
+ self.visited.add(s)
+ self.maybeTypedef(
+ "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"]
+ )
+ self.visitActorType(s.actor)
+
+ def visitUniquePtrType(self, s):
+ if s in self.visited:
+ return
+ self.visited.add(s)
+
+ def visitVoidType(self, v):
+ assert 0
+
+ def visitMessageType(self, v):
+ assert 0
+
+ def visitProtocolType(self, v):
+ assert 0
+
+
+def _fieldStaticAssertions(sd):
+ staticasserts = []
+ for (size, fields) in itertools.groupby(
+ sd.fields_member_order(), lambda f: pod_size(f.ipdltype)
+ ):
+ if size == pod_size_sentinel:
+ continue
+
+ fields = list(fields)
+ if len(fields) == 1:
+ continue
+
+ staticasserts.append(
+ StmtCode(
+ """
+ static_assert(
+ (offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected},
+ "Bad assumptions about field layout!");
+ """,
+ struct=sd.name,
+ first=fields[0].memberVar(),
+ last=fields[-1].memberVar(),
+ expected=ExprLiteral.Int(size * (len(fields) - 1)),
+ )
+ )
+
+ return staticasserts
+
+
+def _generateCxxStruct(sd):
+ """ """
+ # compute all the typedefs and forward decls we need to make
+ gettypedeps = _ComputeTypeDeps(sd.decl.type)
+ for f in sd.fields:
+ f.ipdltype.accept(gettypedeps)
+
+ usingTypedefs = gettypedeps.usingTypedefs
+ forwarddeclstmts = gettypedeps.forwardDeclStmts
+ fulldecltypes = gettypedeps.fullDeclTypes
+
+ struct = Class(sd.name, final=True)
+ struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC])
+
+ constreftype = Type(sd.name, const=True, ref=True)
+
+ # Struct()
+ # We want the default constructor to be declared if it is available, but
+ # some of our members may not be default-constructible. Silence the
+ # warning which clang generates in that case.
+ #
+ # Members which need value initialization will be handled by wrapping
+ # the member in a template type when declaring them.
+ struct.addcode(
+ """
+ #ifdef __clang__
+ # pragma clang diagnostic push
+ # if __has_warning("-Wdefaulted-function-deleted")
+ # pragma clang diagnostic ignored "-Wdefaulted-function-deleted"
+ # endif
+ #endif
+ ${name}() = default;
+ #ifdef __clang__
+ # pragma clang diagnostic pop
+ #endif
+
+ """,
+ name=sd.name,
+ )
+
+ # If this is an empty struct (no fields), then the default ctor
+ # and "create-with-fields" ctors are equivalent.
+ if len(sd.fields):
+ assert len(sd.fields) == len(sd.packed_field_order)
+
+ # Struct(const field1& _f1, ...)
+ valctor = ConstructorDefn(
+ ConstructorDecl(
+ sd.name,
+ params=[
+ Decl(
+ f.forceMoveType()
+ if _cxxTypeNeedsMoveForData(f.ipdltype)
+ else f.constRefType(),
+ f.argVar().name,
+ )
+ for f in sd.fields_ipdl_order()
+ ],
+ force_inline=True,
+ )
+ )
+ valctor.memberinits = []
+ for f in sd.fields_member_order():
+ arg = f.argVar()
+ if _cxxTypeNeedsMoveForData(f.ipdltype):
+ arg = ExprMove(arg)
+ valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg]))
+
+ struct.addstmts([valctor, Whitespace.NL])
+
+ # If a constructor which moves each argument would be different from the
+ # `const T&` version, also generate that constructor.
+ if not all(
+ _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype)
+ for f in sd.fields_ipdl_order()
+ ):
+ # Struct(field1&& _f1, ...)
+ valmovector = ConstructorDefn(
+ ConstructorDecl(
+ sd.name,
+ params=[
+ Decl(
+ f.forceMoveType()
+ if _cxxTypeCanMove(f.ipdltype)
+ else f.constRefType(),
+ f.argVar().name,
+ )
+ for f in sd.fields_ipdl_order()
+ ],
+ force_inline=True,
+ )
+ )
+
+ valmovector.memberinits = []
+ for f in sd.fields_member_order():
+ arg = f.argVar()
+ if _cxxTypeCanMove(f.ipdltype):
+ arg = ExprMove(arg)
+ valmovector.memberinits.append(
+ ExprMemberInit(f.memberVar(), args=[arg])
+ )
+
+ struct.addstmts([valmovector, Whitespace.NL])
+
+ # The default copy, move, and assignment constructors, and the default
+ # destructor, will do the right thing.
+
+ if "Comparable" in sd.attributes:
+ # bool operator==(const Struct& _o)
+ ovar = ExprVar("_o")
+ opeqeq = MethodDefn(
+ MethodDecl(
+ "operator==",
+ params=[Decl(constreftype, ovar.name)],
+ ret=Type.BOOL,
+ const=True,
+ )
+ )
+ for f in sd.fields_ipdl_order():
+ ifneq = StmtIf(
+ ExprNot(
+ ExprBinary(
+ ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar))
+ )
+ )
+ )
+ ifneq.addifstmt(StmtReturn.FALSE)
+ opeqeq.addstmt(ifneq)
+ opeqeq.addstmt(StmtReturn.TRUE)
+ struct.addstmts([opeqeq, Whitespace.NL])
+
+ # bool operator!=(const Struct& _o)
+ opneq = MethodDefn(
+ MethodDecl(
+ "operator!=",
+ params=[Decl(constreftype, ovar.name)],
+ ret=Type.BOOL,
+ const=True,
+ )
+ )
+ opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar]))))
+ struct.addstmts([opneq, Whitespace.NL])
+
+ # field1& f1()
+ # const field1& f1() const
+ for f in sd.fields_ipdl_order():
+ get = MethodDefn(
+ MethodDecl(
+ f.getMethod().name, params=[], ret=f.refType(), force_inline=True
+ )
+ )
+ get.addstmt(StmtReturn(f.refExpr()))
+
+ getconstdecl = deepcopy(get.decl)
+ getconstdecl.ret = f.constRefType()
+ getconstdecl.const = True
+ getconst = MethodDefn(getconstdecl)
+ getconst.addstmt(StmtReturn(f.constRefExpr()))
+
+ struct.addstmts([get, getconst, Whitespace.NL])
+
+ # private:
+ struct.addstmt(Label.PRIVATE)
+
+ # Static assertions to ensure our assumptions about field layout match
+ # what the compiler is actually producing. We define this as a member
+ # function, rather than throwing the assertions in the constructor or
+ # similar, because we don't want to evaluate the static assertions every
+ # time the header file containing the structure is included.
+ staticasserts = _fieldStaticAssertions(sd)
+ if staticasserts:
+ method = MethodDefn(
+ MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True)
+ )
+ method.addstmts(staticasserts)
+ struct.addstmts([method])
+
+ # members
+ struct.addstmts(
+ [
+ StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name))
+ for f in sd.fields_member_order()
+ ]
+ )
+
+ return forwarddeclstmts, fulldecltypes, struct
+
+
+def _effectiveMemberType(f):
+ effective_type = f.bareType()
+ # Structs must be copyable for backwards compatibility reasons, so we use
+ # CopyableTArray<T> as their member type for arrays. This is not exposed
+ # in the method signatures, these keep using nsTArray<T>, which is a base
+ # class of CopyableTArray<T>.
+ if effective_type.name == "nsTArray":
+ effective_type.name = "CopyableTArray"
+ return Type("::mozilla::ipc::IPDLStructMember", T=[effective_type])
+
+
+# --------------------------------------------------
+
+
+def _generateCxxUnion(ud):
+ # This Union class basically consists of a type (enum) and a
+ # union for storage. The union can contain POD and non-POD
+ # types. Each type needs a copy/move ctor, assignment operators,
+ # and dtor.
+ #
+ # Rather than templating this class and only providing
+ # specializations for the types we support, which is slightly
+ # "unsafe" in that C++ code can add additional specializations
+ # without the IPDL compiler's knowledge, we instead explicitly
+ # implement non-templated methods for each supported type.
+ #
+ # The one complication that arises is that C++, for arcane
+ # reasons, does not allow the placement destructor of a
+ # builtin type, like int, to be directly invoked. So we need
+ # to hack around this by internally typedef'ing all
+ # constituent types. Sigh.
+ #
+ # So, for each type, this "Union" class needs:
+ # (private)
+ # - entry in the type enum
+ # - entry in the storage union
+ # - [type]ptr() method to get a type* from the underlying union
+ # - same as above to get a const type*
+ # - typedef to hack around placement delete limitations
+ # (public)
+ # - placement delete case for dtor
+ # - copy ctor
+ # - move ctor
+ # - case in generic copy ctor
+ # - copy operator= impl
+ # - move operator= impl
+ # - case in generic operator=
+ # - operator [type&]
+ # - operator [const type&] const
+ # - [type&] get_[type]()
+ # - [const type&] get_[type]() const
+ #
+ cls = Class(ud.name, final=True)
+ # const Union&, i.e., Union type with inparam semantics
+ inClsType = Type(ud.name, const=True, ref=True)
+ refClsType = Type(ud.name, ref=True)
+ rvalueRefClsType = Type(ud.name, rvalref=True)
+ typetype = Type("Type")
+ valuetype = Type("Value")
+ mtypevar = ExprVar("mType")
+ mvaluevar = ExprVar("mValue")
+ maybedtorvar = ExprVar("MaybeDestroy")
+ assertsanityvar = ExprVar("AssertSanity")
+ tnonevar = ExprVar("T__None")
+ tlastvar = ExprVar("T__Last")
+
+ def callAssertSanity(uvar=None, expectTypeVar=None):
+ func = assertsanityvar
+ args = []
+ if uvar is not None:
+ func = ExprSelect(uvar, ".", assertsanityvar.name)
+ if expectTypeVar is not None:
+ args.append(expectTypeVar)
+ return ExprCall(func, args=args)
+
+ def maybeDestroy():
+ return StmtExpr(ExprCall(maybedtorvar))
+
+ # compute all the typedefs and forward decls we need to make
+ gettypedeps = _ComputeTypeDeps(ud.decl.type)
+ for c in ud.components:
+ c.ipdltype.accept(gettypedeps)
+
+ usingTypedefs = gettypedeps.usingTypedefs
+ forwarddeclstmts = gettypedeps.forwardDeclStmts
+ fulldecltypes = gettypedeps.fullDeclTypes
+
+ # the |Type| enum, used to switch on the discunion's real type
+ cls.addstmt(Label.PUBLIC)
+ typeenum = TypeEnum(typetype.name)
+ typeenum.addId(tnonevar.name, 0)
+ firstid = ud.components[0].enum()
+ typeenum.addId(firstid, 1)
+ for c in ud.components[1:]:
+ typeenum.addId(c.enum())
+ typeenum.addId(tlastvar.name, ud.components[-1].enum())
+ cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL])
+
+ cls.addstmt(Label.PRIVATE)
+ cls.addstmts(
+ usingTypedefs
+ # hacky typedef's that allow placement dtors of builtins
+ + [Typedef(c.internalType(), c.typedef()) for c in ud.components]
+ )
+ cls.addstmt(Whitespace.NL)
+
+ # the C++ union the discunion use for storage
+ valueunion = TypeUnion(valuetype.name)
+ for c in ud.components:
+ valueunion.addComponent(c.unionType(), c.name)
+ cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL])
+
+ # for each constituent type T, add private accessors that
+ # return a pointer to the Value union storage casted to |T*|
+ # and |const T*|
+ for c in ud.components:
+ getptr = MethodDefn(
+ MethodDecl(
+ c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True
+ )
+ )
+ getptr.addstmt(StmtReturn(c.ptrToSelfExpr()))
+
+ getptrconst = MethodDefn(
+ MethodDecl(
+ c.getConstPtrName(),
+ params=[],
+ ret=c.constPtrToType(),
+ const=True,
+ force_inline=True,
+ )
+ )
+ getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr()))
+
+ cls.addstmts([getptr, getptrconst])
+ cls.addstmt(Whitespace.NL)
+
+ # add a helper method that invokes the placement dtor on the
+ # current underlying value, only if |aNewType| is different
+ # than the current type, and returns true if the underlying
+ # value needs to be re-constructed
+ maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID))
+ # wasn't /actually/ dtor'd, but it needs to be re-constructed
+ ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar))
+ ifnone.addifstmt(StmtReturn())
+ # need to destroy. switch on underlying type
+ dtorswitch = StmtSwitch(mtypevar)
+ for c in ud.components:
+ dtorswitch.addcase(
+ CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()])
+ )
+ dtorswitch.addcase(
+ DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()])
+ )
+ maybedtor.addstmts([ifnone, dtorswitch])
+ cls.addstmts([maybedtor, Whitespace.NL])
+
+ # add helper methods that ensure the discunion has a
+ # valid type
+ sanity = MethodDefn(
+ MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True)
+ )
+ sanity.addstmts(
+ [
+ _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"),
+ _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"),
+ ]
+ )
+ cls.addstmt(sanity)
+
+ atypevar = ExprVar("aType")
+ sanity2 = MethodDefn(
+ MethodDecl(
+ assertsanityvar.name,
+ params=[Decl(typetype, atypevar.name)],
+ ret=Type.VOID,
+ const=True,
+ force_inline=True,
+ )
+ )
+ sanity2.addstmts(
+ [
+ StmtExpr(ExprCall(assertsanityvar)),
+ _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"),
+ ]
+ )
+ cls.addstmts([sanity2, Whitespace.NL])
+
+ # ---- begin public methods -----
+
+ # Union() default ctor
+ cls.addstmts(
+ [
+ Label.PUBLIC,
+ ConstructorDefn(
+ ConstructorDecl(ud.name, force_inline=True),
+ memberinits=[ExprMemberInit(mtypevar, [tnonevar])],
+ ),
+ Whitespace.NL,
+ ]
+ )
+
+ # Union(const T&) copy & Union(T&&) move ctors
+ othervar = ExprVar("aOther")
+ for c in ud.components:
+ if not _cxxTypeNeedsMoveForData(c.ipdltype):
+ copyctor = ConstructorDefn(
+ ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)])
+ )
+ copyctor.addstmts(
+ [
+ StmtExpr(c.callCtor(othervar)),
+ StmtExpr(ExprAssn(mtypevar, c.enumvar())),
+ ]
+ )
+ cls.addstmts([copyctor, Whitespace.NL])
+
+ if not _cxxTypeCanMove(c.ipdltype):
+ continue
+ movector = ConstructorDefn(
+ ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)])
+ )
+ movector.addstmts(
+ [
+ StmtExpr(c.callCtor(ExprMove(othervar))),
+ StmtExpr(ExprAssn(mtypevar, c.enumvar())),
+ ]
+ )
+ cls.addstmts([movector, Whitespace.NL])
+
+ unionNeedsMove = any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components)
+
+ # Union(const Union&) copy ctor
+ if not unionNeedsMove:
+ copyctor = ConstructorDefn(
+ ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)])
+ )
+ othertype = ud.callType(othervar)
+ copyswitch = StmtSwitch(othertype)
+ for c in ud.components:
+ copyswitch.addcase(
+ CaseLabel(c.enum()),
+ StmtBlock(
+ [
+ StmtExpr(
+ c.callCtor(
+ ExprCall(
+ ExprSelect(othervar, ".", c.getConstTypeName())
+ )
+ )
+ ),
+ StmtBreak(),
+ ]
+ ),
+ )
+ copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
+ copyswitch.addcase(
+ DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
+ )
+ copyctor.addstmts(
+ [
+ StmtExpr(callAssertSanity(uvar=othervar)),
+ copyswitch,
+ StmtExpr(ExprAssn(mtypevar, othertype)),
+ ]
+ )
+ cls.addstmts([copyctor, Whitespace.NL])
+
+ # Union(Union&&) move ctor
+ movector = ConstructorDefn(
+ ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)])
+ )
+ othertypevar = ExprVar("t")
+ moveswitch = StmtSwitch(othertypevar)
+ for c in ud.components:
+ case = StmtBlock()
+ if c.recursive:
+ # This is sound as we set othervar.mTypeVar to T__None after the
+ # switch. The pointer in the union will be left dangling.
+ case.addstmts(
+ [
+ # ptr_C() = other.ptr_C()
+ StmtExpr(
+ ExprAssn(
+ c.callGetPtr(),
+ ExprCall(
+ ExprSelect(othervar, ".", ExprVar(c.getPtrName()))
+ ),
+ )
+ )
+ ]
+ )
+ else:
+ case.addstmts(
+ [
+ # new ... (Move(other.get_C()))
+ StmtExpr(
+ c.callCtor(
+ ExprMove(
+ ExprCall(ExprSelect(othervar, ".", c.getTypeName()))
+ )
+ )
+ ),
+ # other.MaybeDestroy(T__None)
+ StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))),
+ ]
+ )
+ case.addstmts([StmtBreak()])
+ moveswitch.addcase(CaseLabel(c.enum()), case)
+ moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()]))
+ moveswitch.addcase(
+ DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()])
+ )
+ movector.addstmts(
+ [
+ StmtExpr(callAssertSanity(uvar=othervar)),
+ StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)),
+ moveswitch,
+ StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)),
+ StmtExpr(ExprAssn(mtypevar, othertypevar)),
+ ]
+ )
+ cls.addstmts([movector, Whitespace.NL])
+
+ # ~Union()
+ dtor = DestructorDefn(DestructorDecl(ud.name))
+ dtor.addstmt(maybeDestroy())
+ cls.addstmts([dtor, Whitespace.NL])
+
+ # type()
+ typemeth = MethodDefn(
+ MethodDecl("type", ret=typetype, const=True, force_inline=True)
+ )
+ typemeth.addstmt(StmtReturn(mtypevar))
+ cls.addstmts([typemeth, Whitespace.NL])
+
+ # Union& operator= methods
+ rhsvar = ExprVar("aRhs")
+ for c in ud.components:
+
+ def opeqBody(rhs):
+ return [
+ # might need to placement-delete old value first
+ maybeDestroy(),
+ StmtExpr(c.callCtor(rhs)),
+ StmtExpr(ExprAssn(mtypevar, c.enumvar())),
+ StmtReturn(ExprDeref(ExprVar.THIS)),
+ ]
+
+ if not _cxxTypeNeedsMoveForData(c.ipdltype):
+ # Union& operator=(const T&)
+ opeq = MethodDefn(
+ MethodDecl(
+ "operator=",
+ params=[Decl(c.constRefType(), rhsvar.name)],
+ ret=refClsType,
+ )
+ )
+ opeq.addstmts(opeqBody(rhsvar))
+ cls.addstmts([opeq, Whitespace.NL])
+
+ # Union& operator=(T&&)
+ if not _cxxTypeCanMove(c.ipdltype):
+ continue
+
+ opeq = MethodDefn(
+ MethodDecl(
+ "operator=",
+ params=[Decl(c.forceMoveType(), rhsvar.name)],
+ ret=refClsType,
+ )
+ )
+ opeq.addstmts(opeqBody(ExprMove(rhsvar)))
+ cls.addstmts([opeq, Whitespace.NL])
+
+ # Union& operator=(const Union&)
+ if not unionNeedsMove:
+ opeq = MethodDefn(
+ MethodDecl(
+ "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType
+ )
+ )
+ rhstypevar = ExprVar("t")
+ opeqswitch = StmtSwitch(rhstypevar)
+ for c in ud.components:
+ case = StmtBlock()
+ case.addstmts(
+ [
+ maybeDestroy(),
+ StmtExpr(
+ c.callCtor(
+ ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName()))
+ )
+ ),
+ StmtBreak(),
+ ]
+ )
+ opeqswitch.addcase(CaseLabel(c.enum()), case)
+ opeqswitch.addcase(
+ CaseLabel(tnonevar.name),
+ StmtBlock([maybeDestroy(), StmtBreak()]),
+ )
+ opeqswitch.addcase(
+ DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
+ )
+ opeq.addstmts(
+ [
+ StmtExpr(callAssertSanity(uvar=rhsvar)),
+ StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
+ opeqswitch,
+ StmtExpr(ExprAssn(mtypevar, rhstypevar)),
+ StmtReturn(ExprDeref(ExprVar.THIS)),
+ ]
+ )
+ cls.addstmts([opeq, Whitespace.NL])
+
+ # Union& operator=(Union&&)
+ opeq = MethodDefn(
+ MethodDecl(
+ "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType
+ )
+ )
+ rhstypevar = ExprVar("t")
+ opeqswitch = StmtSwitch(rhstypevar)
+ for c in ud.components:
+ case = StmtBlock()
+ if c.recursive:
+ case.addstmts(
+ [
+ maybeDestroy(),
+ StmtExpr(
+ ExprAssn(
+ c.callGetPtr(),
+ ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))),
+ )
+ ),
+ ]
+ )
+ else:
+ case.addstmts(
+ [
+ maybeDestroy(),
+ StmtExpr(
+ c.callCtor(
+ ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())))
+ )
+ ),
+ # other.MaybeDestroy()
+ StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))),
+ ]
+ )
+ case.addstmts([StmtBreak()])
+ opeqswitch.addcase(CaseLabel(c.enum()), case)
+ opeqswitch.addcase(
+ CaseLabel(tnonevar.name),
+ StmtBlock([maybeDestroy(), StmtBreak()]),
+ )
+ opeqswitch.addcase(
+ DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()])
+ )
+ opeq.addstmts(
+ [
+ StmtExpr(callAssertSanity(uvar=rhsvar)),
+ StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)),
+ opeqswitch,
+ StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)),
+ StmtExpr(ExprAssn(mtypevar, rhstypevar)),
+ StmtReturn(ExprDeref(ExprVar.THIS)),
+ ]
+ )
+ cls.addstmts([opeq, Whitespace.NL])
+
+ if "Comparable" in ud.attributes:
+ # bool operator==(const T&)
+ for c in ud.components:
+ opeqeq = MethodDefn(
+ MethodDecl(
+ "operator==",
+ params=[Decl(c.constRefType(), rhsvar.name)],
+ ret=Type.BOOL,
+ const=True,
+ )
+ )
+ opeqeq.addstmt(
+ StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar))
+ )
+ cls.addstmts([opeqeq, Whitespace.NL])
+
+ # bool operator==(const Union&)
+ opeqeq = MethodDefn(
+ MethodDecl(
+ "operator==",
+ params=[Decl(inClsType, rhsvar.name)],
+ ret=Type.BOOL,
+ const=True,
+ )
+ )
+ iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar)))
+ iftypesmismatch.addifstmt(StmtReturn.FALSE)
+ opeqeq.addstmts([iftypesmismatch, Whitespace.NL])
+
+ opeqeqswitch = StmtSwitch(ud.callType())
+ for c in ud.components:
+ case = StmtBlock()
+ case.addstmt(
+ StmtReturn(
+ ExprBinary(
+ ExprCall(ExprVar(c.getTypeName())),
+ "==",
+ ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())),
+ )
+ )
+ )
+ opeqeqswitch.addcase(CaseLabel(c.enum()), case)
+ opeqeqswitch.addcase(
+ DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE])
+ )
+ opeqeq.addstmt(opeqeqswitch)
+
+ cls.addstmts([opeqeq, Whitespace.NL])
+
+ # accessors for each type: operator T&, operator const T&,
+ # T& get(), const T& get()
+ for c in ud.components:
+ getValueVar = ExprVar(c.getTypeName())
+ getConstValueVar = ExprVar(c.getConstTypeName())
+
+ getvalue = MethodDefn(
+ MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True)
+ )
+ getvalue.addstmts(
+ [
+ StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
+ StmtReturn(ExprDeref(c.callGetPtr())),
+ ]
+ )
+
+ getconstvalue = MethodDefn(
+ MethodDecl(
+ getConstValueVar.name,
+ ret=c.constRefType(),
+ const=True,
+ force_inline=True,
+ )
+ )
+ getconstvalue.addstmts(
+ [
+ StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())),
+ StmtReturn(c.getConstValue()),
+ ]
+ )
+
+ cls.addstmts([getvalue, getconstvalue])
+
+ optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True))
+ optype.addstmt(StmtReturn(ExprCall(getValueVar)))
+ opconsttype = MethodDefn(
+ MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True)
+ )
+ opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar)))
+
+ cls.addstmts([optype, opconsttype, Whitespace.NL])
+ # private vars
+ cls.addstmts(
+ [
+ Label.PRIVATE,
+ StmtDecl(Decl(valuetype, mvaluevar.name)),
+ StmtDecl(Decl(typetype, mtypevar.name)),
+ ]
+ )
+
+ return forwarddeclstmts, fulldecltypes, cls
+
+
+# -----------------------------------------------------------------------------
+
+
+class _FindFriends(ipdl.ast.Visitor):
+ def __init__(self):
+ self.mytype = None # ProtocolType
+ self.vtype = None # ProtocolType
+ self.friends = set() # set<ProtocolType>
+
+ def findFriends(self, ptype):
+ self.mytype = ptype
+ for toplvl in ptype.toplevels():
+ self.walkDownTheProtocolTree(toplvl)
+ return self.friends
+
+ # TODO could make this into a _iterProtocolTreeHelper ...
+ def walkDownTheProtocolTree(self, ptype):
+ if ptype != self.mytype:
+ # don't want to |friend| ourself!
+ self.visit(ptype)
+ for mtype in ptype.manages:
+ if mtype is not ptype:
+ self.walkDownTheProtocolTree(mtype)
+
+ def visit(self, ptype):
+ # |vtype| is the type currently being visited
+ savedptype = self.vtype
+ self.vtype = ptype
+ ptype._ast.accept(self)
+ self.vtype = savedptype
+
+ def visitMessageDecl(self, md):
+ for it in self.iterActorParams(md):
+ if it.protocol == self.mytype:
+ self.friends.add(self.vtype)
+
+ def iterActorParams(self, md):
+ for param in md.inParams:
+ for actor in ipdl.type.iteractortypes(param.type):
+ yield actor
+ for ret in md.outParams:
+ for actor in ipdl.type.iteractortypes(ret.type):
+ yield actor
+
+
+class _GenerateProtocolActorCode(ipdl.ast.Visitor):
+ def __init__(self, myside):
+ self.side = myside # "parent" or "child"
+ self.prettyside = myside.title()
+ self.clsname = None
+ self.protocol = None
+ self.hdrfile = None
+ self.cppfile = None
+ self.ns = None
+ self.cls = None
+ self.protocolCxxIncludes = []
+ self.actorForwardDecls = []
+ self.usingDecls = []
+ self.externalIncludes = set()
+ self.nonForwardDeclaredHeaders = set()
+ self.typedefSet = set(
+ [
+ Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"),
+ Typedef(Type("base::ProcessId"), "ProcessId"),
+ Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"),
+ Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]),
+ Typedef(
+ Type("mozilla::ipc::ManagedEndpoint"),
+ "ManagedEndpoint",
+ ["FooSide"],
+ ),
+ Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]),
+ Typedef(
+ Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason"
+ ),
+ ]
+ )
+
+ def lower(self, tu, clsname, cxxHeaderFile, cxxFile):
+ self.clsname = clsname
+ self.hdrfile = cxxHeaderFile
+ self.cppfile = cxxFile
+ tu.accept(self)
+
+ def standardTypedefs(self):
+ return [
+ Typedef(Type("mozilla::ipc::IProtocol"), "IProtocol"),
+ Typedef(Type("IPC::Message"), "Message"),
+ Typedef(Type("base::ProcessHandle"), "ProcessHandle"),
+ Typedef(Type("mozilla::ipc::MessageChannel"), "MessageChannel"),
+ Typedef(Type("mozilla::ipc::SharedMemory"), "SharedMemory"),
+ ]
+
+ def visitTranslationUnit(self, tu):
+ self.protocol = tu.protocol
+
+ hf = self.hdrfile
+ cf = self.cppfile
+
+ # make the C++ header
+ hf.addthings(
+ [_DISCLAIMER]
+ + _includeGuardStart(hf)
+ + [
+ Whitespace.NL,
+ CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'),
+ ]
+ )
+
+ for inc in tu.includes:
+ inc.accept(self)
+ for inc in tu.cxxIncludes:
+ inc.accept(self)
+
+ for using in tu.builtinUsing:
+ using.accept(self)
+ for using in tu.using:
+ using.accept(self)
+ for su in tu.structsAndUnions:
+ su.accept(self)
+
+ # this generates the actor's full impl in self.cls
+ tu.protocol.accept(self)
+
+ clsdecl, clsdefn = _splitClassDeclDefn(self.cls)
+
+ # XXX damn C++ ... return types in the method defn aren't in
+ # class scope
+ for stmt in clsdefn.stmts:
+ if isinstance(stmt, MethodDefn):
+ if stmt.decl.ret and stmt.decl.ret.name == "Result":
+ stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name
+
+ def setToIncludes(s):
+ return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))]
+
+ def makeNamespace(p, file):
+ if 0 == len(p.namespaces):
+ return file
+ ns = Namespace(p.namespaces[-1].name)
+ outerns = _putInNamespaces(ns, p.namespaces[:-1])
+ file.addthing(outerns)
+ return ns
+
+ if len(self.nonForwardDeclaredHeaders) != 0:
+ self.hdrfile.addthings(
+ [
+ Whitespace("// Headers for things that cannot be forward declared"),
+ Whitespace.NL,
+ ]
+ + setToIncludes(self.nonForwardDeclaredHeaders)
+ + [Whitespace.NL]
+ )
+ self.hdrfile.addthings(self.actorForwardDecls)
+ self.hdrfile.addthings(self.usingDecls)
+
+ hdrns = makeNamespace(self.protocol, self.hdrfile)
+ hdrns.addstmts(
+ [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL]
+ )
+
+ actortype = ActorType(tu.protocol.decl.type)
+ traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side)
+
+ self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf))
+
+ # If the implementation type is not overridden, add an implicit import
+ # for the default implementation header file. Explicit implementation
+ # types will specify their headers manually with `include`.
+ if self.protocol.implAttribute(self.side) is None:
+ assert self.protocol.name.startswith("P")
+ self.externalIncludes.add(
+ "".join(n.name + "/" for n in self.protocol.namespaces)
+ + self.protocol.name[1:]
+ + self.side.capitalize()
+ + ".h"
+ )
+
+ # make the .cpp file
+ cf.addthings(
+ [
+ _DISCLAIMER,
+ Whitespace.NL,
+ CppDirective(
+ "include",
+ '"' + _protocolHeaderName(self.protocol, self.side) + '.h"',
+ ),
+ ]
+ + setToIncludes(self.externalIncludes)
+ )
+
+ cf.addthings(
+ (
+ [Whitespace.NL]
+ + [
+ CppDirective("include", '"%s.h"' % (inc))
+ for inc in self.protocolCxxIncludes
+ ]
+ + [Whitespace.NL]
+ + [
+ CppDirective("include", '"%s"' % filename)
+ for filename in ipdl.builtin.CppIncludes
+ ]
+ + [Whitespace.NL]
+ )
+ )
+
+ cppns = makeNamespace(self.protocol, cf)
+ cppns.addstmts(
+ [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL]
+ )
+
+ cf.addthing(traitsdefn)
+
+ def visitUsingStmt(self, using):
+ if using.decl.fullname is not None:
+ self.typedefSet.add(
+ Typedef(Type(using.decl.fullname), using.decl.shortname)
+ )
+
+ if using.header is None:
+ return
+
+ if using.canBeForwardDeclared():
+ spec = using.type
+
+ self.usingDecls.extend(
+ [
+ _makeForwardDeclForQClass(
+ spec.baseid,
+ spec.quals,
+ cls=using.isClass(),
+ struct=using.isStruct(),
+ ),
+ Whitespace.NL,
+ ]
+ )
+ self.externalIncludes.add(using.header)
+ else:
+ self.nonForwardDeclaredHeaders.add(using.header)
+
+ def visitCxxInclude(self, inc):
+ self.externalIncludes.add(inc.file)
+
+ def visitInclude(self, inc):
+ if inc.tu.filetype == "header":
+ # Including a header will declare any globals defined by "using"
+ # statements into our scope. To serialize these, we also may need
+ # cxx include statements, so visit them as well.
+ for cxxinc in inc.tu.cxxIncludes:
+ cxxinc.accept(self)
+ for using in inc.tu.using:
+ using.accept(self)
+ for su in inc.tu.structsAndUnions:
+ su.accept(self)
+ else:
+ # Includes for protocols only include types explicitly exported by
+ # those protocols.
+ ip = inc.tu.protocol
+ if ip == self.protocol:
+ return
+
+ self.actorForwardDecls.extend(
+ [
+ _makeForwardDeclForActor(ip.decl.type, self.side),
+ _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)),
+ Whitespace.NL,
+ ]
+ )
+ self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side))
+
+ if ip.decl.fullname is not None:
+ self.typedefSet.add(
+ Typedef(
+ Type(_actorName(ip.decl.fullname, self.side.title())),
+ _actorName(ip.decl.shortname, self.side.title()),
+ )
+ )
+
+ self.typedefSet.add(
+ Typedef(
+ Type(
+ _actorName(ip.decl.fullname, _otherSide(self.side).title())
+ ),
+ _actorName(ip.decl.shortname, _otherSide(self.side).title()),
+ )
+ )
+
+ def visitStructDecl(self, sd):
+ if sd.decl.fullname is not None:
+ self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name))
+
+ def visitUnionDecl(self, ud):
+ if ud.decl.fullname is not None:
+ self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name))
+
+ def visitProtocol(self, p):
+ self.hdrfile.addcode(
+ """
+ #ifdef DEBUG
+ #include "prenv.h"
+ #endif // DEBUG
+
+ #include "mozilla/Tainting.h"
+ #include "mozilla/ipc/MessageChannel.h"
+ #include "mozilla/ipc/ProtocolUtils.h"
+ """
+ )
+
+ self.protocol = p
+ ptype = p.decl.type
+ toplevel = p.decl.type.toplevel()
+
+ hasAsyncReturns = False
+ for md in p.messageDecls:
+ if md.hasAsyncReturns():
+ hasAsyncReturns = True
+ break
+
+ inherits = []
+ if ptype.isToplevel():
+ inherits.append(Inherit(p.openedProtocolInterfaceType(), viz="public"))
+ else:
+ inherits.append(Inherit(p.managerInterfaceType(), viz="public"))
+
+ if ptype.isToplevel() and self.side == "parent":
+ self.hdrfile.addthings(
+ [_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL]
+ )
+
+ self.cls = Class(self.clsname, inherits=inherits, abstract=True)
+
+ self.cls.addstmt(Label.PRIVATE)
+ friends = _FindFriends().findFriends(ptype)
+ if ptype.isManaged():
+ friends.update(ptype.managers)
+
+ # |friend| managed actors so that they can call our Dealloc*()
+ friends.update(ptype.manages)
+
+ # don't friend ourself if we're a self-managed protocol
+ friends.discard(ptype)
+
+ for friend in sorted(friends, key=lambda f: f.fullname()):
+ self.actorForwardDecls.extend(
+ [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL]
+ )
+ self.cls.addstmt(
+ FriendClassDecl(_actorName(friend.fullname(), self.prettyside))
+ )
+
+ self.cls.addstmt(Label.PROTECTED)
+ for typedef in sorted(self.typedefSet):
+ self.cls.addstmt(typedef)
+
+ self.cls.addstmt(Whitespace.NL)
+
+ if hasAsyncReturns:
+ self.cls.addstmt(Label.PUBLIC)
+ for md in p.messageDecls:
+ if self.sendsMessage(md) and md.hasAsyncReturns():
+ self.cls.addstmt(
+ Typedef(_makePromise(md.returns, self.side), md.promiseName())
+ )
+ if self.receivesMessage(md) and md.hasAsyncReturns():
+ self.cls.addstmt(
+ Typedef(_makeResolver(md.returns, self.side), md.resolverName())
+ )
+ self.cls.addstmt(Whitespace.NL)
+
+ self.cls.addstmt(Label.PROTECTED)
+ # interface methods that the concrete subclass has to impl
+ for md in p.messageDecls:
+ isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor()
+
+ if self.receivesMessage(md):
+ # generate Recv/Answer* interface
+ implicit = not isdtor
+ returnsems = "resolver" if md.decl.type.isAsync() else "out"
+ recvDecl = MethodDecl(
+ md.recvMethod(),
+ params=md.makeCxxParams(
+ paramsems="move",
+ returnsems=returnsems,
+ side=self.side,
+ implicit=implicit,
+ direction="recv",
+ ),
+ ret=Type("mozilla::ipc::IPCResult"),
+ methodspec=MethodSpec.VIRTUAL,
+ )
+
+ # These method implementations cause problems when trying to
+ # override them with different types in a direct call class.
+ #
+ # For the `isdtor` case there's a simple solution: it doesn't
+ # make much sense to specify arguments and then completely
+ # ignore them, and the no-arg case isn't a problem for
+ # overriding.
+ if isctor or (isdtor and not md.inParams):
+ defaultRecv = MethodDefn(recvDecl)
+ defaultRecv.addcode("return IPC_OK();\n")
+ self.cls.addstmt(defaultRecv)
+ elif self.protocol.implAttribute(self.side) == "virtual":
+ # If we're using virtual calls, we need the methods to be
+ # declared on the base class.
+ recvDecl.methodspec = MethodSpec.PURE
+ self.cls.addstmt(StmtDecl(recvDecl))
+
+ # If we're using virtual calls, we need the methods to be declared on
+ # the base class.
+ if self.protocol.implAttribute(self.side) == "virtual":
+ for md in p.messageDecls:
+ managed = md.decl.type.constructedType()
+ if not ptype.isManagerOf(managed) or md.decl.type.isDtor():
+ continue
+
+ # add the Alloc interface for managed actors
+ actortype = md.actorDecl().bareType(self.side)
+
+ if managed.isRefcounted():
+ if not self.receivesMessage(md):
+ continue
+
+ actortype.ptr = False
+ actortype = _alreadyaddrefed(actortype)
+
+ self.cls.addstmt(
+ StmtDecl(
+ MethodDecl(
+ _allocMethod(managed, self.side),
+ params=md.makeCxxParams(
+ side=self.side, implicit=False, direction="recv"
+ ),
+ ret=actortype,
+ methodspec=MethodSpec.PURE,
+ )
+ )
+ )
+
+ # add the Dealloc interface for all managed non-refcounted actors,
+ # even without ctors. This is useful for protocols which use
+ # ManagedEndpoint for construction.
+ for managed in ptype.manages:
+ if managed.isRefcounted():
+ continue
+
+ self.cls.addstmt(
+ StmtDecl(
+ MethodDecl(
+ _deallocMethod(managed, self.side),
+ params=[
+ Decl(p.managedCxxType(managed, self.side), "aActor")
+ ],
+ ret=Type.BOOL,
+ methodspec=MethodSpec.PURE,
+ )
+ )
+ )
+
+ if ptype.isToplevel():
+ # void ProcessingError(code); default to no-op
+ processingerror = MethodDefn(
+ MethodDecl(
+ p.processingErrorVar().name,
+ params=[
+ Param(_Result.Type(), "aCode"),
+ Param(Type("char", const=True, ptr=True), "aReason"),
+ ],
+ methodspec=MethodSpec.OVERRIDE,
+ )
+ )
+
+ # bool ShouldContinueFromReplyTimeout(); default to |true|
+ shouldcontinue = MethodDefn(
+ MethodDecl(
+ p.shouldContinueFromTimeoutVar().name,
+ ret=Type.BOOL,
+ methodspec=MethodSpec.OVERRIDE,
+ )
+ )
+ shouldcontinue.addcode("return true;\n")
+
+ self.cls.addstmts(
+ [
+ processingerror,
+ shouldcontinue,
+ Whitespace.NL,
+ ]
+ )
+
+ self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL]))
+
+ self.cls.addstmt(Label.PUBLIC)
+ # Actor()
+ ctor = ConstructorDefn(ConstructorDecl(self.clsname))
+ side = ExprVar("mozilla::ipc::" + self.side.title() + "Side")
+ if ptype.isToplevel():
+ name = ExprLiteral.String(_actorName(p.name, self.side))
+ ctor.memberinits = [
+ ExprMemberInit(
+ ExprVar("mozilla::ipc::IToplevelProtocol"),
+ [name, _protocolId(ptype), side],
+ )
+ ]
+ else:
+ ctor.memberinits = [
+ ExprMemberInit(
+ ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side]
+ )
+ ]
+
+ ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname)
+ self.cls.addstmts([ctor, Whitespace.NL])
+
+ # ~Actor()
+ dtor = DestructorDefn(
+ DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL)
+ )
+ dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname)
+
+ self.cls.addstmts([dtor, Whitespace.NL])
+
+ if ptype.isRefcounted():
+ if not ptype.isToplevel():
+ self.cls.addcode(
+ """
+ NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING
+ """
+ )
+ self.cls.addstmt(Label.PROTECTED)
+ self.cls.addcode(
+ """
+ void ActorAlloc() final { AddRef(); }
+ void ActorDealloc() final { Release(); }
+ """
+ )
+
+ self.cls.addstmt(Label.PUBLIC)
+ if ptype.hasOtherPid():
+ otherpidmeth = MethodDefn(
+ MethodDecl("OtherPid", ret=Type("::base::ProcessId"), const=True)
+ )
+ otherpidmeth.addcode(
+ """
+ ::base::ProcessId pid =
+ ::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid();
+ MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId);
+ return pid;
+ """
+ )
+ self.cls.addstmts([otherpidmeth, Whitespace.NL])
+
+ if not ptype.isToplevel():
+ if 1 == len(p.managers):
+ # manager() const
+ managertype = p.managerActorType(self.side, ptr=True)
+ managermeth = MethodDefn(
+ MethodDecl("Manager", ret=managertype, const=True)
+ )
+ managermeth.addcode(
+ """
+ return static_cast<${type}>(IProtocol::Manager());
+ """,
+ type=managertype,
+ )
+
+ self.cls.addstmts([managermeth, Whitespace.NL])
+
+ def actorFromIter(itervar):
+ return ExprCode("${iter}.Get()->GetKey()", iter=itervar)
+
+ def forLoopOverHashtable(hashtable, itervar, const=False):
+ itermeth = "ConstIter" if const else "Iter"
+ return StmtFor(
+ init=ExprCode(
+ "auto ${itervar} = ${hashtable}.${itermeth}()",
+ itervar=itervar,
+ hashtable=hashtable,
+ itermeth=itermeth,
+ ),
+ cond=ExprCode("!${itervar}.Done()", itervar=itervar),
+ update=ExprCode("${itervar}.Next()", itervar=itervar),
+ )
+
+ # Managed[T](Array& inout) const
+ # const Array<T>& Managed() const
+ for managed in ptype.manages:
+ container = p.managedVar(managed, self.side)
+
+ meth = MethodDefn(
+ MethodDecl(
+ p.managedMethod(managed, self.side).name,
+ params=[
+ Decl(
+ _cxxArrayType(
+ p.managedCxxType(managed, self.side), ref=True
+ ),
+ "aArr",
+ )
+ ],
+ const=True,
+ )
+ )
+ meth.addcode("${container}.ToArray(aArr);\n", container=container)
+
+ refmeth = MethodDefn(
+ MethodDecl(
+ p.managedMethod(managed, self.side).name,
+ params=[],
+ ret=p.managedVarType(managed, self.side, const=True, ref=True),
+ const=True,
+ )
+ )
+ refmeth.addcode("return ${container};\n", container=container)
+
+ self.cls.addstmts([meth, refmeth, Whitespace.NL])
+
+ # AllManagedActors(Array& inout) const
+ arrvar = ExprVar("arr__")
+ managedmeth = MethodDefn(
+ MethodDecl(
+ "AllManagedActors",
+ params=[
+ Decl(
+ _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True),
+ arrvar.name,
+ )
+ ],
+ methodspec=MethodSpec.OVERRIDE,
+ const=True,
+ )
+ )
+
+ # Count the number of managed actors, and allocate space in the output array.
+ managedmeth.addcode(
+ """
+ uint32_t total = 0;
+ """
+ )
+ for managed in ptype.manages:
+ managedmeth.addcode(
+ """
+ total += ${container}.Count();
+ """,
+ container=p.managedVar(managed, self.side),
+ )
+ managedmeth.addcode(
+ """
+ arr__.SetCapacity(total);
+
+ """
+ )
+
+ for managed in ptype.manages:
+ managedmeth.addcode(
+ """
+ for (auto* key : ${container}) {
+ arr__.AppendElement(key->GetLifecycleProxy());
+ }
+
+ """,
+ container=p.managedVar(managed, self.side),
+ )
+
+ self.cls.addstmts([managedmeth, Whitespace.NL])
+
+ # OpenPEndpoint(...)/BindPEndpoint(...)
+ for managed in ptype.manages:
+ self.genManagedEndpoint(managed)
+
+ # OnMessageReceived()/OnCallReceived()
+
+ # save these away for use in message handler case stmts
+ msgvar = ExprVar("msg__")
+ self.msgvar = msgvar
+ replyvar = ExprVar("reply__")
+ self.replyvar = replyvar
+ var = ExprVar("v__")
+ self.var = var
+ # for ctor recv cases, we can't read the actor ID into a PFoo*
+ # because it doesn't exist on this side yet. Use a "special"
+ # actor handle instead
+ handlevar = ExprVar("handle__")
+ self.handlevar = handlevar
+
+ msgtype = ExprCode("msg__.type()")
+ self.asyncSwitch = StmtSwitch(msgtype)
+ self.syncSwitch = None
+ self.interruptSwitch = None
+ if toplevel.isSync() or toplevel.isInterrupt():
+ self.syncSwitch = StmtSwitch(msgtype)
+ if toplevel.isInterrupt():
+ self.interruptSwitch = StmtSwitch(msgtype)
+
+ # Add a handler for the MANAGED_ENDPOINT_BOUND and
+ # MANAGED_ENDPOINT_DROPPED message types for managed actors.
+ if not ptype.isToplevel():
+ clearawaitingmanagedendpointbind = """
+ if (!mAwaitingManagedEndpointBind) {
+ NS_WARNING("Unexpected managed endpoint lifecycle message after actor bound!");
+ return MsgNotAllowed;
+ }
+ mAwaitingManagedEndpointBind = false;
+ """
+ self.asyncSwitch.addcase(
+ CaseLabel("MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"),
+ StmtBlock(
+ [
+ StmtCode(clearawaitingmanagedendpointbind),
+ StmtReturn(_Result.Processed),
+ ]
+ ),
+ )
+ self.asyncSwitch.addcase(
+ CaseLabel("MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"),
+ StmtBlock(
+ [
+ StmtCode(clearawaitingmanagedendpointbind),
+ *self.destroyActor(
+ None,
+ ExprVar.THIS,
+ why=_DestroyReason.ManagedEndpointDropped,
+ ),
+ StmtReturn(_Result.Processed),
+ ]
+ ),
+ )
+
+ # implement Send*() methods and add dispatcher cases to
+ # message switch()es
+ for md in p.messageDecls:
+ self.visitMessageDecl(md)
+
+ # add default cases
+ default = StmtCode(
+ """
+ return MsgNotKnown;
+ """
+ )
+ self.asyncSwitch.addcase(DefaultLabel(), default)
+ if toplevel.isSync() or toplevel.isInterrupt():
+ self.syncSwitch.addcase(DefaultLabel(), default)
+ if toplevel.isInterrupt():
+ self.interruptSwitch.addcase(DefaultLabel(), default)
+
+ self.cls.addstmts(self.implementManagerIface())
+
+ def makeHandlerMethod(name, switch, hasReply, dispatches=False):
+ params = [Decl(Type("Message", const=True, ref=True), msgvar.name)]
+ if hasReply:
+ params.append(Decl(Type("UniquePtr<Message>", ref=True), replyvar.name))
+
+ method = MethodDefn(
+ MethodDecl(
+ name,
+ methodspec=MethodSpec.OVERRIDE,
+ params=params,
+ ret=_Result.Type(),
+ )
+ )
+
+ if not switch:
+ method.addcode(
+ """
+ MOZ_ASSERT_UNREACHABLE("message protocol not supported");
+ return MsgNotKnown;
+ """
+ )
+ return method
+
+ if dispatches:
+ if hasReply:
+ ondeadactor = [StmtReturn(_Result.RouteError)]
+ else:
+ ondeadactor = [
+ self.logMessage(
+ None, ExprAddrOf(msgvar), "Ignored message for dead actor"
+ ),
+ StmtReturn(_Result.Processed),
+ ]
+
+ method.addcode(
+ """
+ int32_t route__ = ${msgvar}.routing_id();
+ if (MSG_ROUTING_CONTROL != route__) {
+ IProtocol* routed__ = Lookup(route__);
+ if (!routed__ || !routed__->GetLifecycleProxy()) {
+ $*{ondeadactor}
+ }
+
+ RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ =
+ routed__->GetLifecycleProxy();
+ return proxy__->Get()->${name}($,{args});
+ }
+
+ """,
+ msgvar=msgvar,
+ ondeadactor=ondeadactor,
+ name=name,
+ args=[p.name for p in params],
+ )
+
+ # bug 509581: don't generate the switch stmt if there
+ # is only the default case; MSVC doesn't like that
+ if switch.nr_cases > 1:
+ method.addstmt(switch)
+ else:
+ method.addstmt(StmtReturn(_Result.NotKnown))
+
+ return method
+
+ dispatches = ptype.isToplevel() and ptype.isManager()
+ self.cls.addstmts(
+ [
+ makeHandlerMethod(
+ "OnMessageReceived",
+ self.asyncSwitch,
+ hasReply=False,
+ dispatches=dispatches,
+ ),
+ Whitespace.NL,
+ ]
+ )
+ self.cls.addstmts(
+ [
+ makeHandlerMethod(
+ "OnMessageReceived",
+ self.syncSwitch,
+ hasReply=True,
+ dispatches=dispatches,
+ ),
+ Whitespace.NL,
+ ]
+ )
+ self.cls.addstmts(
+ [
+ makeHandlerMethod(
+ "OnCallReceived",
+ self.interruptSwitch,
+ hasReply=True,
+ dispatches=dispatches,
+ ),
+ Whitespace.NL,
+ ]
+ )
+
+ clearsubtreevar = ExprVar("ClearSubtree")
+
+ if ptype.isToplevel():
+ # OnChannelClose()
+ onclose = MethodDefn(
+ MethodDecl("OnChannelClose", methodspec=MethodSpec.OVERRIDE)
+ )
+ onclose.addcode(
+ """
+ DestroySubtree(NormalShutdown);
+ ClearSubtree();
+ DeallocShmems();
+ if (GetLifecycleProxy()) {
+ GetLifecycleProxy()->Release();
+ }
+ """
+ )
+ self.cls.addstmts([onclose, Whitespace.NL])
+
+ # OnChannelError()
+ onerror = MethodDefn(
+ MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE)
+ )
+ onerror.addcode(
+ """
+ DestroySubtree(AbnormalShutdown);
+ ClearSubtree();
+ DeallocShmems();
+ if (GetLifecycleProxy()) {
+ GetLifecycleProxy()->Release();
+ }
+ """
+ )
+ self.cls.addstmts([onerror, Whitespace.NL])
+
+ if ptype.isToplevel() and ptype.isInterrupt():
+ processnative = MethodDefn(
+ MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID)
+ )
+ processnative.addcode(
+ """
+ #ifdef OS_WIN
+ GetIPCChannel()->ProcessNativeEventsInInterruptCall();
+ #else
+ FatalError("This method is Windows-only");
+ #endif
+ """
+ )
+
+ self.cls.addstmts([processnative, Whitespace.NL])
+
+ # private methods
+ self.cls.addstmt(Label.PRIVATE)
+
+ # ClearSubtree()
+ clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name))
+ for managed in ptype.manages:
+ clearsubtree.addcode(
+ """
+ for (auto* key : ${container}) {
+ key->ClearSubtree();
+ }
+ for (auto* key : ${container}) {
+ // Recursively releasing ${container} kids.
+ auto* proxy = key->GetLifecycleProxy();
+ NS_IF_RELEASE(proxy);
+ }
+ ${container}.Clear();
+
+ """,
+ container=p.managedVar(managed, self.side),
+ )
+
+ # don't release our own IPC reference: either the manager will do it,
+ # or we're toplevel
+ self.cls.addstmts([clearsubtree, Whitespace.NL])
+
+ if not ptype.isToplevel():
+ self.cls.addstmts(
+ [
+ StmtDecl(
+ Decl(Type.BOOL, "mAwaitingManagedEndpointBind"),
+ init=ExprLiteral.FALSE,
+ ),
+ Whitespace.NL,
+ ]
+ )
+
+ for managed in ptype.manages:
+ self.cls.addstmts(
+ [
+ StmtDecl(
+ Decl(
+ p.managedVarType(managed, self.side),
+ p.managedVar(managed, self.side).name,
+ )
+ )
+ ]
+ )
+
+ def genManagedEndpoint(self, managed):
+ hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side)
+ thereEp = "ManagedEndpoint<%s>" % _actorName(
+ managed.name(), _otherSide(self.side)
+ )
+
+ actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor")
+
+ # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor)
+ openmeth = MethodDefn(
+ MethodDecl(
+ "Open%sEndpoint" % managed.name(),
+ params=[
+ Decl(self.protocol.managedCxxType(managed, self.side), actor.name)
+ ],
+ ret=Type(thereEp),
+ )
+ )
+ openmeth.addcode(
+ """
+ $*{bind}
+ // Mark our actor as awaiting the other side to be bound. This will
+ // be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is
+ // received.
+ aActor->mAwaitingManagedEndpointBind = true;
+ return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor);
+ """,
+ bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))),
+ thereEp=thereEp,
+ )
+
+ # void BindPEndpoint(ManagedEndpoint<PHere>&& aEndpoint, PHere* aActor)
+ bindmeth = MethodDefn(
+ MethodDecl(
+ "Bind%sEndpoint" % managed.name(),
+ params=[
+ Decl(Type(hereEp), "aEndpoint"),
+ Decl(self.protocol.managedCxxType(managed, self.side), actor.name),
+ ],
+ ret=Type.BOOL,
+ )
+ )
+ bindmeth.addcode(
+ """
+ return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this, ${container});
+ """,
+ container=self.protocol.managedVar(managed, self.side),
+ )
+
+ self.cls.addstmts([openmeth, bindmeth, Whitespace.NL])
+
+ def implementManagerIface(self):
+ p = self.protocol
+ protocolbase = Type("IProtocol", ptr=True)
+
+ methods = []
+
+ if p.decl.type.isToplevel():
+ # FIXME: This used to be declared conditionally based on whether
+ # shmem appeared somewhere in the protocol hierarchy, however that
+ # caused issues due to Shmem instances hidden within custom C++
+ # types.
+ self.asyncSwitch.addcase(
+ CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"),
+ self.genShmemCreatedHandler(),
+ )
+ self.asyncSwitch.addcase(
+ CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"),
+ self.genShmemDestroyedHandler(),
+ )
+
+ # Keep track of types created with an INOUT ctor. We need to call
+ # Register() or RegisterID() for them depending on the side the managee
+ # is created.
+ inoutCtorTypes = []
+ for msg in p.messageDecls:
+ msgtype = msg.decl.type
+ if msgtype.isCtor() and msgtype.isInout():
+ inoutCtorTypes.append(msgtype.constructedType())
+
+ # all protocols share the "same" RemoveManagee() implementation
+ pvar = ExprVar("aProtocolId")
+ listenervar = ExprVar("aListener")
+ removemanagee = MethodDefn(
+ MethodDecl(
+ p.removeManageeMethod().name,
+ params=[
+ Decl(_protocolIdType(), pvar.name),
+ Decl(protocolbase, listenervar.name),
+ ],
+ methodspec=MethodSpec.OVERRIDE,
+ )
+ )
+
+ if not len(p.managesStmts):
+ removemanagee.addcode(
+ """
+ FatalError("unreached");
+ return;
+ """
+ )
+ else:
+ switchontype = StmtSwitch(pvar)
+ for managee in p.managesStmts:
+ manageeipdltype = managee.decl.type
+ manageecxxtype = _cxxBareType(
+ ipdl.type.ActorType(manageeipdltype), self.side
+ )
+ case = ExprCode(
+ """
+ {
+ ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener);
+
+ const bool removed = ${container}.EnsureRemoved(actor);
+ MOZ_RELEASE_ASSERT(removed, "actor not managed by this!");
+
+ auto* proxy = actor->GetLifecycleProxy();
+ NS_IF_RELEASE(proxy);
+ return;
+ }
+ """,
+ manageecxxtype=manageecxxtype,
+ container=p.managedVar(manageeipdltype, self.side),
+ )
+ switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
+ switchontype.addcase(
+ DefaultLabel(),
+ ExprCode(
+ """
+ FatalError("unreached");
+ return;
+ """
+ ),
+ )
+ removemanagee.addstmt(switchontype)
+
+ # The `DeallocManagee` method is called for managed actors to trigger
+ # deallocation when ActorLifecycleProxy is freed.
+ deallocmanagee = MethodDefn(
+ MethodDecl(
+ p.deallocManageeMethod().name,
+ params=[
+ Decl(_protocolIdType(), pvar.name),
+ Decl(protocolbase, listenervar.name),
+ ],
+ methodspec=MethodSpec.OVERRIDE,
+ )
+ )
+
+ if not len(p.managesStmts):
+ deallocmanagee.addcode(
+ """
+ FatalError("unreached");
+ return;
+ """
+ )
+ else:
+ switchontype = StmtSwitch(pvar)
+ for managee in p.managesStmts:
+ manageeipdltype = managee.decl.type
+ # Reference counted actor types don't have corresponding
+ # `Dealloc` methods, as they are deallocated by releasing the
+ # IPDL-held reference.
+ if manageeipdltype.isRefcounted():
+ continue
+
+ case = StmtCode(
+ """
+ ${concrete}->${dealloc}(static_cast<${type}>(aListener));
+ return;
+ """,
+ concrete=self.concreteThis(),
+ dealloc=_deallocMethod(manageeipdltype, self.side),
+ type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side),
+ )
+ switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case)
+ switchontype.addcase(
+ DefaultLabel(),
+ StmtCode(
+ """
+ FatalError("unreached");
+ return;
+ """
+ ),
+ )
+ deallocmanagee.addstmt(switchontype)
+
+ return methods + [removemanagee, deallocmanagee, Whitespace.NL]
+
+ def genShmemCreatedHandler(self):
+ assert self.protocol.decl.type.isToplevel()
+
+ return StmtCode(
+ """
+ {
+ if (!ShmemCreated(${msgvar})) {
+ return MsgPayloadError;
+ }
+ return MsgProcessed;
+ }
+ """,
+ msgvar=self.msgvar,
+ )
+
+ def genShmemDestroyedHandler(self):
+ assert self.protocol.decl.type.isToplevel()
+
+ return StmtCode(
+ """
+ {
+ if (!ShmemDestroyed(${msgvar})) {
+ return MsgPayloadError;
+ }
+ return MsgProcessed;
+ }
+ """,
+ msgvar=self.msgvar,
+ )
+
+ # -------------------------------------------------------------------------
+ # The next few functions are the crux of the IPDL code generator.
+ # They generate code for all the nasty work of message
+ # serialization/deserialization and dispatching handlers for
+ # received messages.
+ ##
+
+ def concreteThis(self):
+ implAttr = self.protocol.implAttribute(self.side)
+ if implAttr == "virtual":
+ return ExprVar.THIS
+
+ if implAttr is None:
+ assert self.protocol.name.startswith("P")
+ className = self.protocol.name[1:] + self.side.capitalize()
+ else:
+ assert isinstance(implAttr, ipdl.ast.StringLiteral)
+ className = implAttr.value
+
+ return ExprCode("static_cast<${className}*>(this)", className=className)
+
+ def thisCall(self, function, args):
+ return ExprCall(ExprSelect(self.concreteThis(), "->", function), args=args)
+
+ def visitMessageDecl(self, md):
+ isctor = md.decl.type.isCtor()
+ isdtor = md.decl.type.isDtor()
+ decltype = md.decl.type
+ sendmethod = None
+ movesendmethod = None
+ promisesendmethod = None
+ recvlbl, recvcase = None, None
+
+ def addRecvCase(lbl, case):
+ if decltype.isAsync():
+ self.asyncSwitch.addcase(lbl, case)
+ elif decltype.isSync():
+ self.syncSwitch.addcase(lbl, case)
+ elif decltype.isInterrupt():
+ self.interruptSwitch.addcase(lbl, case)
+ else:
+ assert 0
+
+ if self.sendsMessage(md):
+ isasync = decltype.isAsync()
+
+ # NOTE: Don't generate helper ctors for refcounted types.
+ #
+ # Safety concerns around providing your own actor to a ctor (namely
+ # that the return value won't be checked, and the argument will be
+ # `delete`-ed) are less critical with refcounted actors, due to the
+ # actor being held alive by the callsite.
+ #
+ # This allows refcounted actors to not implement crashing AllocPFoo
+ # methods on the sending side.
+ if isctor and not md.decl.type.constructedType().isRefcounted():
+ self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL])
+
+ if isctor and isasync:
+ sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md)
+ elif isctor:
+ sendmethod = self.genBlockingCtorMethod(md)
+ elif isdtor and isasync:
+ sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md)
+ elif isdtor:
+ sendmethod = self.genBlockingDtorMethod(md)
+ elif isasync:
+ (
+ sendmethod,
+ movesendmethod,
+ promisesendmethod,
+ (recvlbl, recvcase),
+ ) = self.genAsyncSendMethod(md)
+ else:
+ sendmethod, movesendmethod = self.genBlockingSendMethod(md)
+
+ # XXX figure out what to do here
+ if isdtor and md.decl.type.constructedType().isToplevel():
+ sendmethod = None
+
+ if sendmethod is not None:
+ self.cls.addstmts([sendmethod, Whitespace.NL])
+ if movesendmethod is not None:
+ self.cls.addstmts([movesendmethod, Whitespace.NL])
+ if promisesendmethod is not None:
+ self.cls.addstmts([promisesendmethod, Whitespace.NL])
+ if recvcase is not None:
+ addRecvCase(recvlbl, recvcase)
+ recvlbl, recvcase = None, None
+
+ if self.receivesMessage(md):
+ if isctor:
+ recvlbl, recvcase = self.genCtorRecvCase(md)
+ elif isdtor:
+ recvlbl, recvcase = self.genDtorRecvCase(md)
+ else:
+ recvlbl, recvcase = self.genRecvCase(md)
+
+ # XXX figure out what to do here
+ if isdtor and md.decl.type.constructedType().isToplevel():
+ return
+
+ addRecvCase(recvlbl, recvcase)
+
+ def genAsyncCtor(self, md):
+ actor = md.actorDecl()
+ method = MethodDefn(self.makeSendMethodDecl(md))
+
+ msgvar, stmts = self.makeMessage(md, errfnSendCtor)
+ sendok, sendstmts = self.sendAsync(md, msgvar)
+
+ method.addcode(
+ """
+ $*{bind}
+
+ // Build our constructor message.
+ $*{stmts}
+
+ // Notify the other side about the newly created actor. This can
+ // fail if our manager has already been destroyed.
+ //
+ // NOTE: If the send call fails due to toplevel channel teardown,
+ // the `IProtocol::ChannelSend` wrapper absorbs the error for us,
+ // so we don't tear down actors unexpectedly.
+ $*{sendstmts}
+
+ // Warn, destroy the actor, and return null if the message failed to
+ // send. Otherwise, return the successfully created actor reference.
+ if (!${sendok}) {
+ NS_WARNING("Error sending ${actorname} constructor");
+ $*{destroy}
+ return nullptr;
+ }
+ return ${actor};
+ """,
+ bind=self.bindManagedActor(actor),
+ stmts=stmts,
+ sendstmts=sendstmts,
+ sendok=sendok,
+ destroy=self.destroyActor(
+ md, actor.var(), why=_DestroyReason.FailedConstructor
+ ),
+ actor=actor.var(),
+ actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
+ )
+
+ lbl = CaseLabel(md.pqReplyId())
+ case = StmtBlock()
+ case.addstmt(StmtReturn(_Result.Processed))
+ # TODO not really sure what to do with async ctor "replies" yet.
+ # destroy actor if there was an error? tricky ...
+
+ return method, (lbl, case)
+
+ def genBlockingCtorMethod(self, md):
+ actor = md.actorDecl()
+ method = MethodDefn(self.makeSendMethodDecl(md))
+
+ msgvar, stmts = self.makeMessage(md, errfnSendCtor)
+
+ replyvar = self.replyvar
+ sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
+ replystmts = self.deserializeReply(
+ md,
+ replyvar,
+ self.side,
+ errfnSendCtor,
+ errfnSentinel(ExprLiteral.NULL),
+ )
+
+ method.addcode(
+ """
+ $*{bind}
+
+ // Build our constructor message.
+ $*{stmts}
+
+ // Synchronously send the constructor message to the other side. If
+ // the send fails, e.g. due to the remote side shutting down, the
+ // actor will be destroyed and potentially freed.
+ UniquePtr<Message> ${replyvar};
+ $*{sendstmts}
+
+ if (!(${sendok})) {
+ // Warn, destroy the actor and return null if the message
+ // failed to send.
+ NS_WARNING("Error sending constructor");
+ $*{destroy}
+ return nullptr;
+ }
+
+ $*{replystmts}
+ return ${actor};
+ """,
+ bind=self.bindManagedActor(actor),
+ stmts=stmts,
+ replyvar=replyvar,
+ sendstmts=sendstmts,
+ sendok=sendok,
+ destroy=self.destroyActor(
+ md, actor.var(), why=_DestroyReason.FailedConstructor
+ ),
+ replystmts=replystmts,
+ actor=actor.var(),
+ actorname=actor.ipdltype.protocol.name() + self.side.capitalize(),
+ )
+
+ return method
+
+ def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None):
+ actorproto = actordecl.ipdltype.protocol
+
+ if idexpr is None:
+ setManagerArgs = [ExprVar.THIS]
+ else:
+ setManagerArgs = [ExprVar.THIS, idexpr]
+
+ return [
+ StmtCode(
+ """
+ if (!${actor}) {
+ NS_WARNING("Cannot bind null ${actorname} actor");
+ return ${errfn};
+ }
+
+ ${actor}->SetManagerAndRegister($,{setManagerArgs});
+ ${container}.Insert(${actor});
+ """,
+ actor=actordecl.var(),
+ actorname=actorproto.name() + self.side.capitalize(),
+ errfn=errfn,
+ setManagerArgs=setManagerArgs,
+ container=self.protocol.managedVar(actorproto, self.side),
+ )
+ ]
+
+ def genHelperCtor(self, md):
+ helperdecl = self.makeSendMethodDecl(md)
+ helperdecl.params = helperdecl.params[1:]
+ helper = MethodDefn(helperdecl)
+
+ helper.addstmts(
+ [
+ self.callAllocActor(md, retsems="out", side=self.side),
+ StmtReturn(
+ ExprCall(
+ ExprVar(helperdecl.name), args=md.makeCxxArgs(paramsems="move")
+ )
+ ),
+ ]
+ )
+ return helper
+
+ def genAsyncDtor(self, md):
+ actorvar = ExprVar("actor")
+ method = MethodDefn(self.makeDtorMethodDecl(md, actorvar))
+
+ method.addstmt(self.dtorPrologue(actorvar))
+
+ msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar)
+ sendok, sendstmts = self.sendAsync(md, msgvar, actorvar)
+ method.addstmts(
+ stmts
+ + sendstmts
+ + [Whitespace.NL]
+ + self.dtorEpilogue(md, actorvar)
+ + [StmtReturn(sendok)]
+ )
+
+ lbl = CaseLabel(md.pqReplyId())
+ case = StmtBlock()
+ case.addstmt(StmtReturn(_Result.Processed))
+ # TODO if the dtor is "inherently racy", keep the actor alive
+ # until the other side acks
+
+ return method, (lbl, case)
+
+ def genBlockingDtorMethod(self, md):
+ actorvar = ExprVar("actor")
+ method = MethodDefn(self.makeDtorMethodDecl(md, actorvar))
+
+ method.addstmt(self.dtorPrologue(actorvar))
+
+ msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar)
+
+ replyvar = self.replyvar
+ sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar)
+ method.addstmts(
+ stmts
+ + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))]
+ + sendstmts
+ )
+
+ destmts = self.deserializeReply(
+ md, replyvar, self.side, errfnSend, errfnSentinel(), actorvar
+ )
+ ifsendok = StmtIf(ExprLiteral.FALSE)
+ ifsendok.addifstmts(destmts)
+ ifsendok.addifstmts(
+ [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))]
+ )
+
+ method.addstmt(ifsendok)
+
+ method.addstmts(
+ self.dtorEpilogue(md, actorvar) + [Whitespace.NL, StmtReturn(sendok)]
+ )
+
+ return method
+
+ def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion):
+ if md and md.decl.type.isCtor():
+ destroyedType = md.decl.type.constructedType()
+ else:
+ destroyedType = self.protocol.decl.type
+
+ return [
+ StmtCode(
+ """
+ IProtocol* mgr = ${actor}->Manager();
+ ${actor}->DestroySubtree(${why});
+ ${actor}->ClearSubtree();
+ mgr->RemoveManagee(${protoId}, ${actor});
+ """,
+ actor=actorexpr,
+ why=why,
+ protoId=_protocolId(destroyedType),
+ )
+ ]
+
+ def dtorPrologue(self, actorexpr):
+ return StmtCode(
+ """
+ if (!${actor} || !${actor}->CanSend()) {
+ NS_WARNING("Attempt to __delete__ missing or closed actor");
+ return false;
+ }
+ """,
+ actor=actorexpr,
+ )
+
+ def dtorEpilogue(self, md, actorexpr):
+ return self.destroyActor(md, actorexpr)
+
+ def genRecvAsyncReplyCase(self, md):
+ lbl = CaseLabel(md.pqReplyId())
+ case = StmtBlock()
+ resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply(
+ md, self.side, errfnRecv, errfnSentinel(_Result.ValuError)
+ )
+
+ if len(md.returns) > 1:
+ resolvetype = _tuple([d.bareType(self.side) for d in md.returns])
+ resolvearg = ExprCall(
+ ExprVar("std::make_tuple"), args=[ExprMove(p.var()) for p in md.returns]
+ )
+ else:
+ resolvetype = md.returns[0].bareType(self.side)
+ resolvearg = ExprMove(md.returns[0].var())
+
+ case.addcode(
+ """
+ $*{prologue}
+
+ UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback =
+ GetIPCChannel()->PopCallback(${msgvar}, Id());
+
+ typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder;
+ auto* callback = static_cast<CallbackHolder*>(untypedCallback.get());
+ if (!callback) {
+ FatalError("Error unknown callback");
+ return MsgProcessingError;
+ }
+
+ if (${resolve}) {
+ $*{desstmts}
+ callback->Resolve(${resolvearg});
+ } else {
+ $*{desrej}
+ callback->Reject(std::move(${reason}));
+ }
+ return MsgProcessed;
+ """,
+ prologue=prologue,
+ msgvar=self.msgvar,
+ resolve=resolve,
+ resolvetype=resolvetype,
+ desstmts=desstmts,
+ resolvearg=resolvearg,
+ desrej=desrej,
+ reason=reason,
+ )
+
+ return (lbl, case)
+
+ def genAsyncSendMethod(self, md):
+ decl = self.makeSendMethodDecl(md)
+ if "VirtualSendImpl" in md.attributes:
+ decl.methodspec = MethodSpec.VIRTUAL
+ method = MethodDefn(decl)
+ msgvar, stmts = self.makeMessage(md, errfnSend)
+ retvar, sendstmts = self.sendAsync(md, msgvar)
+
+ method.addstmts(stmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)])
+
+ movemethod = None
+
+ # Add the promise overload if we need one.
+ if md.returns:
+ decl = self.makeSendMethodDecl(md, promise=True)
+ if "VirtualSendImpl" in md.attributes:
+ decl.methodspec = MethodSpec.VIRTUAL
+ promisemethod = MethodDefn(decl)
+ stmts = self.sendAsyncWithPromise(md)
+ promisemethod.addstmts(stmts)
+
+ (lbl, case) = self.genRecvAsyncReplyCase(md)
+ else:
+ (promisemethod, lbl, case) = (None, None, None)
+
+ return method, movemethod, promisemethod, (lbl, case)
+
+ def genBlockingSendMethod(self, md):
+ method = MethodDefn(self.makeSendMethodDecl(md))
+
+ msgvar, serstmts = self.makeMessage(md, errfnSend)
+ replyvar = self.replyvar
+
+ sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar)
+ failif = StmtIf(ExprNot(sendok))
+ failif.addifstmt(StmtReturn.FALSE)
+
+ desstmts = self.deserializeReply(
+ md, replyvar, self.side, errfnSend, errfnSentinel()
+ )
+
+ method.addstmts(
+ serstmts
+ + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))]
+ + sendstmts
+ + [failif]
+ + desstmts
+ + [Whitespace.NL, StmtReturn.TRUE]
+ )
+
+ movemethod = None
+
+ return method, movemethod
+
+ def genCtorRecvCase(self, md):
+ lbl = CaseLabel(md.pqMsgId())
+ case = StmtBlock()
+ actorhandle = self.handlevar
+
+ stmts = self.deserializeMessage(
+ md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
+ )
+
+ idvar, saveIdStmts = self.saveActorId(md)
+ case.addstmts(
+ stmts
+ + [
+ StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
+ for r in md.returns
+ ]
+ # alloc the actor, register it under the foreign ID
+ + [self.callAllocActor(md, retsems="in", side=self.side)]
+ + self.bindManagedActor(
+ md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(actorhandle)
+ )
+ + [Whitespace.NL]
+ + saveIdStmts
+ + self.invokeRecvHandler(md)
+ + self.makeReply(md, errfnRecv, idvar)
+ + [Whitespace.NL, StmtReturn(_Result.Processed)]
+ )
+
+ return lbl, case
+
+ def genDtorRecvCase(self, md):
+ lbl = CaseLabel(md.pqMsgId())
+ case = StmtBlock()
+
+ stmts = self.deserializeMessage(
+ md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
+ )
+
+ idvar, saveIdStmts = self.saveActorId(md)
+ case.addstmts(
+ stmts
+ + [
+ StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
+ for r in md.returns
+ ]
+ + self.invokeRecvHandler(md)
+ + [Whitespace.NL]
+ + saveIdStmts
+ + self.makeReply(md, errfnRecv, routingId=idvar)
+ + [Whitespace.NL]
+ + self.dtorEpilogue(md, ExprVar.THIS)
+ + [Whitespace.NL, StmtReturn(_Result.Processed)]
+ )
+
+ return lbl, case
+
+ def genRecvCase(self, md):
+ lbl = CaseLabel(md.pqMsgId())
+ case = StmtBlock()
+
+ stmts = self.deserializeMessage(
+ md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError)
+ )
+
+ idvar, saveIdStmts = self.saveActorId(md)
+ declstmts = [
+ StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[])
+ for r in md.returns
+ ]
+ if md.decl.type.isAsync() and md.returns:
+ declstmts = self.makeResolver(md, errfnRecv, routingId=idvar)
+ case.addstmts(
+ stmts
+ + saveIdStmts
+ + declstmts
+ + self.invokeRecvHandler(md)
+ + [Whitespace.NL]
+ + self.makeReply(md, errfnRecv, routingId=idvar)
+ + [StmtReturn(_Result.Processed)]
+ )
+
+ return lbl, case
+
+ # helper methods
+
+ def makeMessage(self, md, errfn, fromActor=None):
+ msgvar = self.msgvar
+ writervar = ExprVar("writer__")
+ routingId = self.protocol.routingId(fromActor)
+ this = fromActor or ExprVar.THIS
+
+ stmts = (
+ [
+ StmtDecl(
+ Decl(Type("UniquePtr<IPC::Message>"), msgvar.name),
+ init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]),
+ ),
+ StmtDecl(
+ Decl(Type("IPC::MessageWriter"), writervar.name),
+ initargs=[ExprDeref(msgvar), this],
+ ),
+ ]
+ + [Whitespace.NL]
+ + [
+ _ParamTraits.checkedWrite(
+ p.ipdltype,
+ p.var(),
+ ExprAddrOf(writervar),
+ sentinelKey=p.name,
+ )
+ for p in md.params
+ ]
+ + [Whitespace.NL]
+ + self.setMessageFlags(md, msgvar)
+ )
+ return msgvar, stmts
+
+ def makeResolver(self, md, errfn, routingId):
+ if routingId is None:
+ routingId = self.protocol.routingId()
+ if not md.decl.type.isAsync() or not md.hasReply():
+ return []
+
+ def paramValue(idx):
+ assert idx < len(md.returns)
+ if len(md.returns) > 1:
+ return ExprCode("std::get<${idx}>(aParam)", idx=idx)
+ return ExprVar("aParam")
+
+ serializeParams = [
+ _ParamTraits.checkedWrite(
+ p.ipdltype,
+ paramValue(idx),
+ ExprAddrOf(ExprVar("writer__")),
+ sentinelKey=p.name,
+ )
+ for idx, p in enumerate(md.returns)
+ ]
+
+ return [
+ StmtCode(
+ """
+ UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId}));
+ ${replyvar}->set_seqno(${msgvar}.seqno());
+
+ RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ =
+ new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this);
+
+ ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) {
+ resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) {
+ IPC::MessageWriter writer__(*${replyvar}, self__);
+ $*{serializeParams}
+ ${logSendingReply}
+ });
+ };
+ """,
+ msgvar=self.msgvar,
+ resolvertype=Type(md.resolverName()),
+ routingId=routingId,
+ resolveType=_resolveType(md.returns, self.side),
+ replyvar=self.replyvar,
+ replyCtor=ExprVar(md.pqReplyCtorFunc()),
+ serializeParams=serializeParams,
+ logSendingReply=self.logMessage(
+ md,
+ self.replyvar,
+ "Sending reply ",
+ actor=ExprVar("self__"),
+ ),
+ )
+ ]
+
+ def makeReply(self, md, errfn, routingId):
+ if routingId is None:
+ routingId = self.protocol.routingId()
+ # TODO special cases for async ctor/dtor replies
+ if not md.decl.type.hasReply():
+ return []
+ if md.decl.type.isAsync() and md.decl.type.hasReply():
+ return []
+
+ replyvar = self.replyvar
+ return (
+ [
+ StmtExpr(
+ ExprAssn(
+ replyvar,
+ ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]),
+ )
+ ),
+ StmtDecl(
+ Decl(Type("IPC::MessageWriter"), "writer__"),
+ initargs=[ExprDeref(replyvar), ExprVar.THIS],
+ ),
+ Whitespace.NL,
+ ]
+ + [
+ _ParamTraits.checkedWrite(
+ r.ipdltype,
+ r.var(),
+ ExprAddrOf(ExprVar("writer__")),
+ sentinelKey=r.name,
+ )
+ for r in md.returns
+ ]
+ + self.setMessageFlags(md, replyvar)
+ + [self.logMessage(md, replyvar, "Sending reply ")]
+ )
+
+ def setMessageFlags(self, md, var, seqno=None):
+ stmts = []
+
+ if seqno:
+ stmts.append(
+ StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno]))
+ )
+
+ return stmts + [Whitespace.NL]
+
+ def deserializeMessage(self, md, side, errfn, errfnSent):
+ msgvar = self.msgvar
+ msgexpr = ExprAddrOf(msgvar)
+ readervar = ExprVar("reader__")
+ isctor = md.decl.type.isCtor()
+ stmts = [
+ self.logMessage(md, msgexpr, "Received ", receiving=True),
+ self.profilerLabel(md),
+ Whitespace.NL,
+ ]
+
+ if 0 == len(md.params):
+ return stmts
+
+ start, reads = 0, []
+ if isctor:
+ # return the raw actor handle so that its ID can be used
+ # to construct the "real" actor
+ handlevar = self.handlevar
+ handletype = Type("ActorHandle")
+ reads = [
+ _ParamTraits.checkedRead(
+ None,
+ handletype,
+ handlevar,
+ ExprAddrOf(readervar),
+ errfn,
+ "'%s'" % handletype.name,
+ sentinelKey="actor",
+ errfnSentinel=errfnSent,
+ )
+ ]
+ start = 1
+
+ def maybeTainted(p, side):
+ if md.decl.type.tainted and "NoTaint" not in p.attributes:
+ return Type("Tainted", T=p.bareType(side))
+ return p.bareType(side)
+
+ reads.extend(
+ [
+ _ParamTraits.checkedRead(
+ p.ipdltype,
+ maybeTainted(p, side),
+ p.var(),
+ ExprAddrOf(readervar),
+ errfn,
+ "'%s'" % p.ipdltype.name(),
+ sentinelKey=p.name,
+ errfnSentinel=errfnSent,
+ )
+ for p in md.params[start:]
+ ]
+ )
+
+ stmts.extend(
+ (
+ [
+ StmtDecl(
+ Decl(Type("IPC::MessageReader"), readervar.name),
+ initargs=[msgvar, ExprVar.THIS],
+ )
+ ]
+ + [Whitespace.NL]
+ + reads
+ + [StmtCode("${reader}.EndRead();\n", reader=readervar)]
+ )
+ )
+
+ return stmts
+
+ def deserializeAsyncReply(self, md, side, errfn, errfnSent):
+ msgvar = self.msgvar
+ readervar = ExprVar("reader__")
+ msgexpr = ExprAddrOf(msgvar)
+ isctor = md.decl.type.isCtor()
+ resolve = ExprVar("resolve__")
+ reason = ExprVar("reason__")
+
+ # NOTE: The `resolve__` and `reason__` parameters don't have sentinels,
+ # as they are serialized by the IPDLResolverInner type in
+ # ProtocolUtils.cpp rather than by generated code.
+ desresolve = [
+ StmtCode(
+ """
+ bool resolve__ = false;
+ if (!IPC::ReadParam(&${readervar}, &resolve__)) {
+ FatalError("Error deserializing bool");
+ return MsgValueError;
+ }
+ """,
+ readervar=readervar,
+ ),
+ ]
+ desrej = [
+ StmtCode(
+ """
+ ResponseRejectReason reason__{};
+ if (!IPC::ReadParam(&${readervar}, &reason__)) {
+ FatalError("Error deserializing ResponseRejectReason");
+ return MsgValueError;
+ }
+ ${readervar}.EndRead();
+ """,
+ readervar=readervar,
+ ),
+ ]
+ prologue = [
+ self.logMessage(md, msgexpr, "Received ", receiving=True),
+ self.profilerLabel(md),
+ Whitespace.NL,
+ ]
+
+ if not md.returns:
+ return prologue
+
+ prologue.extend(
+ [
+ StmtDecl(
+ Decl(Type("IPC::MessageReader"), readervar.name),
+ initargs=[msgvar, ExprVar.THIS],
+ )
+ ]
+ + desresolve
+ )
+
+ start, reads = 0, []
+ if isctor:
+ # return the raw actor handle so that its ID can be used
+ # to construct the "real" actor
+ handlevar = self.handlevar
+ handletype = Type("ActorHandle")
+ reads = [
+ _ParamTraits.checkedRead(
+ None,
+ handletype,
+ handlevar,
+ ExprAddrOf(readervar),
+ errfn,
+ "'%s'" % handletype.name,
+ sentinelKey="actor",
+ errfnSentinel=errfnSent,
+ )
+ ]
+ start = 1
+
+ stmts = (
+ reads
+ + [
+ _ParamTraits.checkedRead(
+ p.ipdltype,
+ p.bareType(side),
+ p.var(),
+ ExprAddrOf(readervar),
+ errfn,
+ "'%s'" % p.ipdltype.name(),
+ sentinelKey=p.name,
+ errfnSentinel=errfnSent,
+ )
+ for p in md.returns[start:]
+ ]
+ + [StmtCode("${reader}.EndRead();", reader=readervar)]
+ )
+
+ return resolve, reason, prologue, desrej, stmts
+
+ def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None):
+ stmts = [
+ Whitespace.NL,
+ self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True),
+ ]
+ if 0 == len(md.returns):
+ return stmts
+
+ def tempvar(r):
+ return ExprVar(r.var().name + "__reply")
+
+ readervar = ExprVar("reader__")
+ stmts.extend(
+ [
+ Whitespace.NL,
+ StmtDecl(
+ Decl(Type("IPC::MessageReader"), readervar.name),
+ initargs=[ExprDeref(self.replyvar), ExprVar.THIS],
+ ),
+ ]
+ + [Whitespace.NL]
+ + [
+ _ParamTraits.checkedRead(
+ r.ipdltype,
+ r.bareType(side),
+ tempvar(r),
+ ExprAddrOf(readervar),
+ errfn,
+ "'%s'" % r.ipdltype.name(),
+ sentinelKey=r.name,
+ errfnSentinel=errfnSentinel,
+ )
+ for r in md.returns
+ ]
+ # Move-assign the values out of the variables created with
+ # checkedRead into outparams.
+ + [
+ StmtExpr(ExprAssn(ExprDeref(r.var()), ExprMove(tempvar(r))))
+ for r in md.returns
+ ]
+ + [StmtCode("${reader}.EndRead();", reader=readervar)]
+ )
+
+ return stmts
+
+ def sendAsync(self, md, msgexpr, actor=None):
+ sendok = ExprVar("sendok__")
+ resolvefn = ExprVar("aResolve")
+ rejectfn = ExprVar("aReject")
+
+ stmts = [
+ Whitespace.NL,
+ self.logMessage(md, msgexpr, "Sending ", actor),
+ self.profilerLabel(md),
+ ]
+ stmts.append(Whitespace.NL)
+
+ # Generate the actual call expression.
+ send = ExprVar("ChannelSend")
+ if actor is not None:
+ send = ExprSelect(actor, "->", send.name)
+ if md.returns:
+ stmts.append(
+ StmtExpr(
+ ExprCall(
+ send,
+ args=[
+ ExprMove(msgexpr),
+ ExprVar(md.pqReplyId()),
+ ExprMove(resolvefn),
+ ExprMove(rejectfn),
+ ],
+ )
+ )
+ )
+ retvar = None
+ else:
+ stmts.append(
+ StmtDecl(
+ Decl(Type.BOOL, sendok.name),
+ init=ExprCall(send, args=[ExprMove(msgexpr)]),
+ )
+ )
+ retvar = sendok
+
+ return (retvar, stmts)
+
+ def sendBlocking(self, md, msgexpr, replyexpr, actor=None):
+ send = ExprVar("ChannelSend")
+ if md.decl.type.isInterrupt():
+ send = ExprVar("ChannelCall")
+ if actor is not None:
+ send = ExprSelect(actor, "->", send.name)
+
+ sendok = ExprVar("sendok__")
+ self.externalIncludes.add("mozilla/ProfilerMarkers.h")
+ return (
+ sendok,
+ (
+ [
+ Whitespace.NL,
+ self.logMessage(md, msgexpr, "Sending ", actor),
+ self.profilerLabel(md),
+ ]
+ + [
+ Whitespace.NL,
+ StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE),
+ StmtBlock(
+ [
+ StmtExpr(
+ ExprCall(
+ ExprVar("AUTO_PROFILER_TRACING_MARKER"),
+ [
+ ExprLiteral.String("Sync IPC"),
+ ExprLiteral.String(
+ self.protocol.name
+ + "::"
+ + md.prettyMsgName()
+ ),
+ ExprVar("IPC"),
+ ],
+ )
+ ),
+ StmtExpr(
+ ExprAssn(
+ sendok,
+ ExprCall(
+ send,
+ args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)],
+ ),
+ )
+ ),
+ ]
+ ),
+ ]
+ ),
+ )
+
+ def sendAsyncWithPromise(self, md):
+ # Create a new promise, and forward to the callback send overload.
+ promise = _makePromise(md.returns, self.side, resolver=True)
+
+ if len(md.returns) > 1:
+ resolvetype = _tuple([d.bareType(self.side) for d in md.returns])
+ else:
+ resolvetype = md.returns[0].bareType(self.side)
+
+ resolve = ExprCode(
+ """
+ [promise__](${resolvetype}&& aValue) {
+ promise__->Resolve(std::move(aValue), __func__);
+ }
+ """,
+ resolvetype=resolvetype,
+ )
+ reject = ExprCode(
+ """
+ [promise__](ResponseRejectReason&& aReason) {
+ promise__->Reject(std::move(aReason), __func__);
+ }
+ """,
+ resolvetype=resolvetype,
+ )
+
+ args = [ExprMove(p.var()) for p in md.params] + [resolve, reject]
+ stmt = StmtCode(
+ """
+ RefPtr<${promise}> promise__ = new ${promise}(__func__);
+ promise__->UseDirectTaskDispatch(__func__);
+ ${send}($,{args});
+ return promise__;
+ """,
+ promise=promise,
+ send=md.sendMethod(),
+ args=args,
+ )
+ return [stmt]
+
+ def callAllocActor(self, md, retsems, side):
+ actortype = md.actorDecl().bareType(self.side)
+ if md.decl.type.constructedType().isRefcounted():
+ actortype.ptr = False
+ actortype = _refptr(actortype)
+
+ callalloc = self.thisCall(
+ _allocMethod(md.decl.type.constructedType(), side),
+ args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False),
+ )
+
+ return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=callalloc)
+
+ def invokeRecvHandler(self, md):
+ retsems = "in"
+ if md.decl.type.isAsync() and md.returns:
+ retsems = "resolver"
+ okdecl = StmtDecl(
+ Decl(Type("mozilla::ipc::IPCResult"), "__ok"),
+ init=self.thisCall(
+ md.recvMethod(),
+ md.makeCxxArgs(
+ paramsems="move",
+ retsems=retsems,
+ retcallsems="out",
+ ),
+ ),
+ )
+ failif = StmtIf(ExprNot(ExprVar("__ok")))
+ failif.addifstmts(
+ [
+ _protocolErrorBreakpoint("Handler returned error code!"),
+ Whitespace(
+ "// Error handled in mozilla::ipc::IPCResult\n", indent=True
+ ),
+ StmtReturn(_Result.ProcessingError),
+ ]
+ )
+ return [okdecl, failif]
+
+ def makeDtorMethodDecl(self, md, actorvar):
+ decl = self.makeSendMethodDecl(md)
+ decl.params.insert(
+ 0,
+ Decl(
+ _cxxInType(
+ ipdl.type.ActorType(md.decl.type.constructedType()),
+ side=self.side,
+ direction="send",
+ ),
+ actorvar.name,
+ ),
+ )
+ decl.methodspec = MethodSpec.STATIC
+ return decl
+
+ def makeSendMethodDecl(self, md, promise=False, paramsems="in"):
+ implicit = md.decl.type.hasImplicitActorParam()
+ if md.decl.type.isAsync() and md.returns:
+ if promise:
+ returnsems = "promise"
+ rettype = _refptr(Type(md.promiseName()))
+ else:
+ returnsems = "callback"
+ rettype = Type.VOID
+ else:
+ assert not promise
+ returnsems = "out"
+ rettype = Type.BOOL
+ decl = MethodDecl(
+ md.sendMethod(),
+ params=md.makeCxxParams(
+ paramsems,
+ returnsems=returnsems,
+ side=self.side,
+ implicit=implicit,
+ direction="send",
+ ),
+ warn_unused=(
+ (self.side == "parent" and returnsems != "callback")
+ or (md.decl.type.isCtor() and not md.decl.type.isAsync())
+ ),
+ ret=rettype,
+ )
+ if md.decl.type.isCtor():
+ decl.ret = md.actorDecl().bareType(self.side)
+ return decl
+
+ def logMessage(self, md, msgptr, pfx, actor=None, receiving=False):
+ actorname = _actorName(self.protocol.name, self.side)
+ return StmtCode(
+ """
+ if (mozilla::ipc::LoggingEnabledFor(${actorname})) {
+ mozilla::ipc::LogMessageForProtocol(
+ ${actorname},
+ ${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(),
+ ${pfx},
+ ${msgptr}->type(),
+ mozilla::ipc::MessageDirection::${direction});
+ }
+ """,
+ actorname=ExprLiteral.String(actorname),
+ actor=actor or ExprVar.THIS,
+ pfx=ExprLiteral.String(pfx),
+ msgptr=msgptr,
+ direction="eReceiving" if receiving else "eSending",
+ )
+
+ def profilerLabel(self, md):
+ self.externalIncludes.add("mozilla/ProfilerLabels.h")
+ return StmtCode(
+ """
+ AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER);
+ """,
+ name=self.protocol.name,
+ msgname=md.prettyMsgName(),
+ )
+
+ def saveActorId(self, md):
+ idvar = ExprVar("id__")
+ if md.decl.type.hasReply():
+ # only save the ID if we're actually going to use it, to
+ # avoid unused-variable warnings
+ saveIdStmts = [
+ StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId())
+ ]
+ else:
+ saveIdStmts = []
+ return idvar, saveIdStmts
+
+
+class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
+ def __init__(self):
+ _GenerateProtocolActorCode.__init__(self, "parent")
+
+ def sendsMessage(self, md):
+ return not md.decl.type.isIn()
+
+ def receivesMessage(self, md):
+ return md.decl.type.isInout() or md.decl.type.isIn()
+
+
+class _GenerateProtocolChildCode(_GenerateProtocolActorCode):
+ def __init__(self):
+ _GenerateProtocolActorCode.__init__(self, "child")
+
+ def sendsMessage(self, md):
+ return not md.decl.type.isOut()
+
+ def receivesMessage(self, md):
+ return md.decl.type.isInout() or md.decl.type.isOut()
+
+
+# -----------------------------------------------------------------------------
+# Utility passes
+##
+
+
+def _splitClassDeclDefn(cls):
+ """Destructively split |cls| methods into declarations and
+ definitions (if |not methodDecl.force_inline|). Return classDecl,
+ methodDefns."""
+ defns = Block()
+
+ for i, stmt in enumerate(cls.stmts):
+ if isinstance(stmt, MethodDefn) and not stmt.decl.force_inline:
+ decl, defn = _splitMethodDeclDefn(stmt, cls)
+ cls.stmts[i] = StmtDecl(decl)
+ if defn:
+ defns.addstmts([defn, Whitespace.NL])
+
+ return cls, defns
+
+
+def _splitMethodDeclDefn(md, cls):
+ # Pure methods have decls but no defns.
+ if md.decl.methodspec == MethodSpec.PURE:
+ return md.decl, None
+
+ saveddecl = deepcopy(md.decl)
+ md.decl.cls = cls
+ # Don't emit method specifiers on method defns.
+ md.decl.methodspec = MethodSpec.NONE
+ md.decl.warn_unused = False
+ md.decl.only_for_definition = True
+ for param in md.decl.params:
+ if isinstance(param, Param):
+ param.default = None
+ return saveddecl, md
+
+
+def _splitFuncDeclDefn(fun):
+ assert not fun.decl.force_inline
+ return StmtDecl(fun.decl), fun
diff --git a/ipc/ipdl/ipdl/parser.py b/ipc/ipdl/ipdl/parser.py
new file mode 100644
index 0000000000..1857131868
--- /dev/null
+++ b/ipc/ipdl/ipdl/parser.py
@@ -0,0 +1,680 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import os
+from ply import lex, yacc
+
+from ipdl.ast import *
+
+# -----------------------------------------------------------------------------
+
+
+class ParseError(Exception):
+ def __init__(self, loc, fmt, *args):
+ self.loc = loc
+ self.error = (
+ "%s%s: error: %s" % (Parser.includeStackString(), loc, fmt)
+ ) % args
+
+ def __str__(self):
+ return self.error
+
+
+def _safeLinenoValue(t):
+ lineno, value = 0, "???"
+ if hasattr(t, "lineno"):
+ lineno = t.lineno
+ if hasattr(t, "value"):
+ value = t.value
+ return lineno, value
+
+
+def _error(loc, fmt, *args):
+ raise ParseError(loc, fmt, *args)
+
+
+class Parser:
+ # when we reach an |include [protocol] foo;| statement, we need to
+ # save the current parser state and create a new one. this "stack" is
+ # where that state is saved
+ #
+ # there is one Parser per file
+ current = None
+ parseStack = []
+ parsed = {}
+
+ def __init__(self, type, name, debug=False):
+ assert type and name
+ self.type = type
+ self.debug = debug
+ self.filename = None
+ self.includedirs = None
+ self.loc = None # not always up to date
+ self.lexer = None
+ self.parser = None
+ self.tu = TranslationUnit(type, name)
+ self.direction = None
+
+ def parse(self, input, filename, includedirs):
+ assert os.path.isabs(filename)
+
+ if self.tu.name in Parser.parsed:
+ priorTU = Parser.parsed[self.tu.name].tu
+ if os.path.normcase(priorTU.filename) != os.path.normcase(filename):
+ _error(
+ Loc(filename),
+ "Trying to load `%s' from a file when we'd already seen it in file `%s'"
+ % (self.tu.name, priorTU.filename),
+ )
+
+ return priorTU
+
+ self.lexer = lex.lex(debug=self.debug)
+ self.parser = yacc.yacc(debug=self.debug, write_tables=False)
+ self.filename = filename
+ self.includedirs = includedirs
+ self.tu.filename = filename
+
+ Parser.parsed[self.tu.name] = self
+ Parser.parseStack.append(Parser.current)
+ Parser.current = self
+
+ try:
+ ast = self.parser.parse(input=input, lexer=self.lexer, debug=self.debug)
+ finally:
+ Parser.current = Parser.parseStack.pop()
+
+ return ast
+
+ def resolveIncludePath(self, filepath):
+ """Return the absolute path from which the possibly partial
+ |filepath| should be read, or |None| if |filepath| cannot be located."""
+ for incdir in self.includedirs + [""]:
+ realpath = os.path.join(incdir, filepath)
+ if os.path.isfile(realpath):
+ return os.path.abspath(realpath)
+ return None
+
+ # returns a GCC-style string representation of the include stack.
+ # e.g.,
+ # in file included from 'foo.ipdl', line 120:
+ # in file included from 'bar.ipd', line 12:
+ # which can be printed above a proper error message or warning
+ @staticmethod
+ def includeStackString():
+ s = ""
+ for parse in Parser.parseStack[1:]:
+ s += " in file included from `%s', line %d:\n" % (
+ parse.loc.filename,
+ parse.loc.lineno,
+ )
+ return s
+
+
+def locFromTok(p, num):
+ return Loc(Parser.current.filename, p.lineno(num))
+
+
+# -----------------------------------------------------------------------------
+
+reserved = set(
+ (
+ "async",
+ "both",
+ "child",
+ "class",
+ "from",
+ "include",
+ "intr",
+ "manager",
+ "manages",
+ "namespace",
+ "nullable",
+ "or",
+ "parent",
+ "protocol",
+ "returns",
+ "struct",
+ "sync",
+ "union",
+ "UniquePtr",
+ "using",
+ )
+)
+tokens = [
+ "COLONCOLON",
+ "ID",
+ "STRING",
+] + [r.upper() for r in reserved]
+
+t_COLONCOLON = "::"
+
+literals = "(){}[]<>;:,?="
+t_ignore = " \f\t\v"
+
+
+def t_linecomment(t):
+ r"//[^\n]*"
+
+
+def t_multilinecomment(t):
+ r"/\*(\n|.)*?\*/"
+ t.lexer.lineno += t.value.count("\n")
+
+
+def t_NL(t):
+ r"(?:\r\n|\n|\n)+"
+ t.lexer.lineno += len(t.value)
+
+
+def t_ID(t):
+ r"[a-zA-Z_][a-zA-Z0-9_]*"
+ if t.value in reserved:
+ t.type = t.value.upper()
+ return t
+
+
+def t_STRING(t):
+ r'"[^"\n]*"'
+ t.value = StringLiteral(Loc(Parser.current.filename, t.lineno), t.value[1:-1])
+ return t
+
+
+def t_error(t):
+ _error(
+ Loc(Parser.current.filename, t.lineno),
+ "lexically invalid characters `%s",
+ t.value,
+ )
+
+
+# -----------------------------------------------------------------------------
+
+
+def p_TranslationUnit(p):
+ """TranslationUnit : Preamble NamespacedStuff"""
+ tu = Parser.current.tu
+ tu.loc = Loc(tu.filename)
+ for stmt in p[1]:
+ if isinstance(stmt, CxxInclude):
+ tu.addCxxInclude(stmt)
+ elif isinstance(stmt, Include):
+ tu.addInclude(stmt)
+ elif isinstance(stmt, UsingStmt):
+ tu.addUsingStmt(stmt)
+ else:
+ assert 0
+
+ for thing in p[2]:
+ if isinstance(thing, StructDecl):
+ tu.addStructDecl(thing)
+ elif isinstance(thing, UnionDecl):
+ tu.addUnionDecl(thing)
+ elif isinstance(thing, Protocol):
+ if tu.protocol is not None:
+ _error(thing.loc, "only one protocol definition per file")
+ tu.protocol = thing
+ else:
+ assert 0
+
+ # The "canonical" namespace of the tu, what it's considered to be
+ # in for the purposes of C++: |#include "foo/bar/TU.h"|
+ if tu.protocol:
+ assert tu.filetype == "protocol"
+ tu.namespaces = tu.protocol.namespaces
+ tu.name = tu.protocol.name
+ else:
+ assert tu.filetype == "header"
+ # There's not really a canonical "thing" in headers. So
+ # somewhat arbitrarily use the namespace of the last
+ # interesting thing that was declared.
+ for thing in reversed(tu.structsAndUnions):
+ tu.namespaces = thing.namespaces
+ break
+
+ p[0] = tu
+
+
+# --------------------
+# Preamble
+
+
+def p_Preamble(p):
+ """Preamble : Preamble PreambleStmt ';'
+ |"""
+ if 1 == len(p):
+ p[0] = []
+ else:
+ p[1].append(p[2])
+ p[0] = p[1]
+
+
+def p_PreambleStmt(p):
+ """PreambleStmt : CxxIncludeStmt
+ | IncludeStmt
+ | UsingStmt"""
+ p[0] = p[1]
+
+
+def p_CxxIncludeStmt(p):
+ """CxxIncludeStmt : INCLUDE STRING"""
+ p[0] = CxxInclude(locFromTok(p, 1), p[2].value)
+
+
+def p_IncludeStmt(p):
+ """IncludeStmt : INCLUDE PROTOCOL ID
+ | INCLUDE ID"""
+ loc = locFromTok(p, 1)
+
+ Parser.current.loc = loc
+ if 4 == len(p):
+ id = p[3]
+ type = "protocol"
+ else:
+ id = p[2]
+ type = "header"
+ inc = Include(loc, type, id)
+
+ path = Parser.current.resolveIncludePath(inc.file)
+ if path is None:
+ raise ParseError(loc, "can't locate include file `%s'" % (inc.file))
+
+ inc.tu = Parser(type, id).parse(open(path).read(), path, Parser.current.includedirs)
+ p[0] = inc
+
+
+def p_UsingKind(p):
+ """UsingKind : CLASS
+ | STRUCT
+ |"""
+ p[0] = p[1] if 2 == len(p) else None
+
+
+def p_UsingStmt(p):
+ """UsingStmt : Attributes USING UsingKind CxxType FROM STRING"""
+ p[0] = UsingStmt(
+ locFromTok(p, 2),
+ attributes=p[1],
+ kind=p[3],
+ cxxTypeSpec=p[4],
+ cxxHeader=p[6].value,
+ )
+
+
+# --------------------
+# Namespaced stuff
+
+
+def p_NamespacedStuff(p):
+ """NamespacedStuff : NamespacedStuff NamespaceThing
+ | NamespaceThing"""
+ if 2 == len(p):
+ p[0] = p[1]
+ else:
+ p[1].extend(p[2])
+ p[0] = p[1]
+
+
+def p_NamespaceThing(p):
+ """NamespaceThing : NAMESPACE ID '{' NamespacedStuff '}'
+ | StructDecl
+ | UnionDecl
+ | ProtocolDefn"""
+ if 2 == len(p):
+ p[0] = [p[1]]
+ else:
+ for thing in p[4]:
+ thing.addOuterNamespace(Namespace(locFromTok(p, 1), p[2]))
+ p[0] = p[4]
+
+
+def p_StructDecl(p):
+ """StructDecl : Attributes STRUCT ID '{' StructFields '}' ';'
+ | Attributes STRUCT ID '{' '}' ';'"""
+ if 8 == len(p):
+ p[0] = StructDecl(locFromTok(p, 2), p[3], p[5], p[1])
+ else:
+ p[0] = StructDecl(locFromTok(p, 2), p[3], [], p[1])
+
+
+def p_StructFields(p):
+ """StructFields : StructFields StructField ';'
+ | StructField ';'"""
+ if 3 == len(p):
+ p[0] = [p[1]]
+ else:
+ p[1].append(p[2])
+ p[0] = p[1]
+
+
+def p_StructField(p):
+ """StructField : Type ID"""
+ p[0] = StructField(locFromTok(p, 1), p[1], p[2])
+
+
+def p_UnionDecl(p):
+ """UnionDecl : Attributes UNION ID '{' ComponentTypes '}' ';'"""
+ p[0] = UnionDecl(locFromTok(p, 2), p[3], p[5], p[1])
+
+
+def p_ComponentTypes(p):
+ """ComponentTypes : ComponentTypes Type ';'
+ | Type ';'"""
+ if 3 == len(p):
+ p[0] = [p[1]]
+ else:
+ p[1].append(p[2])
+ p[0] = p[1]
+
+
+def p_ProtocolDefn(p):
+ """ProtocolDefn : Attributes OptionalSendSemantics \
+ PROTOCOL ID '{' ProtocolBody '}' ';'"""
+ protocol = p[6]
+ protocol.loc = locFromTok(p, 3)
+ protocol.name = p[4]
+ protocol.attributes = p[1]
+ protocol.sendSemantics = p[2]
+ p[0] = protocol
+
+ if Parser.current.type == "header":
+ _error(
+ protocol.loc,
+ "can't define a protocol in a header. Do it in a protocol spec instead.",
+ )
+
+
+def p_ProtocolBody(p):
+ """ProtocolBody : ManagersStmtOpt"""
+ p[0] = p[1]
+
+
+# --------------------
+# manager/manages stmts
+
+
+def p_ManagersStmtOpt(p):
+ """ManagersStmtOpt : ManagersStmt ManagesStmtsOpt
+ | ManagesStmtsOpt"""
+ if 2 == len(p):
+ p[0] = p[1]
+ else:
+ p[2].managers = p[1]
+ p[0] = p[2]
+
+
+def p_ManagersStmt(p):
+ """ManagersStmt : MANAGER ManagerList ';'"""
+ if 1 == len(p):
+ p[0] = []
+ else:
+ p[0] = p[2]
+
+
+def p_ManagerList(p):
+ """ManagerList : ID
+ | ManagerList OR ID"""
+ if 2 == len(p):
+ p[0] = [Manager(locFromTok(p, 1), p[1])]
+ else:
+ p[1].append(Manager(locFromTok(p, 3), p[3]))
+ p[0] = p[1]
+
+
+def p_ManagesStmtsOpt(p):
+ """ManagesStmtsOpt : ManagesStmt ManagesStmtsOpt
+ | MessageDeclsOpt"""
+ if 2 == len(p):
+ p[0] = p[1]
+ else:
+ p[2].managesStmts.insert(0, p[1])
+ p[0] = p[2]
+
+
+def p_ManagesStmt(p):
+ """ManagesStmt : MANAGES ID ';'"""
+ p[0] = ManagesStmt(locFromTok(p, 1), p[2])
+
+
+# --------------------
+# Message decls
+
+
+def p_MessageDeclsOpt(p):
+ """MessageDeclsOpt : MessageDeclThing MessageDeclsOpt
+ |"""
+ if 1 == len(p):
+ # we fill in |loc| in the Protocol rule
+ p[0] = Protocol(None)
+ else:
+ p[2].messageDecls.insert(0, p[1])
+ p[0] = p[2]
+
+
+def p_MessageDeclThing(p):
+ """MessageDeclThing : MessageDirectionLabel ':' MessageDecl ';'
+ | MessageDecl ';'"""
+ if 3 == len(p):
+ p[0] = p[1]
+ else:
+ p[0] = p[3]
+
+
+def p_MessageDirectionLabel(p):
+ """MessageDirectionLabel : PARENT
+ | CHILD
+ | BOTH"""
+ if p[1] == "parent":
+ Parser.current.direction = IN
+ elif p[1] == "child":
+ Parser.current.direction = OUT
+ elif p[1] == "both":
+ Parser.current.direction = INOUT
+ else:
+ assert 0
+
+
+def p_MessageDecl(p):
+ """MessageDecl : Attributes SendSemantics MessageBody"""
+ msg = p[3]
+ msg.attributes = p[1]
+ msg.sendSemantics = p[2]
+
+ if Parser.current.direction is None:
+ _error(msg.loc, "missing message direction")
+ msg.direction = Parser.current.direction
+
+ p[0] = msg
+
+
+def p_MessageBody(p):
+ """MessageBody : ID MessageInParams MessageOutParams"""
+ # FIXME/cjones: need better loc info: use one of the quals
+ name = p[1]
+ msg = MessageDecl(locFromTok(p, 1))
+ msg.name = name
+ msg.addInParams(p[2])
+ msg.addOutParams(p[3])
+
+ p[0] = msg
+
+
+def p_MessageInParams(p):
+ """MessageInParams : '(' ParamList ')'"""
+ p[0] = p[2]
+
+
+def p_MessageOutParams(p):
+ """MessageOutParams : RETURNS '(' ParamList ')'
+ |"""
+ if 1 == len(p):
+ p[0] = []
+ else:
+ p[0] = p[3]
+
+
+# --------------------
+# Attributes
+def p_Attributes(p):
+ """Attributes : '[' AttributeList ']'
+ |"""
+ p[0] = {}
+ if 4 == len(p):
+ for attr in p[2]:
+ if attr.name in p[0]:
+ _error(attr.loc, "Repeated extended attribute `%s'", attr.name)
+ p[0][attr.name] = attr
+
+
+def p_AttributeList(p):
+ """AttributeList : Attribute ',' AttributeList
+ | Attribute"""
+ p[0] = [p[1]]
+ if 4 == len(p):
+ p[0] += p[3]
+
+
+def p_Attribute(p):
+ """Attribute : ID AttributeValue"""
+ p[0] = Attribute(locFromTok(p, 1), p[1], p[2])
+
+
+def p_AttributeValue(p):
+ """AttributeValue : '=' ID
+ | '=' STRING
+ |"""
+ if 1 == len(p):
+ p[0] = None
+ else:
+ p[0] = p[2]
+
+
+def p_SendSemantics(p):
+ """SendSemantics : ASYNC
+ | SYNC
+ | INTR"""
+ if p[1] == "async":
+ p[0] = ASYNC
+ elif p[1] == "sync":
+ p[0] = SYNC
+ else:
+ assert p[1] == "intr"
+ p[0] = INTR
+
+
+def p_OptionalSendSemantics(p):
+ """OptionalSendSemantics : SendSemantics
+ |"""
+ if 2 == len(p):
+ p[0] = p[1]
+ else:
+ p[0] = ASYNC
+
+
+# --------------------
+# Minor stuff
+
+
+def p_ParamList(p):
+ """ParamList : ParamList ',' Param
+ | Param
+ |"""
+ if 1 == len(p):
+ p[0] = []
+ elif 2 == len(p):
+ p[0] = [p[1]]
+ else:
+ p[1].append(p[3])
+ p[0] = p[1]
+
+
+def p_Param(p):
+ """Param : Attributes Type ID"""
+ p[0] = Param(locFromTok(p, 2), p[2], p[3], p[1])
+
+
+def p_Type(p):
+ """Type : MaybeNullable BasicType"""
+ # only some types are nullable; we check this in the type checker
+ p[2].nullable = p[1]
+ p[0] = p[2]
+
+
+def p_BasicType(p):
+ """BasicType : CxxID
+ | CxxID '[' ']'
+ | CxxID '?'
+ | CxxUniquePtrInst"""
+ # ID == CxxType; we forbid qnames here,
+ # in favor of the |using| declaration
+ if not isinstance(p[1], TypeSpec):
+ assert (len(p[1]) == 2) or (len(p[1]) == 3)
+ if 2 == len(p[1]):
+ # p[1] is CxxID. isunique = False
+ p[1] = p[1] + (False,)
+ loc, id, isunique = p[1]
+ p[1] = TypeSpec(loc, id)
+ p[1].uniqueptr = isunique
+ if 4 == len(p):
+ p[1].array = True
+ if 3 == len(p):
+ p[1].maybe = True
+ p[0] = p[1]
+
+
+def p_MaybeNullable(p):
+ """MaybeNullable : NULLABLE
+ |"""
+ p[0] = 2 == len(p)
+
+
+# --------------------
+# C++ stuff
+
+
+def p_CxxType(p):
+ """CxxType : QualifiedID
+ | CxxID"""
+ if isinstance(p[1], QualifiedId):
+ p[0] = p[1]
+ else:
+ loc, id = p[1]
+ p[0] = QualifiedId(loc, id)
+
+
+def p_QualifiedID(p):
+ """QualifiedID : QualifiedID COLONCOLON CxxID
+ | CxxID COLONCOLON CxxID"""
+ if isinstance(p[1], QualifiedId):
+ loc, id = p[3]
+ p[1].qualify(id)
+ p[0] = p[1]
+ else:
+ loc1, id1 = p[1]
+ _, id2 = p[3]
+ p[0] = QualifiedId(loc1, id2, [id1])
+
+
+def p_CxxID(p):
+ """CxxID : ID
+ | CxxTemplateInst"""
+ if isinstance(p[1], tuple):
+ p[0] = p[1]
+ else:
+ p[0] = (locFromTok(p, 1), str(p[1]))
+
+
+def p_CxxTemplateInst(p):
+ """CxxTemplateInst : ID '<' ID '>'"""
+ p[0] = (locFromTok(p, 1), str(p[1]) + "<" + str(p[3]) + ">")
+
+
+def p_CxxUniquePtrInst(p):
+ """CxxUniquePtrInst : UNIQUEPTR '<' ID '>'"""
+ p[0] = (locFromTok(p, 1), str(p[3]), True)
+
+
+def p_error(t):
+ lineno, value = _safeLinenoValue(t)
+ _error(Loc(Parser.current.filename, lineno), "bad syntax near `%s'", value)
diff --git a/ipc/ipdl/ipdl/type.py b/ipc/ipdl/ipdl/type.py
new file mode 100644
index 0000000000..76f908dd41
--- /dev/null
+++ b/ipc/ipdl/ipdl/type.py
@@ -0,0 +1,1748 @@
+# vim: set ts=4 sw=4 tw=99 et:
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import os
+import sys
+
+from ipdl.ast import CxxInclude, Decl, Loc, QualifiedId, StructDecl
+from ipdl.ast import UnionDecl, UsingStmt, Visitor, StringLiteral
+from ipdl.ast import ASYNC, SYNC, INTR
+from ipdl.ast import IN, OUT, INOUT
+from ipdl.ast import NOT_NESTED, INSIDE_SYNC_NESTED, INSIDE_CPOW_NESTED
+from ipdl.ast import priorityList
+import ipdl.builtin as builtin
+from ipdl.util import hash_str
+
+_DELETE_MSG = "__delete__"
+
+
+class TypeVisitor:
+ def __init__(self):
+ self.visited = set()
+
+ def defaultVisit(self, node, *args):
+ raise Exception(
+ "INTERNAL ERROR: no visitor for node type `%s'" % (node.__class__.__name__)
+ )
+
+ def visitVoidType(self, v, *args):
+ pass
+
+ def visitBuiltinCType(self, b, *args):
+ pass
+
+ def visitImportedCxxType(self, t, *args):
+ pass
+
+ def visitMessageType(self, m, *args):
+ for param in m.params:
+ param.accept(self, *args)
+ for ret in m.returns:
+ ret.accept(self, *args)
+ if m.cdtype is not None:
+ m.cdtype.accept(self, *args)
+
+ def visitProtocolType(self, p, *args):
+ # NB: don't visit manager and manages. a naive default impl
+ # could result in an infinite loop
+ pass
+
+ def visitActorType(self, a, *args):
+ a.protocol.accept(self, *args)
+
+ def visitStructType(self, s, *args):
+ if s in self.visited:
+ return
+
+ self.visited.add(s)
+ for field in s.fields:
+ field.accept(self, *args)
+
+ def visitUnionType(self, u, *args):
+ if u in self.visited:
+ return
+
+ self.visited.add(u)
+ for component in u.components:
+ component.accept(self, *args)
+
+ def visitArrayType(self, a, *args):
+ a.basetype.accept(self, *args)
+
+ def visitMaybeType(self, m, *args):
+ m.basetype.accept(self, *args)
+
+ def visitUniquePtrType(self, m, *args):
+ m.basetype.accept(self, *args)
+
+ def visitNotNullType(self, m, *args):
+ m.basetype.accept(self, *args)
+
+ def visitShmemType(self, s, *args):
+ pass
+
+ def visitByteBufType(self, s, *args):
+ pass
+
+ def visitShmemChmodType(self, c, *args):
+ c.shmem.accept(self)
+
+ def visitFDType(self, s, *args):
+ pass
+
+ def visitEndpointType(self, s, *args):
+ pass
+
+ def visitManagedEndpointType(self, s, *args):
+ pass
+
+
+class Type:
+ def __cmp__(self, o):
+ return cmp(self.fullname(), o.fullname())
+
+ def __eq__(self, o):
+ return self.__class__ == o.__class__ and self.fullname() == o.fullname()
+
+ def __hash__(self):
+ return hash_str(self.fullname())
+
+ # Is this a C++ type?
+ def isCxx(self):
+ return False
+
+ # Is this an IPDL type?
+
+ def isIPDL(self):
+ return False
+
+ # Is this type neither compound nor an array?
+
+ def isAtom(self):
+ return False
+
+ def isRefcounted(self):
+ return False
+
+ # Should this type be wrapped in `NotNull<T>` unless marked `nullable`?
+
+ def supportsNullable(self):
+ return False
+
+ def typename(self):
+ return self.__class__.__name__
+
+ def name(self):
+ raise NotImplementedError()
+
+ def fullname(self):
+ raise NotImplementedError()
+
+ def accept(self, visitor, *args):
+ visit = getattr(visitor, "visit" + self.__class__.__name__, None)
+ if visit is None:
+ return getattr(visitor, "defaultVisit")(self, *args)
+ return visit(self, *args)
+
+
+class VoidType(Type):
+ def isCxx(self):
+ return True
+
+ def isIPDL(self):
+ return False
+
+ def isAtom(self):
+ return True
+
+ def name(self):
+ return "void"
+
+ def fullname(self):
+ return "void"
+
+
+VOID = VoidType()
+
+# --------------------
+
+
+class BuiltinCType(Type):
+ def __init__(self, name):
+ self._name = name
+
+ def isCxx(self):
+ return True
+
+ def isAtom(self):
+ return True
+
+ def isSendMoveOnly(self):
+ return False
+
+ def isDataMoveOnly(self):
+ return False
+
+ def name(self):
+ return self._name
+
+ def fullname(self):
+ return self._name
+
+
+class ImportedCxxType(Type):
+ def __init__(self, qname, refcounted, sendmoveonly, datamoveonly):
+ assert isinstance(qname, QualifiedId)
+ self.loc = qname.loc
+ self.qname = qname
+ self.refcounted = refcounted
+ self.sendmoveonly = sendmoveonly
+ self.datamoveonly = datamoveonly
+
+ def isCxx(self):
+ return True
+
+ def isAtom(self):
+ return True
+
+ def isRefcounted(self):
+ return self.refcounted
+
+ def supportsNullable(self):
+ return self.refcounted
+
+ def isSendMoveOnly(self):
+ return self.sendmoveonly
+
+ def isDataMoveOnly(self):
+ return self.datamoveonly
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+# --------------------
+
+
+class IPDLType(Type):
+ def isIPDL(self):
+ return True
+
+ def isMessage(self):
+ return False
+
+ def isProtocol(self):
+ return False
+
+ def isActor(self):
+ return False
+
+ def isStruct(self):
+ return False
+
+ def isUnion(self):
+ return False
+
+ def isArray(self):
+ return False
+
+ def isMaybe(self):
+ return False
+
+ def isUniquePtr(self):
+ return False
+
+ def isNotNull(self):
+ return False
+
+ def isAtom(self):
+ return True
+
+ def isCompound(self):
+ return False
+
+ def isShmem(self):
+ return False
+
+ def isByteBuf(self):
+ return False
+
+ def isFD(self):
+ return False
+
+ def isEndpoint(self):
+ return False
+
+ def isManagedEndpoint(self):
+ return False
+
+ def hasBaseType(self):
+ return False
+
+
+class SendSemanticsType(IPDLType):
+ def __init__(self, nestedRange, sendSemantics):
+ self.nestedRange = nestedRange
+ self.sendSemantics = sendSemantics
+
+ def isAsync(self):
+ return self.sendSemantics == ASYNC
+
+ def isSync(self):
+ return self.sendSemantics == SYNC
+
+ def isInterrupt(self):
+ return self.sendSemantics is INTR
+
+ def sendSemanticsSatisfiedBy(self, greater):
+ def _unwrap(nr):
+ if isinstance(nr, dict):
+ return _unwrap(nr["nested"])
+ elif isinstance(nr, int):
+ return nr
+ else:
+ raise ValueError("Got unexpected nestedRange value: %s" % nr)
+
+ lesser = self
+ lnr0, gnr0, lnr1, gnr1 = (
+ _unwrap(lesser.nestedRange[0]),
+ _unwrap(greater.nestedRange[0]),
+ _unwrap(lesser.nestedRange[1]),
+ _unwrap(greater.nestedRange[1]),
+ )
+ if lnr0 < gnr0 or lnr1 > gnr1:
+ return False
+
+ # Protocols that use intr semantics are not allowed to use
+ # message nesting.
+ if greater.isInterrupt() and lesser.nestedRange != (NOT_NESTED, NOT_NESTED):
+ return False
+
+ if lesser.isAsync():
+ return True
+ elif lesser.isSync() and not greater.isAsync():
+ return True
+ elif greater.isInterrupt():
+ return True
+
+ return False
+
+
+class MessageType(SendSemanticsType):
+ def __init__(
+ self,
+ nested,
+ prio,
+ replyPrio,
+ sendSemantics,
+ direction,
+ ctor=False,
+ dtor=False,
+ cdtype=None,
+ compress=False,
+ tainted=False,
+ lazySend=False,
+ ):
+ assert not (ctor and dtor)
+ assert not (ctor or dtor) or cdtype is not None
+
+ SendSemanticsType.__init__(self, (nested, nested), sendSemantics)
+ self.nested = nested
+ self.prio = prio
+ self.replyPrio = replyPrio
+ self.direction = direction
+ self.params = []
+ self.returns = []
+ self.ctor = ctor
+ self.dtor = dtor
+ self.cdtype = cdtype
+ self.compress = compress
+ self.tainted = tainted
+ self.lazySend = lazySend
+
+ def isMessage(self):
+ return True
+
+ def isCtor(self):
+ return self.ctor
+
+ def isDtor(self):
+ return self.dtor
+
+ def constructedType(self):
+ return self.cdtype
+
+ def isIn(self):
+ return self.direction is IN
+
+ def isOut(self):
+ return self.direction is OUT
+
+ def isInout(self):
+ return self.direction is INOUT
+
+ def hasReply(self):
+ return len(self.returns) or self.isSync() or self.isInterrupt()
+
+ def hasImplicitActorParam(self):
+ return self.isCtor()
+
+
+class ProtocolType(SendSemanticsType):
+ def __init__(self, qname, nested, sendSemantics, refcounted, needsotherpid):
+ SendSemanticsType.__init__(self, (NOT_NESTED, nested), sendSemantics)
+ self.qname = qname
+ self.managers = [] # ProtocolType
+ self.manages = []
+ self.hasDelete = False
+ self.refcounted = refcounted
+ self.needsotherpid = needsotherpid
+
+ def isProtocol(self):
+ return True
+
+ def isRefcounted(self):
+ return self.refcounted
+
+ def hasOtherPid(self):
+ return all(top.needsotherpid for top in self.toplevels())
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+ def addManager(self, mgrtype):
+ assert mgrtype.isIPDL() and mgrtype.isProtocol()
+ self.managers.append(mgrtype)
+
+ def managedBy(self, mgr):
+ self.managers = list(mgr)
+
+ def toplevel(self):
+ if self.isToplevel():
+ return self
+ for mgr in self.managers:
+ if mgr is not self:
+ return mgr.toplevel()
+
+ def toplevels(self):
+ if self.isToplevel():
+ return [self]
+ toplevels = list()
+ for mgr in self.managers:
+ if mgr is not self:
+ toplevels.extend(mgr.toplevels())
+ return set(toplevels)
+
+ def isManagerOf(self, pt):
+ for managed in self.manages:
+ if pt is managed:
+ return True
+ return False
+
+ def isManagedBy(self, pt):
+ return pt in self.managers
+
+ def isManager(self):
+ return len(self.manages) > 0
+
+ def isManaged(self):
+ return 0 < len(self.managers)
+
+ def isToplevel(self):
+ return not self.isManaged()
+
+ def manager(self):
+ assert 1 == len(self.managers)
+ for mgr in self.managers:
+ return mgr
+
+
+class ActorType(IPDLType):
+ def __init__(self, protocol):
+ self.protocol = protocol
+
+ def isActor(self):
+ return True
+
+ def isRefcounted(self):
+ return self.protocol.isRefcounted()
+
+ def supportsNullable(self):
+ return True
+
+ def name(self):
+ return self.protocol.name()
+
+ def fullname(self):
+ return self.protocol.fullname()
+
+
+class _CompoundType(IPDLType):
+ def __init__(self):
+ self.defined = False # bool
+ self.mutualRec = set() # set(_CompoundType | ArrayType)
+
+ def isAtom(self):
+ return False
+
+ def isCompound(self):
+ return True
+
+ def itercomponents(self):
+ raise Exception('"pure virtual" method')
+
+ def mutuallyRecursiveWith(self, t, exploring=None):
+ """|self| is mutually recursive with |t| iff |self| and |t|
+ are in a cycle in the type graph rooted at |self|. This function
+ looks for such a cycle and returns True if found."""
+ if exploring is None:
+ exploring = set()
+
+ if t.isAtom():
+ return False
+ elif t is self or t in self.mutualRec:
+ return True
+ elif t.hasBaseType():
+ isrec = self.mutuallyRecursiveWith(t.basetype, exploring)
+ if isrec:
+ self.mutualRec.add(t)
+ return isrec
+ elif t in exploring:
+ return False
+
+ exploring.add(t)
+ for c in t.itercomponents():
+ if self.mutuallyRecursiveWith(c, exploring):
+ self.mutualRec.add(c)
+ return True
+ exploring.remove(t)
+
+ return False
+
+
+class StructType(_CompoundType):
+ def __init__(self, qname, fields):
+ _CompoundType.__init__(self)
+ self.qname = qname
+ self.fields = fields # [ Type ]
+
+ def isStruct(self):
+ return True
+
+ def itercomponents(self):
+ for f in self.fields:
+ yield f
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class UnionType(_CompoundType):
+ def __init__(self, qname, components):
+ _CompoundType.__init__(self)
+ self.qname = qname
+ self.components = components # [ Type ]
+
+ def isUnion(self):
+ return True
+
+ def itercomponents(self):
+ for c in self.components:
+ yield c
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class ArrayType(IPDLType):
+ def __init__(self, basetype):
+ self.basetype = basetype
+
+ def isAtom(self):
+ return False
+
+ def isArray(self):
+ return True
+
+ def hasBaseType(self):
+ return True
+
+ def name(self):
+ return self.basetype.name() + "[]"
+
+ def fullname(self):
+ return self.basetype.fullname() + "[]"
+
+
+class MaybeType(IPDLType):
+ def __init__(self, basetype):
+ self.basetype = basetype
+
+ def isAtom(self):
+ return False
+
+ def isMaybe(self):
+ return True
+
+ def hasBaseType(self):
+ return True
+
+ def name(self):
+ return self.basetype.name() + "?"
+
+ def fullname(self):
+ return self.basetype.fullname() + "?"
+
+
+class ShmemType(IPDLType):
+ def __init__(self, qname):
+ self.qname = qname
+
+ def isShmem(self):
+ return True
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class ByteBufType(IPDLType):
+ def __init__(self, qname):
+ self.qname = qname
+
+ def isByteBuf(self):
+ return True
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class FDType(IPDLType):
+ def __init__(self, qname):
+ self.qname = qname
+
+ def isFD(self):
+ return True
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class EndpointType(IPDLType):
+ def __init__(self, qname, actor):
+ self.qname = qname
+ self.actor = actor
+
+ def isEndpoint(self):
+ return True
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class ManagedEndpointType(IPDLType):
+ def __init__(self, qname, actor):
+ self.qname = qname
+ self.actor = actor
+
+ def isManagedEndpoint(self):
+ return True
+
+ def name(self):
+ return self.qname.baseid
+
+ def fullname(self):
+ return str(self.qname)
+
+
+class UniquePtrType(IPDLType):
+ def __init__(self, basetype):
+ self.basetype = basetype
+
+ def isAtom(self):
+ return False
+
+ def isUniquePtr(self):
+ return True
+
+ def hasBaseType(self):
+ return True
+
+ def name(self):
+ return "UniquePtr<" + self.basetype.name() + ">"
+
+ def fullname(self):
+ return "mozilla::UniquePtr<" + self.basetype.fullname() + ">"
+
+
+class NotNullType(IPDLType):
+ def __init__(self, basetype):
+ self.basetype = basetype
+
+ def isAtom(self):
+ return False
+
+ def isNotNull(self):
+ return True
+
+ def hasBaseType(self):
+ return True
+
+ def name(self):
+ return "NotNull<" + self.basetype.name() + ">"
+
+ def fullname(self):
+ return "mozilla::NotNull<" + self.basetype.fullname() + ">"
+
+
+def iteractortypes(t, visited=None):
+ """Iterate over any actor(s) buried in |type|."""
+ if visited is None:
+ visited = set()
+
+ # XXX |yield| semantics makes it hard to use TypeVisitor
+ if not t.isIPDL():
+ return
+ elif t.isActor():
+ yield t
+ elif t.hasBaseType():
+ for actor in iteractortypes(t.basetype, visited):
+ yield actor
+ elif t.isCompound() and t not in visited:
+ visited.add(t)
+ for c in t.itercomponents():
+ for actor in iteractortypes(c, visited):
+ yield actor
+
+
+def hasshmem(type):
+ """Return true iff |type| is shmem or has it buried within."""
+
+ class found(BaseException):
+ pass
+
+ class findShmem(TypeVisitor):
+ def visitShmemType(self, s):
+ raise found()
+
+ try:
+ type.accept(findShmem())
+ except found:
+ return True
+ return False
+
+
+# --------------------
+_builtinloc = Loc("<builtin>", 0)
+
+
+def makeBuiltinUsing(tname):
+ quals = tname.split("::")
+ base = quals.pop()
+ quals = quals[0:]
+ return UsingStmt(_builtinloc, QualifiedId(_builtinloc, base, quals))
+
+
+builtinUsing = [makeBuiltinUsing(t) for t in builtin.Types]
+builtinHeaderIncludes = [CxxInclude(_builtinloc, f) for f in builtin.HeaderIncludes]
+
+
+def errormsg(loc, fmt, *args):
+ while not isinstance(loc, Loc):
+ if loc is None:
+ loc = Loc.NONE
+ else:
+ loc = loc.loc
+ return "%s: error: %s" % (str(loc), fmt % args)
+
+
+# --------------------
+
+
+class SymbolTable:
+ def __init__(self, errors):
+ self.errors = errors
+ self.scopes = [{}] # stack({})
+ self.currentScope = self.scopes[0]
+
+ def enterScope(self):
+ assert isinstance(self.scopes[0], dict)
+ assert isinstance(self.currentScope, dict)
+
+ self.scopes.append({})
+ self.currentScope = self.scopes[-1]
+
+ def exitScope(self):
+ symtab = self.scopes.pop()
+ assert self.currentScope is symtab
+
+ self.currentScope = self.scopes[-1]
+
+ assert isinstance(self.scopes[0], dict)
+ assert isinstance(self.currentScope, dict)
+
+ def lookup(self, sym):
+ # NB: since IPDL doesn't allow any aliased names of different types,
+ # it doesn't matter in which order we walk the scope chain to resolve
+ # |sym|
+ for scope in self.scopes:
+ decl = scope.get(sym, None)
+ if decl is not None:
+ return decl
+ return None
+
+ def declare(self, decl):
+ assert decl.progname or decl.shortname or decl.fullname
+ assert decl.loc
+ assert decl.type
+
+ def tryadd(name):
+ olddecl = self.lookup(name)
+ if olddecl is not None:
+ self.errors.append(
+ errormsg(
+ decl.loc,
+ "redeclaration of symbol `%s', first declared at %s",
+ name,
+ olddecl.loc,
+ )
+ )
+ return
+ self.currentScope[name] = decl
+ decl.scope = self.currentScope
+
+ if decl.progname:
+ tryadd(decl.progname)
+ if decl.shortname:
+ tryadd(decl.shortname)
+ if decl.fullname:
+ tryadd(decl.fullname)
+
+
+class TypeCheck:
+ """This pass sets the .decl attribute of AST nodes for which that is relevant;
+ a decl says where, with what type, and under what name(s) a node was
+ declared.
+
+ With this information, it type checks the AST."""
+
+ def __init__(self):
+ # NB: no IPDL compile will EVER print a warning. A program has
+ # one of two attributes: it is either well typed, or not well typed.
+ self.errors = [] # [ string ]
+
+ def check(self, tu, errout=sys.stderr):
+ def runpass(tcheckpass):
+ tu.accept(tcheckpass)
+ if len(self.errors):
+ self.reportErrors(errout)
+ return False
+ return True
+
+ # tag each relevant node with "decl" information, giving type, name,
+ # and location of declaration
+ if not runpass(GatherDecls(builtinUsing, self.errors)):
+ return False
+
+ # now that the nodes have decls, type checking is much easier.
+ if not runpass(CheckTypes(self.errors)):
+ return False
+
+ return True
+
+ def reportErrors(self, errout):
+ for error in self.errors:
+ print(error, file=errout)
+
+
+class TcheckVisitor(Visitor):
+ def __init__(self, errors):
+ self.errors = errors
+
+ def error(self, loc, fmt, *args):
+ self.errors.append(errormsg(loc, fmt, *args))
+
+
+class GatherDecls(TcheckVisitor):
+ def __init__(self, builtinUsing, errors):
+ TcheckVisitor.__init__(self, errors)
+
+ # |self.symtab| is the symbol table for the translation unit
+ # currently being visited
+ self.symtab = None
+ self.builtinUsing = builtinUsing
+
+ def declare(
+ self, loc, type, shortname=None, fullname=None, progname=None, attributes={}
+ ):
+ d = Decl(loc)
+ d.type = type
+ d.progname = progname
+ d.shortname = shortname
+ d.fullname = fullname
+ d.attributes = attributes
+ self.symtab.declare(d)
+ return d
+
+ # Check that only attributes allowed by an attribute spec are present
+ # within the given attribute dictionary. The spec value may be either
+ # `None`, for a valueless attribute, a list of valid attribute values, or a
+ # callable which returns a truthy value if the attribute is valid.
+ def checkAttributes(self, attributes, spec):
+ for attr in attributes.values():
+ if attr.name not in spec:
+ self.error(attr.loc, "unknown attribute `%s'", attr.name)
+ continue
+
+ aspec = spec[attr.name]
+ if aspec is None:
+ if attr.value is not None:
+ self.error(
+ attr.loc,
+ "unexpected value for valueless attribute `%s'",
+ attr.name,
+ )
+ elif isinstance(aspec, (list, tuple)):
+ if not any(
+ isinstance(attr.value, s)
+ if isinstance(s, type)
+ else attr.value == s
+ for s in aspec
+ ):
+ self.error(
+ attr.loc,
+ "invalid value for attribute `%s', expected one of: %s",
+ attr.name,
+ ", ".join(
+ s.__name__ if isinstance(s, type) else str(s) for s in aspec
+ ),
+ )
+ elif callable(aspec):
+ if not aspec(attr.value):
+ self.error(attr.loc, "invalid value for attribute `%s'", attr.name)
+ else:
+ raise Exception("INTERNAL ERROR: Invalid attribute spec")
+
+ def visitTranslationUnit(self, tu):
+ # all TranslationUnits declare symbols in global scope
+ if hasattr(tu, "visited"):
+ return
+ tu.visited = True
+ savedSymtab = self.symtab
+ self.symtab = SymbolTable(self.errors)
+
+ # pretend like the translation unit "using"-ed these for the
+ # sake of type checking and C++ code generation
+ tu.builtinUsing = self.builtinUsing
+
+ # for everyone's sanity, enforce that the filename and tu name
+ # match
+ basefilename = os.path.basename(tu.filename)
+ expectedfilename = "%s.ipdl" % (tu.name)
+ if not tu.protocol:
+ # header
+ expectedfilename += "h"
+ if basefilename != expectedfilename:
+ self.error(
+ tu.loc,
+ "expected file for translation unit `%s' to be named `%s'; instead it's named `%s'", # NOQA: E501
+ tu.name,
+ expectedfilename,
+ basefilename,
+ )
+
+ if tu.protocol:
+ assert tu.name == tu.protocol.name
+
+ p = tu.protocol
+
+ self.checkAttributes(
+ p.attributes,
+ {
+ "ManualDealloc": None,
+ "NestedUpTo": ("not", "inside_sync", "inside_cpow"),
+ "NeedsOtherPid": None,
+ "ChildImpl": ("virtual", StringLiteral),
+ "ParentImpl": ("virtual", StringLiteral),
+ },
+ )
+
+ # FIXME/cjones: it's a little weird and counterintuitive
+ # to put both the namespace and non-namespaced name in the
+ # global scope. try to figure out something better; maybe
+ # a type-neutral |using| that works for C++ and protocol
+ # types?
+ qname = p.qname()
+ fullname = str(qname)
+ p.decl = self.declare(
+ loc=p.loc,
+ type=ProtocolType(
+ qname,
+ p.nestedUpTo(),
+ p.sendSemantics,
+ "ManualDealloc" not in p.attributes,
+ "NeedsOtherPid" in p.attributes,
+ ),
+ shortname=p.name,
+ fullname=fullname,
+ )
+
+ p.parentEndpointDecl = self.declare(
+ loc=p.loc,
+ type=EndpointType(
+ QualifiedId(
+ p.loc, "Endpoint<" + fullname + "Parent>", ["mozilla", "ipc"]
+ ),
+ ActorType(p.decl.type),
+ ),
+ shortname="Endpoint<" + p.name + "Parent>",
+ )
+ p.childEndpointDecl = self.declare(
+ loc=p.loc,
+ type=EndpointType(
+ QualifiedId(
+ p.loc, "Endpoint<" + fullname + "Child>", ["mozilla", "ipc"]
+ ),
+ ActorType(p.decl.type),
+ ),
+ shortname="Endpoint<" + p.name + "Child>",
+ )
+
+ p.parentManagedEndpointDecl = self.declare(
+ loc=p.loc,
+ type=ManagedEndpointType(
+ QualifiedId(
+ p.loc,
+ "ManagedEndpoint<" + fullname + "Parent>",
+ ["mozilla", "ipc"],
+ ),
+ ActorType(p.decl.type),
+ ),
+ shortname="ManagedEndpoint<" + p.name + "Parent>",
+ )
+ p.childManagedEndpointDecl = self.declare(
+ loc=p.loc,
+ type=ManagedEndpointType(
+ QualifiedId(
+ p.loc,
+ "ManagedEndpoint<" + fullname + "Child>",
+ ["mozilla", "ipc"],
+ ),
+ ActorType(p.decl.type),
+ ),
+ shortname="ManagedEndpoint<" + p.name + "Child>",
+ )
+
+ # XXX ugh, this sucks. but we need this information to compute
+ # what friend decls we need in generated C++
+ p.decl.type._ast = p
+
+ # make sure we have decls for all dependent protocols
+ for pinc in tu.includes:
+ pinc.accept(self)
+
+ # declare imported (and builtin) C and C++ types
+ for ctype in builtin.CTypes:
+ self.declare(
+ loc=_builtinloc,
+ type=BuiltinCType(ctype),
+ shortname=ctype,
+ )
+ for using in tu.builtinUsing:
+ using.accept(self)
+ for using in tu.using:
+ using.accept(self)
+
+ # first pass to "forward-declare" all structs and unions in
+ # order to support recursive definitions
+ for su in tu.structsAndUnions:
+ self.declareStructOrUnion(su)
+
+ # second pass to check each definition
+ for su in tu.structsAndUnions:
+ su.accept(self)
+
+ if tu.protocol:
+ # grab symbols in the protocol itself
+ p.accept(self)
+
+ self.symtab = savedSymtab
+
+ def declareStructOrUnion(self, su):
+ if hasattr(su, "decl"):
+ self.symtab.declare(su.decl)
+ return
+
+ qname = su.qname()
+ fullname = str(qname)
+
+ if isinstance(su, StructDecl):
+ sutype = StructType(qname, [])
+ elif isinstance(su, UnionDecl):
+ sutype = UnionType(qname, [])
+ else:
+ assert 0 and "unknown type"
+
+ # XXX more suckage. this time for pickling structs/unions
+ # declared in headers.
+ sutype._ast = su
+
+ su.decl = self.declare(
+ loc=su.loc, type=sutype, shortname=su.name, fullname=fullname
+ )
+
+ def visitInclude(self, inc):
+ if inc.tu is None:
+ self.error(
+ inc.loc,
+ "(type checking here will be unreliable because of an earlier error)",
+ )
+ return
+ inc.tu.accept(self)
+ if inc.tu.protocol:
+ self.symtab.declare(inc.tu.protocol.decl)
+ self.symtab.declare(inc.tu.protocol.parentEndpointDecl)
+ self.symtab.declare(inc.tu.protocol.childEndpointDecl)
+ self.symtab.declare(inc.tu.protocol.parentManagedEndpointDecl)
+ self.symtab.declare(inc.tu.protocol.childManagedEndpointDecl)
+ else:
+ # This is a header. Import its "exported" globals into
+ # our scope.
+ for using in inc.tu.using:
+ using.accept(self)
+ for su in inc.tu.structsAndUnions:
+ self.declareStructOrUnion(su)
+
+ def visitStructDecl(self, sd):
+ # If we've already processed this struct, don't do it again.
+ if hasattr(sd, "visited"):
+ return
+
+ stype = sd.decl.type
+
+ self.symtab.enterScope()
+ sd.visited = True
+
+ self.checkAttributes(sd.attributes, {"Comparable": None})
+
+ for f in sd.fields:
+ ftypedecl = self.symtab.lookup(str(f.typespec))
+ if ftypedecl is None:
+ self.error(
+ f.loc,
+ "field `%s' of struct `%s' has unknown type `%s'",
+ f.name,
+ sd.name,
+ str(f.typespec),
+ )
+ continue
+
+ f.decl = self.declare(
+ loc=f.loc,
+ type=self._canonicalType(ftypedecl.type, f.typespec),
+ shortname=f.name,
+ fullname=None,
+ )
+ stype.fields.append(f.decl.type)
+
+ self.symtab.exitScope()
+
+ def visitUnionDecl(self, ud):
+ utype = ud.decl.type
+
+ # If we've already processed this union, don't do it again.
+ if len(utype.components):
+ return
+
+ self.checkAttributes(ud.attributes, {"Comparable": None})
+
+ for c in ud.components:
+ cdecl = self.symtab.lookup(str(c))
+ if cdecl is None:
+ self.error(
+ c.loc, "unknown component type `%s' of union `%s'", str(c), ud.name
+ )
+ continue
+ utype.components.append(self._canonicalType(cdecl.type, c))
+
+ def visitUsingStmt(self, using):
+ fullname = str(using.type)
+
+ self.checkAttributes(
+ using.attributes,
+ {
+ "MoveOnly": (None, "data", "send"),
+ "RefCounted": None,
+ },
+ )
+
+ if fullname == "::mozilla::ipc::Shmem":
+ ipdltype = ShmemType(using.type)
+ elif fullname == "::mozilla::ipc::ByteBuf":
+ ipdltype = ByteBufType(using.type)
+ elif fullname == "::mozilla::ipc::FileDescriptor":
+ ipdltype = FDType(using.type)
+ else:
+ ipdltype = ImportedCxxType(
+ using.type,
+ using.isRefcounted(),
+ using.isSendMoveOnly(),
+ using.isDataMoveOnly(),
+ )
+ existingType = self.symtab.lookup(ipdltype.fullname())
+ if existingType and existingType.fullname == ipdltype.fullname():
+ if ipdltype.isRefcounted() != existingType.type.isRefcounted():
+ self.error(
+ using.loc,
+ "inconsistent refcounted status of type `%s`",
+ str(using.type),
+ )
+ if (
+ ipdltype.isSendMoveOnly() != existingType.type.isSendMoveOnly()
+ or ipdltype.isDataMoveOnly() != existingType.type.isDataMoveOnly()
+ ):
+ self.error(
+ using.loc,
+ "inconsistent moveonly status of type `%s`",
+ str(using.type),
+ )
+ using.decl = existingType
+ return
+ using.decl = self.declare(
+ loc=using.loc,
+ type=ipdltype,
+ shortname=using.type.baseid,
+ fullname=fullname,
+ )
+
+ def visitProtocol(self, p):
+ # protocol scope
+ self.symtab.enterScope()
+
+ seenmgrs = set()
+ for mgr in p.managers:
+ if mgr.name in seenmgrs:
+ self.error(mgr.loc, "manager `%s' appears multiple times", mgr.name)
+ continue
+
+ seenmgrs.add(mgr.name)
+ mgr.of = p
+ mgr.accept(self)
+
+ for managed in p.managesStmts:
+ managed.manager = p
+ managed.accept(self)
+
+ if not (p.managers or p.messageDecls or p.managesStmts):
+ self.error(p.loc, "top-level protocol `%s' cannot be empty", p.name)
+
+ setattr(self, "currentProtocolDecl", p.decl)
+ for msg in p.messageDecls:
+ msg.accept(self)
+ del self.currentProtocolDecl
+
+ p.decl.type.hasDelete = not not self.symtab.lookup(_DELETE_MSG)
+ if not (p.decl.type.hasDelete or p.decl.type.isToplevel()):
+ self.error(
+ p.loc,
+ "destructor declaration `%s(...)' required for managed protocol `%s'",
+ _DELETE_MSG,
+ p.name,
+ )
+
+ if not p.decl.type.isToplevel() and p.decl.type.needsotherpid:
+ self.error(p.loc, "[NeedsOtherPid] only applies to toplevel protocols")
+
+ if p.decl.type.isToplevel() and not p.decl.type.isRefcounted():
+ self.error(p.loc, "Toplevel protocols cannot be [ManualDealloc]")
+
+ # FIXME/cjones declare all the little C++ thingies that will
+ # be generated. they're not relevant to IPDL itself, but
+ # those ("invisible") symbols can clash with others in the
+ # IPDL spec, and we'd like to catch those before C++ compilers
+ # are allowed to obfuscate the error
+
+ self.symtab.exitScope()
+
+ def visitManager(self, mgr):
+ mgrdecl = self.symtab.lookup(mgr.name)
+ pdecl = mgr.of.decl
+ assert pdecl
+
+ pname, mgrname = pdecl.shortname, mgr.name
+ loc = mgr.loc
+
+ if mgrdecl is None:
+ self.error(
+ loc,
+ "protocol `%s' referenced as |manager| of `%s' has not been declared",
+ mgrname,
+ pname,
+ )
+ elif not isinstance(mgrdecl.type, ProtocolType):
+ self.error(
+ loc,
+ "entity `%s' referenced as |manager| of `%s' is not of `protocol' type; instead it is of type `%s'", # NOQA: E501
+ mgrname,
+ pname,
+ mgrdecl.type.typename(),
+ )
+ else:
+ mgr.decl = mgrdecl
+ pdecl.type.addManager(mgrdecl.type)
+
+ def visitManagesStmt(self, mgs):
+ mgsdecl = self.symtab.lookup(mgs.name)
+ pdecl = mgs.manager.decl
+ assert pdecl
+
+ pname, mgsname = pdecl.shortname, mgs.name
+ loc = mgs.loc
+
+ if mgsdecl is None:
+ self.error(
+ loc,
+ "protocol `%s', managed by `%s', has not been declared",
+ mgsname,
+ pname,
+ )
+ elif not isinstance(mgsdecl.type, ProtocolType):
+ self.error(
+ loc,
+ "%s declares itself managing a non-`protocol' entity `%s' of type `%s'",
+ pname,
+ mgsname,
+ mgsdecl.type.typename(),
+ )
+ else:
+ mgs.decl = mgsdecl
+ pdecl.type.manages.append(mgsdecl.type)
+
+ def visitMessageDecl(self, md):
+ msgname = md.name
+ loc = md.loc
+
+ self.checkAttributes(
+ md.attributes,
+ {
+ "Tainted": None,
+ "Compress": (None, "all"),
+ "Priority": priorityList,
+ "ReplyPriority": priorityList,
+ "Nested": ("not", "inside_sync", "inside_cpow"),
+ "LegacyIntr": None,
+ "VirtualSendImpl": None,
+ "LazySend": None,
+ },
+ )
+
+ if md.sendSemantics is INTR and "LegacyIntr" not in md.attributes:
+ self.error(
+ loc,
+ "intr message `%s' allowed only with [LegacyIntr]; DO NOT USE IN SHIPPING CODE",
+ msgname,
+ )
+
+ if md.sendSemantics is INTR and "Priority" in md.attributes:
+ self.error(loc, "intr message `%s' cannot specify [Priority]", msgname)
+
+ if md.sendSemantics is INTR and "Nested" in md.attributes:
+ self.error(loc, "intr message `%s' cannot specify [Nested]", msgname)
+
+ if md.sendSemantics is not ASYNC and "LazySend" in md.attributes:
+ self.error(loc, "non-async message `%s' cannot specify [LazySend]", msgname)
+
+ if md.sendSemantics is not ASYNC and "ReplyPriority" in md.attributes:
+ self.error(
+ loc, "non-async message `%s' cannot specify [ReplyPriority]", msgname
+ )
+
+ if not md.outParams and "ReplyPriority" in md.attributes:
+ self.error(
+ loc, "non-returns message `%s' cannot specify [ReplyPriority]", msgname
+ )
+
+ isctor = False
+ isdtor = False
+ cdtype = None
+
+ decl = self.symtab.lookup(msgname)
+ if decl is not None and decl.type.isProtocol():
+ # probably a ctor. we'll check validity later.
+ msgname += "Constructor"
+ isctor = True
+ cdtype = decl.type
+ elif decl is not None:
+ self.error(
+ loc,
+ "message name `%s' already declared as `%s'",
+ msgname,
+ decl.type.typename(),
+ )
+ # if we error here, no big deal; move on to find more
+
+ if _DELETE_MSG == msgname:
+ isdtor = True
+ cdtype = self.currentProtocolDecl.type
+
+ # enter message scope
+ self.symtab.enterScope()
+
+ msgtype = MessageType(
+ nested=md.nested(),
+ prio=md.priority(),
+ replyPrio=md.replyPriority(),
+ sendSemantics=md.sendSemantics,
+ direction=md.direction,
+ ctor=isctor,
+ dtor=isdtor,
+ cdtype=cdtype,
+ compress=md.attributes.get("Compress"),
+ tainted="Tainted" in md.attributes,
+ lazySend="LazySend" in md.attributes,
+ )
+
+ # replace inparam Param nodes with proper Decls
+ def paramToDecl(param):
+ self.checkAttributes(
+ param.attributes,
+ {
+ # Passback indicates that the argument is unused by the Parent and is
+ # merely returned to the Child later.
+ # AllValid indicates that the entire span of values representable by
+ # the type are acceptable. e.g. 0-255 in a uint8
+ "NoTaint": ("passback", "allvalid")
+ },
+ )
+
+ ptname = param.typespec.basename()
+ ploc = param.typespec.loc
+
+ if "NoTaint" in param.attributes and "Tainted" not in md.attributes:
+ self.error(
+ ploc,
+ "argument typename `%s' of message `%s' has a NoTaint attribute, but the message lacks the Tainted attribute",
+ ptname,
+ msgname,
+ )
+
+ ptdecl = self.symtab.lookup(ptname)
+ if ptdecl is None:
+ self.error(
+ ploc,
+ "argument typename `%s' of message `%s' has not been declared",
+ ptname,
+ msgname,
+ )
+ ptype = VOID
+ else:
+ ptype = self._canonicalType(ptdecl.type, param.typespec)
+ return self.declare(
+ loc=ploc, type=ptype, progname=param.name, attributes=param.attributes
+ )
+
+ for i, inparam in enumerate(md.inParams):
+ pdecl = paramToDecl(inparam)
+ msgtype.params.append(pdecl.type)
+ md.inParams[i] = pdecl
+ for i, outparam in enumerate(md.outParams):
+ pdecl = paramToDecl(outparam)
+ msgtype.returns.append(pdecl.type)
+ md.outParams[i] = pdecl
+
+ self.symtab.exitScope()
+
+ md.decl = self.declare(loc=loc, type=msgtype, progname=msgname)
+ md.protocolDecl = self.currentProtocolDecl
+ md.decl._md = md
+
+ def _canonicalType(self, itype, typespec):
+ loc = typespec.loc
+ if typespec.uniqueptr:
+ itype = UniquePtrType(itype)
+
+ if itype.isIPDL() and itype.isProtocol():
+ itype = ActorType(itype)
+
+ if itype.supportsNullable():
+ if not typespec.nullable:
+ itype = NotNullType(itype)
+ elif typespec.nullable:
+ self.error(
+ loc, "`nullable' qualifier for type `%s' is unsupported", itype.name()
+ )
+
+ if typespec.array:
+ itype = ArrayType(itype)
+
+ if typespec.maybe:
+ itype = MaybeType(itype)
+
+ return itype
+
+
+# -----------------------------------------------------------------------------
+
+
+def checkcycles(p, stack=None):
+ cycles = []
+
+ if stack is None:
+ stack = []
+
+ for cp in p.manages:
+ # special case for self-managed protocols
+ if cp is p:
+ continue
+
+ if cp in stack:
+ return [stack + [p, cp]]
+ cycles += checkcycles(cp, stack + [p])
+
+ return cycles
+
+
+def formatcycles(cycles):
+ r = []
+ for cycle in cycles:
+ s = " -> ".join([ptype.name() for ptype in cycle])
+ r.append("`%s'" % s)
+ return ", ".join(r)
+
+
+def fullyDefined(t, exploring=None):
+ """The rules for "full definition" of a type are
+ defined(atom) := true
+ defined(array basetype) := defined(basetype)
+ defined(struct f1 f2...) := defined(f1) and defined(f2) and ...
+ defined(union c1 c2 ...) := defined(c1) or defined(c2) or ..."""
+ if exploring is None:
+ exploring = set()
+
+ if t.isAtom():
+ return True
+ elif t.hasBaseType():
+ return fullyDefined(t.basetype, exploring)
+ elif t.defined:
+ return True
+ assert t.isCompound()
+
+ if t in exploring:
+ return False
+
+ exploring.add(t)
+ for c in t.itercomponents():
+ cdefined = fullyDefined(c, exploring)
+ if t.isStruct() and not cdefined:
+ t.defined = False
+ break
+ elif t.isUnion() and cdefined:
+ t.defined = True
+ break
+ else:
+ if t.isStruct():
+ t.defined = True
+ elif t.isUnion():
+ t.defined = False
+ exploring.remove(t)
+
+ return t.defined
+
+
+class CheckTypes(TcheckVisitor):
+ def __init__(self, errors):
+ TcheckVisitor.__init__(self, errors)
+ self.visited = set()
+ self.ptype = None
+
+ def visitInclude(self, inc):
+ if inc.tu.filename in self.visited:
+ return
+ self.visited.add(inc.tu.filename)
+ if inc.tu.protocol:
+ inc.tu.protocol.accept(self)
+
+ def visitStructDecl(self, sd):
+ if not fullyDefined(sd.decl.type):
+ self.error(sd.decl.loc, "struct `%s' is only partially defined", sd.name)
+
+ def visitUnionDecl(self, ud):
+ if not fullyDefined(ud.decl.type):
+ self.error(ud.decl.loc, "union `%s' is only partially defined", ud.name)
+
+ def visitProtocol(self, p):
+ self.ptype = p.decl.type
+
+ # check that we require no more "power" than our manager protocols
+ ptype, pname = p.decl.type, p.decl.shortname
+
+ for mgrtype in ptype.managers:
+ if mgrtype is not None and not ptype.sendSemanticsSatisfiedBy(mgrtype):
+ self.error(
+ p.decl.loc,
+ "protocol `%s' requires more powerful send semantics than its manager `%s' provides", # NOQA: E501
+ pname,
+ mgrtype.name(),
+ )
+
+ if ptype.isInterrupt() and ptype.nestedRange != (NOT_NESTED, NOT_NESTED):
+ self.error(
+ p.decl.loc, "intr protocol `%s' cannot specify [NestedUpTo]", p.name
+ )
+
+ if ptype.isToplevel():
+ cycles = checkcycles(p.decl.type)
+ if cycles:
+ self.error(
+ p.decl.loc,
+ "cycle(s) detected in manager/manages hierarchy: %s",
+ formatcycles(cycles),
+ )
+
+ if 1 == len(ptype.managers) and ptype is ptype.manager():
+ self.error(
+ p.decl.loc, "top-level protocol `%s' cannot manage itself", p.name
+ )
+
+ return Visitor.visitProtocol(self, p)
+
+ def visitManagesStmt(self, mgs):
+ pdecl = mgs.manager.decl
+ ptype, pname = pdecl.type, pdecl.shortname
+
+ mgsdecl = mgs.decl
+ mgstype, mgsname = mgsdecl.type, mgsdecl.shortname
+
+ loc = mgs.loc
+
+ # we added this information; sanity check it
+ assert ptype.isManagerOf(mgstype)
+
+ # check that the "managed" protocol agrees
+ if not mgstype.isManagedBy(ptype):
+ self.error(
+ loc,
+ "|manages| declaration in protocol `%s' does not match any |manager| declaration in protocol `%s'", # NOQA: E501
+ pname,
+ mgsname,
+ )
+
+ def visitManager(self, mgr):
+ pdecl = mgr.of.decl
+ ptype, pname = pdecl.type, pdecl.shortname
+
+ mgrdecl = mgr.decl
+ mgrtype, mgrname = mgrdecl.type, mgrdecl.shortname
+
+ # we added this information; sanity check it
+ assert ptype.isManagedBy(mgrtype)
+
+ loc = mgr.loc
+
+ # check that the "manager" protocol agrees
+ if not mgrtype.isManagerOf(ptype):
+ self.error(
+ loc,
+ "|manager| declaration in protocol `%s' does not match any |manages| declaration in protocol `%s'", # NOQA: E501
+ pname,
+ mgrname,
+ )
+
+ def visitMessageDecl(self, md):
+ mtype, mname = md.decl.type, md.decl.progname
+ ptype, pname = md.protocolDecl.type, md.protocolDecl.shortname
+
+ loc = md.decl.loc
+
+ if mtype.nested == INSIDE_SYNC_NESTED and not mtype.isSync():
+ self.error(
+ loc,
+ "inside_sync nested messages must be sync (here, message `%s' in protocol `%s')",
+ mname,
+ pname,
+ )
+
+ if mtype.nested == INSIDE_CPOW_NESTED and (mtype.isOut() or mtype.isInout()):
+ self.error(
+ loc,
+ "inside_cpow nested parent-to-child messages are verboten (here, message `%s' in protocol `%s')", # NOQA: E501
+ mname,
+ pname,
+ )
+
+ # We allow inside_sync messages that are themselves sync to be sent from the
+ # parent. Normal and inside_cpow nested messages that are sync can only come from
+ # the child.
+ if (
+ mtype.isSync()
+ and mtype.nested == NOT_NESTED
+ and (mtype.isOut() or mtype.isInout())
+ ):
+ self.error(
+ loc,
+ "sync parent-to-child messages are verboten (here, message `%s' in protocol `%s')",
+ mname,
+ pname,
+ )
+
+ if not mtype.sendSemanticsSatisfiedBy(ptype):
+ self.error(
+ loc,
+ "message `%s' requires more powerful send semantics than its protocol `%s' provides", # NOQA: E501
+ mname,
+ pname,
+ )
+
+ if (mtype.isCtor() or mtype.isDtor()) and mtype.isAsync() and mtype.returns:
+ self.error(
+ loc, "asynchronous ctor/dtor message `%s' declares return values", mname
+ )
+
+ if mtype.compress and (not mtype.isAsync() or mtype.isCtor() or mtype.isDtor()):
+
+ if mtype.isCtor() or mtype.isDtor():
+ message_type = "constructor" if mtype.isCtor() else "destructor"
+ error_message = (
+ "%s messages can't use compression (here, in protocol `%s')"
+ % (message_type, pname)
+ )
+ else:
+ error_message = (
+ "message `%s' in protocol `%s' requests compression but is not async"
+ % (mname, pname) # NOQA: E501
+ )
+
+ self.error(loc, error_message)
+
+ if mtype.isCtor() and not ptype.isManagerOf(mtype.constructedType()):
+ self.error(
+ loc,
+ "ctor for protocol `%s', which is not managed by protocol `%s'",
+ mname[: -len("constructor")],
+ pname,
+ )
diff --git a/ipc/ipdl/ipdl/util.py b/ipc/ipdl/ipdl/util.py
new file mode 100644
index 0000000000..60d9c904e2
--- /dev/null
+++ b/ipc/ipdl/ipdl/util.py
@@ -0,0 +1,12 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import zlib
+
+
+# The built-in hash over the str type is non-deterministic, so we need to do
+# this instead.
+def hash_str(s):
+ assert isinstance(s, str)
+ return zlib.adler32(s.encode("utf-8"))