diff options
Diffstat (limited to 'ipc/ipdl/ipdl')
-rw-r--r-- | ipc/ipdl/ipdl/__init__.py | 98 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/ast.py | 468 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/builtin.py | 76 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cgen.py | 108 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/checker.py | 79 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/__init__.py | 3 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/ast.py | 1033 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/cgen.py | 557 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/cxx/code.py | 187 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/lower.py | 5688 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/parser.py | 680 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/type.py | 1748 | ||||
-rw-r--r-- | ipc/ipdl/ipdl/util.py | 12 |
13 files changed, 10737 insertions, 0 deletions
diff --git a/ipc/ipdl/ipdl/__init__.py b/ipc/ipdl/ipdl/__init__.py new file mode 100644 index 0000000000..50ceb4f953 --- /dev/null +++ b/ipc/ipdl/ipdl/__init__.py @@ -0,0 +1,98 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +__all__ = [ + "gencxx", + "genipdl", + "parse", + "typecheck", + "writeifmodified", + "checkSyncMessage", + "checkFixedSyncMessages", +] + +import os +import sys +from io import StringIO + +from ipdl.cgen import IPDLCodeGen +from ipdl.lower import LowerToCxx, msgenums +from ipdl.parser import Parser, ParseError +from ipdl.type import TypeCheck +from ipdl.checker import checkSyncMessage, checkFixedSyncMessages + +from ipdl.cxx.cgen import CxxCodeGen + + +def parse(specstring, filename="/stdin", includedirs=[], errout=sys.stderr): + """Return an IPDL AST if parsing was successful. Print errors to |errout| + if it is not.""" + # The file type and name are later enforced by the type checker. + # This is just a hint to the parser. + prefix, ext = os.path.splitext(filename) + name = os.path.basename(prefix) + if ext == ".ipdlh": + type = "header" + else: + type = "protocol" + + try: + return Parser(type, name).parse( + specstring, os.path.abspath(filename), includedirs + ) + except ParseError as p: + print(p, file=errout) + return None + + +def typecheck(ast, errout=sys.stderr): + """Return True iff |ast| is well typed. Print errors to |errout| if + it is not.""" + return TypeCheck().check(ast, errout) + + +def gencxx(ipdlfilename, ast, outheadersdir, outcppdir, segmentcapacitydict): + headers, cpps = LowerToCxx().lower(ast, segmentcapacitydict) + + def resolveHeader(hdr): + return [ + hdr, + os.path.join( + outheadersdir, *([ns.name for ns in ast.namespaces] + [hdr.name]) + ), + ] + + def resolveCpp(cpp): + return [cpp, os.path.join(outcppdir, cpp.name)] + + for ast, filename in [resolveHeader(hdr) for hdr in headers] + [ + resolveCpp(cpp) for cpp in cpps + ]: + tempfile = StringIO() + CxxCodeGen(tempfile).cgen(ast) + writeifmodified(tempfile.getvalue(), filename) + + +def genipdl(ast, outdir): + return IPDLCodeGen().cgen(ast) + + +def genmsgenum(ast): + return msgenums(ast.protocol, pretty=True) + + +def writeifmodified(contents, file): + contents = contents.encode("utf-8") + dir = os.path.dirname(file) + os.path.exists(dir) or os.makedirs(dir) + + oldcontents = None + if os.path.exists(file): + fd = open(file, "rb") + oldcontents = fd.read() + fd.close() + if oldcontents != contents: + fd = open(file, "wb") + fd.write(contents) + fd.close() diff --git a/ipc/ipdl/ipdl/ast.py b/ipc/ipdl/ipdl/ast.py new file mode 100644 index 0000000000..e50cbb5c65 --- /dev/null +++ b/ipc/ipdl/ipdl/ast.py @@ -0,0 +1,468 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +from .util import hash_str + + +NOT_NESTED = 1 +INSIDE_SYNC_NESTED = 2 +INSIDE_CPOW_NESTED = 3 + +NESTED_ATTR_MAP = { + "not": NOT_NESTED, + "inside_sync": INSIDE_SYNC_NESTED, + "inside_cpow": INSIDE_CPOW_NESTED, +} + +# Each element of this list is the IPDL source representation of a priority. +priorityList = ["normal", "input", "vsync", "mediumhigh", "control"] + +priorityAttrMap = {src: idx for idx, src in enumerate(priorityList)} + +NORMAL_PRIORITY = priorityAttrMap["normal"] + + +class Visitor: + def defaultVisit(self, node): + raise Exception( + "INTERNAL ERROR: no visitor for node type `%s'" % (node.__class__.__name__) + ) + + def visitTranslationUnit(self, tu): + for cxxInc in tu.cxxIncludes: + cxxInc.accept(self) + for inc in tu.includes: + inc.accept(self) + for su in tu.structsAndUnions: + su.accept(self) + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + if tu.protocol: + tu.protocol.accept(self) + + def visitCxxInclude(self, inc): + pass + + def visitInclude(self, inc): + # Note: we don't visit the child AST here, because that needs delicate + # and pass-specific handling + pass + + def visitStructDecl(self, struct): + for f in struct.fields: + f.accept(self) + for a in struct.attributes.values(): + a.accept(self) + + def visitStructField(self, field): + field.typespec.accept(self) + + def visitUnionDecl(self, union): + for t in union.components: + t.accept(self) + for a in union.attributes.values(): + a.accept(self) + + def visitUsingStmt(self, using): + for a in using.attributes.values(): + a.accept(self) + + def visitProtocol(self, p): + for namespace in p.namespaces: + namespace.accept(self) + for mgr in p.managers: + mgr.accept(self) + for managed in p.managesStmts: + managed.accept(self) + for msgDecl in p.messageDecls: + msgDecl.accept(self) + for a in p.attributes.values(): + a.accept(self) + + def visitNamespace(self, ns): + pass + + def visitManager(self, mgr): + pass + + def visitManagesStmt(self, mgs): + pass + + def visitMessageDecl(self, md): + for inParam in md.inParams: + inParam.accept(self) + for outParam in md.outParams: + outParam.accept(self) + for a in md.attributes.values(): + a.accept(self) + + def visitParam(self, decl): + for a in decl.attributes.values(): + a.accept(self) + + def visitTypeSpec(self, ts): + pass + + def visitAttribute(self, a): + if isinstance(a.value, Node): + a.value.accept(self) + + def visitStringLiteral(self, sl): + pass + + def visitDecl(self, d): + for a in d.attributes.values(): + a.accept(self) + + +class Loc: + def __init__(self, filename="<??>", lineno=0): + assert filename + self.filename = filename + self.lineno = lineno + + def __repr__(self): + return "%r:%r" % (self.filename, self.lineno) + + def __str__(self): + return "%s:%s" % (self.filename, self.lineno) + + +Loc.NONE = Loc(filename="<??>", lineno=0) + + +class _struct: + pass + + +class Node: + def __init__(self, loc=Loc.NONE): + self.loc = loc + + def accept(self, visitor): + visit = getattr(visitor, "visit" + self.__class__.__name__, None) + if visit is None: + return getattr(visitor, "defaultVisit")(self) + return visit(self) + + def addAttrs(self, attrsName): + if not hasattr(self, attrsName): + setattr(self, attrsName, _struct()) + + +class NamespacedNode(Node): + def __init__(self, loc=Loc.NONE, name=None): + Node.__init__(self, loc) + self.name = name + self.namespaces = [] + + def addOuterNamespace(self, namespace): + self.namespaces.insert(0, namespace) + + def qname(self): + return QualifiedId(self.loc, self.name, [ns.name for ns in self.namespaces]) + + +class TranslationUnit(NamespacedNode): + def __init__(self, type, name): + NamespacedNode.__init__(self, name=name) + self.filetype = type + self.filename = None + self.cxxIncludes = [] + self.includes = [] + self.builtinUsing = [] + self.using = [] + self.structsAndUnions = [] + self.protocol = None + + def addCxxInclude(self, cxxInclude): + self.cxxIncludes.append(cxxInclude) + + def addInclude(self, inc): + self.includes.append(inc) + + def addStructDecl(self, struct): + self.structsAndUnions.append(struct) + + def addUnionDecl(self, union): + self.structsAndUnions.append(union) + + def addUsingStmt(self, using): + self.using.append(using) + + def setProtocol(self, protocol): + self.protocol = protocol + + +class CxxInclude(Node): + def __init__(self, loc, cxxFile): + Node.__init__(self, loc) + self.file = cxxFile + + +class Include(Node): + def __init__(self, loc, type, name): + Node.__init__(self, loc) + suffix = "ipdl" + if type == "header": + suffix += "h" + self.file = "%s.%s" % (name, suffix) + + +class UsingStmt(Node): + def __init__( + self, + loc, + cxxTypeSpec, + cxxHeader=None, + kind=None, + attributes={}, + ): + Node.__init__(self, loc) + assert isinstance(cxxTypeSpec, QualifiedId) + assert cxxHeader is None or isinstance(cxxHeader, str) + assert kind is None or kind == "class" or kind == "struct" + self.type = cxxTypeSpec + self.header = cxxHeader + self.kind = kind + self.attributes = attributes + + def canBeForwardDeclared(self): + return self.isClass() or self.isStruct() + + def isClass(self): + return self.kind == "class" + + def isStruct(self): + return self.kind == "struct" + + def isRefcounted(self): + return "RefCounted" in self.attributes + + def isSendMoveOnly(self): + moveonly = self.attributes.get("MoveOnly") + return moveonly and moveonly.value in (None, "send") + + def isDataMoveOnly(self): + moveonly = self.attributes.get("MoveOnly") + return moveonly and moveonly.value in (None, "data") + + +# "singletons" + + +class PrettyPrinted: + @classmethod + def __hash__(cls): + return hash_str(cls.pretty) + + @classmethod + def __str__(cls): + return cls.pretty + + +class ASYNC(PrettyPrinted): + pretty = "async" + + +class INTR(PrettyPrinted): + pretty = "intr" + + +class SYNC(PrettyPrinted): + pretty = "sync" + + +class INOUT(PrettyPrinted): + pretty = "inout" + + +class IN(PrettyPrinted): + pretty = "in" + + +class OUT(PrettyPrinted): + pretty = "out" + + +class Namespace(Node): + def __init__(self, loc, namespace): + Node.__init__(self, loc) + self.name = namespace + + +class Protocol(NamespacedNode): + def __init__(self, loc): + NamespacedNode.__init__(self, loc) + self.attributes = {} + self.sendSemantics = ASYNC + self.managers = [] + self.managesStmts = [] + self.messageDecls = [] + + def nestedUpTo(self): + if "NestedUpTo" not in self.attributes: + return NOT_NESTED + + return NESTED_ATTR_MAP.get(self.attributes["NestedUpTo"].value, NOT_NESTED) + + def implAttribute(self, side): + assert side in ("parent", "child") + attr = self.attributes.get(side.capitalize() + "Impl") + if attr is not None: + return attr.value + return None + + +class StructField(Node): + def __init__(self, loc, type, name): + Node.__init__(self, loc) + self.typespec = type + self.name = name + + +class StructDecl(NamespacedNode): + def __init__(self, loc, name, fields, attributes): + NamespacedNode.__init__(self, loc, name) + self.fields = fields + self.attributes = attributes + # A list of indices into `fields` for determining the order in + # which fields are laid out in memory. We don't just reorder + # `fields` itself so as to keep the ordering reasonably stable + # for e.g. C++ constructors when new fields are added. + self.packed_field_ordering = [] + + +class UnionDecl(NamespacedNode): + def __init__(self, loc, name, components, attributes): + NamespacedNode.__init__(self, loc, name) + self.components = components + self.attributes = attributes + + +class Manager(Node): + def __init__(self, loc, managerName): + Node.__init__(self, loc) + self.name = managerName + + +class ManagesStmt(Node): + def __init__(self, loc, managedName): + Node.__init__(self, loc) + self.name = managedName + + +class MessageDecl(Node): + def __init__(self, loc): + Node.__init__(self, loc) + self.name = None + self.attributes = {} + self.sendSemantics = ASYNC + self.direction = None + self.inParams = [] + self.outParams = [] + + def addInParams(self, inParamsList): + self.inParams += inParamsList + + def addOutParams(self, outParamsList): + self.outParams += outParamsList + + def nested(self): + if "Nested" not in self.attributes: + return NOT_NESTED + + return NESTED_ATTR_MAP.get(self.attributes["Nested"].value, NOT_NESTED) + + def priority(self): + if "Priority" in self.attributes: + sourcePriority = self.attributes["Priority"].value + else: + sourcePriority = "normal" + return priorityAttrMap.get(sourcePriority, NORMAL_PRIORITY) + + def replyPriority(self): + if "ReplyPriority" in self.attributes: + sourcePriority = self.attributes["ReplyPriority"].value + if sourcePriority in priorityAttrMap: + return priorityAttrMap[sourcePriority] + return self.priority() + + +class Param(Node): + def __init__(self, loc, typespec, name, attributes={}): + Node.__init__(self, loc) + self.name = name + self.typespec = typespec + self.attributes = attributes + + +class TypeSpec(Node): + def __init__(self, loc, spec): + Node.__init__(self, loc) + assert isinstance(spec, str) + self.spec = spec # str + self.array = False # bool + self.maybe = False # bool + self.nullable = False # bool + self.uniqueptr = False # bool + + def basename(self): + return self.spec + + def __str__(self): + return self.spec + + +class Attribute(Node): + def __init__(self, loc, name, value): + Node.__init__(self, loc) + self.name = name + self.value = value + + +class StringLiteral(Node): + def __init__(self, loc, value): + Node.__init__(self, loc) + self.value = value + + def __str__(self): + return '"%s"' % self.value + + +class QualifiedId: # FIXME inherit from node? + def __init__(self, loc, baseid, quals=[]): + assert isinstance(baseid, str) + for qual in quals: + assert isinstance(qual, str) + + self.loc = loc + self.baseid = baseid + self.quals = quals + + def qualify(self, id): + self.quals.append(self.baseid) + self.baseid = id + + def __str__(self): + # NOTE: include a leading "::" in order to force all QualifiedIds to be + # fully qualified types in C++ + return "::" + "::".join(self.quals + [self.baseid]) + + +# added by type checking passes + + +class Decl(Node): + def __init__(self, loc): + Node.__init__(self, loc) + self.progname = None # what the programmer typed, if relevant + self.shortname = None # shortest way to refer to this decl + self.fullname = None # full way to refer to this decl + self.loc = loc + self.type = None + self.scope = None + self.attributes = {} diff --git a/ipc/ipdl/ipdl/builtin.py b/ipc/ipdl/ipdl/builtin.py new file mode 100644 index 0000000000..b1bab64af8 --- /dev/null +++ b/ipc/ipdl/ipdl/builtin.py @@ -0,0 +1,76 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +# WARNING: the syntax of the builtin types is not checked, so please +# don't add something syntactically invalid. It will not be fun to +# track down the bug. + +# C types +# These types don't live in any namespace, so can't be imported with `using` +# statements like normal C++ types. +CTypes = ( + "bool", + "char", + "short", + "int", + "long", + "float", + "double", +) + +# C++ types +# These types must be fully qualified, and will be `typedef`-ed into IPDL +# structs to make them readily available when used. +Types = ( + # stdint types + "int8_t", + "uint8_t", + "int16_t", + "uint16_t", + "int32_t", + "uint32_t", + "int64_t", + "uint64_t", + "intptr_t", + "uintptr_t", + # You may be tempted to add size_t. Do not! See bug 1525199. + # Mozilla types: "less" standard things we know how serialize/deserialize + "nsresult", + "nsString", + "nsCString", + "mozilla::ipc::Shmem", + "mozilla::ipc::ByteBuf", + "mozilla::UniquePtr", + "mozilla::ipc::FileDescriptor", +) + + +# XXX(Bug 1677487) Can we restrict including ByteBuf.h, FileDescriptor.h, +# MozPromise.h and Shmem.h to those protocols that really use them? +HeaderIncludes = ( + "mozilla/Attributes.h", + "IPCMessageStart.h", + "mozilla/RefPtr.h", + "nsString.h", + "nsTArray.h", + "nsTHashtable.h", + "mozilla/MozPromise.h", + "mozilla/OperatorNewExtensions.h", + "mozilla/UniquePtr.h", + "mozilla/ipc/ByteBuf.h", + "mozilla/ipc/FileDescriptor.h", + "mozilla/ipc/IPCForwards.h", + "mozilla/ipc/Shmem.h", +) + +CppIncludes = ( + "ipc/IPCMessageUtils.h", + "ipc/IPCMessageUtilsSpecializations.h", + "nsIFile.h", + "mozilla/ipc/Endpoint.h", + "mozilla/ipc/ProtocolMessageUtils.h", + "mozilla/ipc/ProtocolUtils.h", + "mozilla/ipc/ShmemMessageUtils.h", + "mozilla/ipc/TaintingIPCUtils.h", +) diff --git a/ipc/ipdl/ipdl/cgen.py b/ipc/ipdl/ipdl/cgen.py new file mode 100644 index 0000000000..8ed8da4d81 --- /dev/null +++ b/ipc/ipdl/ipdl/cgen.py @@ -0,0 +1,108 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys + +from ipdl.ast import Visitor + + +class CodePrinter: + def __init__(self, outf=sys.stdout, indentCols=4): + self.outf = outf + self.col = 0 + self.indentCols = indentCols + + def write(self, str): + self.outf.write(str) + + def printdent(self, str=""): + self.write((" " * self.col) + str) + + def println(self, str=""): + self.write(str + "\n") + + def printdentln(self, str): + self.write((" " * self.col) + str + "\n") + + def indent(self): + self.col += self.indentCols + + def dedent(self): + self.col -= self.indentCols + + +# ----------------------------------------------------------------------------- +class IPDLCodeGen(CodePrinter, Visitor): + """Spits back out equivalent IPDL to the code that generated this. + Also known as pretty-printing.""" + + def __init__(self, outf=sys.stdout, indentCols=4, printed=set()): + CodePrinter.__init__(self, outf, indentCols) + self.printed = printed + + def visitTranslationUnit(self, tu): + self.printed.add(tu.filename) + self.println("//\n// Automatically generated by ipdlc\n//") + CodeGen.visitTranslationUnit(self, tu) # NOQA: F821 + + def visitCxxInclude(self, inc): + self.println('include "' + inc.file + '";') + + def visitProtocolInclude(self, inc): + self.println('include protocol "' + inc.file + '";') + if inc.tu.filename not in self.printed: + self.println("/* Included file:") + IPDLCodeGen( + outf=self.outf, indentCols=self.indentCols, printed=self.printed + ).visitTranslationUnit(inc.tu) + + self.println("*/") + + def visitProtocol(self, p): + self.println() + for namespace in p.namespaces: + namespace.accept(self) + + self.println("%s protocol %s\n{" % (p.sendSemantics[0], p.name)) + self.indent() + + for mgs in p.managesStmts: + mgs.accept(self) + if len(p.managesStmts): + self.println() + + for msgDecl in p.messageDecls: + msgDecl.accept(self) + self.println() + + self.dedent() + self.println("}") + self.write("}\n" * len(p.namespaces)) + + def visitManagerStmt(self, mgr): + self.printdentln("manager " + mgr.name + ";") + + def visitManagesStmt(self, mgs): + self.printdentln("manages " + mgs.name + ";") + + def visitMessageDecl(self, msg): + self.printdent("%s %s %s(" % (msg.sendSemantics[0], msg.direction[0], msg.name)) + for i, inp in enumerate(msg.inParams): + inp.accept(self) + if i != (len(msg.inParams) - 1): + self.write(", ") + self.write(")") + if 0 == len(msg.outParams): + self.println(";") + return + + self.println() + self.indent() + self.printdent("returns (") + for i, outp in enumerate(msg.outParams): + outp.accept(self) + if i != (len(msg.outParams) - 1): + self.write(", ") + self.println(");") + self.dedent() diff --git a/ipc/ipdl/ipdl/checker.py b/ipc/ipdl/ipdl/checker.py new file mode 100644 index 0000000000..c93dc3e0e8 --- /dev/null +++ b/ipc/ipdl/ipdl/checker.py @@ -0,0 +1,79 @@ +# vim: set ts=4 sw=4 tw=99 et: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys +from ipdl.ast import Visitor, ASYNC + + +class SyncMessageChecker(Visitor): + syncMsgList = [] + seenProtocols = [] + seenSyncMessages = [] + + def __init__(self, syncMsgList): + SyncMessageChecker.syncMsgList = syncMsgList + self.errors = [] + + def prettyMsgName(self, msg): + return "%s::%s" % (self.currentProtocol, msg) + + def errorUnknownSyncMessage(self, loc, msg): + self.errors.append("%s: error: Unknown sync IPC message %s" % (str(loc), msg)) + + def errorAsyncMessageCanRemove(self, loc, msg): + self.errors.append( + "%s: error: IPC message %s is async, can be delisted" % (str(loc), msg) + ) + + def visitProtocol(self, p): + self.errors = [] + self.currentProtocol = p.name + SyncMessageChecker.seenProtocols.append(p.name) + Visitor.visitProtocol(self, p) + + def visitMessageDecl(self, md): + pn = self.prettyMsgName(md.name) + if md.sendSemantics is not ASYNC: + if pn not in SyncMessageChecker.syncMsgList: + self.errorUnknownSyncMessage(md.loc, pn) + SyncMessageChecker.seenSyncMessages.append(pn) + elif pn in SyncMessageChecker.syncMsgList: + self.errorAsyncMessageCanRemove(md.loc, pn) + + @staticmethod + def getFixedSyncMessages(): + return set(SyncMessageChecker.syncMsgList) - set( + SyncMessageChecker.seenSyncMessages + ) + + +def checkSyncMessage(tu, syncMsgList, errout=sys.stderr): + checker = SyncMessageChecker(syncMsgList) + tu.accept(checker) + if len(checker.errors): + for error in checker.errors: + print(error, file=errout) + return False + return True + + +def checkFixedSyncMessages(config, errout=sys.stderr): + fixed = SyncMessageChecker.getFixedSyncMessages() + error_free = True + for item in fixed: + protocol = item.split("::")[0] + # Ignore things like sync messages in test protocols we didn't compile. + # Also, ignore platform-specific IPC messages. + if ( + protocol in SyncMessageChecker.seenProtocols + and "platform" not in config.options(item) + ): + print( + "Error: Sync IPC message %s not found, it appears to be fixed.\n" + "Please remove it from sync-messages.ini." % item, + file=errout, + ) + error_free = False + return error_free diff --git a/ipc/ipdl/ipdl/cxx/__init__.py b/ipc/ipdl/ipdl/cxx/__init__.py new file mode 100644 index 0000000000..6fbe8159b2 --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/__init__.py @@ -0,0 +1,3 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. diff --git a/ipc/ipdl/ipdl/cxx/ast.py b/ipc/ipdl/ipdl/cxx/ast.py new file mode 100644 index 0000000000..3b9448859b --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/ast.py @@ -0,0 +1,1033 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import copy +import functools + +from ipdl.util import hash_str + + +class Visitor: + def defaultVisit(self, node): + raise Exception( + "INTERNAL ERROR: no visitor for node type `%s'" % (node.__class__.__name__) + ) + + def visitWhitespace(self, ws): + pass + + def visitVerbatimNode(self, verb): + pass + + def visitGroupNode(self, group): + for node in group.nodes: + node.accept(self) + + def visitFile(self, f): + for thing in f.stuff: + thing.accept(self) + + def visitCppDirective(self, ppd): + pass + + def visitBlock(self, block): + for stmt in block.stmts: + stmt.accept(self) + + def visitNamespace(self, ns): + self.visitBlock(ns) + + def visitType(self, type): + pass + + def visitTypeArray(self, ta): + ta.basetype.accept(self) + ta.nmemb.accept(self) + + def visitTypeEnum(self, enum): + pass + + def visitTypeFunction(self, fn): + pass + + def visitTypeUnion(self, union): + for t, name in union.components: + t.accept(self) + + def visitTypedef(self, tdef): + tdef.fromtype.accept(self) + + def visitUsing(self, us): + us.type.accept(self) + + def visitForwardDecl(self, fd): + pass + + def visitDecl(self, decl): + decl.type.accept(self) + + def visitParam(self, param): + self.visitDecl(param) + if param.default is not None: + param.default.accept(self) + + def visitClass(self, cls): + for inherit in cls.inherits: + inherit.accept(self) + self.visitBlock(cls) + + def visitInherit(self, inh): + pass + + def visitFriendClassDecl(self, fcd): + pass + + def visitMethodDecl(self, meth): + for param in meth.params: + param.accept(self) + if meth.ret is not None: + meth.ret.accept(self) + if meth.typeop is not None: + meth.typeop.accept(self) + if meth.T is not None: + meth.T.accept(self) + + def visitMethodDefn(self, meth): + meth.decl.accept(self) + self.visitBlock(meth) + + def visitFunctionDecl(self, fun): + self.visitMethodDecl(fun) + + def visitFunctionDefn(self, fd): + self.visitMethodDefn(fd) + + def visitConstructorDecl(self, ctor): + self.visitMethodDecl(ctor) + + def visitConstructorDefn(self, cd): + cd.decl.accept(self) + for init in cd.memberinits: + init.accept(self) + self.visitBlock(cd) + + def visitDestructorDecl(self, dtor): + self.visitMethodDecl(dtor) + + def visitDestructorDefn(self, dd): + dd.decl.accept(self) + self.visitBlock(dd) + + def visitExprLiteral(self, l): + pass + + def visitExprVar(self, v): + pass + + def visitExprPrefixUnop(self, e): + e.expr.accept(self) + + def visitExprBinary(self, e): + e.left.accept(self) + e.right.accept(self) + + def visitExprConditional(self, c): + c.cond.accept(self) + c.ife.accept(self) + c.elsee.accept(self) + + def visitExprAddrOf(self, eao): + self.visitExprPrefixUnop(eao) + + def visitExprDeref(self, ed): + self.visitExprPrefixUnop(ed) + + def visitExprNot(self, en): + self.visitExprPrefixUnop(en) + + def visitExprCast(self, ec): + ec.expr.accept(self) + + def visitExprSelect(self, es): + es.obj.accept(self) + + def visitExprAssn(self, ea): + ea.lhs.accept(self) + ea.rhs.accept(self) + + def visitExprCall(self, ec): + ec.func.accept(self) + for arg in ec.args: + arg.accept(self) + + def visitExprMove(self, ec): + self.visitExprCall(ec) + + def visitExprNothing(self, ec): + self.visitExprCall(ec) + + def visitExprSome(self, ec): + self.visitExprCall(ec) + + def visitExprNew(self, en): + en.ctype.accept(self) + if en.newargs is not None: + for arg in en.newargs: + arg.accept(self) + if en.args is not None: + for arg in en.args: + arg.accept(self) + + def visitExprDelete(self, ed): + ed.obj.accept(self) + + def visitExprMemberInit(self, minit): + self.visitExprCall(minit) + + def visitExprLambda(self, l): + self.visitBlock(l) + + def visitStmtBlock(self, sb): + self.visitBlock(sb) + + def visitStmtDecl(self, sd): + sd.decl.accept(self) + if sd.init is not None: + sd.init.accept(self) + + def visitLabel(self, label): + pass + + def visitCaseLabel(self, case): + pass + + def visitDefaultLabel(self, dl): + pass + + def visitStmtIf(self, si): + si.cond.accept(self) + si.ifb.accept(self) + if si.elseb is not None: + si.elseb.accept(self) + + def visitStmtFor(self, sf): + if sf.init is not None: + sf.init.accept(self) + if sf.cond is not None: + sf.cond.accept(self) + if sf.update is not None: + sf.update.accept(self) + + def visitStmtSwitch(self, ss): + ss.expr.accept(self) + self.visitBlock(ss) + + def visitStmtBreak(self, sb): + pass + + def visitStmtExpr(self, se): + se.expr.accept(self) + + def visitStmtReturn(self, sr): + if sr.expr is not None: + sr.expr.accept(self) + + +# ------------------------------ + + +class Node: + def __init__(self): + pass + + def accept(self, visitor): + visit = getattr(visitor, "visit" + self.__class__.__name__, None) + if visit is None: + return getattr(visitor, "defaultVisit")(self) + return visit(self) + + +class Whitespace(Node): + # yes, this is silly. but we need to stick comments in the + # generated code without resorting to more serious hacks + def __init__(self, ws, indent=False): + Node.__init__(self) + self.ws = ws + self.indent = indent + + +Whitespace.NL = Whitespace("\n") + + +class VerbatimNode(Node): + # A block of text to be written verbatim to the output file. + # + # NOTE: This node is usually created by `code`. See `code.py` for details. + # FIXME: Merge Whitespace and VerbatimNode? They're identical. + def __init__(self, text, indent=0): + Node.__init__(self) + self.text = text + self.indent = indent + + +class GroupNode(Node): + # A group of nodes to be treated as a single node. These nodes have an + # optional indentation level which should be applied when generating them. + # + # NOTE: This node is usually created by `code`. See `code.py` for details. + def __init__(self, nodes, offset=0): + Node.__init__(self) + self.nodes = nodes + self.offset = offset + + +class File(Node): + def __init__(self, filename): + Node.__init__(self) + self.name = filename + # array of stuff in the file --- stmts and preprocessor thingies + self.stuff = [] + + def addthing(self, thing): + assert thing is not None + assert not isinstance(thing, list) + self.stuff.append(thing) + + def addthings(self, things): + for t in things: + self.addthing(t) + + # "look like" a Block so code doesn't have to care whether they're + # in global scope or not + def addstmt(self, stmt): + assert stmt is not None + assert not isinstance(stmt, list) + self.stuff.append(stmt) + + def addstmts(self, stmts): + for s in stmts: + self.addstmt(s) + + def addcode(self, tmpl, **context): + from ipdl.cxx.code import StmtCode + + self.addstmt(StmtCode(tmpl, **context)) + + +class CppDirective(Node): + """represents |#[directive] [rest]|, where |rest| is any string""" + + def __init__(self, directive, rest=None): + Node.__init__(self) + self.directive = directive + self.rest = rest + + +class Block(Node): + def __init__(self): + Node.__init__(self) + self.stmts = [] + + def addstmt(self, stmt): + assert stmt is not None + assert not isinstance(stmt, tuple) + self.stmts.append(stmt) + + def addstmts(self, stmts): + for s in stmts: + self.addstmt(s) + + def addcode(self, tmpl, **context): + from ipdl.cxx.code import StmtCode + + self.addstmt(StmtCode(tmpl, **context)) + + +# ------------------------------ +# type and decl thingies + + +class Namespace(Block): + def __init__(self, name): + assert isinstance(name, str) + + Block.__init__(self) + self.name = name + + +class Type(Node): + def __init__( + self, + name, + const=False, + ptr=False, + ptrptr=False, + ptrconstptr=False, + ref=False, + rvalref=False, + rightconst=False, + hasimplicitcopyctor=True, + T=None, + inner=None, + ): + """ + Represents the type |name<T>::inner| with the ptr and const + modifiers as specified. + + To avoid getting fancy with recursive types, we limit the kinds + of pointer types that can be be constructed. + + ptr => T* + ptrptr => T** + ptrconstptr => T* const* + ref => T& + rvalref => T&& + + Any type, naked or pointer, can be const (const T) or ref (T&).""" + # XXX(nika): This type is complex enough at this point, perhaps we + # should get "fancy with recursive types" to simplify it. + assert isinstance(name, str) + assert isinstance(const, bool) + assert isinstance(ptr, bool) + assert isinstance(ptrptr, bool) + assert isinstance(ptrconstptr, bool) + assert isinstance(ref, bool) + assert isinstance(rvalref, bool) + assert isinstance(rightconst, bool) + assert not isinstance(T, str) + + Node.__init__(self) + self.name = name + self.const = const + self.ptr = ptr + self.ptrptr = ptrptr + self.ptrconstptr = ptrconstptr + self.ref = ref + self.rvalref = rvalref + self.rightconst = rightconst + self.hasimplicitcopyctor = hasimplicitcopyctor + self.T = T + self.inner = inner + # XXX could get serious here with recursive types, but shouldn't + # need that for this codegen + + def __deepcopy__(self, memo): + return Type( + self.name, + const=self.const, + ptr=self.ptr, + ptrptr=self.ptrptr, + ptrconstptr=self.ptrconstptr, + ref=self.ref, + rvalref=self.rvalref, + rightconst=self.rightconst, + T=copy.deepcopy(self.T, memo), + inner=copy.deepcopy(self.inner, memo), + ) + + +Type.BOOL = Type("bool") +Type.INT = Type("int") +Type.INT32 = Type("int32_t") +Type.INTPTR = Type("intptr_t") +Type.NSRESULT = Type("nsresult") +Type.UINT32 = Type("uint32_t") +Type.UINT32PTR = Type("uint32_t", ptr=True) +Type.SIZE = Type("size_t") +Type.VOID = Type("void") +Type.VOIDPTR = Type("void", ptr=True) +Type.AUTO = Type("auto") +Type.AUTORVAL = Type("auto", rvalref=True) + + +class TypeArray(Node): + def __init__(self, basetype, nmemb): + """the type |basetype DECLNAME[nmemb]|. |nmemb| is an Expr""" + self.basetype = basetype + self.nmemb = nmemb + + +class TypeEnum(Node): + def __init__(self, name=None): + """name can be None""" + Node.__init__(self) + self.name = name + self.idnums = [] # pairs of ('Foo', [num]) or ('Foo', None) + + def addId(self, id, num=None): + self.idnums.append((id, num)) + + +class TypeUnion(Node): + def __init__(self, name=None): + Node.__init__(self) + self.name = name + self.components = [] # [ Decl ] + + def addComponent(self, type, name): + self.components.append(Decl(type, name)) + + +class TypeFunction(Node): + def __init__(self, params=[], ret=Type("void")): + """Anonymous function type std::function<>""" + self.params = params + self.ret = ret + + +@functools.total_ordering +class Typedef(Node): + def __init__(self, fromtype, totypename, templateargs=[]): + assert isinstance(totypename, str) + + Node.__init__(self) + self.fromtype = fromtype + self.totypename = totypename + self.templateargs = templateargs + + def __lt__(self, other): + return self.totypename < other.totypename + + def __eq__(self, other): + return self.__class__ == other.__class__ and self.totypename == other.totypename + + def __hash__(self): + return hash_str(self.totypename) + + +class Using(Node): + def __init__(self, type): + Node.__init__(self) + self.type = type + + +class ForwardDecl(Node): + def __init__(self, pqname, cls=False, struct=False): + # Exactly one of cls and struct must be set + assert cls ^ struct + + self.pqname = pqname + self.cls = cls + self.struct = struct + + +class Decl(Node): + """represents |Foo bar|, e.g. in a function signature""" + + def __init__(self, type, name): + assert type is not None + assert not isinstance(type, str) + assert isinstance(name, str) + + Node.__init__(self) + self.type = type + self.name = name + + def __deepcopy__(self, memo): + return Decl(copy.deepcopy(self.type, memo), self.name) + + +class Param(Decl): + def __init__(self, type, name, default=None): + Decl.__init__(self, type, name) + self.default = default + + def __deepcopy__(self, memo): + return Param( + copy.deepcopy(self.type, memo), self.name, copy.deepcopy(self.default, memo) + ) + + +# ------------------------------ +# class stuff + + +class Class(Block): + def __init__( + self, + name, + inherits=[], + interface=False, + abstract=False, + final=False, + specializes=None, + struct=False, + ): + assert not (interface and abstract) + assert not (abstract and final) + assert not (interface and final) + assert not (inherits and specializes) + + Block.__init__(self) + self.name = name + self.inherits = inherits # [ Type ] + self.interface = interface # bool + self.abstract = abstract # bool + self.final = final # bool + self.specializes = specializes # Type or None + self.struct = struct # bool + + +class Inherit(Node): + def __init__(self, type, viz="public"): + assert isinstance(viz, str) + Node.__init__(self) + self.type = type + self.viz = viz + + +class FriendClassDecl(Node): + def __init__(self, friend): + Node.__init__(self) + self.friend = friend + + +# Python2 polyfill for Python3's Enum() functional API. + + +def make_enum(name, members_str): + members_list = members_str.split() + members_dict = {} + for member_value, member in enumerate(members_list, start=1): + members_dict[member] = member_value + return type(name, (), members_dict) + + +MethodSpec = make_enum("MethodSpec", "NONE VIRTUAL PURE OVERRIDE STATIC") + + +class MethodDecl(Node): + def __init__( + self, + name, + params=[], + ret=Type("void"), + methodspec=MethodSpec.NONE, + const=False, + warn_unused=False, + force_inline=False, + typeop=None, + T=None, + cls=None, + ): + assert not (name and typeop) + assert name is None or isinstance(name, str) + assert not isinstance(ret, list) + for decl in params: + assert not isinstance(decl, str) + assert not isinstance(T, int) + assert isinstance(const, bool) + assert isinstance(warn_unused, bool) + assert isinstance(force_inline, bool) + + if typeop is not None: + assert methodspec == MethodSpec.NONE + ret = None + + Node.__init__(self) + self.name = name + self.params = params # [ Param ] + self.ret = ret # Type or None + self.methodspec = methodspec # enum + self.const = const # bool + self.warn_unused = warn_unused # bool + self.force_inline = force_inline or bool(T) # bool + self.typeop = typeop # Type or None + self.T = T # Type or None + self.cls = cls # Class or None + self.only_for_definition = False + + def __deepcopy__(self, memo): + return MethodDecl( + self.name, + params=copy.deepcopy(self.params, memo), + ret=copy.deepcopy(self.ret, memo), + methodspec=self.methodspec, + const=self.const, + warn_unused=self.warn_unused, + force_inline=self.force_inline, + typeop=copy.deepcopy(self.typeop, memo), + T=copy.deepcopy(self.T, memo), + ) + + +class MethodDefn(Block): + def __init__(self, decl): + Block.__init__(self) + self.decl = decl + + +class FunctionDecl(MethodDecl): + def __init__( + self, + name, + params=[], + ret=Type("void"), + methodspec=MethodSpec.NONE, + warn_unused=False, + force_inline=False, + T=None, + ): + assert methodspec == MethodSpec.NONE or methodspec == MethodSpec.STATIC + MethodDecl.__init__( + self, + name, + params=params, + ret=ret, + methodspec=methodspec, + warn_unused=warn_unused, + force_inline=force_inline, + T=T, + ) + + +class FunctionDefn(MethodDefn): + def __init__(self, decl): + MethodDefn.__init__(self, decl) + + +class ConstructorDecl(MethodDecl): + def __init__(self, name, params=[], explicit=False, force_inline=False): + MethodDecl.__init__( + self, name, params=params, ret=None, force_inline=force_inline + ) + self.explicit = explicit + + def __deepcopy__(self, memo): + return ConstructorDecl( + self.name, copy.deepcopy(self.params, memo), self.explicit + ) + + +class ConstructorDefn(MethodDefn): + def __init__(self, decl, memberinits=[]): + MethodDefn.__init__(self, decl) + self.memberinits = memberinits + + +class DestructorDecl(MethodDecl): + def __init__(self, name, methodspec=MethodSpec.NONE, force_inline=False): + # C++ allows pure or override destructors, but ipdl cgen does not. + assert methodspec == MethodSpec.NONE or methodspec == MethodSpec.VIRTUAL + MethodDecl.__init__( + self, + name, + params=[], + ret=None, + methodspec=methodspec, + force_inline=force_inline, + ) + + def __deepcopy__(self, memo): + return DestructorDecl( + self.name, methodspec=self.methodspec, force_inline=self.force_inline + ) + + +class DestructorDefn(MethodDefn): + def __init__(self, decl): + MethodDefn.__init__(self, decl) + + +# ------------------------------ +# expressions + + +class ExprVar(Node): + def __init__(self, name): + assert isinstance(name, str) + + Node.__init__(self) + self.name = name + + +ExprVar.THIS = ExprVar("this") + + +class ExprLiteral(Node): + def __init__(self, value, type): + """|type| is a Python format specifier; 'd' for example""" + Node.__init__(self) + self.value = value + self.type = type + + @staticmethod + def Int(i): + return ExprLiteral(i, "d") + + @staticmethod + def String(s): + return ExprLiteral('"' + s + '"', "s") + + def __str__(self): + return ("%" + self.type) % (self.value) + + +ExprLiteral.ZERO = ExprLiteral.Int(0) +ExprLiteral.ONE = ExprLiteral.Int(1) +ExprLiteral.NULL = ExprVar("nullptr") +ExprLiteral.TRUE = ExprVar("true") +ExprLiteral.FALSE = ExprVar("false") + + +class ExprPrefixUnop(Node): + def __init__(self, expr, op): + assert not isinstance(expr, tuple) + self.expr = expr + self.op = op + + +class ExprNot(ExprPrefixUnop): + def __init__(self, expr): + ExprPrefixUnop.__init__(self, expr, "!") + + +class ExprAddrOf(ExprPrefixUnop): + def __init__(self, expr): + ExprPrefixUnop.__init__(self, expr, "&") + + +class ExprDeref(ExprPrefixUnop): + def __init__(self, expr): + ExprPrefixUnop.__init__(self, expr, "*") + + +class ExprCast(Node): + def __init__(self, expr, type, static=False, const=False): + # Exactly one of these should be set + assert static ^ const + + Node.__init__(self) + self.expr = expr + self.type = type + self.static = static + self.const = const + + +class ExprBinary(Node): + def __init__(self, left, op, right): + Node.__init__(self) + self.left = left + self.op = op + self.right = right + + +class ExprConditional(Node): + def __init__(self, cond, ife, elsee): + Node.__init__(self) + self.cond = cond + self.ife = ife + self.elsee = elsee + + +class ExprSelect(Node): + def __init__(self, obj, op, field): + assert obj and op and field + assert not isinstance(obj, str) + assert isinstance(op, str) + + Node.__init__(self) + self.obj = obj + self.op = op + if isinstance(field, str): + self.field = ExprVar(field) + else: + self.field = field + + +class ExprAssn(Node): + def __init__(self, lhs, rhs, op="="): + Node.__init__(self) + self.lhs = lhs + self.op = op + self.rhs = rhs + + +class ExprCall(Node): + def __init__(self, func, args=[]): + assert hasattr(func, "accept") + assert isinstance(args, list) + for arg in args: + assert arg and not isinstance(arg, str) + + Node.__init__(self) + self.func = func + self.args = args + + +class ExprMove(ExprCall): + def __init__(self, arg): + ExprCall.__init__(self, ExprVar("std::move"), args=[arg]) + + +class ExprNew(Node): + # XXX taking some poetic license ... + def __init__(self, ctype, args=[], newargs=None): + assert not (ctype.const or ctype.ref or ctype.rvalref) + + Node.__init__(self) + self.ctype = ctype + self.args = args + self.newargs = newargs + + +class ExprDelete(Node): + def __init__(self, obj): + Node.__init__(self) + self.obj = obj + + +class ExprMemberInit(ExprCall): + def __init__(self, member, args=[]): + ExprCall.__init__(self, member, args) + + +class ExprLambda(Block): + def __init__(self, captures=[], params=[], ret=None): + Block.__init__(self) + assert isinstance(captures, list) + assert isinstance(params, list) + self.captures = captures + self.params = params + self.ret = ret + + +# ------------------------------ +# statements etc. + + +class StmtBlock(Block): + def __init__(self, stmts=[]): + Block.__init__(self) + self.addstmts(stmts) + + +class StmtDecl(Node): + def __init__(self, decl, init=None, initargs=None): + assert not (init and initargs) + assert not isinstance(init, str) # easy to confuse with Decl + assert not isinstance(init, list) + assert not isinstance(decl, tuple) + + Node.__init__(self) + self.decl = decl + self.init = init + self.initargs = initargs + + +class Label(Node): + def __init__(self, name): + Node.__init__(self) + self.name = name + + +Label.PUBLIC = Label("public") +Label.PROTECTED = Label("protected") +Label.PRIVATE = Label("private") + + +class CaseLabel(Node): + def __init__(self, name): + Node.__init__(self) + self.name = name + + +class DefaultLabel(Node): + def __init__(self): + Node.__init__(self) + + +class StmtIf(Node): + def __init__(self, cond): + Node.__init__(self) + self.cond = cond + self.ifb = Block() + self.elseb = None + + def addifstmt(self, stmt): + self.ifb.addstmt(stmt) + + def addifstmts(self, stmts): + self.ifb.addstmts(stmts) + + def addelsestmt(self, stmt): + if self.elseb is None: + self.elseb = Block() + self.elseb.addstmt(stmt) + + def addelsestmts(self, stmts): + if self.elseb is None: + self.elseb = Block() + self.elseb.addstmts(stmts) + + +class StmtFor(Block): + def __init__(self, init=None, cond=None, update=None): + Block.__init__(self) + self.init = init + self.cond = cond + self.update = update + + +class StmtRangedFor(Block): + def __init__(self, var, iteree): + assert isinstance(var, ExprVar) + assert iteree + + Block.__init__(self) + self.var = var + self.iteree = iteree + + +class StmtSwitch(Block): + def __init__(self, expr): + Block.__init__(self) + self.expr = expr + self.nr_cases = 0 + + def addcase(self, case, block): + """NOTE: |case| is not checked for uniqueness""" + assert not isinstance(case, str) + assert ( + isinstance(block, StmtBreak) + or isinstance(block, StmtReturn) + or isinstance(block, StmtSwitch) + or isinstance(block, GroupNode) + or isinstance(block, VerbatimNode) + or ( + hasattr(block, "stmts") + and ( + isinstance(block.stmts[-1], StmtBreak) + or isinstance(block.stmts[-1], StmtReturn) + or isinstance(block.stmts[-1], GroupNode) + or isinstance(block.stmts[-1], VerbatimNode) + ) + ) + ) + self.addstmt(case) + self.addstmt(block) + self.nr_cases += 1 + + +class StmtBreak(Node): + def __init__(self): + Node.__init__(self) + + +class StmtExpr(Node): + def __init__(self, expr): + assert expr is not None + + Node.__init__(self) + self.expr = expr + + +class StmtReturn(Node): + def __init__(self, expr=None): + Node.__init__(self) + self.expr = expr + + +StmtReturn.TRUE = StmtReturn(ExprLiteral.TRUE) +StmtReturn.FALSE = StmtReturn(ExprLiteral.FALSE) diff --git a/ipc/ipdl/ipdl/cxx/cgen.py b/ipc/ipdl/ipdl/cxx/cgen.py new file mode 100644 index 0000000000..fecb90e97b --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/cgen.py @@ -0,0 +1,557 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import sys + +from ipdl.cgen import CodePrinter +from ipdl.cxx.ast import MethodSpec, TypeArray, Visitor, DestructorDecl + + +class CxxCodeGen(CodePrinter, Visitor): + def __init__(self, outf=sys.stdout, indentCols=4): + CodePrinter.__init__(self, outf, indentCols) + + def cgen(self, cxxfile): + cxxfile.accept(self) + + def visitWhitespace(self, ws): + if ws.indent: + self.printdent("") + self.write(ws.ws) + + def visitVerbatimNode(self, verb): + if verb.indent: + self.printdent("") + self.write(verb.text) + + def visitGroupNode(self, group): + offsetCols = self.indentCols * group.offset + self.col += offsetCols + for node in group.nodes: + node.accept(self) + self.col -= offsetCols + + def visitCppDirective(self, cd): + if cd.rest: + self.println("#%s %s" % (cd.directive, cd.rest)) + else: + self.println("#%s" % (cd.directive)) + + def visitNamespace(self, ns): + self.println("namespace " + ns.name + " {") + self.visitBlock(ns) + self.println("} // namespace " + ns.name) + + def visitType(self, t): + if t.const: + self.write("const ") + + self.write(t.name) + + if t.T is not None: + self.write("<") + if type(t.T) is list: + t.T[0].accept(self) + for tt in t.T[1:]: + self.write(", ") + tt.accept(self) + else: + t.T.accept(self) + self.write(">") + + if t.inner is not None: + self.write("::") + t.inner.accept(self) + + ts = "" + if t.ptr: + ts += "*" + elif t.ptrptr: + ts += "**" + elif t.ptrconstptr: + ts += "* const*" + + if t.ref: + ts += "&" + elif t.rvalref: + ts += "&&" + elif t.rightconst: + ts += " const" + + self.write(ts) + + def visitTypeEnum(self, te): + self.write("enum") + if te.name: + self.write(" " + te.name) + self.println(" {") + + self.indent() + nids = len(te.idnums) + for i, (id, num) in enumerate(te.idnums): + self.printdent(id) + if num: + self.write(" = " + str(num)) + if i != (nids - 1): + self.write(",") + self.println() + self.dedent() + self.printdent("}") + + def visitTypeUnion(self, u): + self.write("union") + if u.name: + self.write(" " + u.name) + self.println(" {") + + self.indent() + for decl in u.components: + self.printdent() + decl.accept(self) + self.println(";") + self.dedent() + + self.printdent("}") + + def visitTypeFunction(self, fn): + self.write("std::function<") + fn.ret.accept(self) + self.write("(") + self.writeDeclList(fn.params) + self.write(")>") + + def visitTypedef(self, td): + if td.templateargs: + formals = ", ".join(["class " + T for T in td.templateargs]) + args = ", ".join(td.templateargs) + self.printdent("template<" + formals + "> using " + td.totypename + " = ") + td.fromtype.accept(self) + self.println("<" + args + ">;") + else: + self.printdent("typedef ") + td.fromtype.accept(self) + self.println(" " + td.totypename + ";") + + def visitUsing(self, us): + self.printdent("using ") + us.type.accept(self) + self.println(";") + + def visitForwardDecl(self, fd): + if fd.cls: + self.printdent("class ") + elif fd.struct: + self.printdent("struct ") + self.write(str(fd.pqname)) + self.println(";") + + def visitDecl(self, d): + # C-syntax arrays make code generation much more annoying + if isinstance(d.type, TypeArray): + d.type.basetype.accept(self) + else: + d.type.accept(self) + + if d.name: + self.write(" " + d.name) + + if isinstance(d.type, TypeArray): + self.write("[") + d.type.nmemb.accept(self) + self.write("]") + + def visitParam(self, p): + self.visitDecl(p) + if p.default is not None: + self.write(" = ") + p.default.accept(self) + + def visitClass(self, c): + if c.specializes is not None: + self.printdentln("template<>") + + if c.struct: + self.printdent("struct") + else: + self.printdent("class") + self.write(" " + c.name) + if c.final: + self.write(" final") + + if c.specializes is not None: + self.write(" <") + c.specializes.accept(self) + self.write(">") + + ninh = len(c.inherits) + if 0 < ninh: + self.println(" :") + self.indent() + for i, inherit in enumerate(c.inherits): + self.printdent() + inherit.accept(self) + if i != (ninh - 1): + self.println(",") + self.dedent() + self.println() + + self.printdentln("{") + self.indent() + + self.visitBlock(c) + + self.dedent() + self.printdentln("};") + + def visitInherit(self, inh): + self.write(inh.viz + " ") + inh.type.accept(self) + + def visitFriendClassDecl(self, fcd): + self.printdentln("friend class " + fcd.friend + ";") + + def visitMethodDecl(self, md): + if md.T: + self.write("template<") + self.write("typename ") + md.T.accept(self) + self.println(">") + self.printdent() + + if md.warn_unused: + self.write("[[nodiscard]] ") + + if md.methodspec == MethodSpec.STATIC: + self.write("static ") + elif md.methodspec == MethodSpec.VIRTUAL or md.methodspec == MethodSpec.PURE: + self.write("virtual ") + + if md.ret: + if md.only_for_definition: + self.write("auto ") + else: + md.ret.accept(self) + self.println() + self.printdent() + + if md.cls is not None: + assert md.only_for_definition + + self.write(md.cls.name) + if md.cls.specializes is not None: + self.write("<") + md.cls.specializes.accept(self) + self.write(">") + self.write("::") + + if md.typeop is not None: + self.write("operator ") + md.typeop.accept(self) + else: + if isinstance(md, DestructorDecl): + self.write("~") + self.write(md.name) + + self.write("(") + self.writeDeclList(md.params) + self.write(")") + + if md.const: + self.write(" const") + if md.ret and md.only_for_definition: + self.write(" -> ") + md.ret.accept(self) + + if md.methodspec == MethodSpec.OVERRIDE: + self.write(" override") + elif md.methodspec == MethodSpec.PURE: + self.write(" = 0") + + def visitMethodDefn(self, md): + # Method specifiers are for decls, not defns. + assert md.decl.methodspec == MethodSpec.NONE + + self.printdent() + md.decl.accept(self) + self.println() + + self.printdentln("{") + self.indent() + self.visitBlock(md) + self.dedent() + self.printdentln("}") + + def visitConstructorDecl(self, cd): + if cd.explicit: + self.write("explicit ") + else: + self.write("MOZ_IMPLICIT ") + self.visitMethodDecl(cd) + + def visitConstructorDefn(self, cd): + self.printdent() + cd.decl.accept(self) + if len(cd.memberinits): + self.println(" :") + self.indent() + ninits = len(cd.memberinits) + for i, init in enumerate(cd.memberinits): + self.printdent() + init.accept(self) + if i != (ninits - 1): + self.println(",") + self.dedent() + self.println() + + self.printdentln("{") + self.indent() + + self.visitBlock(cd) + + self.dedent() + self.printdentln("}") + + def visitDestructorDecl(self, dd): + self.visitMethodDecl(dd) + + def visitDestructorDefn(self, dd): + self.printdent() + dd.decl.accept(self) + self.println() + + self.printdentln("{") + self.indent() + + self.visitBlock(dd) + + self.dedent() + self.printdentln("}") + + def visitExprLiteral(self, el): + self.write(str(el)) + + def visitExprVar(self, ev): + self.write(ev.name) + + def visitExprPrefixUnop(self, e): + self.write("(") + self.write(e.op) + self.write("(") + e.expr.accept(self) + self.write(")") + self.write(")") + + def visitExprCast(self, c): + if c.static: + pfx, sfx = "static_cast<", ">" + else: + assert c.const + pfx, sfx = "const_cast<", ">" + self.write(pfx) + c.type.accept(self) + self.write(sfx + "(") + c.expr.accept(self) + self.write(")") + + def visitExprBinary(self, e): + self.write("(") + e.left.accept(self) + self.write(") " + e.op + " (") + e.right.accept(self) + self.write(")") + + def visitExprConditional(self, c): + self.write("(") + c.cond.accept(self) + self.write(" ? ") + c.ife.accept(self) + self.write(" : ") + c.elsee.accept(self) + self.write(")") + + def visitExprSelect(self, es): + self.write("(") + es.obj.accept(self) + self.write(")") + self.write(es.op) + es.field.accept(self) + + def visitExprAssn(self, ea): + ea.lhs.accept(self) + self.write(" " + ea.op + " ") + ea.rhs.accept(self) + + def visitExprCall(self, ec): + ec.func.accept(self) + self.write("(") + self.writeExprList(ec.args) + self.write(")") + + def visitExprNew(self, en): + self.write("new ") + if en.newargs is not None: + self.write("(") + self.writeExprList(en.newargs) + self.write(") ") + en.ctype.accept(self) + if en.args is not None: + self.write("(") + self.writeExprList(en.args) + self.write(")") + + def visitExprDelete(self, ed): + self.write("delete ") + ed.obj.accept(self) + + def visitExprLambda(self, l): + self.write("[") + ncaptures = len(l.captures) + for i, c in enumerate(l.captures): + c.accept(self) + if i != (ncaptures - 1): + self.write(", ") + self.write("](") + self.writeDeclList(l.params) + self.write(")") + if l.ret: + self.write(" -> ") + l.ret.accept(self) + self.println(" {") + self.indent() + self.visitBlock(l) + self.dedent() + self.printdent("}") + + def visitStmtBlock(self, b): + self.printdentln("{") + self.indent() + self.visitBlock(b) + self.dedent() + self.printdentln("}") + + def visitLabel(self, label): + self.dedent() # better not be at global scope ... + self.printdentln(label.name + ":") + self.indent() + + def visitCaseLabel(self, cl): + self.dedent() + self.printdentln("case " + cl.name + ":") + self.indent() + + def visitDefaultLabel(self, dl): + self.dedent() + self.printdentln("default:") + self.indent() + + def visitStmtIf(self, si): + self.printdent("if (") + si.cond.accept(self) + self.println(") {") + self.indent() + si.ifb.accept(self) + self.dedent() + self.printdentln("}") + + if si.elseb is not None: + self.printdentln("else {") + self.indent() + si.elseb.accept(self) + self.dedent() + self.printdentln("}") + + def visitStmtFor(self, sf): + self.printdent("for (") + if sf.init is not None: + sf.init.accept(self) + self.write("; ") + if sf.cond is not None: + sf.cond.accept(self) + self.write("; ") + if sf.update is not None: + sf.update.accept(self) + self.println(") {") + + self.indent() + self.visitBlock(sf) + self.dedent() + self.printdentln("}") + + def visitStmtRangedFor(self, rf): + self.printdent("for (auto& ") + rf.var.accept(self) + self.write(" : ") + rf.iteree.accept(self) + self.println(") {") + + self.indent() + self.visitBlock(rf) + self.dedent() + self.printdentln("}") + + def visitStmtSwitch(self, sw): + self.printdent("switch (") + sw.expr.accept(self) + self.println(") {") + self.indent() + self.visitBlock(sw) + self.dedent() + self.printdentln("}") + + def visitStmtBreak(self, sb): + self.printdentln("break;") + + def visitStmtDecl(self, sd): + self.printdent() + sd.decl.accept(self) + if sd.initargs is not None: + self.write("{") + self.writeDeclList(sd.initargs) + self.write("}") + if sd.init is not None: + self.write(" = ") + sd.init.accept(self) + self.println(";") + + def visitStmtExpr(self, se): + self.printdent() + se.expr.accept(self) + self.println(";") + + def visitStmtReturn(self, sr): + self.printdent("return") + if sr.expr: + self.write(" ") + sr.expr.accept(self) + self.println(";") + + def writeDeclList(self, decls): + # FIXME/cjones: try to do nice formatting of these guys + + ndecls = len(decls) + if 0 == ndecls: + return + elif 1 == ndecls: + decls[0].accept(self) + return + + self.indent() + self.indent() + for i, decl in enumerate(decls): + self.println() + self.printdent() + decl.accept(self) + if i != (ndecls - 1): + self.write(",") + self.dedent() + self.dedent() + + def writeExprList(self, exprs): + # FIXME/cjones: try to do nice formatting and share code with + # writeDeclList() + nexprs = len(exprs) + for i, expr in enumerate(exprs): + expr.accept(self) + if i != (nexprs - 1): + self.write(", ") diff --git a/ipc/ipdl/ipdl/cxx/code.py b/ipc/ipdl/ipdl/cxx/code.py new file mode 100644 index 0000000000..0b5019b623 --- /dev/null +++ b/ipc/ipdl/ipdl/cxx/code.py @@ -0,0 +1,187 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +# This module contains functionality for adding formatted, opaque "code" blocks +# into the IPDL ast. These code objects follow IPDL C++ ast patterns, and +# perform lowering in much the same way. + +# In general it is recommended to use these for blocks of code which would +# otherwise be specified by building a hardcoded IPDL-C++ AST, as users of this +# API are often easier to read than users of the AST APIs in these cases. + +import re +import math +import textwrap + +from ipdl.cxx.ast import Node, Whitespace, GroupNode, VerbatimNode + + +# ----------------------------------------------------------------------------- +# Public API. + + +def StmtCode(tmpl, **kwargs): + """Perform template substitution to build opaque C++ AST nodes. See the + module documentation for more information on the templating syntax. + + StmtCode nodes should be used where Stmt* nodes are used. They are placed + on their own line and indented.""" + return _code(tmpl, False, kwargs) + + +def ExprCode(tmpl, **kwargs): + """Perform template substitution to build opaque C++ AST nodes. See the + module documentation for more information on the templating syntax. + + ExprCode nodes should be used where Expr* nodes are used. They are placed + inline, and no trailing newline is added.""" + return _code(tmpl, True, kwargs) + + +def StmtVerbatim(text): + """Build an opaque C++ AST node which emits input text verbatim. + + StmtVerbatim nodes should be used where Stmt* nodes are used. They are placed + on their own line and indented.""" + return _verbatim(text, False) + + +def ExprVerbatim(text): + """Build an opaque C++ AST node which emits input text verbatim. + + ExprVerbatim nodes should be used where Expr* nodes are used. They are + placed inline, and no trailing newline is added.""" + return _verbatim(text, True) + + +# ----------------------------------------------------------------------------- +# Implementation + + +def _code(tmpl, inline, context): + # Remove common indentation, and strip the preceding newline from + # '''-quoting, because we usually don't want it. + if tmpl.startswith("\n"): + tmpl = tmpl[1:] + tmpl = textwrap.dedent(tmpl) + + # Process each line in turn, building up a list of nodes. + nodes = [] + for idx, line in enumerate(tmpl.splitlines()): + # Place newline tokens between lines in the input. + if idx > 0: + nodes.append(Whitespace.NL) + + # Don't indent the first line if `inline` is set. + skip_indent = inline and idx == 0 + nodes.append(_line(line.rstrip(), skip_indent, idx + 1, context)) + + # If we're inline, don't add the final trailing newline. + if not inline: + nodes.append(Whitespace.NL) + return GroupNode(nodes) + + +def _verbatim(text, inline): + # For simplicitly, _verbatim is implemented using the same logic as _code, + # but with '$' characters escaped. This ensures we only need to worry about + # a single, albeit complex, codepath. + return _code(text.replace("$", "$$"), inline, {}) + + +# Pattern used to identify substitutions. +_substPat = re.compile( + r""" + \$(?: + (?P<escaped>\$) | # '$$' is an escaped '$' + (?P<list>[*,])?{(?P<expr>[^}]+)} | # ${expr}, $*{expr}, or $,{expr} + (?P<invalid>) # For error reporting + ) + """, + re.IGNORECASE | re.VERBOSE, +) + + +def _line(raw, skip_indent, lineno, context): + assert "\n" not in raw + + # Determine the level of indentation used for this line + line = raw.lstrip() + offset = int(math.ceil((len(raw) - len(line)) / 4)) + + # If line starts with a directive, don't indent it. + if line.startswith("#"): + skip_indent = True + + column = 0 + children = [] + for match in _substPat.finditer(line): + if match.group("invalid") is not None: + raise ValueError("Invalid substitution on line %d" % lineno) + + # Any text from before the current entry should be written, and column + # advanced. + if match.start() > column: + before = line[column : match.start()] + children.append(VerbatimNode(before)) + column = match.end() + + # If we have an escaped group, emit a '$' node. + if match.group("escaped") is not None: + children.append(VerbatimNode("$")) + continue + + # At this point we should have an expression. + list_chr = match.group("list") + expr = match.group("expr") + assert expr is not None + + # Evaluate our expression in the context to get the values. + try: + values = eval(expr, context, {}) + except Exception as e: + msg = "%s in substitution on line %d" % (repr(e), lineno) + raise ValueError(msg) from e + + # If we aren't dealing with lists, wrap the result into a + # single-element list. + if list_chr is None: + values = [values] + + # Check if this substitution is inline, or the entire line. + inline = match.span() != (0, len(line)) + + for idx, value in enumerate(values): + # If we're using ',' as list mode, put a comma between each node. + if idx > 0 and list_chr == ",": + children.append(VerbatimNode(", ")) + + # If our value isn't a node, turn it into one. Verbatim should be + # inline unless indent isn't being skipped, and the match isn't + # inline. + if not isinstance(value, Node): + value = _verbatim(str(value), skip_indent or inline) + children.append(value) + + # If we were the entire line, indentation is handled by the added child + # nodes. Do this after the above loop such that created verbatims have + # the correct inline-ness. + if not inline: + skip_indent = True + + # Add any remaining text in the line. + if len(line) > column: + children.append(VerbatimNode(line[column:])) + + # If we have no children, just emit the empty string. This will become a + # blank line. + if len(children) == 0: + return VerbatimNode("") + + # Add the initial indent if we aren't skipping it. + if not skip_indent: + children.insert(0, VerbatimNode("", indent=True)) + + # Wrap ourselves into a group node with the correct indent offset + return GroupNode(children, offset=offset) diff --git a/ipc/ipdl/ipdl/lower.py b/ipc/ipdl/ipdl/lower.py new file mode 100644 index 0000000000..b8ef219b9b --- /dev/null +++ b/ipc/ipdl/ipdl/lower.py @@ -0,0 +1,5688 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import re +from copy import deepcopy +from collections import OrderedDict +import itertools + +import ipdl.ast +import ipdl.builtin +from ipdl.cxx.ast import * +from ipdl.cxx.code import * +from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes +from ipdl.util import hash_str + + +# ----------------------------------------------------------------------------- +# "Public" interface to lowering +## + + +class LowerToCxx: + def lower(self, tu, segmentcapacitydict): + """returns |[ header: File ], [ cpp : File ]| representing the + lowered form of |tu|""" + # annotate the AST with IPDL/C++ IR-type stuff used later + tu.accept(_DecorateWithCxxStuff()) + + # Any modifications to the filename scheme here need corresponding + # modifications in the ipdl.py driver script. + name = tu.name + pheader, pcpp = File(name + ".h"), File(name + ".cpp") + + _GenerateProtocolCode().lower(tu, pheader, pcpp, segmentcapacitydict) + headers = [pheader] + cpps = [pcpp] + + if tu.protocol: + pname = tu.protocol.name + + parentheader, parentcpp = ( + File(pname + "Parent.h"), + File(pname + "Parent.cpp"), + ) + _GenerateProtocolParentCode().lower( + tu, pname + "Parent", parentheader, parentcpp + ) + + childheader, childcpp = File(pname + "Child.h"), File(pname + "Child.cpp") + _GenerateProtocolChildCode().lower( + tu, pname + "Child", childheader, childcpp + ) + + headers += [parentheader, childheader] + cpps += [parentcpp, childcpp] + + return headers, cpps + + +# ----------------------------------------------------------------------------- +# Helper code +## + + +def hashfunc(value): + h = hash_str(value) % 2 ** 32 + if h < 0: + h += 2 ** 32 + return h + + +_NULL_ACTOR_ID = ExprLiteral.ZERO +_FREED_ACTOR_ID = ExprLiteral.ONE + +_DISCLAIMER = Whitespace( + """// +// Automatically generated by ipdlc. +// Edit at your own risk +// + +""" +) + + +class _struct: + pass + + +def _namespacedHeaderName(name, namespaces): + pfx = "/".join([ns.name for ns in namespaces]) + if pfx: + return pfx + "/" + name + else: + return name + + +def _ipdlhHeaderName(tu): + assert tu.filetype == "header" + return _namespacedHeaderName(tu.name, tu.namespaces) + + +def _protocolHeaderName(p, side=""): + if side: + side = side.title() + base = p.name + side + return _namespacedHeaderName(base, p.namespaces) + + +def _includeGuardMacroName(headerfile): + return re.sub(r"[./]", "_", headerfile.name) + + +def _includeGuardStart(headerfile): + guard = _includeGuardMacroName(headerfile) + return [CppDirective("ifndef", guard), CppDirective("define", guard)] + + +def _includeGuardEnd(headerfile): + guard = _includeGuardMacroName(headerfile) + return [CppDirective("endif", "// ifndef " + guard)] + + +def _messageStartName(ptype): + return ptype.name() + "MsgStart" + + +def _protocolId(ptype): + return ExprVar(_messageStartName(ptype)) + + +def _protocolIdType(): + return Type.INT32 + + +def _actorName(pname, side): + """|pname| is the protocol name. |side| is 'Parent' or 'Child'.""" + tag = side + if not tag[0].isupper(): + tag = side.title() + return pname + tag + + +def _actorIdType(): + return Type.INT32 + + +def _actorTypeTagType(): + return Type.INT32 + + +def _actorId(actor=None): + if actor is not None: + return ExprCall(ExprSelect(actor, "->", "Id")) + return ExprCall(ExprVar("Id")) + + +def _actorHId(actorhandle): + return ExprSelect(actorhandle, ".", "mId") + + +def _backstagePass(): + return ExprCall(ExprVar("mozilla::ipc::PrivateIPDLInterface")) + + +def _deleteId(): + return ExprVar("Msg___delete____ID") + + +def _deleteReplyId(): + return ExprVar("Reply___delete____ID") + + +def _lookupListener(idexpr): + return ExprCall(ExprVar("Lookup"), args=[idexpr]) + + +def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False): + fd = ForwardDecl(clsname, cls=cls, struct=struct) + if 0 == len(quals): + return fd + + outerns = Namespace(quals[0]) + innerns = outerns + for ns in quals[1:]: + tmpns = Namespace(ns) + innerns.addstmt(tmpns) + innerns = tmpns + + innerns.addstmt(fd) + return outerns + + +def _makeForwardDeclForActor(ptype, side): + return _makeForwardDeclForQClass( + _actorName(ptype.qname.baseid, side), ptype.qname.quals + ) + + +def _makeForwardDecl(type): + return _makeForwardDeclForQClass(type.name(), type.qname.quals) + + +def _putInNamespaces(cxxthing, namespaces): + """|namespaces| is in order [ outer, ..., inner ]""" + if 0 == len(namespaces): + return cxxthing + + outerns = Namespace(namespaces[0].name) + innerns = outerns + for ns in namespaces[1:]: + newns = Namespace(ns.name) + innerns.addstmt(newns) + innerns = newns + innerns.addstmt(cxxthing) + return outerns + + +def _sendPrefix(msgtype): + """Prefix of the name of the C++ method that sends |msgtype|.""" + if msgtype.isInterrupt(): + return "Call" + return "Send" + + +def _recvPrefix(msgtype): + """Prefix of the name of the C++ method that handles |msgtype|.""" + if msgtype.isInterrupt(): + return "Answer" + return "Recv" + + +def _flatTypeName(ipdltype): + """Return a 'flattened' IPDL type name that can be used as an + identifier. + E.g., |Foo[]| --> |ArrayOfFoo|.""" + # NB: this logic depends heavily on what IPDL types are allowed to + # be constructed; e.g., Foo[][] is disallowed. needs to be kept in + # sync with grammar. + if ipdltype.isIPDL() and ipdltype.isArray(): + return "ArrayOf" + _flatTypeName(ipdltype.basetype) + if ipdltype.isIPDL() and ipdltype.isMaybe(): + return "Maybe" + _flatTypeName(ipdltype.basetype) + # NotNull types just assume the underlying variant name to avoid unnecessary + # noise, as a NotNull<T> and T should never exist in the same union. + if ipdltype.isIPDL() and ipdltype.isNotNull(): + return _flatTypeName(ipdltype.basetype) + return ipdltype.name() + + +def _hasVisibleActor(ipdltype): + """Return true iff a C++ decl of |ipdltype| would have an Actor* type. + For example: |Actor[]| would turn into |Array<ActorParent*>|, so this + function would return true for |Actor[]|.""" + return ipdltype.isIPDL() and ( + ipdltype.isActor() + or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype)) + ) + + +def _abortIfFalse(cond, msg): + return StmtExpr( + ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)]) + ) + + +def _refptr(T): + return Type("RefPtr", T=T) + + +def _alreadyaddrefed(T): + return Type("already_AddRefed", T=T) + + +def _tuple(types, const=False, ref=False): + return Type("std::tuple", T=types, const=const, ref=ref) + + +def _promise(resolvetype, rejecttype, tail, resolver=False): + inner = Type("Private") if resolver else None + return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner) + + +def _makePromise(returns, side, resolver=False): + if len(returns) > 1: + resolvetype = _tuple([d.bareType(side) for d in returns]) + else: + resolvetype = returns[0].bareType(side) + + # MozPromise is purposefully made to be exclusive only. Really, we mean it. + return _promise( + resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver + ) + + +def _resolveType(returns, side): + if len(returns) > 1: + return _tuple([d.inType(side, "send") for d in returns]) + return returns[0].inType(side, "send") + + +def _makeResolver(returns, side): + return TypeFunction([Decl(_resolveType(returns, side), "")]) + + +def _cxxArrayType(basetype, const=False, ref=False): + return Type("nsTArray", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False) + + +def _cxxSpanType(basetype, const=False, ref=False): + basetype = deepcopy(basetype) + basetype.rightconst = True + return Type( + "mozilla::Span", T=basetype, const=const, ref=ref, hasimplicitcopyctor=True + ) + + +def _cxxMaybeType(basetype, const=False, ref=False): + return Type( + "mozilla::Maybe", + T=basetype, + const=const, + ref=ref, + hasimplicitcopyctor=basetype.hasimplicitcopyctor, + ) + + +def _cxxReadResultType(basetype, const=False, ref=False): + return Type( + "IPC::ReadResult", + T=basetype, + const=const, + ref=ref, + hasimplicitcopyctor=basetype.hasimplicitcopyctor, + ) + + +def _cxxNotNullType(basetype, const=False, ref=False): + return Type( + "mozilla::NotNull", + T=basetype, + const=const, + ref=ref, + hasimplicitcopyctor=basetype.hasimplicitcopyctor, + ) + + +def _cxxManagedContainerType(basetype, const=False, ref=False): + return Type( + "ManagedContainer", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False + ) + + +def _cxxLifecycleProxyType(ptr=False): + return Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr) + + +def _otherSide(side): + if side == "child": + return "parent" + if side == "parent": + return "child" + assert 0 + + +def _ifLogging(topLevelProtocol, stmts): + return StmtCode( + """ + if (mozilla::ipc::LoggingEnabledFor(${proto})) { + $*{stmts} + } + """, + proto=topLevelProtocol, + stmts=stmts, + ) + + +# XXX we need to remove these and install proper error handling + + +def _printErrorMessage(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[msg])) + + +def _protocolErrorBreakpoint(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr( + ExprCall(ExprVar("mozilla::ipc::ProtocolErrorBreakpoint"), args=[msg]) + ) + + +def _printWarningMessage(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[msg])) + + +def _fatalError(msg): + return StmtExpr(ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)])) + + +def _logicError(msg): + return StmtExpr( + ExprCall(ExprVar("mozilla::ipc::LogicError"), args=[ExprLiteral.String(msg)]) + ) + + +def _sentinelReadError(classname): + return StmtExpr( + ExprCall( + ExprVar("mozilla::ipc::SentinelReadError"), + args=[ExprLiteral.String(classname)], + ) + ) + + +# Results that IPDL-generated code returns back to *Channel code. +# Users never see these + + +class _Result: + @staticmethod + def Type(): + return Type("Result") + + Processed = ExprVar("MsgProcessed") + NotKnown = ExprVar("MsgNotKnown") + NotAllowed = ExprVar("MsgNotAllowed") + PayloadError = ExprVar("MsgPayloadError") + ProcessingError = ExprVar("MsgProcessingError") + RouteError = ExprVar("MsgRouteError") + ValuError = ExprVar("MsgValueError") # [sic] + + +# these |errfn*| are functions that generate code to be executed on an +# error, such as "bad actor ID". each is given a Python string +# containing a description of the error + +# used in user-facing Send*() methods + + +def errfnSend(msg, errcode=ExprLiteral.FALSE): + return [_fatalError(msg), StmtReturn(errcode)] + + +def errfnSendCtor(msg): + return errfnSend(msg, errcode=ExprLiteral.NULL) + + +# TODO should this error handling be strengthened for dtors? + + +def errfnSendDtor(msg): + return [_printErrorMessage(msg), StmtReturn.FALSE] + + +# used in |OnMessage*()| handlers that hand in-messages off to Recv*() +# interface methods + + +def errfnRecv(msg, errcode=_Result.ValuError): + return [_fatalError(msg), StmtReturn(errcode)] + + +def errfnSentinel(rvalue=ExprLiteral.FALSE): + def inner(msg): + return [_sentinelReadError(msg), StmtReturn(rvalue)] + + return inner + + +def _destroyMethod(): + return ExprVar("ActorDestroy") + + +def errfnUnreachable(msg): + return [_logicError(msg)] + + +def readResultError(): + return ExprCode("{}") + + +class _DestroyReason: + @staticmethod + def Type(): + return Type("ActorDestroyReason") + + Deletion = ExprVar("Deletion") + AncestorDeletion = ExprVar("AncestorDeletion") + NormalShutdown = ExprVar("NormalShutdown") + AbnormalShutdown = ExprVar("AbnormalShutdown") + FailedConstructor = ExprVar("FailedConstructor") + ManagedEndpointDropped = ExprVar("ManagedEndpointDropped") + + +class _ResponseRejectReason: + @staticmethod + def Type(): + return Type("ResponseRejectReason") + + SendError = ExprVar("ResponseRejectReason::SendError") + ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed") + HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected") + ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed") + + +# ----------------------------------------------------------------------------- +# Intermediate representation (IR) nodes used during lowering + + +class _ConvertToCxxType(TypeVisitor): + def __init__(self, side, fq): + self.side = side + self.fq = fq + + def typename(self, thing): + if self.fq: + return thing.fullname() + return thing.name() + + def visitImportedCxxType(self, t): + cxxtype = Type(self.typename(t)) + if t.isRefcounted(): + cxxtype = _refptr(cxxtype) + return cxxtype + + def visitBuiltinCType(self, b): + return Type(self.typename(b)) + + def visitActorType(self, a): + if self.side is None: + return Type( + "::mozilla::ipc::SideVariant", + T=[ + _cxxBareType(a, "parent", self.fq), + _cxxBareType(a, "child", self.fq), + ], + ) + return Type(_actorName(self.typename(a.protocol), self.side), ptr=True) + + def visitStructType(self, s): + return Type(self.typename(s)) + + def visitUnionType(self, u): + return Type(self.typename(u)) + + def visitArrayType(self, a): + basecxxtype = a.basetype.accept(self) + return _cxxArrayType(basecxxtype) + + def visitMaybeType(self, m): + basecxxtype = m.basetype.accept(self) + return _cxxMaybeType(basecxxtype) + + def visitShmemType(self, s): + return Type(self.typename(s)) + + def visitByteBufType(self, s): + return Type(self.typename(s)) + + def visitFDType(self, s): + return Type(self.typename(s)) + + def visitEndpointType(self, s): + return Type(self.typename(s)) + + def visitManagedEndpointType(self, s): + return Type(self.typename(s)) + + def visitUniquePtrType(self, s): + return Type(self.typename(s)) + + def visitNotNullType(self, n): + basecxxtype = n.basetype.accept(self) + return _cxxNotNullType(basecxxtype) + + def visitProtocolType(self, p): + assert 0 + + def visitMessageType(self, m): + assert 0 + + def visitVoidType(self, v): + assert 0 + + +def _cxxBareType(ipdltype, side, fq=False): + return ipdltype.accept(_ConvertToCxxType(side, fq)) + + +def _cxxRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + t.ref = True + return t + + +def _cxxConstRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + return t + if ipdltype.isIPDL() and ipdltype.isShmem(): + t.ref = True + return t + if ipdltype.isIPDL() and ipdltype.isNotNull(): + # If the inner type chooses to use a raw pointer, wrap that instead. + inner = _cxxConstRefType(ipdltype.basetype, side) + if inner.ptr: + t = _cxxNotNullType(inner) + return t + if ipdltype.isIPDL() and ipdltype.hasBaseType(): + # Keep same constness as inner type. + inner = _cxxConstRefType(ipdltype.basetype, side) + t.const = inner.const or not inner.ref + t.ref = True + return t + if ipdltype.isCxx() and (ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly()): + t.const = True + t.ref = True + return t + if ipdltype.isCxx() and ipdltype.isRefcounted(): + # Use T* instead of const RefPtr<T>& + t = t.T + t.ptr = True + return t + t.const = True + t.ref = True + return t + + +def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None): + """Returns `True` if serializing ipdltype requires a mutable reference, e.g. + because the underlying resource represented by the value is being + transferred to another process. This is occasionally distinct from whether + the C++ type exposes a copy constructor, such as for types which are not + cheaply copiable, but are not mutated when serialized.""" + + if visited is None: + visited = set() + + visited.add(ipdltype) + + if ipdltype.isCxx(): + return ipdltype.isSendMoveOnly() + + if ipdltype.isIPDL(): + if ipdltype.hasBaseType(): + return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited) + if ipdltype.isStruct() or ipdltype.isUnion(): + return any( + _cxxTypeNeedsMoveForSend(t, "compound", visited) + for t in ipdltype.itercomponents() + if t not in visited + ) + + # For historical reasons, shmem is `const_cast` to a mutable reference + # when being stored in a struct or union (see + # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning + # that they do not cause the containing struct to require move for + # sending. + if ipdltype.isShmem(): + return context != "compound" + + return ( + ipdltype.isByteBuf() + or ipdltype.isEndpoint() + or ipdltype.isManagedEndpoint() + ) + + return False + + +def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None): + """Returns `True` if the bare C++ type corresponding to ipdltype does not + satisfy std::is_copy_constructible_v<T>. All C++ types supported by IPDL + must support std::is_move_constructible_v<T>, so non-movable types must be + passed behind a `UniquePtr`.""" + + if visited is None: + visited = set() + + visited.add(ipdltype) + + if ipdltype.isCxx(): + return ipdltype.isDataMoveOnly() + + if ipdltype.isIPDL(): + if ipdltype.isUniquePtr(): + return True + + # When nested within a maybe or array, arrays are no longer copyable. + if context == "wrapper" and ipdltype.isArray(): + return True + if ipdltype.hasBaseType(): + return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited) + if ipdltype.isStruct() or ipdltype.isUnion(): + return any( + _cxxTypeNeedsMoveForData(t, "compound", visited) + for t in ipdltype.itercomponents() + if t not in visited + ) + return ( + ipdltype.isByteBuf() + or ipdltype.isEndpoint() + or ipdltype.isManagedEndpoint() + ) + + return False + + +def _cxxTypeCanMove(ipdltype): + return not (ipdltype.isIPDL() and ipdltype.isActor()) + + +def _cxxForceMoveRefType(ipdltype, side): + assert _cxxTypeCanMove(ipdltype) + t = _cxxBareType(ipdltype, side) + t.rvalref = True + return t + + +def _cxxPtrToType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor() and side is not None: + t.ptr = False + t.ptrptr = True + return t + t.ptr = True + return t + + +def _cxxConstPtrToType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor() and side is not None: + t.ptr = False + t.ptrconstptr = True + return t + t.const = True + t.ptr = True + return t + + +def _cxxInType(ipdltype, side, direction): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + return t + if ipdltype.isIPDL() and ipdltype.isNotNull(): + # If the inner type chooses to use a raw pointer, wrap that instead. + inner = _cxxInType(ipdltype.basetype, side, direction) + if inner.ptr: + t = _cxxNotNullType(inner) + return t + if _cxxTypeNeedsMoveForSend(ipdltype): + t.rvalref = True + return t + if ipdltype.isCxx(): + if ipdltype.isRefcounted(): + # Use T* instead of const RefPtr<T>& + t = t.T + t.ptr = True + return t + if ipdltype.name() == "nsCString": + t = Type("nsACString") + if ipdltype.name() == "nsString": + t = Type("nsAString") + # Use Span<T const> rather than nsTArray<T> for array types which aren't + # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not + # for recv signatures. + if direction == "send" and ipdltype.isIPDL() and ipdltype.isArray(): + inner = _cxxBareType(ipdltype.basetype, side) + return _cxxSpanType(inner) + + t.const = True + t.ref = True + return t + + +def _allocMethod(ptype, side): + return "Alloc" + ptype.name() + side.title() + + +def _deallocMethod(ptype, side): + return "Dealloc" + ptype.name() + side.title() + + +## +# A _HybridDecl straddles IPDL and C++ decls. It knows which C++ +# types correspond to which IPDL types, and it also knows how +# serialize and deserialize "special" IPDL C++ types. +## + + +class _HybridDecl: + """A hybrid decl stores both an IPDL type and all the C++ type + info needed by later passes, along with a basic name for the decl.""" + + def __init__(self, ipdltype, name, attributes={}): + self.ipdltype = ipdltype + self.name = name + self.attributes = attributes + + def var(self): + return ExprVar(self.name) + + def bareType(self, side, fq=False): + """Return this decl's unqualified C++ type.""" + return _cxxBareType(self.ipdltype, side, fq=fq) + + def refType(self, side): + """Return this decl's C++ type as a 'reference' type, which is not + necessarily a C++ reference.""" + return _cxxRefType(self.ipdltype, side) + + def constRefType(self, side): + """Return this decl's C++ type as a const, 'reference' type.""" + return _cxxConstRefType(self.ipdltype, side) + + def ptrToType(self, side): + return _cxxPtrToType(self.ipdltype, side) + + def constPtrToType(self, side): + return _cxxConstPtrToType(self.ipdltype, side) + + def inType(self, side, direction): + """Return this decl's C++ Type with sending inparam semantics.""" + return _cxxInType(self.ipdltype, side, direction) + + def outType(self, side): + """Return this decl's C++ Type with outparam semantics.""" + t = self.bareType(side) + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + t.ptr = False + t.ptrptr = True + return t + t.ptr = True + return t + + def forceMoveType(self, side): + """Return this decl's C++ Type with forced move semantics.""" + assert _cxxTypeCanMove(self.ipdltype) + return _cxxForceMoveRefType(self.ipdltype, side) + + +# -------------------------------------------------- + + +class HasFQName: + def fqClassName(self): + return self.decl.type.fullname() + + +class _CompoundTypeComponent(_HybridDecl): + # @override the following methods to make the side argument optional. + def bareType(self, side=None, fq=False): + return _HybridDecl.bareType(self, side, fq=fq) + + def refType(self, side=None): + return _HybridDecl.refType(self, side) + + def constRefType(self, side=None): + return _HybridDecl.constRefType(self, side) + + def ptrToType(self, side=None): + return _HybridDecl.ptrToType(self, side) + + def constPtrToType(self, side=None): + return _HybridDecl.constPtrToType(self, side) + + def forceMoveType(self, side=None): + return _HybridDecl.forceMoveType(self, side) + + +class StructDecl(ipdl.ast.StructDecl, HasFQName): + def fields_ipdl_order(self): + for f in self.fields: + yield f + + def fields_member_order(self): + assert len(self.packed_field_order) == len(self.fields) + + for i in self.packed_field_order: + yield self.fields[i] + + @staticmethod + def upgrade(structDecl): + assert isinstance(structDecl, ipdl.ast.StructDecl) + structDecl.__class__ = StructDecl + + +class _StructField(_CompoundTypeComponent): + def __init__(self, ipdltype, name, sd): + self.basename = name + + _CompoundTypeComponent.__init__(self, ipdltype, name) + + def getMethod(self, thisexpr=None, sel="."): + meth = self.var() + if thisexpr is not None: + return ExprSelect(thisexpr, sel, meth.name) + return meth + + def refExpr(self, thisexpr=None): + ref = self.memberVar() + if thisexpr is not None: + ref = ExprSelect(thisexpr, ".", ref.name) + return ref + + def constRefExpr(self, thisexpr=None): + # sigh, gross hack + refexpr = self.refExpr(thisexpr) + if "Shmem" == self.ipdltype.name(): + refexpr = ExprCast(refexpr, Type("Shmem", ref=True), const=True) + return refexpr + + def argVar(self): + return ExprVar("_" + self.name) + + def memberVar(self): + return ExprVar(self.name + "_") + + +class UnionDecl(ipdl.ast.UnionDecl, HasFQName): + def callType(self, var=None): + func = ExprVar("type") + if var is not None: + func = ExprSelect(var, ".", func.name) + return ExprCall(func) + + @staticmethod + def upgrade(unionDecl): + assert isinstance(unionDecl, ipdl.ast.UnionDecl) + unionDecl.__class__ = UnionDecl + + +class _UnionMember(_CompoundTypeComponent): + """Not in the AFL sense, but rather a member (e.g. |int;|) of an + IPDL union type.""" + + def __init__(self, ipdltype, ud): + flatname = _flatTypeName(ipdltype) + + _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname) + self.flattypename = flatname + + # To create a finite object with a mutually recursive type, a union must + # be present somewhere in the recursive loop. Because of that we only + # need to care about introducing indirections inside unions. + self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype) + + def enum(self): + return "T" + self.flattypename + + def enumvar(self): + return ExprVar(self.enum()) + + def internalType(self): + if self.recursive: + return self.ptrToType() + else: + return self.bareType() + + def unionType(self): + """Type used for storage in generated C union decl.""" + if self.recursive: + return self.ptrToType() + else: + return Type("mozilla::AlignedStorage2", T=self.internalType()) + + def unionValue(self): + # NB: knows that Union's storage C union is named |mValue| + return ExprSelect(ExprVar("mValue"), ".", self.name) + + def typedef(self): + return self.flattypename + "__tdef" + + def callGetConstPtr(self): + """Return an expression of type self.constptrToSelfType()""" + return ExprCall(ExprVar(self.getConstPtrName())) + + def callGetPtr(self): + """Return an expression of type self.ptrToSelfType()""" + return ExprCall(ExprVar(self.getPtrName())) + + def callCtor(self, expr=None): + assert not isinstance(expr, list) + + if expr is None: + args = None + elif ( + self.ipdltype.isIPDL() + and self.ipdltype.isArray() + and not isinstance(expr, ExprMove) + ): + args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])] + else: + args = [expr] + + if self.recursive: + return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args)) + else: + return ExprNew( + self.bareType(), + args=args, + newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()], + ) + + def callDtor(self): + if self.recursive: + return ExprDelete(self.callGetPtr()) + else: + return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef())) + + def getTypeName(self): + return "get_" + self.flattypename + + def getConstTypeName(self): + return "get_" + self.flattypename + + def getOtherTypeName(self): + return "get_" + self.otherflattypename + + def getPtrName(self): + return "ptr_" + self.flattypename + + def getConstPtrName(self): + return "constptr_" + self.flattypename + + def ptrToSelfExpr(self): + """|*ptrToSelfExpr()| has type |self.bareType()|""" + v = self.unionValue() + if self.recursive: + return v + else: + return ExprCall(ExprSelect(v, ".", "addr")) + + def constptrToSelfExpr(self): + """|*constptrToSelfExpr()| has type |self.constType()|""" + v = self.unionValue() + if self.recursive: + return v + return ExprCall(ExprSelect(v, ".", "addr")) + + def ptrToInternalType(self): + t = self.ptrToType() + if self.recursive: + t.ref = True + return t + + def defaultValue(self, fq=False): + # Use the default constructor for any class that does not have an + # implicit copy constructor. + if not self.bareType().hasimplicitcopyctor: + return None + + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + return ExprLiteral.NULL + # XXX sneaky here, maybe need ExprCtor()? + return ExprCall(self.bareType(fq=fq)) + + def getConstValue(self): + v = ExprDeref(self.callGetConstPtr()) + # sigh + if "Shmem" == self.ipdltype.name(): + v = ExprCast(v, Type("Shmem", ref=True), const=True) + return v + + +# -------------------------------------------------- + + +class MessageDecl(ipdl.ast.MessageDecl): + def baseName(self): + return self.name + + def recvMethod(self): + name = _recvPrefix(self.decl.type) + self.baseName() + if self.decl.type.isCtor(): + name += "Constructor" + return name + + def sendMethod(self): + name = _sendPrefix(self.decl.type) + self.baseName() + if self.decl.type.isCtor(): + name += "Constructor" + return name + + def hasReply(self): + return ( + self.decl.type.hasReply() + or self.decl.type.isCtor() + or self.decl.type.isDtor() + ) + + def hasAsyncReturns(self): + return self.decl.type.isAsync() and self.returns + + def msgCtorFunc(self): + return "Msg_%s" % (self.decl.progname) + + def prettyMsgName(self, pfx=""): + return pfx + self.msgCtorFunc() + + def pqMsgCtorFunc(self): + return "%s::%s" % (self.namespace, self.msgCtorFunc()) + + def msgId(self): + return self.msgCtorFunc() + "__ID" + + def pqMsgId(self): + return "%s::%s" % (self.namespace, self.msgId()) + + def replyCtorFunc(self): + return "Reply_%s" % (self.decl.progname) + + def pqReplyCtorFunc(self): + return "%s::%s" % (self.namespace, self.replyCtorFunc()) + + def replyId(self): + return self.replyCtorFunc() + "__ID" + + def pqReplyId(self): + return "%s::%s" % (self.namespace, self.replyId()) + + def prettyReplyName(self, pfx=""): + return pfx + self.replyCtorFunc() + + def promiseName(self): + name = self.baseName() + if self.decl.type.isCtor(): + name += "Constructor" + name += "Promise" + return name + + def resolverName(self): + return self.baseName() + "Resolver" + + def actorDecl(self): + return self.params[0] + + def makeCxxParams( + self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None + ): + """Return a list of C++ decls per the spec'd configuration. + |params| and |returns| is the C++ semantics of those: 'in', 'out', or None.""" + + def makeDecl(d, sems): + if ( + self.decl.type.tainted + and "NoTaint" not in d.attributes + and direction == "recv" + ): + # Tainted types are passed by-value, allowing the receiver to move them if desired. + assert sems != "out" + return Decl(Type("Tainted", T=d.bareType(side)), d.name) + + if sems == "in": + t = d.inType(side, direction) + # If this is the `recv` side, and we're not using "move" + # semantics, that means we're an alloc method, and cannot accept + # values by rvalue reference. Downgrade to an lvalue reference. + if direction == "recv" and t.rvalref: + t.rvalref = False + t.ref = True + return Decl(t, d.name) + elif sems == "move": + assert direction == "recv" + # For legacy reasons, use an rvalue reference when generating + # parameters for recv methods which accept arrays. + if d.ipdltype.isIPDL() and d.ipdltype.isArray(): + t = d.bareType(side) + t.rvalref = True + return Decl(t, d.name) + return Decl(d.inType(side, direction), d.name) + elif sems == "out": + return Decl(d.outType(side), d.name) + else: + assert 0 + + def makeResolverDecl(returns): + return Decl(Type(self.resolverName(), rvalref=True), "aResolve") + + def makeCallbackResolveDecl(returns): + if len(returns) > 1: + resolvetype = _tuple([d.bareType(side) for d in returns]) + else: + resolvetype = returns[0].bareType(side) + + return Decl( + Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True), + "aResolve", + ) + + def makeCallbackRejectDecl(returns): + return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject") + + cxxparams = [] + if paramsems is not None: + cxxparams.extend([makeDecl(d, paramsems) for d in self.params]) + + if returnsems == "promise" and self.returns: + pass + elif returnsems == "callback" and self.returns: + cxxparams.extend( + [ + makeCallbackResolveDecl(self.returns), + makeCallbackRejectDecl(self.returns), + ] + ) + elif returnsems == "resolver" and self.returns: + cxxparams.extend([makeResolverDecl(self.returns)]) + elif returnsems is not None: + cxxparams.extend([makeDecl(r, returnsems) for r in self.returns]) + + if not implicit and self.decl.type.hasImplicitActorParam(): + cxxparams = cxxparams[1:] + + return cxxparams + + def makeCxxArgs( + self, paramsems="in", retsems="out", retcallsems="out", implicit=True + ): + assert not retcallsems or retsems # retcallsems => returnsems + cxxargs = [] + + if paramsems == "move": + # We don't std::move() RefPtr<T> types because current Recv*() + # implementors take these parameters as T*, and + # std::move(RefPtr<T>) doesn't coerce to T*. + # We also don't move NotNull, as it has no move constructor. + cxxargs.extend( + [ + p.var() + if p.ipdltype.isRefcounted() + or (p.ipdltype.isIPDL() and p.ipdltype.isNotNull()) + else ExprMove(p.var()) + for p in self.params + ] + ) + elif paramsems == "in": + cxxargs.extend([p.var() for p in self.params]) + else: + assert False + + for ret in self.returns: + if retsems == "in": + if retcallsems == "in": + cxxargs.append(ret.var()) + elif retcallsems == "out": + cxxargs.append(ExprAddrOf(ret.var())) + else: + assert 0 + elif retsems == "out": + if retcallsems == "in": + cxxargs.append(ExprDeref(ret.var())) + elif retcallsems == "out": + cxxargs.append(ret.var()) + else: + assert 0 + elif retsems == "resolver": + pass + if retsems == "resolver": + cxxargs.append(ExprMove(ExprVar("resolver"))) + + if not implicit: + assert self.decl.type.hasImplicitActorParam() + cxxargs = cxxargs[1:] + + return cxxargs + + @staticmethod + def upgrade(messageDecl): + assert isinstance(messageDecl, ipdl.ast.MessageDecl) + if messageDecl.decl.type.hasImplicitActorParam(): + messageDecl.params.insert( + 0, + _HybridDecl( + ipdl.type.ActorType(messageDecl.decl.type.constructedType()), + "actor", + ), + ) + messageDecl.__class__ = MessageDecl + + +# -------------------------------------------------- +def _usesShmem(p): + for md in p.messageDecls: + for param in md.inParams: + if ipdl.type.hasshmem(param.type): + return True + for ret in md.outParams: + if ipdl.type.hasshmem(ret.type): + return True + return False + + +def _subtreeUsesShmem(p): + if _usesShmem(p): + return True + + ptype = p.decl.type + for mgd in ptype.manages: + if ptype is not mgd: + if _subtreeUsesShmem(mgd._ast): + return True + return False + + +class Protocol(ipdl.ast.Protocol): + def managerInterfaceType(self, ptr=False): + return Type("mozilla::ipc::IProtocol", ptr=ptr) + + def openedProtocolInterfaceType(self, ptr=False): + return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr) + + def _ipdlmgrtype(self): + assert 1 == len(self.decl.type.managers) + for mgr in self.decl.type.managers: + return mgr + + def managerActorType(self, side, ptr=False): + return Type(_actorName(self._ipdlmgrtype().name(), side), ptr=ptr) + + def unregisterMethod(self, actorThis=None): + if actorThis is not None: + return ExprSelect(actorThis, "->", "Unregister") + return ExprVar("Unregister") + + def removeManageeMethod(self): + return ExprVar("RemoveManagee") + + def deallocManageeMethod(self): + return ExprVar("DeallocManagee") + + def getChannelMethod(self): + return ExprVar("GetIPCChannel") + + def callGetChannel(self, actorThis=None): + fn = self.getChannelMethod() + if actorThis is not None: + fn = ExprSelect(actorThis, "->", fn.name) + return ExprCall(fn) + + def processingErrorVar(self): + assert self.decl.type.isToplevel() + return ExprVar("ProcessingError") + + def shouldContinueFromTimeoutVar(self): + assert self.decl.type.isToplevel() + return ExprVar("ShouldContinueFromReplyTimeout") + + def routingId(self, actorThis=None): + if self.decl.type.isToplevel(): + return ExprVar("MSG_ROUTING_CONTROL") + if actorThis is not None: + return ExprCall(ExprSelect(actorThis, "->", "Id")) + return ExprCall(ExprVar("Id")) + + def managerVar(self, thisexpr=None): + assert thisexpr is not None or not self.decl.type.isToplevel() + mvar = ExprCall(ExprVar("Manager"), args=[]) + if thisexpr is not None: + mvar = ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[]) + return mvar + + def managedCxxType(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return Type(_actorName(actortype.name(), side), ptr=True) + + def managedMethod(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return ExprVar("Managed" + _actorName(actortype.name(), side)) + + def managedVar(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return ExprVar("mManaged" + _actorName(actortype.name(), side)) + + def managedVarType(self, actortype, side, const=False, ref=False): + assert self.decl.type.isManagerOf(actortype) + return _cxxManagedContainerType( + Type(_actorName(actortype.name(), side)), const=const, ref=ref + ) + + def subtreeUsesShmem(self): + return _subtreeUsesShmem(self) + + @staticmethod + def upgrade(protocol): + assert isinstance(protocol, ipdl.ast.Protocol) + protocol.__class__ = Protocol + + +class TranslationUnit(ipdl.ast.TranslationUnit): + @staticmethod + def upgrade(tu): + assert isinstance(tu, ipdl.ast.TranslationUnit) + tu.__class__ = TranslationUnit + + +# ----------------------------------------------------------------------------- + +pod_types = { + "::int8_t": 1, + "::uint8_t": 1, + "::int16_t": 2, + "::uint16_t": 2, + "::int32_t": 4, + "::uint32_t": 4, + "::int64_t": 8, + "::uint64_t": 8, + "float": 4, + "double": 8, +} +max_pod_size = max(pod_types.values()) +# We claim that all types we don't recognize are automatically "bigger" +# than pod types for ease of sorting. +pod_size_sentinel = max_pod_size * 2 + + +def pod_size(ipdltype): + if not ipdltype.isCxx(): + return pod_size_sentinel + + return pod_types.get(ipdltype.fullname(), pod_size_sentinel) + + +class _DecorateWithCxxStuff(ipdl.ast.Visitor): + """Phase 1 of lowering: decorate the IPDL AST with information + relevant to C++ code generation. + + This pass results in an AST that is a poor man's "IR"; in reality, a + "hybrid" AST mainly consisting of IPDL nodes with new C++ info along + with some new IPDL/C++ nodes that are tuned for C++ codegen.""" + + def __init__(self): + self.visitedTus = set() + self.protocolName = None + + def visitTranslationUnit(self, tu): + if tu not in self.visitedTus: + self.visitedTus.add(tu) + ipdl.ast.Visitor.visitTranslationUnit(self, tu) + if not isinstance(tu, TranslationUnit): + TranslationUnit.upgrade(tu) + + def visitInclude(self, inc): + if inc.tu.filetype == "header": + inc.tu.accept(self) + + def visitProtocol(self, pro): + self.protocolName = pro.name + Protocol.upgrade(pro) + return ipdl.ast.Visitor.visitProtocol(self, pro) + + def visitStructDecl(self, sd): + if not isinstance(sd, StructDecl): + newfields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields] + + # Compute a permutation of the fields for in-memory storage such + # that the memory layout of the structure will be well-packed. + permutation = list(range(len(newfields))) + + # Note that the results of `pod_size` ensure that non-POD fields + # sort before POD ones. + def size(idx): + return pod_size(newfields[idx].ipdltype) + + permutation.sort(key=size, reverse=True) + + sd.fields = newfields + sd.packed_field_order = permutation + StructDecl.upgrade(sd) + + def visitUnionDecl(self, ud): + ud.components = [_UnionMember(ctype, ud) for ctype in ud.decl.type.components] + UnionDecl.upgrade(ud) + + def visitDecl(self, decl): + return _HybridDecl(decl.type, decl.progname, decl.attributes) + + def visitMessageDecl(self, md): + md.namespace = self.protocolName + md.params = [param.accept(self) for param in md.inParams] + md.returns = [ret.accept(self) for ret in md.outParams] + MessageDecl.upgrade(md) + + +# ----------------------------------------------------------------------------- + + +def msgenums(protocol, pretty=False): + msgenum = TypeEnum("MessageType") + msgstart = _messageStartName(protocol.decl.type) + " << 16" + msgenum.addId(protocol.name + "Start", msgstart) + + for md in protocol.messageDecls: + msgenum.addId(md.prettyMsgName() if pretty else md.msgId()) + if md.hasReply(): + msgenum.addId(md.prettyReplyName() if pretty else md.replyId()) + + msgenum.addId(protocol.name + "End") + return msgenum + + +class _GenerateProtocolCode(ipdl.ast.Visitor): + """Creates code common to both the parent and child actors.""" + + def __init__(self): + self.protocol = None # protocol we're generating a class for + self.hdrfile = None # what will become Protocol.h + self.cppfile = None # what will become Protocol.cpp + self.cppIncludeHeaders = [] + self.structUnionDefns = [] + self.funcDefns = [] + + def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict): + self.protocol = tu.protocol + self.hdrfile = cxxHeaderFile + self.cppfile = cxxFile + self.segmentcapacitydict = segmentcapacitydict + tu.accept(self) + + def visitTranslationUnit(self, tu): + hf = self.hdrfile + + hf.addthing(_DISCLAIMER) + hf.addthings(_includeGuardStart(hf)) + hf.addthing(Whitespace.NL) + + for inc in builtinHeaderIncludes: + self.visitBuiltinCxxInclude(inc) + + # Compute the set of includes we need for declared structure/union + # classes for this protocol. + typesToIncludes = {} + for using in tu.using: + typestr = str(using.type) + if typestr not in typesToIncludes: + typesToIncludes[typestr] = using.header + else: + assert typesToIncludes[typestr] == using.header + + aggregateTypeIncludes = set() + for su in tu.structsAndUnions: + typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes) + if isinstance(su, ipdl.ast.StructDecl): + aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h") + for f in su.fields: + f.ipdltype.accept(typedeps) + elif isinstance(su, ipdl.ast.UnionDecl): + for c in su.components: + c.ipdltype.accept(typedeps) + + aggregateTypeIncludes.update(typedeps.includeHeaders) + + if len(aggregateTypeIncludes) != 0: + hf.addthing(Whitespace.NL) + hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL]) + + for headername in sorted(iter(aggregateTypeIncludes)): + hf.addthing(CppDirective("include", '"' + headername + '"')) + + # Manually run Visitor.visitTranslationUnit. For dependency resolution + # we need to handle structs and unions separately. + for cxxInc in tu.cxxIncludes: + cxxInc.accept(self) + for inc in tu.includes: + inc.accept(self) + self.generateStructsAndUnions(tu) + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + if tu.protocol: + tu.protocol.accept(self) + + if tu.filetype == "header": + self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h") + + hf.addthing(Whitespace.NL) + hf.addthings(_includeGuardEnd(hf)) + + cf = self.cppfile + cf.addthings( + ( + [_DISCLAIMER, Whitespace.NL] + + [ + CppDirective("include", '"' + h + '"') + for h in self.cppIncludeHeaders + ] + + [Whitespace.NL] + + [ + CppDirective("include", '"%s"' % filename) + for filename in ipdl.builtin.CppIncludes + ] + + [Whitespace.NL] + ) + ) + + if self.protocol: + # construct the namespace into which we'll stick all our defns + ns = Namespace(self.protocol.name) + cf.addthing(_putInNamespaces(ns, self.protocol.namespaces)) + ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL])) + + cf.addthings(self.structUnionDefns) + + def visitBuiltinCxxInclude(self, inc): + self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"')) + + def visitCxxInclude(self, inc): + self.cppIncludeHeaders.append(inc.file) + + def visitInclude(self, inc): + if inc.tu.filetype == "header": + self.hdrfile.addthing( + CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"') + ) + # Inherit cpp includes defined by imported header files, as they may + # be required to serialize an imported `using` type. + for cxxinc in inc.tu.cxxIncludes: + cxxinc.accept(self) + else: + self.cppIncludeHeaders += [ + _protocolHeaderName(inc.tu.protocol, "parent") + ".h", + _protocolHeaderName(inc.tu.protocol, "child") + ".h", + ] + + def generateStructsAndUnions(self, tu): + """Generate the definitions for all structs and unions. This will + re-order the declarations if needed in the C++ code such that + dependencies have already been defined.""" + decls = OrderedDict() + for su in tu.structsAndUnions: + if isinstance(su, StructDecl): + which = "struct" + forwarddecls, fulldecltypes, cls = _generateCxxStruct(su) + traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type) + else: + assert isinstance(su, UnionDecl) + which = "union" + forwarddecls, fulldecltypes, cls = _generateCxxUnion(su) + traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type) + + clsdecl, methoddefns = _splitClassDeclDefn(cls) + + # Store the declarations in the decls map so we can emit in + # dependency order. + decls[su.decl.type] = ( + fulldecltypes, + [Whitespace.NL] + + forwarddecls + + [ + Whitespace( + """ +//----------------------------------------------------------------------------- +// Declaration of the IPDL type |%s %s| +// +""" + % (which, su.name) + ), + _putInNamespaces(clsdecl, su.namespaces), + ] + + [Whitespace.NL, traitsdecl], + ) + + self.structUnionDefns.extend( + [ + Whitespace( + """ +//----------------------------------------------------------------------------- +// Method definitions for the IPDL type |%s %s| +// +""" + % (which, su.name) + ), + _putInNamespaces(methoddefns, su.namespaces), + Whitespace.NL, + traitsdefns, + ] + ) + + # Generate the declarations structs in dependency order. + def gen_struct(deps, defn): + for dep in deps: + if dep in decls: + d, t = decls[dep] + del decls[dep] + gen_struct(d, t) + self.hdrfile.addthings(defn) + + while len(decls) > 0: + _, (d, t) = decls.popitem(False) + gen_struct(d, t) + + def visitProtocol(self, p): + self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h") + self.cppIncludeHeaders.append( + _protocolHeaderName(self.protocol, "Parent") + ".h" + ) + self.cppIncludeHeaders.append( + _protocolHeaderName(self.protocol, "Child") + ".h" + ) + + # Forward declare our own actors. + self.hdrfile.addthings( + [ + Whitespace.NL, + _makeForwardDeclForActor(p.decl.type, "Parent"), + _makeForwardDeclForActor(p.decl.type, "Child"), + ] + ) + + self.hdrfile.addthing( + Whitespace( + """ +//----------------------------------------------------------------------------- +// Code common to %sChild and %sParent +// +""" + % (p.name, p.name) + ) + ) + + # construct the namespace into which we'll stick all our decls + ns = Namespace(self.protocol.name) + self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces)) + ns.addstmt(Whitespace.NL) + + for func in self.genEndpointFuncs(): + edecl, edefn = _splitFuncDeclDefn(func) + ns.addstmts([edecl, Whitespace.NL]) + self.funcDefns.append(edefn) + + # spit out message type enum and classes + msgenum = msgenums(self.protocol) + ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL]) + + for md in p.messageDecls: + decls = [] + + # Look up the segment capacity used for serializing this + # message. If the capacity is not specified, use '0' for + # the default capacity (defined in ipc_message.cc) + name = "%s::%s" % (md.namespace, md.decl.progname) + segmentcapacity = self.segmentcapacitydict.get(name, 0) + + mfDecl, mfDefn = _splitFuncDeclDefn( + _generateMessageConstructor(md, segmentcapacity, p, forReply=False) + ) + decls.append(mfDecl) + self.funcDefns.append(mfDefn) + + if md.hasReply(): + rfDecl, rfDefn = _splitFuncDeclDefn( + _generateMessageConstructor(md, 0, p, forReply=True) + ) + decls.append(rfDecl) + self.funcDefns.append(rfDefn) + + decls.append(Whitespace.NL) + ns.addstmts(decls) + + ns.addstmts([Whitespace.NL, Whitespace.NL]) + + # Generate code for PFoo::CreateEndpoints. + def genEndpointFuncs(self): + p = self.protocol.decl.type + tparent = _cxxBareType(ActorType(p), "Parent", fq=True) + tchild = _cxxBareType(ActorType(p), "Child", fq=True) + + def mkOverload(includepids): + params = [] + if includepids: + params = [ + Decl(Type("base::ProcessId"), "aParentDestPid"), + Decl(Type("base::ProcessId"), "aChildDestPid"), + ] + params += [ + Decl( + Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True), + "aParent", + ), + Decl( + Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True), + "aChild", + ), + ] + openfunc = MethodDefn( + MethodDecl("CreateEndpoints", params=params, ret=Type.NSRESULT) + ) + openfunc.addcode( + """ + return mozilla::ipc::CreateEndpoints( + mozilla::ipc::PrivateIPDLInterface(), + $,{args}); + """, + args=[ExprVar(d.name) for d in params], + ) + return openfunc + + funcs = [mkOverload(True)] + if not p.hasOtherPid(): + funcs.append(mkOverload(False)) + return funcs + + +# -------------------------------------------------- + +cppPriorityList = list( + map(lambda src: src.upper() + "_PRIORITY", ipdl.ast.priorityList) +) + + +def _generateMessageConstructor(md, segmentSize, protocol, forReply=False): + if forReply: + clsname = md.replyCtorFunc() + msgid = md.replyId() + replyEnum = "REPLY" + prioEnum = cppPriorityList[md.decl.type.replyPrio] + else: + clsname = md.msgCtorFunc() + msgid = md.msgId() + replyEnum = "NOT_REPLY" + prioEnum = cppPriorityList[md.decl.type.prio] + + nested = md.decl.type.nested + compress = md.decl.type.compress + lazySend = md.decl.type.lazySend + + routingId = ExprVar("routingId") + + func = FunctionDefn( + FunctionDecl( + clsname, + params=[Decl(Type("int32_t"), routingId.name)], + ret=Type("mozilla::UniquePtr<IPC::Message>"), + ) + ) + + if not compress: + compression = "COMPRESSION_NONE" + elif compress.value == "all": + compression = "COMPRESSION_ALL" + else: + assert compress.value is None + compression = "COMPRESSION_ENABLED" + + if lazySend: + lazySendEnum = "LAZY_SEND" + else: + lazySendEnum = "EAGER_SEND" + + if nested == ipdl.ast.NOT_NESTED: + nestedEnum = "NOT_NESTED" + elif nested == ipdl.ast.INSIDE_SYNC_NESTED: + nestedEnum = "NESTED_INSIDE_SYNC" + else: + assert nested == ipdl.ast.INSIDE_CPOW_NESTED + nestedEnum = "NESTED_INSIDE_CPOW" + + if md.decl.type.isSync(): + syncEnum = "SYNC" + else: + syncEnum = "ASYNC" + + # FIXME(bug ???) - remove support for interrupt messages from the IPDL compiler. + if md.decl.type.isInterrupt(): + func.addcode( + """ + static_assert( + false, + "runtime support for intr messages has been removed from IPDL"); + """ + ) + + if md.decl.type.isCtor(): + ctorEnum = "CONSTRUCTOR" + else: + ctorEnum = "NOT_CONSTRUCTOR" + + def messageEnum(valname): + return ExprVar("IPC::Message::" + valname) + + flags = ExprCall( + ExprVar("IPC::Message::HeaderFlags"), + args=[ + messageEnum(nestedEnum), + messageEnum(prioEnum), + messageEnum(compression), + messageEnum(lazySendEnum), + messageEnum(ctorEnum), + messageEnum(syncEnum), + messageEnum(replyEnum), + ], + ) + + segmentSize = int(segmentSize) + if not segmentSize: + segmentSize = 0 + func.addstmt( + StmtReturn( + ExprCall( + ExprVar("IPC::Message::IPDLMessage"), + args=[ + routingId, + ExprVar(msgid), + ExprLiteral.Int(int(segmentSize)), + flags, + ], + ) + ) + ) + + return func + + +# -------------------------------------------------- + + +class _ParamTraits: + var = ExprVar("aVar") + writervar = ExprVar("aWriter") + readervar = ExprVar("aReader") + + @classmethod + def ifsideis(cls, rdrwtr, side, then, els=None): + cxxside = ExprVar("mozilla::ipc::ChildSide") + if side == "parent": + cxxside = ExprVar("mozilla::ipc::ParentSide") + + ifstmt = StmtIf( + ExprBinary( + cxxside, + "==", + ExprCode("${rdrwtr}->GetActor()->GetSide()", rdrwtr=rdrwtr), + ) + ) + ifstmt.addifstmt(then) + if els is not None: + ifstmt.addelsestmt(els) + return ifstmt + + @classmethod + def fatalError(cls, rdrwtr, reason): + return StmtCode( + "${rdrwtr}->FatalError(${reason});", + rdrwtr=rdrwtr, + reason=ExprLiteral.String(reason), + ) + + @classmethod + def writeSentinel(cls, writervar, sentinelKey): + return [ + Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True), + StmtExpr( + ExprCall( + ExprSelect(writervar, "->", "WriteSentinel"), + args=[ExprLiteral.Int(hashfunc(sentinelKey))], + ) + ), + ] + + @classmethod + def readSentinel(cls, readervar, sentinelKey, sentinelFail): + # Read the sentinel + read = ExprCall( + ExprSelect(readervar, "->", "ReadSentinel"), + args=[ExprLiteral.Int(hashfunc(sentinelKey))], + ) + ifsentinel = StmtIf(ExprNot(read)) + ifsentinel.addifstmts(sentinelFail) + + return [ + Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True), + ifsentinel, + ] + + @classmethod + def write(cls, var, writervar, ipdltype=None): + if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype): + var = ExprMove(var) + return ExprCall(ExprVar("IPC::WriteParam"), args=[writervar, var]) + + @classmethod + def checkedWrite(cls, ipdltype, var, writervar, sentinelKey): + assert sentinelKey + block = Block() + + block.addstmts( + [ + StmtExpr(cls.write(var, writervar, ipdltype)), + ] + ) + block.addstmts(cls.writeSentinel(writervar, sentinelKey)) + return block + + @classmethod + def bulkSentinelKey(cls, fields): + return " | ".join(f.basename for f in fields) + + @classmethod + def checkedBulkWrite(cls, var, size, fields): + block = Block() + first = fields[0] + + block.addstmts( + [ + StmtExpr( + ExprCall( + ExprSelect(cls.writervar, "->", "WriteBytes"), + args=[ + ExprAddrOf( + ExprCall(first.getMethod(thisexpr=var, sel=".")) + ), + ExprLiteral.Int(size * len(fields)), + ], + ) + ) + ] + ) + block.addstmts(cls.writeSentinel(cls.writervar, cls.bulkSentinelKey(fields))) + + return block + + @classmethod + def checkedBulkRead(cls, var, size, fields): + block = Block() + first = fields[0] + + readbytes = ExprCall( + ExprSelect(cls.readervar, "->", "ReadBytesInto"), + args=[ + ExprAddrOf(ExprCall(first.getMethod(thisexpr=var, sel="->"))), + ExprLiteral.Int(size * len(fields)), + ], + ) + ifbad = StmtIf(ExprNot(readbytes)) + errmsg = "Error bulk reading fields from %s" % first.ipdltype.name() + ifbad.addifstmts( + [cls.fatalError(cls.readervar, errmsg), StmtReturn(readResultError())] + ) + block.addstmt(ifbad) + block.addstmts( + cls.readSentinel( + cls.readervar, + cls.bulkSentinelKey(fields), + errfnSentinel(readResultError())(errmsg), + ) + ) + + return block + + @classmethod + def checkedRead( + cls, + ipdltype, + cxxtype, + var, + readervar, + errfn, + paramtype, + sentinelKey, + errfnSentinel, + ): + assert isinstance(var, ExprVar) + + if not isinstance(paramtype, list): + paramtype = ["Error deserializing " + paramtype] + + block = Block() + + # Read the data + block.addcode( + """ + auto ${maybevar} = IPC::ReadParam<${ty}>(${reader}); + if (!${maybevar}) { + $*{errfn} + } + auto& ${var} = *${maybevar}; + """, + maybevar=ExprVar("maybe__" + var.name), + ty=cxxtype, + reader=readervar, + errfn=errfn(*paramtype), + var=var, + ) + + block.addstmts( + cls.readSentinel(readervar, sentinelKey, errfnSentinel(*paramtype)) + ) + + return block + + # Helper wrapper for checkedRead for use within _ParamTraits + @classmethod + def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what): + def errfn(msg): + return [cls.fatalError(cls.readervar, msg), StmtReturn(readResultError())] + + return cls.checkedRead( + ipdltype, + cxxtype, + var, + cls.readervar, + errfn=errfn, + paramtype=what, + sentinelKey=sentinelKey, + errfnSentinel=errfnSentinel(readResultError()), + ) + + @classmethod + def generateDecl(cls, fortype, write, read, needsmove=False): + # ParamTraits impls are selected ignoring constness, and references. + pt = Class( + "ParamTraits", + specializes=Type( + fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr + ), + struct=True, + ) + + # typedef T paramType; + pt.addstmt(Typedef(fortype, "paramType")) + + # static void Write(Message*, const T&); + if needsmove: + intype = Type("paramType", rvalref=True) + else: + intype = Type("paramType", ref=True, const=True) + writemthd = MethodDefn( + MethodDecl( + "Write", + params=[ + Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name), + Decl(intype, cls.var.name), + ], + methodspec=MethodSpec.STATIC, + ) + ) + writemthd.addstmts(write) + pt.addstmt(writemthd) + + # static ReadResult<T> Read(MessageReader*); + readmthd = MethodDefn( + MethodDecl( + "Read", + params=[ + Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name), + ], + ret=Type("IPC::ReadResult<paramType>"), + methodspec=MethodSpec.STATIC, + ) + ) + readmthd.addstmts(read) + pt.addstmt(readmthd) + + # Split the class into declaration and definition + clsdecl, methoddefns = _splitClassDeclDefn(pt) + + namespaces = [Namespace("IPC")] + clsns = _putInNamespaces(clsdecl, namespaces) + defns = _putInNamespaces(methoddefns, namespaces) + return clsns, defns + + @classmethod + def actorPickling(cls, actortype, side): + """Generates pickling for IPDL actors. This is a |nullable| deserializer. + Write and read callers will perform nullability validation.""" + + cxxtype = _cxxBareType(actortype, side, fq=True) + + write = StmtCode( + """ + MOZ_RELEASE_ASSERT( + ${writervar}->GetActor(), + "Cannot serialize managed actors without an actor"); + + int32_t id; + if (!${var}) { + id = 0; // kNullActorId + } else { + id = ${var}->Id(); + if (id == 1) { // kFreedActorId + ${var}->FatalError("Actor has been |delete|d"); + } + MOZ_RELEASE_ASSERT( + ${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(), + "Actor must be from the same channel as the" + " actor it's being sent over"); + MOZ_RELEASE_ASSERT( + ${var}->CanSend(), + "Actor must still be open when sending"); + } + + ${write}; + """, + var=cls.var, + writervar=cls.writervar, + write=cls.write(ExprVar("id"), cls.writervar), + ) + + # bool Read(..) impl + read = StmtCode( + """ + MOZ_RELEASE_ASSERT( + ${readervar}->GetActor(), + "Cannot deserialize managed actors without an actor"); + mozilla::Maybe<mozilla::ipc::IProtocol*> actor = ${readervar}->GetActor() + ->ReadActor(${readervar}, true, ${actortype}, ${protocolid}); + if (actor.isSome()) { + return static_cast<${cxxtype}>(actor.ref()); + } + return {}; + """, + readervar=cls.readervar, + actortype=ExprLiteral.String(actortype.name()), + protocolid=_protocolId(actortype), + cxxtype=cxxtype, + ) + + return cls.generateDecl(cxxtype, [write], [read]) + + @classmethod + def structPickling(cls, structtype): + sd = structtype._ast + # NOTE: Not using _cxxBareType here as we don't have a side + cxxtype = Type(structtype.fullname()) + + write = [] + read = [] + + # First serialize/deserialize all non-pod data in IPDL order. These need + # to be read/written first because they'll be used to invoke the IPDL + # struct's constructor. + ctorargs = [] + for f in sd.fields_ipdl_order(): + if pod_size(f.ipdltype) == pod_size_sentinel: + write.append( + cls.checkedWrite( + f.ipdltype, + ExprCall(f.getMethod(thisexpr=cls.var, sel=".")), + cls.writervar, + sentinelKey=f.basename, + ) + ) + read.append( + cls._checkedRead( + f.ipdltype, + f.bareType(fq=True), + f.argVar(), + f.basename, + "'" + + f.getMethod().name + + "' " + + "(" + + f.ipdltype.name() + + ") member of " + + "'" + + structtype.name() + + "'", + ) + ) + if _cxxTypeCanMove(f.ipdltype): + ctorargs.append(ExprMove(f.argVar())) + else: + ctorargs.append(f.argVar()) + else: + # We're going to bulk-read in this value later, so we'll just + # zero-initialize it for now. + ctorargs.append(ExprCode("${type}{0}", type=f.bareType(fq=True))) + + resultvar = ExprVar("result__") + read.append( + StmtDecl( + Decl(_cxxReadResultType(Type("paramType")), resultvar.name), + initargs=[ExprVar("std::in_place")] + ctorargs, + ) + ) + + # After non-pod data, bulk read/write pod data in member order. This has + # to be done after the result has been constructed, so that we have + # somewhere to read into. + for (size, fields) in itertools.groupby( + sd.fields_member_order(), lambda f: pod_size(f.ipdltype) + ): + if size != pod_size_sentinel: + fields = list(fields) + write.append(cls.checkedBulkWrite(cls.var, size, fields)) + read.append(cls.checkedBulkRead(resultvar, size, fields)) + + read.append(StmtReturn(resultvar)) + + return cls.generateDecl( + cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(structtype) + ) + + @classmethod + def unionPickling(cls, uniontype): + # NOTE: Not using _cxxBareType here as we don't have a side + cxxtype = Type(uniontype.fullname()) + ud = uniontype._ast + + # Use typedef to set up an alias so it's easier to reference the struct type. + alias = "union__" + typevar = ExprVar("type") + + prelude = [ + Typedef(cxxtype, alias), + ] + + writeswitch = StmtSwitch(typevar) + write = prelude + [ + StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)), + cls.checkedWrite( + None, typevar, cls.writervar, sentinelKey=uniontype.name() + ), + Whitespace.NL, + writeswitch, + ] + + readswitch = StmtSwitch(typevar) + read = prelude + [ + cls._checkedRead( + None, + Type.INT, + typevar, + uniontype.name(), + "type of union " + uniontype.name(), + ), + Whitespace.NL, + readswitch, + ] + + for c in ud.components: + caselabel = CaseLabel(alias + "::" + c.enum()) + origenum = c.enum() + + writecase = StmtBlock() + wstmt = cls.checkedWrite( + c.ipdltype, + ExprCall(ExprSelect(cls.var, ".", c.getTypeName())), + cls.writervar, + sentinelKey=c.enum(), + ) + writecase.addstmts([wstmt, StmtReturn()]) + writeswitch.addcase(caselabel, writecase) + + readcase = StmtBlock() + tmpvar = ExprVar("tmp") + readcase.addstmts( + [ + cls._checkedRead( + c.ipdltype, + c.bareType(fq=True), + tmpvar, + origenum, + "variant " + origenum + " of union " + uniontype.name(), + ), + StmtReturn(ExprMove(tmpvar)), + ] + ) + readswitch.addcase(caselabel, readcase) + + # Add the error default case + writeswitch.addcase( + DefaultLabel(), + StmtBlock( + [ + cls.fatalError( + cls.writervar, "unknown variant of union " + uniontype.name() + ), + StmtReturn(), + ] + ), + ) + readswitch.addcase( + DefaultLabel(), + StmtBlock( + [ + cls.fatalError( + cls.readervar, "unknown variant of union " + uniontype.name() + ), + StmtReturn(readResultError()), + ] + ), + ) + + return cls.generateDecl( + cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(uniontype) + ) + + +# -------------------------------------------------- + + +class _ComputeTypeDeps(TypeVisitor): + """Pass that gathers the C++ types that a particular IPDL type + (recursively) depends on. There are three kinds of dependencies: (i) + types that need forward declaration; (ii) types that need a |using| + stmt; (iii) IPDL structs or unions which must be fully declared + before this struct. Some types generate multiple kinds.""" + + def __init__(self, fortype, typesToIncludes=None): + ipdl.type.TypeVisitor.__init__(self) + self.usingTypedefs = [] + self.forwardDeclStmts = [] + self.fullDeclTypes = [] + self.includeHeaders = set() + self.fortype = fortype + self.typesToIncludes = typesToIncludes + + def maybeTypedef(self, fqname, name, templateargs=[]): + assert fqname.startswith("::") + if fqname != name: + self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs)) + if self.typesToIncludes is not None and fqname in self.typesToIncludes: + self.includeHeaders.add(self.typesToIncludes[fqname]) + + def visitImportedCxxType(self, t): + if t in self.visited: + return + self.visited.add(t) + self.maybeTypedef(t.fullname(), t.name()) + + def visitActorType(self, t): + if t in self.visited: + return + self.visited.add(t) + + fqname, name = t.fullname(), t.name() + + self.includeHeaders.add("mozilla/ipc/SideVariant.h") + self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent")) + self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child")) + + self.forwardDeclStmts.extend( + [ + _makeForwardDeclForActor(t.protocol, "parent"), + Whitespace.NL, + _makeForwardDeclForActor(t.protocol, "child"), + Whitespace.NL, + ] + ) + + def visitStructOrUnionType(self, su, defaultVisit): + if su in self.visited or su == self.fortype: + return + self.visited.add(su) + self.maybeTypedef(su.fullname(), su.name()) + + # Mutually recursive fields in unions are behind indirection, so we only + # need a forward decl, and don't need a full type declaration. + if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith( + su + ): + self.forwardDeclStmts.append(_makeForwardDecl(su)) + else: + self.fullDeclTypes.append(su) + + return defaultVisit(self, su) + + def visitStructType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitStructType) + + def visitUnionType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType) + + def visitArrayType(self, t): + return TypeVisitor.visitArrayType(self, t) + + def visitMaybeType(self, m): + return TypeVisitor.visitMaybeType(self, m) + + def visitShmemType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem") + + def visitByteBufType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf") + + def visitFDType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor") + + def visitEndpointType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"]) + self.visitActorType(s.actor) + + def visitManagedEndpointType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef( + "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"] + ) + self.visitActorType(s.actor) + + def visitUniquePtrType(self, s): + if s in self.visited: + return + self.visited.add(s) + + def visitVoidType(self, v): + assert 0 + + def visitMessageType(self, v): + assert 0 + + def visitProtocolType(self, v): + assert 0 + + +def _fieldStaticAssertions(sd): + staticasserts = [] + for (size, fields) in itertools.groupby( + sd.fields_member_order(), lambda f: pod_size(f.ipdltype) + ): + if size == pod_size_sentinel: + continue + + fields = list(fields) + if len(fields) == 1: + continue + + staticasserts.append( + StmtCode( + """ + static_assert( + (offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected}, + "Bad assumptions about field layout!"); + """, + struct=sd.name, + first=fields[0].memberVar(), + last=fields[-1].memberVar(), + expected=ExprLiteral.Int(size * (len(fields) - 1)), + ) + ) + + return staticasserts + + +def _generateCxxStruct(sd): + """ """ + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps(sd.decl.type) + for f in sd.fields: + f.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + fulldecltypes = gettypedeps.fullDeclTypes + + struct = Class(sd.name, final=True) + struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC]) + + constreftype = Type(sd.name, const=True, ref=True) + + # Struct() + # We want the default constructor to be declared if it is available, but + # some of our members may not be default-constructible. Silence the + # warning which clang generates in that case. + # + # Members which need value initialization will be handled by wrapping + # the member in a template type when declaring them. + struct.addcode( + """ + #ifdef __clang__ + # pragma clang diagnostic push + # if __has_warning("-Wdefaulted-function-deleted") + # pragma clang diagnostic ignored "-Wdefaulted-function-deleted" + # endif + #endif + ${name}() = default; + #ifdef __clang__ + # pragma clang diagnostic pop + #endif + + """, + name=sd.name, + ) + + # If this is an empty struct (no fields), then the default ctor + # and "create-with-fields" ctors are equivalent. + if len(sd.fields): + assert len(sd.fields) == len(sd.packed_field_order) + + # Struct(const field1& _f1, ...) + valctor = ConstructorDefn( + ConstructorDecl( + sd.name, + params=[ + Decl( + f.forceMoveType() + if _cxxTypeNeedsMoveForData(f.ipdltype) + else f.constRefType(), + f.argVar().name, + ) + for f in sd.fields_ipdl_order() + ], + force_inline=True, + ) + ) + valctor.memberinits = [] + for f in sd.fields_member_order(): + arg = f.argVar() + if _cxxTypeNeedsMoveForData(f.ipdltype): + arg = ExprMove(arg) + valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg])) + + struct.addstmts([valctor, Whitespace.NL]) + + # If a constructor which moves each argument would be different from the + # `const T&` version, also generate that constructor. + if not all( + _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype) + for f in sd.fields_ipdl_order() + ): + # Struct(field1&& _f1, ...) + valmovector = ConstructorDefn( + ConstructorDecl( + sd.name, + params=[ + Decl( + f.forceMoveType() + if _cxxTypeCanMove(f.ipdltype) + else f.constRefType(), + f.argVar().name, + ) + for f in sd.fields_ipdl_order() + ], + force_inline=True, + ) + ) + + valmovector.memberinits = [] + for f in sd.fields_member_order(): + arg = f.argVar() + if _cxxTypeCanMove(f.ipdltype): + arg = ExprMove(arg) + valmovector.memberinits.append( + ExprMemberInit(f.memberVar(), args=[arg]) + ) + + struct.addstmts([valmovector, Whitespace.NL]) + + # The default copy, move, and assignment constructors, and the default + # destructor, will do the right thing. + + if "Comparable" in sd.attributes: + # bool operator==(const Struct& _o) + ovar = ExprVar("_o") + opeqeq = MethodDefn( + MethodDecl( + "operator==", + params=[Decl(constreftype, ovar.name)], + ret=Type.BOOL, + const=True, + ) + ) + for f in sd.fields_ipdl_order(): + ifneq = StmtIf( + ExprNot( + ExprBinary( + ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar)) + ) + ) + ) + ifneq.addifstmt(StmtReturn.FALSE) + opeqeq.addstmt(ifneq) + opeqeq.addstmt(StmtReturn.TRUE) + struct.addstmts([opeqeq, Whitespace.NL]) + + # bool operator!=(const Struct& _o) + opneq = MethodDefn( + MethodDecl( + "operator!=", + params=[Decl(constreftype, ovar.name)], + ret=Type.BOOL, + const=True, + ) + ) + opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar])))) + struct.addstmts([opneq, Whitespace.NL]) + + # field1& f1() + # const field1& f1() const + for f in sd.fields_ipdl_order(): + get = MethodDefn( + MethodDecl( + f.getMethod().name, params=[], ret=f.refType(), force_inline=True + ) + ) + get.addstmt(StmtReturn(f.refExpr())) + + getconstdecl = deepcopy(get.decl) + getconstdecl.ret = f.constRefType() + getconstdecl.const = True + getconst = MethodDefn(getconstdecl) + getconst.addstmt(StmtReturn(f.constRefExpr())) + + struct.addstmts([get, getconst, Whitespace.NL]) + + # private: + struct.addstmt(Label.PRIVATE) + + # Static assertions to ensure our assumptions about field layout match + # what the compiler is actually producing. We define this as a member + # function, rather than throwing the assertions in the constructor or + # similar, because we don't want to evaluate the static assertions every + # time the header file containing the structure is included. + staticasserts = _fieldStaticAssertions(sd) + if staticasserts: + method = MethodDefn( + MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True) + ) + method.addstmts(staticasserts) + struct.addstmts([method]) + + # members + struct.addstmts( + [ + StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name)) + for f in sd.fields_member_order() + ] + ) + + return forwarddeclstmts, fulldecltypes, struct + + +def _effectiveMemberType(f): + effective_type = f.bareType() + # Structs must be copyable for backwards compatibility reasons, so we use + # CopyableTArray<T> as their member type for arrays. This is not exposed + # in the method signatures, these keep using nsTArray<T>, which is a base + # class of CopyableTArray<T>. + if effective_type.name == "nsTArray": + effective_type.name = "CopyableTArray" + return Type("::mozilla::ipc::IPDLStructMember", T=[effective_type]) + + +# -------------------------------------------------- + + +def _generateCxxUnion(ud): + # This Union class basically consists of a type (enum) and a + # union for storage. The union can contain POD and non-POD + # types. Each type needs a copy/move ctor, assignment operators, + # and dtor. + # + # Rather than templating this class and only providing + # specializations for the types we support, which is slightly + # "unsafe" in that C++ code can add additional specializations + # without the IPDL compiler's knowledge, we instead explicitly + # implement non-templated methods for each supported type. + # + # The one complication that arises is that C++, for arcane + # reasons, does not allow the placement destructor of a + # builtin type, like int, to be directly invoked. So we need + # to hack around this by internally typedef'ing all + # constituent types. Sigh. + # + # So, for each type, this "Union" class needs: + # (private) + # - entry in the type enum + # - entry in the storage union + # - [type]ptr() method to get a type* from the underlying union + # - same as above to get a const type* + # - typedef to hack around placement delete limitations + # (public) + # - placement delete case for dtor + # - copy ctor + # - move ctor + # - case in generic copy ctor + # - copy operator= impl + # - move operator= impl + # - case in generic operator= + # - operator [type&] + # - operator [const type&] const + # - [type&] get_[type]() + # - [const type&] get_[type]() const + # + cls = Class(ud.name, final=True) + # const Union&, i.e., Union type with inparam semantics + inClsType = Type(ud.name, const=True, ref=True) + refClsType = Type(ud.name, ref=True) + rvalueRefClsType = Type(ud.name, rvalref=True) + typetype = Type("Type") + valuetype = Type("Value") + mtypevar = ExprVar("mType") + mvaluevar = ExprVar("mValue") + maybedtorvar = ExprVar("MaybeDestroy") + assertsanityvar = ExprVar("AssertSanity") + tnonevar = ExprVar("T__None") + tlastvar = ExprVar("T__Last") + + def callAssertSanity(uvar=None, expectTypeVar=None): + func = assertsanityvar + args = [] + if uvar is not None: + func = ExprSelect(uvar, ".", assertsanityvar.name) + if expectTypeVar is not None: + args.append(expectTypeVar) + return ExprCall(func, args=args) + + def maybeDestroy(): + return StmtExpr(ExprCall(maybedtorvar)) + + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps(ud.decl.type) + for c in ud.components: + c.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + fulldecltypes = gettypedeps.fullDeclTypes + + # the |Type| enum, used to switch on the discunion's real type + cls.addstmt(Label.PUBLIC) + typeenum = TypeEnum(typetype.name) + typeenum.addId(tnonevar.name, 0) + firstid = ud.components[0].enum() + typeenum.addId(firstid, 1) + for c in ud.components[1:]: + typeenum.addId(c.enum()) + typeenum.addId(tlastvar.name, ud.components[-1].enum()) + cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL]) + + cls.addstmt(Label.PRIVATE) + cls.addstmts( + usingTypedefs + # hacky typedef's that allow placement dtors of builtins + + [Typedef(c.internalType(), c.typedef()) for c in ud.components] + ) + cls.addstmt(Whitespace.NL) + + # the C++ union the discunion use for storage + valueunion = TypeUnion(valuetype.name) + for c in ud.components: + valueunion.addComponent(c.unionType(), c.name) + cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL]) + + # for each constituent type T, add private accessors that + # return a pointer to the Value union storage casted to |T*| + # and |const T*| + for c in ud.components: + getptr = MethodDefn( + MethodDecl( + c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True + ) + ) + getptr.addstmt(StmtReturn(c.ptrToSelfExpr())) + + getptrconst = MethodDefn( + MethodDecl( + c.getConstPtrName(), + params=[], + ret=c.constPtrToType(), + const=True, + force_inline=True, + ) + ) + getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr())) + + cls.addstmts([getptr, getptrconst]) + cls.addstmt(Whitespace.NL) + + # add a helper method that invokes the placement dtor on the + # current underlying value, only if |aNewType| is different + # than the current type, and returns true if the underlying + # value needs to be re-constructed + maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID)) + # wasn't /actually/ dtor'd, but it needs to be re-constructed + ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar)) + ifnone.addifstmt(StmtReturn()) + # need to destroy. switch on underlying type + dtorswitch = StmtSwitch(mtypevar) + for c in ud.components: + dtorswitch.addcase( + CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()]) + ) + dtorswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()]) + ) + maybedtor.addstmts([ifnone, dtorswitch]) + cls.addstmts([maybedtor, Whitespace.NL]) + + # add helper methods that ensure the discunion has a + # valid type + sanity = MethodDefn( + MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True) + ) + sanity.addstmts( + [ + _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"), + _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"), + ] + ) + cls.addstmt(sanity) + + atypevar = ExprVar("aType") + sanity2 = MethodDefn( + MethodDecl( + assertsanityvar.name, + params=[Decl(typetype, atypevar.name)], + ret=Type.VOID, + const=True, + force_inline=True, + ) + ) + sanity2.addstmts( + [ + StmtExpr(ExprCall(assertsanityvar)), + _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"), + ] + ) + cls.addstmts([sanity2, Whitespace.NL]) + + # ---- begin public methods ----- + + # Union() default ctor + cls.addstmts( + [ + Label.PUBLIC, + ConstructorDefn( + ConstructorDecl(ud.name, force_inline=True), + memberinits=[ExprMemberInit(mtypevar, [tnonevar])], + ), + Whitespace.NL, + ] + ) + + # Union(const T&) copy & Union(T&&) move ctors + othervar = ExprVar("aOther") + for c in ud.components: + if not _cxxTypeNeedsMoveForData(c.ipdltype): + copyctor = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)]) + ) + copyctor.addstmts( + [ + StmtExpr(c.callCtor(othervar)), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + ] + ) + cls.addstmts([copyctor, Whitespace.NL]) + + if not _cxxTypeCanMove(c.ipdltype): + continue + movector = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)]) + ) + movector.addstmts( + [ + StmtExpr(c.callCtor(ExprMove(othervar))), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + ] + ) + cls.addstmts([movector, Whitespace.NL]) + + unionNeedsMove = any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components) + + # Union(const Union&) copy ctor + if not unionNeedsMove: + copyctor = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)]) + ) + othertype = ud.callType(othervar) + copyswitch = StmtSwitch(othertype) + for c in ud.components: + copyswitch.addcase( + CaseLabel(c.enum()), + StmtBlock( + [ + StmtExpr( + c.callCtor( + ExprCall( + ExprSelect(othervar, ".", c.getConstTypeName()) + ) + ) + ), + StmtBreak(), + ] + ), + ) + copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) + copyswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) + ) + copyctor.addstmts( + [ + StmtExpr(callAssertSanity(uvar=othervar)), + copyswitch, + StmtExpr(ExprAssn(mtypevar, othertype)), + ] + ) + cls.addstmts([copyctor, Whitespace.NL]) + + # Union(Union&&) move ctor + movector = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)]) + ) + othertypevar = ExprVar("t") + moveswitch = StmtSwitch(othertypevar) + for c in ud.components: + case = StmtBlock() + if c.recursive: + # This is sound as we set othervar.mTypeVar to T__None after the + # switch. The pointer in the union will be left dangling. + case.addstmts( + [ + # ptr_C() = other.ptr_C() + StmtExpr( + ExprAssn( + c.callGetPtr(), + ExprCall( + ExprSelect(othervar, ".", ExprVar(c.getPtrName())) + ), + ) + ) + ] + ) + else: + case.addstmts( + [ + # new ... (Move(other.get_C())) + StmtExpr( + c.callCtor( + ExprMove( + ExprCall(ExprSelect(othervar, ".", c.getTypeName())) + ) + ) + ), + # other.MaybeDestroy(T__None) + StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))), + ] + ) + case.addstmts([StmtBreak()]) + moveswitch.addcase(CaseLabel(c.enum()), case) + moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) + moveswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) + ) + movector.addstmts( + [ + StmtExpr(callAssertSanity(uvar=othervar)), + StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)), + moveswitch, + StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)), + StmtExpr(ExprAssn(mtypevar, othertypevar)), + ] + ) + cls.addstmts([movector, Whitespace.NL]) + + # ~Union() + dtor = DestructorDefn(DestructorDecl(ud.name)) + dtor.addstmt(maybeDestroy()) + cls.addstmts([dtor, Whitespace.NL]) + + # type() + typemeth = MethodDefn( + MethodDecl("type", ret=typetype, const=True, force_inline=True) + ) + typemeth.addstmt(StmtReturn(mtypevar)) + cls.addstmts([typemeth, Whitespace.NL]) + + # Union& operator= methods + rhsvar = ExprVar("aRhs") + for c in ud.components: + + def opeqBody(rhs): + return [ + # might need to placement-delete old value first + maybeDestroy(), + StmtExpr(c.callCtor(rhs)), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + StmtReturn(ExprDeref(ExprVar.THIS)), + ] + + if not _cxxTypeNeedsMoveForData(c.ipdltype): + # Union& operator=(const T&) + opeq = MethodDefn( + MethodDecl( + "operator=", + params=[Decl(c.constRefType(), rhsvar.name)], + ret=refClsType, + ) + ) + opeq.addstmts(opeqBody(rhsvar)) + cls.addstmts([opeq, Whitespace.NL]) + + # Union& operator=(T&&) + if not _cxxTypeCanMove(c.ipdltype): + continue + + opeq = MethodDefn( + MethodDecl( + "operator=", + params=[Decl(c.forceMoveType(), rhsvar.name)], + ret=refClsType, + ) + ) + opeq.addstmts(opeqBody(ExprMove(rhsvar))) + cls.addstmts([opeq, Whitespace.NL]) + + # Union& operator=(const Union&) + if not unionNeedsMove: + opeq = MethodDefn( + MethodDecl( + "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType + ) + ) + rhstypevar = ExprVar("t") + opeqswitch = StmtSwitch(rhstypevar) + for c in ud.components: + case = StmtBlock() + case.addstmts( + [ + maybeDestroy(), + StmtExpr( + c.callCtor( + ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName())) + ) + ), + StmtBreak(), + ] + ) + opeqswitch.addcase(CaseLabel(c.enum()), case) + opeqswitch.addcase( + CaseLabel(tnonevar.name), + StmtBlock([maybeDestroy(), StmtBreak()]), + ) + opeqswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) + ) + opeq.addstmts( + [ + StmtExpr(callAssertSanity(uvar=rhsvar)), + StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), + opeqswitch, + StmtExpr(ExprAssn(mtypevar, rhstypevar)), + StmtReturn(ExprDeref(ExprVar.THIS)), + ] + ) + cls.addstmts([opeq, Whitespace.NL]) + + # Union& operator=(Union&&) + opeq = MethodDefn( + MethodDecl( + "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType + ) + ) + rhstypevar = ExprVar("t") + opeqswitch = StmtSwitch(rhstypevar) + for c in ud.components: + case = StmtBlock() + if c.recursive: + case.addstmts( + [ + maybeDestroy(), + StmtExpr( + ExprAssn( + c.callGetPtr(), + ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))), + ) + ), + ] + ) + else: + case.addstmts( + [ + maybeDestroy(), + StmtExpr( + c.callCtor( + ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName()))) + ) + ), + # other.MaybeDestroy() + StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))), + ] + ) + case.addstmts([StmtBreak()]) + opeqswitch.addcase(CaseLabel(c.enum()), case) + opeqswitch.addcase( + CaseLabel(tnonevar.name), + StmtBlock([maybeDestroy(), StmtBreak()]), + ) + opeqswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) + ) + opeq.addstmts( + [ + StmtExpr(callAssertSanity(uvar=rhsvar)), + StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), + opeqswitch, + StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)), + StmtExpr(ExprAssn(mtypevar, rhstypevar)), + StmtReturn(ExprDeref(ExprVar.THIS)), + ] + ) + cls.addstmts([opeq, Whitespace.NL]) + + if "Comparable" in ud.attributes: + # bool operator==(const T&) + for c in ud.components: + opeqeq = MethodDefn( + MethodDecl( + "operator==", + params=[Decl(c.constRefType(), rhsvar.name)], + ret=Type.BOOL, + const=True, + ) + ) + opeqeq.addstmt( + StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar)) + ) + cls.addstmts([opeqeq, Whitespace.NL]) + + # bool operator==(const Union&) + opeqeq = MethodDefn( + MethodDecl( + "operator==", + params=[Decl(inClsType, rhsvar.name)], + ret=Type.BOOL, + const=True, + ) + ) + iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar))) + iftypesmismatch.addifstmt(StmtReturn.FALSE) + opeqeq.addstmts([iftypesmismatch, Whitespace.NL]) + + opeqeqswitch = StmtSwitch(ud.callType()) + for c in ud.components: + case = StmtBlock() + case.addstmt( + StmtReturn( + ExprBinary( + ExprCall(ExprVar(c.getTypeName())), + "==", + ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())), + ) + ) + ) + opeqeqswitch.addcase(CaseLabel(c.enum()), case) + opeqeqswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE]) + ) + opeqeq.addstmt(opeqeqswitch) + + cls.addstmts([opeqeq, Whitespace.NL]) + + # accessors for each type: operator T&, operator const T&, + # T& get(), const T& get() + for c in ud.components: + getValueVar = ExprVar(c.getTypeName()) + getConstValueVar = ExprVar(c.getConstTypeName()) + + getvalue = MethodDefn( + MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True) + ) + getvalue.addstmts( + [ + StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), + StmtReturn(ExprDeref(c.callGetPtr())), + ] + ) + + getconstvalue = MethodDefn( + MethodDecl( + getConstValueVar.name, + ret=c.constRefType(), + const=True, + force_inline=True, + ) + ) + getconstvalue.addstmts( + [ + StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), + StmtReturn(c.getConstValue()), + ] + ) + + cls.addstmts([getvalue, getconstvalue]) + + optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True)) + optype.addstmt(StmtReturn(ExprCall(getValueVar))) + opconsttype = MethodDefn( + MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True) + ) + opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar))) + + cls.addstmts([optype, opconsttype, Whitespace.NL]) + # private vars + cls.addstmts( + [ + Label.PRIVATE, + StmtDecl(Decl(valuetype, mvaluevar.name)), + StmtDecl(Decl(typetype, mtypevar.name)), + ] + ) + + return forwarddeclstmts, fulldecltypes, cls + + +# ----------------------------------------------------------------------------- + + +class _FindFriends(ipdl.ast.Visitor): + def __init__(self): + self.mytype = None # ProtocolType + self.vtype = None # ProtocolType + self.friends = set() # set<ProtocolType> + + def findFriends(self, ptype): + self.mytype = ptype + for toplvl in ptype.toplevels(): + self.walkDownTheProtocolTree(toplvl) + return self.friends + + # TODO could make this into a _iterProtocolTreeHelper ... + def walkDownTheProtocolTree(self, ptype): + if ptype != self.mytype: + # don't want to |friend| ourself! + self.visit(ptype) + for mtype in ptype.manages: + if mtype is not ptype: + self.walkDownTheProtocolTree(mtype) + + def visit(self, ptype): + # |vtype| is the type currently being visited + savedptype = self.vtype + self.vtype = ptype + ptype._ast.accept(self) + self.vtype = savedptype + + def visitMessageDecl(self, md): + for it in self.iterActorParams(md): + if it.protocol == self.mytype: + self.friends.add(self.vtype) + + def iterActorParams(self, md): + for param in md.inParams: + for actor in ipdl.type.iteractortypes(param.type): + yield actor + for ret in md.outParams: + for actor in ipdl.type.iteractortypes(ret.type): + yield actor + + +class _GenerateProtocolActorCode(ipdl.ast.Visitor): + def __init__(self, myside): + self.side = myside # "parent" or "child" + self.prettyside = myside.title() + self.clsname = None + self.protocol = None + self.hdrfile = None + self.cppfile = None + self.ns = None + self.cls = None + self.protocolCxxIncludes = [] + self.actorForwardDecls = [] + self.usingDecls = [] + self.externalIncludes = set() + self.nonForwardDeclaredHeaders = set() + self.typedefSet = set( + [ + Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"), + Typedef(Type("base::ProcessId"), "ProcessId"), + Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"), + Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]), + Typedef( + Type("mozilla::ipc::ManagedEndpoint"), + "ManagedEndpoint", + ["FooSide"], + ), + Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]), + Typedef( + Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason" + ), + ] + ) + + def lower(self, tu, clsname, cxxHeaderFile, cxxFile): + self.clsname = clsname + self.hdrfile = cxxHeaderFile + self.cppfile = cxxFile + tu.accept(self) + + def standardTypedefs(self): + return [ + Typedef(Type("mozilla::ipc::IProtocol"), "IProtocol"), + Typedef(Type("IPC::Message"), "Message"), + Typedef(Type("base::ProcessHandle"), "ProcessHandle"), + Typedef(Type("mozilla::ipc::MessageChannel"), "MessageChannel"), + Typedef(Type("mozilla::ipc::SharedMemory"), "SharedMemory"), + ] + + def visitTranslationUnit(self, tu): + self.protocol = tu.protocol + + hf = self.hdrfile + cf = self.cppfile + + # make the C++ header + hf.addthings( + [_DISCLAIMER] + + _includeGuardStart(hf) + + [ + Whitespace.NL, + CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'), + ] + ) + + for inc in tu.includes: + inc.accept(self) + for inc in tu.cxxIncludes: + inc.accept(self) + + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + for su in tu.structsAndUnions: + su.accept(self) + + # this generates the actor's full impl in self.cls + tu.protocol.accept(self) + + clsdecl, clsdefn = _splitClassDeclDefn(self.cls) + + # XXX damn C++ ... return types in the method defn aren't in + # class scope + for stmt in clsdefn.stmts: + if isinstance(stmt, MethodDefn): + if stmt.decl.ret and stmt.decl.ret.name == "Result": + stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name + + def setToIncludes(s): + return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))] + + def makeNamespace(p, file): + if 0 == len(p.namespaces): + return file + ns = Namespace(p.namespaces[-1].name) + outerns = _putInNamespaces(ns, p.namespaces[:-1]) + file.addthing(outerns) + return ns + + if len(self.nonForwardDeclaredHeaders) != 0: + self.hdrfile.addthings( + [ + Whitespace("// Headers for things that cannot be forward declared"), + Whitespace.NL, + ] + + setToIncludes(self.nonForwardDeclaredHeaders) + + [Whitespace.NL] + ) + self.hdrfile.addthings(self.actorForwardDecls) + self.hdrfile.addthings(self.usingDecls) + + hdrns = makeNamespace(self.protocol, self.hdrfile) + hdrns.addstmts( + [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL] + ) + + actortype = ActorType(tu.protocol.decl.type) + traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side) + + self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf)) + + # If the implementation type is not overridden, add an implicit import + # for the default implementation header file. Explicit implementation + # types will specify their headers manually with `include`. + if self.protocol.implAttribute(self.side) is None: + assert self.protocol.name.startswith("P") + self.externalIncludes.add( + "".join(n.name + "/" for n in self.protocol.namespaces) + + self.protocol.name[1:] + + self.side.capitalize() + + ".h" + ) + + # make the .cpp file + cf.addthings( + [ + _DISCLAIMER, + Whitespace.NL, + CppDirective( + "include", + '"' + _protocolHeaderName(self.protocol, self.side) + '.h"', + ), + ] + + setToIncludes(self.externalIncludes) + ) + + cf.addthings( + ( + [Whitespace.NL] + + [ + CppDirective("include", '"%s.h"' % (inc)) + for inc in self.protocolCxxIncludes + ] + + [Whitespace.NL] + + [ + CppDirective("include", '"%s"' % filename) + for filename in ipdl.builtin.CppIncludes + ] + + [Whitespace.NL] + ) + ) + + cppns = makeNamespace(self.protocol, cf) + cppns.addstmts( + [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL] + ) + + cf.addthing(traitsdefn) + + def visitUsingStmt(self, using): + if using.decl.fullname is not None: + self.typedefSet.add( + Typedef(Type(using.decl.fullname), using.decl.shortname) + ) + + if using.header is None: + return + + if using.canBeForwardDeclared(): + spec = using.type + + self.usingDecls.extend( + [ + _makeForwardDeclForQClass( + spec.baseid, + spec.quals, + cls=using.isClass(), + struct=using.isStruct(), + ), + Whitespace.NL, + ] + ) + self.externalIncludes.add(using.header) + else: + self.nonForwardDeclaredHeaders.add(using.header) + + def visitCxxInclude(self, inc): + self.externalIncludes.add(inc.file) + + def visitInclude(self, inc): + if inc.tu.filetype == "header": + # Including a header will declare any globals defined by "using" + # statements into our scope. To serialize these, we also may need + # cxx include statements, so visit them as well. + for cxxinc in inc.tu.cxxIncludes: + cxxinc.accept(self) + for using in inc.tu.using: + using.accept(self) + for su in inc.tu.structsAndUnions: + su.accept(self) + else: + # Includes for protocols only include types explicitly exported by + # those protocols. + ip = inc.tu.protocol + if ip == self.protocol: + return + + self.actorForwardDecls.extend( + [ + _makeForwardDeclForActor(ip.decl.type, self.side), + _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)), + Whitespace.NL, + ] + ) + self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side)) + + if ip.decl.fullname is not None: + self.typedefSet.add( + Typedef( + Type(_actorName(ip.decl.fullname, self.side.title())), + _actorName(ip.decl.shortname, self.side.title()), + ) + ) + + self.typedefSet.add( + Typedef( + Type( + _actorName(ip.decl.fullname, _otherSide(self.side).title()) + ), + _actorName(ip.decl.shortname, _otherSide(self.side).title()), + ) + ) + + def visitStructDecl(self, sd): + if sd.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name)) + + def visitUnionDecl(self, ud): + if ud.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name)) + + def visitProtocol(self, p): + self.hdrfile.addcode( + """ + #ifdef DEBUG + #include "prenv.h" + #endif // DEBUG + + #include "mozilla/Tainting.h" + #include "mozilla/ipc/MessageChannel.h" + #include "mozilla/ipc/ProtocolUtils.h" + """ + ) + + self.protocol = p + ptype = p.decl.type + toplevel = p.decl.type.toplevel() + + hasAsyncReturns = False + for md in p.messageDecls: + if md.hasAsyncReturns(): + hasAsyncReturns = True + break + + inherits = [] + if ptype.isToplevel(): + inherits.append(Inherit(p.openedProtocolInterfaceType(), viz="public")) + else: + inherits.append(Inherit(p.managerInterfaceType(), viz="public")) + + if ptype.isToplevel() and self.side == "parent": + self.hdrfile.addthings( + [_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL] + ) + + self.cls = Class(self.clsname, inherits=inherits, abstract=True) + + self.cls.addstmt(Label.PRIVATE) + friends = _FindFriends().findFriends(ptype) + if ptype.isManaged(): + friends.update(ptype.managers) + + # |friend| managed actors so that they can call our Dealloc*() + friends.update(ptype.manages) + + # don't friend ourself if we're a self-managed protocol + friends.discard(ptype) + + for friend in sorted(friends, key=lambda f: f.fullname()): + self.actorForwardDecls.extend( + [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL] + ) + self.cls.addstmt( + FriendClassDecl(_actorName(friend.fullname(), self.prettyside)) + ) + + self.cls.addstmt(Label.PROTECTED) + for typedef in sorted(self.typedefSet): + self.cls.addstmt(typedef) + + self.cls.addstmt(Whitespace.NL) + + if hasAsyncReturns: + self.cls.addstmt(Label.PUBLIC) + for md in p.messageDecls: + if self.sendsMessage(md) and md.hasAsyncReturns(): + self.cls.addstmt( + Typedef(_makePromise(md.returns, self.side), md.promiseName()) + ) + if self.receivesMessage(md) and md.hasAsyncReturns(): + self.cls.addstmt( + Typedef(_makeResolver(md.returns, self.side), md.resolverName()) + ) + self.cls.addstmt(Whitespace.NL) + + self.cls.addstmt(Label.PROTECTED) + # interface methods that the concrete subclass has to impl + for md in p.messageDecls: + isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor() + + if self.receivesMessage(md): + # generate Recv/Answer* interface + implicit = not isdtor + returnsems = "resolver" if md.decl.type.isAsync() else "out" + recvDecl = MethodDecl( + md.recvMethod(), + params=md.makeCxxParams( + paramsems="move", + returnsems=returnsems, + side=self.side, + implicit=implicit, + direction="recv", + ), + ret=Type("mozilla::ipc::IPCResult"), + methodspec=MethodSpec.VIRTUAL, + ) + + # These method implementations cause problems when trying to + # override them with different types in a direct call class. + # + # For the `isdtor` case there's a simple solution: it doesn't + # make much sense to specify arguments and then completely + # ignore them, and the no-arg case isn't a problem for + # overriding. + if isctor or (isdtor and not md.inParams): + defaultRecv = MethodDefn(recvDecl) + defaultRecv.addcode("return IPC_OK();\n") + self.cls.addstmt(defaultRecv) + elif self.protocol.implAttribute(self.side) == "virtual": + # If we're using virtual calls, we need the methods to be + # declared on the base class. + recvDecl.methodspec = MethodSpec.PURE + self.cls.addstmt(StmtDecl(recvDecl)) + + # If we're using virtual calls, we need the methods to be declared on + # the base class. + if self.protocol.implAttribute(self.side) == "virtual": + for md in p.messageDecls: + managed = md.decl.type.constructedType() + if not ptype.isManagerOf(managed) or md.decl.type.isDtor(): + continue + + # add the Alloc interface for managed actors + actortype = md.actorDecl().bareType(self.side) + + if managed.isRefcounted(): + if not self.receivesMessage(md): + continue + + actortype.ptr = False + actortype = _alreadyaddrefed(actortype) + + self.cls.addstmt( + StmtDecl( + MethodDecl( + _allocMethod(managed, self.side), + params=md.makeCxxParams( + side=self.side, implicit=False, direction="recv" + ), + ret=actortype, + methodspec=MethodSpec.PURE, + ) + ) + ) + + # add the Dealloc interface for all managed non-refcounted actors, + # even without ctors. This is useful for protocols which use + # ManagedEndpoint for construction. + for managed in ptype.manages: + if managed.isRefcounted(): + continue + + self.cls.addstmt( + StmtDecl( + MethodDecl( + _deallocMethod(managed, self.side), + params=[ + Decl(p.managedCxxType(managed, self.side), "aActor") + ], + ret=Type.BOOL, + methodspec=MethodSpec.PURE, + ) + ) + ) + + if ptype.isToplevel(): + # void ProcessingError(code); default to no-op + processingerror = MethodDefn( + MethodDecl( + p.processingErrorVar().name, + params=[ + Param(_Result.Type(), "aCode"), + Param(Type("char", const=True, ptr=True), "aReason"), + ], + methodspec=MethodSpec.OVERRIDE, + ) + ) + + # bool ShouldContinueFromReplyTimeout(); default to |true| + shouldcontinue = MethodDefn( + MethodDecl( + p.shouldContinueFromTimeoutVar().name, + ret=Type.BOOL, + methodspec=MethodSpec.OVERRIDE, + ) + ) + shouldcontinue.addcode("return true;\n") + + self.cls.addstmts( + [ + processingerror, + shouldcontinue, + Whitespace.NL, + ] + ) + + self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL])) + + self.cls.addstmt(Label.PUBLIC) + # Actor() + ctor = ConstructorDefn(ConstructorDecl(self.clsname)) + side = ExprVar("mozilla::ipc::" + self.side.title() + "Side") + if ptype.isToplevel(): + name = ExprLiteral.String(_actorName(p.name, self.side)) + ctor.memberinits = [ + ExprMemberInit( + ExprVar("mozilla::ipc::IToplevelProtocol"), + [name, _protocolId(ptype), side], + ) + ] + else: + ctor.memberinits = [ + ExprMemberInit( + ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side] + ) + ] + + ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname) + self.cls.addstmts([ctor, Whitespace.NL]) + + # ~Actor() + dtor = DestructorDefn( + DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL) + ) + dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname) + + self.cls.addstmts([dtor, Whitespace.NL]) + + if ptype.isRefcounted(): + if not ptype.isToplevel(): + self.cls.addcode( + """ + NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING + """ + ) + self.cls.addstmt(Label.PROTECTED) + self.cls.addcode( + """ + void ActorAlloc() final { AddRef(); } + void ActorDealloc() final { Release(); } + """ + ) + + self.cls.addstmt(Label.PUBLIC) + if ptype.hasOtherPid(): + otherpidmeth = MethodDefn( + MethodDecl("OtherPid", ret=Type("::base::ProcessId"), const=True) + ) + otherpidmeth.addcode( + """ + ::base::ProcessId pid = + ::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid(); + MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId); + return pid; + """ + ) + self.cls.addstmts([otherpidmeth, Whitespace.NL]) + + if not ptype.isToplevel(): + if 1 == len(p.managers): + # manager() const + managertype = p.managerActorType(self.side, ptr=True) + managermeth = MethodDefn( + MethodDecl("Manager", ret=managertype, const=True) + ) + managermeth.addcode( + """ + return static_cast<${type}>(IProtocol::Manager()); + """, + type=managertype, + ) + + self.cls.addstmts([managermeth, Whitespace.NL]) + + def actorFromIter(itervar): + return ExprCode("${iter}.Get()->GetKey()", iter=itervar) + + def forLoopOverHashtable(hashtable, itervar, const=False): + itermeth = "ConstIter" if const else "Iter" + return StmtFor( + init=ExprCode( + "auto ${itervar} = ${hashtable}.${itermeth}()", + itervar=itervar, + hashtable=hashtable, + itermeth=itermeth, + ), + cond=ExprCode("!${itervar}.Done()", itervar=itervar), + update=ExprCode("${itervar}.Next()", itervar=itervar), + ) + + # Managed[T](Array& inout) const + # const Array<T>& Managed() const + for managed in ptype.manages: + container = p.managedVar(managed, self.side) + + meth = MethodDefn( + MethodDecl( + p.managedMethod(managed, self.side).name, + params=[ + Decl( + _cxxArrayType( + p.managedCxxType(managed, self.side), ref=True + ), + "aArr", + ) + ], + const=True, + ) + ) + meth.addcode("${container}.ToArray(aArr);\n", container=container) + + refmeth = MethodDefn( + MethodDecl( + p.managedMethod(managed, self.side).name, + params=[], + ret=p.managedVarType(managed, self.side, const=True, ref=True), + const=True, + ) + ) + refmeth.addcode("return ${container};\n", container=container) + + self.cls.addstmts([meth, refmeth, Whitespace.NL]) + + # AllManagedActors(Array& inout) const + arrvar = ExprVar("arr__") + managedmeth = MethodDefn( + MethodDecl( + "AllManagedActors", + params=[ + Decl( + _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True), + arrvar.name, + ) + ], + methodspec=MethodSpec.OVERRIDE, + const=True, + ) + ) + + # Count the number of managed actors, and allocate space in the output array. + managedmeth.addcode( + """ + uint32_t total = 0; + """ + ) + for managed in ptype.manages: + managedmeth.addcode( + """ + total += ${container}.Count(); + """, + container=p.managedVar(managed, self.side), + ) + managedmeth.addcode( + """ + arr__.SetCapacity(total); + + """ + ) + + for managed in ptype.manages: + managedmeth.addcode( + """ + for (auto* key : ${container}) { + arr__.AppendElement(key->GetLifecycleProxy()); + } + + """, + container=p.managedVar(managed, self.side), + ) + + self.cls.addstmts([managedmeth, Whitespace.NL]) + + # OpenPEndpoint(...)/BindPEndpoint(...) + for managed in ptype.manages: + self.genManagedEndpoint(managed) + + # OnMessageReceived()/OnCallReceived() + + # save these away for use in message handler case stmts + msgvar = ExprVar("msg__") + self.msgvar = msgvar + replyvar = ExprVar("reply__") + self.replyvar = replyvar + var = ExprVar("v__") + self.var = var + # for ctor recv cases, we can't read the actor ID into a PFoo* + # because it doesn't exist on this side yet. Use a "special" + # actor handle instead + handlevar = ExprVar("handle__") + self.handlevar = handlevar + + msgtype = ExprCode("msg__.type()") + self.asyncSwitch = StmtSwitch(msgtype) + self.syncSwitch = None + self.interruptSwitch = None + if toplevel.isSync() or toplevel.isInterrupt(): + self.syncSwitch = StmtSwitch(msgtype) + if toplevel.isInterrupt(): + self.interruptSwitch = StmtSwitch(msgtype) + + # Add a handler for the MANAGED_ENDPOINT_BOUND and + # MANAGED_ENDPOINT_DROPPED message types for managed actors. + if not ptype.isToplevel(): + clearawaitingmanagedendpointbind = """ + if (!mAwaitingManagedEndpointBind) { + NS_WARNING("Unexpected managed endpoint lifecycle message after actor bound!"); + return MsgNotAllowed; + } + mAwaitingManagedEndpointBind = false; + """ + self.asyncSwitch.addcase( + CaseLabel("MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"), + StmtBlock( + [ + StmtCode(clearawaitingmanagedendpointbind), + StmtReturn(_Result.Processed), + ] + ), + ) + self.asyncSwitch.addcase( + CaseLabel("MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"), + StmtBlock( + [ + StmtCode(clearawaitingmanagedendpointbind), + *self.destroyActor( + None, + ExprVar.THIS, + why=_DestroyReason.ManagedEndpointDropped, + ), + StmtReturn(_Result.Processed), + ] + ), + ) + + # implement Send*() methods and add dispatcher cases to + # message switch()es + for md in p.messageDecls: + self.visitMessageDecl(md) + + # add default cases + default = StmtCode( + """ + return MsgNotKnown; + """ + ) + self.asyncSwitch.addcase(DefaultLabel(), default) + if toplevel.isSync() or toplevel.isInterrupt(): + self.syncSwitch.addcase(DefaultLabel(), default) + if toplevel.isInterrupt(): + self.interruptSwitch.addcase(DefaultLabel(), default) + + self.cls.addstmts(self.implementManagerIface()) + + def makeHandlerMethod(name, switch, hasReply, dispatches=False): + params = [Decl(Type("Message", const=True, ref=True), msgvar.name)] + if hasReply: + params.append(Decl(Type("UniquePtr<Message>", ref=True), replyvar.name)) + + method = MethodDefn( + MethodDecl( + name, + methodspec=MethodSpec.OVERRIDE, + params=params, + ret=_Result.Type(), + ) + ) + + if not switch: + method.addcode( + """ + MOZ_ASSERT_UNREACHABLE("message protocol not supported"); + return MsgNotKnown; + """ + ) + return method + + if dispatches: + if hasReply: + ondeadactor = [StmtReturn(_Result.RouteError)] + else: + ondeadactor = [ + self.logMessage( + None, ExprAddrOf(msgvar), "Ignored message for dead actor" + ), + StmtReturn(_Result.Processed), + ] + + method.addcode( + """ + int32_t route__ = ${msgvar}.routing_id(); + if (MSG_ROUTING_CONTROL != route__) { + IProtocol* routed__ = Lookup(route__); + if (!routed__ || !routed__->GetLifecycleProxy()) { + $*{ondeadactor} + } + + RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ = + routed__->GetLifecycleProxy(); + return proxy__->Get()->${name}($,{args}); + } + + """, + msgvar=msgvar, + ondeadactor=ondeadactor, + name=name, + args=[p.name for p in params], + ) + + # bug 509581: don't generate the switch stmt if there + # is only the default case; MSVC doesn't like that + if switch.nr_cases > 1: + method.addstmt(switch) + else: + method.addstmt(StmtReturn(_Result.NotKnown)) + + return method + + dispatches = ptype.isToplevel() and ptype.isManager() + self.cls.addstmts( + [ + makeHandlerMethod( + "OnMessageReceived", + self.asyncSwitch, + hasReply=False, + dispatches=dispatches, + ), + Whitespace.NL, + ] + ) + self.cls.addstmts( + [ + makeHandlerMethod( + "OnMessageReceived", + self.syncSwitch, + hasReply=True, + dispatches=dispatches, + ), + Whitespace.NL, + ] + ) + self.cls.addstmts( + [ + makeHandlerMethod( + "OnCallReceived", + self.interruptSwitch, + hasReply=True, + dispatches=dispatches, + ), + Whitespace.NL, + ] + ) + + clearsubtreevar = ExprVar("ClearSubtree") + + if ptype.isToplevel(): + # OnChannelClose() + onclose = MethodDefn( + MethodDecl("OnChannelClose", methodspec=MethodSpec.OVERRIDE) + ) + onclose.addcode( + """ + DestroySubtree(NormalShutdown); + ClearSubtree(); + DeallocShmems(); + if (GetLifecycleProxy()) { + GetLifecycleProxy()->Release(); + } + """ + ) + self.cls.addstmts([onclose, Whitespace.NL]) + + # OnChannelError() + onerror = MethodDefn( + MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE) + ) + onerror.addcode( + """ + DestroySubtree(AbnormalShutdown); + ClearSubtree(); + DeallocShmems(); + if (GetLifecycleProxy()) { + GetLifecycleProxy()->Release(); + } + """ + ) + self.cls.addstmts([onerror, Whitespace.NL]) + + if ptype.isToplevel() and ptype.isInterrupt(): + processnative = MethodDefn( + MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID) + ) + processnative.addcode( + """ + #ifdef OS_WIN + GetIPCChannel()->ProcessNativeEventsInInterruptCall(); + #else + FatalError("This method is Windows-only"); + #endif + """ + ) + + self.cls.addstmts([processnative, Whitespace.NL]) + + # private methods + self.cls.addstmt(Label.PRIVATE) + + # ClearSubtree() + clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name)) + for managed in ptype.manages: + clearsubtree.addcode( + """ + for (auto* key : ${container}) { + key->ClearSubtree(); + } + for (auto* key : ${container}) { + // Recursively releasing ${container} kids. + auto* proxy = key->GetLifecycleProxy(); + NS_IF_RELEASE(proxy); + } + ${container}.Clear(); + + """, + container=p.managedVar(managed, self.side), + ) + + # don't release our own IPC reference: either the manager will do it, + # or we're toplevel + self.cls.addstmts([clearsubtree, Whitespace.NL]) + + if not ptype.isToplevel(): + self.cls.addstmts( + [ + StmtDecl( + Decl(Type.BOOL, "mAwaitingManagedEndpointBind"), + init=ExprLiteral.FALSE, + ), + Whitespace.NL, + ] + ) + + for managed in ptype.manages: + self.cls.addstmts( + [ + StmtDecl( + Decl( + p.managedVarType(managed, self.side), + p.managedVar(managed, self.side).name, + ) + ) + ] + ) + + def genManagedEndpoint(self, managed): + hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side) + thereEp = "ManagedEndpoint<%s>" % _actorName( + managed.name(), _otherSide(self.side) + ) + + actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor") + + # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor) + openmeth = MethodDefn( + MethodDecl( + "Open%sEndpoint" % managed.name(), + params=[ + Decl(self.protocol.managedCxxType(managed, self.side), actor.name) + ], + ret=Type(thereEp), + ) + ) + openmeth.addcode( + """ + $*{bind} + // Mark our actor as awaiting the other side to be bound. This will + // be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is + // received. + aActor->mAwaitingManagedEndpointBind = true; + return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor); + """, + bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))), + thereEp=thereEp, + ) + + # void BindPEndpoint(ManagedEndpoint<PHere>&& aEndpoint, PHere* aActor) + bindmeth = MethodDefn( + MethodDecl( + "Bind%sEndpoint" % managed.name(), + params=[ + Decl(Type(hereEp), "aEndpoint"), + Decl(self.protocol.managedCxxType(managed, self.side), actor.name), + ], + ret=Type.BOOL, + ) + ) + bindmeth.addcode( + """ + return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this, ${container}); + """, + container=self.protocol.managedVar(managed, self.side), + ) + + self.cls.addstmts([openmeth, bindmeth, Whitespace.NL]) + + def implementManagerIface(self): + p = self.protocol + protocolbase = Type("IProtocol", ptr=True) + + methods = [] + + if p.decl.type.isToplevel(): + # FIXME: This used to be declared conditionally based on whether + # shmem appeared somewhere in the protocol hierarchy, however that + # caused issues due to Shmem instances hidden within custom C++ + # types. + self.asyncSwitch.addcase( + CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"), + self.genShmemCreatedHandler(), + ) + self.asyncSwitch.addcase( + CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"), + self.genShmemDestroyedHandler(), + ) + + # Keep track of types created with an INOUT ctor. We need to call + # Register() or RegisterID() for them depending on the side the managee + # is created. + inoutCtorTypes = [] + for msg in p.messageDecls: + msgtype = msg.decl.type + if msgtype.isCtor() and msgtype.isInout(): + inoutCtorTypes.append(msgtype.constructedType()) + + # all protocols share the "same" RemoveManagee() implementation + pvar = ExprVar("aProtocolId") + listenervar = ExprVar("aListener") + removemanagee = MethodDefn( + MethodDecl( + p.removeManageeMethod().name, + params=[ + Decl(_protocolIdType(), pvar.name), + Decl(protocolbase, listenervar.name), + ], + methodspec=MethodSpec.OVERRIDE, + ) + ) + + if not len(p.managesStmts): + removemanagee.addcode( + """ + FatalError("unreached"); + return; + """ + ) + else: + switchontype = StmtSwitch(pvar) + for managee in p.managesStmts: + manageeipdltype = managee.decl.type + manageecxxtype = _cxxBareType( + ipdl.type.ActorType(manageeipdltype), self.side + ) + case = ExprCode( + """ + { + ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener); + + const bool removed = ${container}.EnsureRemoved(actor); + MOZ_RELEASE_ASSERT(removed, "actor not managed by this!"); + + auto* proxy = actor->GetLifecycleProxy(); + NS_IF_RELEASE(proxy); + return; + } + """, + manageecxxtype=manageecxxtype, + container=p.managedVar(manageeipdltype, self.side), + ) + switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) + switchontype.addcase( + DefaultLabel(), + ExprCode( + """ + FatalError("unreached"); + return; + """ + ), + ) + removemanagee.addstmt(switchontype) + + # The `DeallocManagee` method is called for managed actors to trigger + # deallocation when ActorLifecycleProxy is freed. + deallocmanagee = MethodDefn( + MethodDecl( + p.deallocManageeMethod().name, + params=[ + Decl(_protocolIdType(), pvar.name), + Decl(protocolbase, listenervar.name), + ], + methodspec=MethodSpec.OVERRIDE, + ) + ) + + if not len(p.managesStmts): + deallocmanagee.addcode( + """ + FatalError("unreached"); + return; + """ + ) + else: + switchontype = StmtSwitch(pvar) + for managee in p.managesStmts: + manageeipdltype = managee.decl.type + # Reference counted actor types don't have corresponding + # `Dealloc` methods, as they are deallocated by releasing the + # IPDL-held reference. + if manageeipdltype.isRefcounted(): + continue + + case = StmtCode( + """ + ${concrete}->${dealloc}(static_cast<${type}>(aListener)); + return; + """, + concrete=self.concreteThis(), + dealloc=_deallocMethod(manageeipdltype, self.side), + type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side), + ) + switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) + switchontype.addcase( + DefaultLabel(), + StmtCode( + """ + FatalError("unreached"); + return; + """ + ), + ) + deallocmanagee.addstmt(switchontype) + + return methods + [removemanagee, deallocmanagee, Whitespace.NL] + + def genShmemCreatedHandler(self): + assert self.protocol.decl.type.isToplevel() + + return StmtCode( + """ + { + if (!ShmemCreated(${msgvar})) { + return MsgPayloadError; + } + return MsgProcessed; + } + """, + msgvar=self.msgvar, + ) + + def genShmemDestroyedHandler(self): + assert self.protocol.decl.type.isToplevel() + + return StmtCode( + """ + { + if (!ShmemDestroyed(${msgvar})) { + return MsgPayloadError; + } + return MsgProcessed; + } + """, + msgvar=self.msgvar, + ) + + # ------------------------------------------------------------------------- + # The next few functions are the crux of the IPDL code generator. + # They generate code for all the nasty work of message + # serialization/deserialization and dispatching handlers for + # received messages. + ## + + def concreteThis(self): + implAttr = self.protocol.implAttribute(self.side) + if implAttr == "virtual": + return ExprVar.THIS + + if implAttr is None: + assert self.protocol.name.startswith("P") + className = self.protocol.name[1:] + self.side.capitalize() + else: + assert isinstance(implAttr, ipdl.ast.StringLiteral) + className = implAttr.value + + return ExprCode("static_cast<${className}*>(this)", className=className) + + def thisCall(self, function, args): + return ExprCall(ExprSelect(self.concreteThis(), "->", function), args=args) + + def visitMessageDecl(self, md): + isctor = md.decl.type.isCtor() + isdtor = md.decl.type.isDtor() + decltype = md.decl.type + sendmethod = None + movesendmethod = None + promisesendmethod = None + recvlbl, recvcase = None, None + + def addRecvCase(lbl, case): + if decltype.isAsync(): + self.asyncSwitch.addcase(lbl, case) + elif decltype.isSync(): + self.syncSwitch.addcase(lbl, case) + elif decltype.isInterrupt(): + self.interruptSwitch.addcase(lbl, case) + else: + assert 0 + + if self.sendsMessage(md): + isasync = decltype.isAsync() + + # NOTE: Don't generate helper ctors for refcounted types. + # + # Safety concerns around providing your own actor to a ctor (namely + # that the return value won't be checked, and the argument will be + # `delete`-ed) are less critical with refcounted actors, due to the + # actor being held alive by the callsite. + # + # This allows refcounted actors to not implement crashing AllocPFoo + # methods on the sending side. + if isctor and not md.decl.type.constructedType().isRefcounted(): + self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL]) + + if isctor and isasync: + sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md) + elif isctor: + sendmethod = self.genBlockingCtorMethod(md) + elif isdtor and isasync: + sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md) + elif isdtor: + sendmethod = self.genBlockingDtorMethod(md) + elif isasync: + ( + sendmethod, + movesendmethod, + promisesendmethod, + (recvlbl, recvcase), + ) = self.genAsyncSendMethod(md) + else: + sendmethod, movesendmethod = self.genBlockingSendMethod(md) + + # XXX figure out what to do here + if isdtor and md.decl.type.constructedType().isToplevel(): + sendmethod = None + + if sendmethod is not None: + self.cls.addstmts([sendmethod, Whitespace.NL]) + if movesendmethod is not None: + self.cls.addstmts([movesendmethod, Whitespace.NL]) + if promisesendmethod is not None: + self.cls.addstmts([promisesendmethod, Whitespace.NL]) + if recvcase is not None: + addRecvCase(recvlbl, recvcase) + recvlbl, recvcase = None, None + + if self.receivesMessage(md): + if isctor: + recvlbl, recvcase = self.genCtorRecvCase(md) + elif isdtor: + recvlbl, recvcase = self.genDtorRecvCase(md) + else: + recvlbl, recvcase = self.genRecvCase(md) + + # XXX figure out what to do here + if isdtor and md.decl.type.constructedType().isToplevel(): + return + + addRecvCase(recvlbl, recvcase) + + def genAsyncCtor(self, md): + actor = md.actorDecl() + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, stmts = self.makeMessage(md, errfnSendCtor) + sendok, sendstmts = self.sendAsync(md, msgvar) + + method.addcode( + """ + $*{bind} + + // Build our constructor message. + $*{stmts} + + // Notify the other side about the newly created actor. This can + // fail if our manager has already been destroyed. + // + // NOTE: If the send call fails due to toplevel channel teardown, + // the `IProtocol::ChannelSend` wrapper absorbs the error for us, + // so we don't tear down actors unexpectedly. + $*{sendstmts} + + // Warn, destroy the actor, and return null if the message failed to + // send. Otherwise, return the successfully created actor reference. + if (!${sendok}) { + NS_WARNING("Error sending ${actorname} constructor"); + $*{destroy} + return nullptr; + } + return ${actor}; + """, + bind=self.bindManagedActor(actor), + stmts=stmts, + sendstmts=sendstmts, + sendok=sendok, + destroy=self.destroyActor( + md, actor.var(), why=_DestroyReason.FailedConstructor + ), + actor=actor.var(), + actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), + ) + + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + case.addstmt(StmtReturn(_Result.Processed)) + # TODO not really sure what to do with async ctor "replies" yet. + # destroy actor if there was an error? tricky ... + + return method, (lbl, case) + + def genBlockingCtorMethod(self, md): + actor = md.actorDecl() + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, stmts = self.makeMessage(md, errfnSendCtor) + + replyvar = self.replyvar + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) + replystmts = self.deserializeReply( + md, + replyvar, + self.side, + errfnSendCtor, + errfnSentinel(ExprLiteral.NULL), + ) + + method.addcode( + """ + $*{bind} + + // Build our constructor message. + $*{stmts} + + // Synchronously send the constructor message to the other side. If + // the send fails, e.g. due to the remote side shutting down, the + // actor will be destroyed and potentially freed. + UniquePtr<Message> ${replyvar}; + $*{sendstmts} + + if (!(${sendok})) { + // Warn, destroy the actor and return null if the message + // failed to send. + NS_WARNING("Error sending constructor"); + $*{destroy} + return nullptr; + } + + $*{replystmts} + return ${actor}; + """, + bind=self.bindManagedActor(actor), + stmts=stmts, + replyvar=replyvar, + sendstmts=sendstmts, + sendok=sendok, + destroy=self.destroyActor( + md, actor.var(), why=_DestroyReason.FailedConstructor + ), + replystmts=replystmts, + actor=actor.var(), + actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), + ) + + return method + + def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None): + actorproto = actordecl.ipdltype.protocol + + if idexpr is None: + setManagerArgs = [ExprVar.THIS] + else: + setManagerArgs = [ExprVar.THIS, idexpr] + + return [ + StmtCode( + """ + if (!${actor}) { + NS_WARNING("Cannot bind null ${actorname} actor"); + return ${errfn}; + } + + ${actor}->SetManagerAndRegister($,{setManagerArgs}); + ${container}.Insert(${actor}); + """, + actor=actordecl.var(), + actorname=actorproto.name() + self.side.capitalize(), + errfn=errfn, + setManagerArgs=setManagerArgs, + container=self.protocol.managedVar(actorproto, self.side), + ) + ] + + def genHelperCtor(self, md): + helperdecl = self.makeSendMethodDecl(md) + helperdecl.params = helperdecl.params[1:] + helper = MethodDefn(helperdecl) + + helper.addstmts( + [ + self.callAllocActor(md, retsems="out", side=self.side), + StmtReturn( + ExprCall( + ExprVar(helperdecl.name), args=md.makeCxxArgs(paramsems="move") + ) + ), + ] + ) + return helper + + def genAsyncDtor(self, md): + actorvar = ExprVar("actor") + method = MethodDefn(self.makeDtorMethodDecl(md, actorvar)) + + method.addstmt(self.dtorPrologue(actorvar)) + + msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) + sendok, sendstmts = self.sendAsync(md, msgvar, actorvar) + method.addstmts( + stmts + + sendstmts + + [Whitespace.NL] + + self.dtorEpilogue(md, actorvar) + + [StmtReturn(sendok)] + ) + + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + case.addstmt(StmtReturn(_Result.Processed)) + # TODO if the dtor is "inherently racy", keep the actor alive + # until the other side acks + + return method, (lbl, case) + + def genBlockingDtorMethod(self, md): + actorvar = ExprVar("actor") + method = MethodDefn(self.makeDtorMethodDecl(md, actorvar)) + + method.addstmt(self.dtorPrologue(actorvar)) + + msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) + + replyvar = self.replyvar + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar) + method.addstmts( + stmts + + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))] + + sendstmts + ) + + destmts = self.deserializeReply( + md, replyvar, self.side, errfnSend, errfnSentinel(), actorvar + ) + ifsendok = StmtIf(ExprLiteral.FALSE) + ifsendok.addifstmts(destmts) + ifsendok.addifstmts( + [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))] + ) + + method.addstmt(ifsendok) + + method.addstmts( + self.dtorEpilogue(md, actorvar) + [Whitespace.NL, StmtReturn(sendok)] + ) + + return method + + def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion): + if md and md.decl.type.isCtor(): + destroyedType = md.decl.type.constructedType() + else: + destroyedType = self.protocol.decl.type + + return [ + StmtCode( + """ + IProtocol* mgr = ${actor}->Manager(); + ${actor}->DestroySubtree(${why}); + ${actor}->ClearSubtree(); + mgr->RemoveManagee(${protoId}, ${actor}); + """, + actor=actorexpr, + why=why, + protoId=_protocolId(destroyedType), + ) + ] + + def dtorPrologue(self, actorexpr): + return StmtCode( + """ + if (!${actor} || !${actor}->CanSend()) { + NS_WARNING("Attempt to __delete__ missing or closed actor"); + return false; + } + """, + actor=actorexpr, + ) + + def dtorEpilogue(self, md, actorexpr): + return self.destroyActor(md, actorexpr) + + def genRecvAsyncReplyCase(self, md): + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply( + md, self.side, errfnRecv, errfnSentinel(_Result.ValuError) + ) + + if len(md.returns) > 1: + resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) + resolvearg = ExprCall( + ExprVar("std::make_tuple"), args=[ExprMove(p.var()) for p in md.returns] + ) + else: + resolvetype = md.returns[0].bareType(self.side) + resolvearg = ExprMove(md.returns[0].var()) + + case.addcode( + """ + $*{prologue} + + UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback = + GetIPCChannel()->PopCallback(${msgvar}, Id()); + + typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder; + auto* callback = static_cast<CallbackHolder*>(untypedCallback.get()); + if (!callback) { + FatalError("Error unknown callback"); + return MsgProcessingError; + } + + if (${resolve}) { + $*{desstmts} + callback->Resolve(${resolvearg}); + } else { + $*{desrej} + callback->Reject(std::move(${reason})); + } + return MsgProcessed; + """, + prologue=prologue, + msgvar=self.msgvar, + resolve=resolve, + resolvetype=resolvetype, + desstmts=desstmts, + resolvearg=resolvearg, + desrej=desrej, + reason=reason, + ) + + return (lbl, case) + + def genAsyncSendMethod(self, md): + decl = self.makeSendMethodDecl(md) + if "VirtualSendImpl" in md.attributes: + decl.methodspec = MethodSpec.VIRTUAL + method = MethodDefn(decl) + msgvar, stmts = self.makeMessage(md, errfnSend) + retvar, sendstmts = self.sendAsync(md, msgvar) + + method.addstmts(stmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)]) + + movemethod = None + + # Add the promise overload if we need one. + if md.returns: + decl = self.makeSendMethodDecl(md, promise=True) + if "VirtualSendImpl" in md.attributes: + decl.methodspec = MethodSpec.VIRTUAL + promisemethod = MethodDefn(decl) + stmts = self.sendAsyncWithPromise(md) + promisemethod.addstmts(stmts) + + (lbl, case) = self.genRecvAsyncReplyCase(md) + else: + (promisemethod, lbl, case) = (None, None, None) + + return method, movemethod, promisemethod, (lbl, case) + + def genBlockingSendMethod(self, md): + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, serstmts = self.makeMessage(md, errfnSend) + replyvar = self.replyvar + + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) + failif = StmtIf(ExprNot(sendok)) + failif.addifstmt(StmtReturn.FALSE) + + desstmts = self.deserializeReply( + md, replyvar, self.side, errfnSend, errfnSentinel() + ) + + method.addstmts( + serstmts + + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))] + + sendstmts + + [failif] + + desstmts + + [Whitespace.NL, StmtReturn.TRUE] + ) + + movemethod = None + + return method, movemethod + + def genCtorRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + actorhandle = self.handlevar + + stmts = self.deserializeMessage( + md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) + ) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + [ + StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) + for r in md.returns + ] + # alloc the actor, register it under the foreign ID + + [self.callAllocActor(md, retsems="in", side=self.side)] + + self.bindManagedActor( + md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(actorhandle) + ) + + [Whitespace.NL] + + saveIdStmts + + self.invokeRecvHandler(md) + + self.makeReply(md, errfnRecv, idvar) + + [Whitespace.NL, StmtReturn(_Result.Processed)] + ) + + return lbl, case + + def genDtorRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + + stmts = self.deserializeMessage( + md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) + ) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + [ + StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) + for r in md.returns + ] + + self.invokeRecvHandler(md) + + [Whitespace.NL] + + saveIdStmts + + self.makeReply(md, errfnRecv, routingId=idvar) + + [Whitespace.NL] + + self.dtorEpilogue(md, ExprVar.THIS) + + [Whitespace.NL, StmtReturn(_Result.Processed)] + ) + + return lbl, case + + def genRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + + stmts = self.deserializeMessage( + md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) + ) + + idvar, saveIdStmts = self.saveActorId(md) + declstmts = [ + StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) + for r in md.returns + ] + if md.decl.type.isAsync() and md.returns: + declstmts = self.makeResolver(md, errfnRecv, routingId=idvar) + case.addstmts( + stmts + + saveIdStmts + + declstmts + + self.invokeRecvHandler(md) + + [Whitespace.NL] + + self.makeReply(md, errfnRecv, routingId=idvar) + + [StmtReturn(_Result.Processed)] + ) + + return lbl, case + + # helper methods + + def makeMessage(self, md, errfn, fromActor=None): + msgvar = self.msgvar + writervar = ExprVar("writer__") + routingId = self.protocol.routingId(fromActor) + this = fromActor or ExprVar.THIS + + stmts = ( + [ + StmtDecl( + Decl(Type("UniquePtr<IPC::Message>"), msgvar.name), + init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]), + ), + StmtDecl( + Decl(Type("IPC::MessageWriter"), writervar.name), + initargs=[ExprDeref(msgvar), this], + ), + ] + + [Whitespace.NL] + + [ + _ParamTraits.checkedWrite( + p.ipdltype, + p.var(), + ExprAddrOf(writervar), + sentinelKey=p.name, + ) + for p in md.params + ] + + [Whitespace.NL] + + self.setMessageFlags(md, msgvar) + ) + return msgvar, stmts + + def makeResolver(self, md, errfn, routingId): + if routingId is None: + routingId = self.protocol.routingId() + if not md.decl.type.isAsync() or not md.hasReply(): + return [] + + def paramValue(idx): + assert idx < len(md.returns) + if len(md.returns) > 1: + return ExprCode("std::get<${idx}>(aParam)", idx=idx) + return ExprVar("aParam") + + serializeParams = [ + _ParamTraits.checkedWrite( + p.ipdltype, + paramValue(idx), + ExprAddrOf(ExprVar("writer__")), + sentinelKey=p.name, + ) + for idx, p in enumerate(md.returns) + ] + + return [ + StmtCode( + """ + UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId})); + ${replyvar}->set_seqno(${msgvar}.seqno()); + + RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ = + new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this); + + ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) { + resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) { + IPC::MessageWriter writer__(*${replyvar}, self__); + $*{serializeParams} + ${logSendingReply} + }); + }; + """, + msgvar=self.msgvar, + resolvertype=Type(md.resolverName()), + routingId=routingId, + resolveType=_resolveType(md.returns, self.side), + replyvar=self.replyvar, + replyCtor=ExprVar(md.pqReplyCtorFunc()), + serializeParams=serializeParams, + logSendingReply=self.logMessage( + md, + self.replyvar, + "Sending reply ", + actor=ExprVar("self__"), + ), + ) + ] + + def makeReply(self, md, errfn, routingId): + if routingId is None: + routingId = self.protocol.routingId() + # TODO special cases for async ctor/dtor replies + if not md.decl.type.hasReply(): + return [] + if md.decl.type.isAsync() and md.decl.type.hasReply(): + return [] + + replyvar = self.replyvar + return ( + [ + StmtExpr( + ExprAssn( + replyvar, + ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]), + ) + ), + StmtDecl( + Decl(Type("IPC::MessageWriter"), "writer__"), + initargs=[ExprDeref(replyvar), ExprVar.THIS], + ), + Whitespace.NL, + ] + + [ + _ParamTraits.checkedWrite( + r.ipdltype, + r.var(), + ExprAddrOf(ExprVar("writer__")), + sentinelKey=r.name, + ) + for r in md.returns + ] + + self.setMessageFlags(md, replyvar) + + [self.logMessage(md, replyvar, "Sending reply ")] + ) + + def setMessageFlags(self, md, var, seqno=None): + stmts = [] + + if seqno: + stmts.append( + StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno])) + ) + + return stmts + [Whitespace.NL] + + def deserializeMessage(self, md, side, errfn, errfnSent): + msgvar = self.msgvar + msgexpr = ExprAddrOf(msgvar) + readervar = ExprVar("reader__") + isctor = md.decl.type.isCtor() + stmts = [ + self.logMessage(md, msgexpr, "Received ", receiving=True), + self.profilerLabel(md), + Whitespace.NL, + ] + + if 0 == len(md.params): + return stmts + + start, reads = 0, [] + if isctor: + # return the raw actor handle so that its ID can be used + # to construct the "real" actor + handlevar = self.handlevar + handletype = Type("ActorHandle") + reads = [ + _ParamTraits.checkedRead( + None, + handletype, + handlevar, + ExprAddrOf(readervar), + errfn, + "'%s'" % handletype.name, + sentinelKey="actor", + errfnSentinel=errfnSent, + ) + ] + start = 1 + + def maybeTainted(p, side): + if md.decl.type.tainted and "NoTaint" not in p.attributes: + return Type("Tainted", T=p.bareType(side)) + return p.bareType(side) + + reads.extend( + [ + _ParamTraits.checkedRead( + p.ipdltype, + maybeTainted(p, side), + p.var(), + ExprAddrOf(readervar), + errfn, + "'%s'" % p.ipdltype.name(), + sentinelKey=p.name, + errfnSentinel=errfnSent, + ) + for p in md.params[start:] + ] + ) + + stmts.extend( + ( + [ + StmtDecl( + Decl(Type("IPC::MessageReader"), readervar.name), + initargs=[msgvar, ExprVar.THIS], + ) + ] + + [Whitespace.NL] + + reads + + [StmtCode("${reader}.EndRead();\n", reader=readervar)] + ) + ) + + return stmts + + def deserializeAsyncReply(self, md, side, errfn, errfnSent): + msgvar = self.msgvar + readervar = ExprVar("reader__") + msgexpr = ExprAddrOf(msgvar) + isctor = md.decl.type.isCtor() + resolve = ExprVar("resolve__") + reason = ExprVar("reason__") + + # NOTE: The `resolve__` and `reason__` parameters don't have sentinels, + # as they are serialized by the IPDLResolverInner type in + # ProtocolUtils.cpp rather than by generated code. + desresolve = [ + StmtCode( + """ + bool resolve__ = false; + if (!IPC::ReadParam(&${readervar}, &resolve__)) { + FatalError("Error deserializing bool"); + return MsgValueError; + } + """, + readervar=readervar, + ), + ] + desrej = [ + StmtCode( + """ + ResponseRejectReason reason__{}; + if (!IPC::ReadParam(&${readervar}, &reason__)) { + FatalError("Error deserializing ResponseRejectReason"); + return MsgValueError; + } + ${readervar}.EndRead(); + """, + readervar=readervar, + ), + ] + prologue = [ + self.logMessage(md, msgexpr, "Received ", receiving=True), + self.profilerLabel(md), + Whitespace.NL, + ] + + if not md.returns: + return prologue + + prologue.extend( + [ + StmtDecl( + Decl(Type("IPC::MessageReader"), readervar.name), + initargs=[msgvar, ExprVar.THIS], + ) + ] + + desresolve + ) + + start, reads = 0, [] + if isctor: + # return the raw actor handle so that its ID can be used + # to construct the "real" actor + handlevar = self.handlevar + handletype = Type("ActorHandle") + reads = [ + _ParamTraits.checkedRead( + None, + handletype, + handlevar, + ExprAddrOf(readervar), + errfn, + "'%s'" % handletype.name, + sentinelKey="actor", + errfnSentinel=errfnSent, + ) + ] + start = 1 + + stmts = ( + reads + + [ + _ParamTraits.checkedRead( + p.ipdltype, + p.bareType(side), + p.var(), + ExprAddrOf(readervar), + errfn, + "'%s'" % p.ipdltype.name(), + sentinelKey=p.name, + errfnSentinel=errfnSent, + ) + for p in md.returns[start:] + ] + + [StmtCode("${reader}.EndRead();", reader=readervar)] + ) + + return resolve, reason, prologue, desrej, stmts + + def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None): + stmts = [ + Whitespace.NL, + self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True), + ] + if 0 == len(md.returns): + return stmts + + def tempvar(r): + return ExprVar(r.var().name + "__reply") + + readervar = ExprVar("reader__") + stmts.extend( + [ + Whitespace.NL, + StmtDecl( + Decl(Type("IPC::MessageReader"), readervar.name), + initargs=[ExprDeref(self.replyvar), ExprVar.THIS], + ), + ] + + [Whitespace.NL] + + [ + _ParamTraits.checkedRead( + r.ipdltype, + r.bareType(side), + tempvar(r), + ExprAddrOf(readervar), + errfn, + "'%s'" % r.ipdltype.name(), + sentinelKey=r.name, + errfnSentinel=errfnSentinel, + ) + for r in md.returns + ] + # Move-assign the values out of the variables created with + # checkedRead into outparams. + + [ + StmtExpr(ExprAssn(ExprDeref(r.var()), ExprMove(tempvar(r)))) + for r in md.returns + ] + + [StmtCode("${reader}.EndRead();", reader=readervar)] + ) + + return stmts + + def sendAsync(self, md, msgexpr, actor=None): + sendok = ExprVar("sendok__") + resolvefn = ExprVar("aResolve") + rejectfn = ExprVar("aReject") + + stmts = [ + Whitespace.NL, + self.logMessage(md, msgexpr, "Sending ", actor), + self.profilerLabel(md), + ] + stmts.append(Whitespace.NL) + + # Generate the actual call expression. + send = ExprVar("ChannelSend") + if actor is not None: + send = ExprSelect(actor, "->", send.name) + if md.returns: + stmts.append( + StmtExpr( + ExprCall( + send, + args=[ + ExprMove(msgexpr), + ExprVar(md.pqReplyId()), + ExprMove(resolvefn), + ExprMove(rejectfn), + ], + ) + ) + ) + retvar = None + else: + stmts.append( + StmtDecl( + Decl(Type.BOOL, sendok.name), + init=ExprCall(send, args=[ExprMove(msgexpr)]), + ) + ) + retvar = sendok + + return (retvar, stmts) + + def sendBlocking(self, md, msgexpr, replyexpr, actor=None): + send = ExprVar("ChannelSend") + if md.decl.type.isInterrupt(): + send = ExprVar("ChannelCall") + if actor is not None: + send = ExprSelect(actor, "->", send.name) + + sendok = ExprVar("sendok__") + self.externalIncludes.add("mozilla/ProfilerMarkers.h") + return ( + sendok, + ( + [ + Whitespace.NL, + self.logMessage(md, msgexpr, "Sending ", actor), + self.profilerLabel(md), + ] + + [ + Whitespace.NL, + StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE), + StmtBlock( + [ + StmtExpr( + ExprCall( + ExprVar("AUTO_PROFILER_TRACING_MARKER"), + [ + ExprLiteral.String("Sync IPC"), + ExprLiteral.String( + self.protocol.name + + "::" + + md.prettyMsgName() + ), + ExprVar("IPC"), + ], + ) + ), + StmtExpr( + ExprAssn( + sendok, + ExprCall( + send, + args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)], + ), + ) + ), + ] + ), + ] + ), + ) + + def sendAsyncWithPromise(self, md): + # Create a new promise, and forward to the callback send overload. + promise = _makePromise(md.returns, self.side, resolver=True) + + if len(md.returns) > 1: + resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) + else: + resolvetype = md.returns[0].bareType(self.side) + + resolve = ExprCode( + """ + [promise__](${resolvetype}&& aValue) { + promise__->Resolve(std::move(aValue), __func__); + } + """, + resolvetype=resolvetype, + ) + reject = ExprCode( + """ + [promise__](ResponseRejectReason&& aReason) { + promise__->Reject(std::move(aReason), __func__); + } + """, + resolvetype=resolvetype, + ) + + args = [ExprMove(p.var()) for p in md.params] + [resolve, reject] + stmt = StmtCode( + """ + RefPtr<${promise}> promise__ = new ${promise}(__func__); + promise__->UseDirectTaskDispatch(__func__); + ${send}($,{args}); + return promise__; + """, + promise=promise, + send=md.sendMethod(), + args=args, + ) + return [stmt] + + def callAllocActor(self, md, retsems, side): + actortype = md.actorDecl().bareType(self.side) + if md.decl.type.constructedType().isRefcounted(): + actortype.ptr = False + actortype = _refptr(actortype) + + callalloc = self.thisCall( + _allocMethod(md.decl.type.constructedType(), side), + args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False), + ) + + return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=callalloc) + + def invokeRecvHandler(self, md): + retsems = "in" + if md.decl.type.isAsync() and md.returns: + retsems = "resolver" + okdecl = StmtDecl( + Decl(Type("mozilla::ipc::IPCResult"), "__ok"), + init=self.thisCall( + md.recvMethod(), + md.makeCxxArgs( + paramsems="move", + retsems=retsems, + retcallsems="out", + ), + ), + ) + failif = StmtIf(ExprNot(ExprVar("__ok"))) + failif.addifstmts( + [ + _protocolErrorBreakpoint("Handler returned error code!"), + Whitespace( + "// Error handled in mozilla::ipc::IPCResult\n", indent=True + ), + StmtReturn(_Result.ProcessingError), + ] + ) + return [okdecl, failif] + + def makeDtorMethodDecl(self, md, actorvar): + decl = self.makeSendMethodDecl(md) + decl.params.insert( + 0, + Decl( + _cxxInType( + ipdl.type.ActorType(md.decl.type.constructedType()), + side=self.side, + direction="send", + ), + actorvar.name, + ), + ) + decl.methodspec = MethodSpec.STATIC + return decl + + def makeSendMethodDecl(self, md, promise=False, paramsems="in"): + implicit = md.decl.type.hasImplicitActorParam() + if md.decl.type.isAsync() and md.returns: + if promise: + returnsems = "promise" + rettype = _refptr(Type(md.promiseName())) + else: + returnsems = "callback" + rettype = Type.VOID + else: + assert not promise + returnsems = "out" + rettype = Type.BOOL + decl = MethodDecl( + md.sendMethod(), + params=md.makeCxxParams( + paramsems, + returnsems=returnsems, + side=self.side, + implicit=implicit, + direction="send", + ), + warn_unused=( + (self.side == "parent" and returnsems != "callback") + or (md.decl.type.isCtor() and not md.decl.type.isAsync()) + ), + ret=rettype, + ) + if md.decl.type.isCtor(): + decl.ret = md.actorDecl().bareType(self.side) + return decl + + def logMessage(self, md, msgptr, pfx, actor=None, receiving=False): + actorname = _actorName(self.protocol.name, self.side) + return StmtCode( + """ + if (mozilla::ipc::LoggingEnabledFor(${actorname})) { + mozilla::ipc::LogMessageForProtocol( + ${actorname}, + ${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(), + ${pfx}, + ${msgptr}->type(), + mozilla::ipc::MessageDirection::${direction}); + } + """, + actorname=ExprLiteral.String(actorname), + actor=actor or ExprVar.THIS, + pfx=ExprLiteral.String(pfx), + msgptr=msgptr, + direction="eReceiving" if receiving else "eSending", + ) + + def profilerLabel(self, md): + self.externalIncludes.add("mozilla/ProfilerLabels.h") + return StmtCode( + """ + AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER); + """, + name=self.protocol.name, + msgname=md.prettyMsgName(), + ) + + def saveActorId(self, md): + idvar = ExprVar("id__") + if md.decl.type.hasReply(): + # only save the ID if we're actually going to use it, to + # avoid unused-variable warnings + saveIdStmts = [ + StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId()) + ] + else: + saveIdStmts = [] + return idvar, saveIdStmts + + +class _GenerateProtocolParentCode(_GenerateProtocolActorCode): + def __init__(self): + _GenerateProtocolActorCode.__init__(self, "parent") + + def sendsMessage(self, md): + return not md.decl.type.isIn() + + def receivesMessage(self, md): + return md.decl.type.isInout() or md.decl.type.isIn() + + +class _GenerateProtocolChildCode(_GenerateProtocolActorCode): + def __init__(self): + _GenerateProtocolActorCode.__init__(self, "child") + + def sendsMessage(self, md): + return not md.decl.type.isOut() + + def receivesMessage(self, md): + return md.decl.type.isInout() or md.decl.type.isOut() + + +# ----------------------------------------------------------------------------- +# Utility passes +## + + +def _splitClassDeclDefn(cls): + """Destructively split |cls| methods into declarations and + definitions (if |not methodDecl.force_inline|). Return classDecl, + methodDefns.""" + defns = Block() + + for i, stmt in enumerate(cls.stmts): + if isinstance(stmt, MethodDefn) and not stmt.decl.force_inline: + decl, defn = _splitMethodDeclDefn(stmt, cls) + cls.stmts[i] = StmtDecl(decl) + if defn: + defns.addstmts([defn, Whitespace.NL]) + + return cls, defns + + +def _splitMethodDeclDefn(md, cls): + # Pure methods have decls but no defns. + if md.decl.methodspec == MethodSpec.PURE: + return md.decl, None + + saveddecl = deepcopy(md.decl) + md.decl.cls = cls + # Don't emit method specifiers on method defns. + md.decl.methodspec = MethodSpec.NONE + md.decl.warn_unused = False + md.decl.only_for_definition = True + for param in md.decl.params: + if isinstance(param, Param): + param.default = None + return saveddecl, md + + +def _splitFuncDeclDefn(fun): + assert not fun.decl.force_inline + return StmtDecl(fun.decl), fun diff --git a/ipc/ipdl/ipdl/parser.py b/ipc/ipdl/ipdl/parser.py new file mode 100644 index 0000000000..1857131868 --- /dev/null +++ b/ipc/ipdl/ipdl/parser.py @@ -0,0 +1,680 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import os +from ply import lex, yacc + +from ipdl.ast import * + +# ----------------------------------------------------------------------------- + + +class ParseError(Exception): + def __init__(self, loc, fmt, *args): + self.loc = loc + self.error = ( + "%s%s: error: %s" % (Parser.includeStackString(), loc, fmt) + ) % args + + def __str__(self): + return self.error + + +def _safeLinenoValue(t): + lineno, value = 0, "???" + if hasattr(t, "lineno"): + lineno = t.lineno + if hasattr(t, "value"): + value = t.value + return lineno, value + + +def _error(loc, fmt, *args): + raise ParseError(loc, fmt, *args) + + +class Parser: + # when we reach an |include [protocol] foo;| statement, we need to + # save the current parser state and create a new one. this "stack" is + # where that state is saved + # + # there is one Parser per file + current = None + parseStack = [] + parsed = {} + + def __init__(self, type, name, debug=False): + assert type and name + self.type = type + self.debug = debug + self.filename = None + self.includedirs = None + self.loc = None # not always up to date + self.lexer = None + self.parser = None + self.tu = TranslationUnit(type, name) + self.direction = None + + def parse(self, input, filename, includedirs): + assert os.path.isabs(filename) + + if self.tu.name in Parser.parsed: + priorTU = Parser.parsed[self.tu.name].tu + if os.path.normcase(priorTU.filename) != os.path.normcase(filename): + _error( + Loc(filename), + "Trying to load `%s' from a file when we'd already seen it in file `%s'" + % (self.tu.name, priorTU.filename), + ) + + return priorTU + + self.lexer = lex.lex(debug=self.debug) + self.parser = yacc.yacc(debug=self.debug, write_tables=False) + self.filename = filename + self.includedirs = includedirs + self.tu.filename = filename + + Parser.parsed[self.tu.name] = self + Parser.parseStack.append(Parser.current) + Parser.current = self + + try: + ast = self.parser.parse(input=input, lexer=self.lexer, debug=self.debug) + finally: + Parser.current = Parser.parseStack.pop() + + return ast + + def resolveIncludePath(self, filepath): + """Return the absolute path from which the possibly partial + |filepath| should be read, or |None| if |filepath| cannot be located.""" + for incdir in self.includedirs + [""]: + realpath = os.path.join(incdir, filepath) + if os.path.isfile(realpath): + return os.path.abspath(realpath) + return None + + # returns a GCC-style string representation of the include stack. + # e.g., + # in file included from 'foo.ipdl', line 120: + # in file included from 'bar.ipd', line 12: + # which can be printed above a proper error message or warning + @staticmethod + def includeStackString(): + s = "" + for parse in Parser.parseStack[1:]: + s += " in file included from `%s', line %d:\n" % ( + parse.loc.filename, + parse.loc.lineno, + ) + return s + + +def locFromTok(p, num): + return Loc(Parser.current.filename, p.lineno(num)) + + +# ----------------------------------------------------------------------------- + +reserved = set( + ( + "async", + "both", + "child", + "class", + "from", + "include", + "intr", + "manager", + "manages", + "namespace", + "nullable", + "or", + "parent", + "protocol", + "returns", + "struct", + "sync", + "union", + "UniquePtr", + "using", + ) +) +tokens = [ + "COLONCOLON", + "ID", + "STRING", +] + [r.upper() for r in reserved] + +t_COLONCOLON = "::" + +literals = "(){}[]<>;:,?=" +t_ignore = " \f\t\v" + + +def t_linecomment(t): + r"//[^\n]*" + + +def t_multilinecomment(t): + r"/\*(\n|.)*?\*/" + t.lexer.lineno += t.value.count("\n") + + +def t_NL(t): + r"(?:\r\n|\n|\n)+" + t.lexer.lineno += len(t.value) + + +def t_ID(t): + r"[a-zA-Z_][a-zA-Z0-9_]*" + if t.value in reserved: + t.type = t.value.upper() + return t + + +def t_STRING(t): + r'"[^"\n]*"' + t.value = StringLiteral(Loc(Parser.current.filename, t.lineno), t.value[1:-1]) + return t + + +def t_error(t): + _error( + Loc(Parser.current.filename, t.lineno), + "lexically invalid characters `%s", + t.value, + ) + + +# ----------------------------------------------------------------------------- + + +def p_TranslationUnit(p): + """TranslationUnit : Preamble NamespacedStuff""" + tu = Parser.current.tu + tu.loc = Loc(tu.filename) + for stmt in p[1]: + if isinstance(stmt, CxxInclude): + tu.addCxxInclude(stmt) + elif isinstance(stmt, Include): + tu.addInclude(stmt) + elif isinstance(stmt, UsingStmt): + tu.addUsingStmt(stmt) + else: + assert 0 + + for thing in p[2]: + if isinstance(thing, StructDecl): + tu.addStructDecl(thing) + elif isinstance(thing, UnionDecl): + tu.addUnionDecl(thing) + elif isinstance(thing, Protocol): + if tu.protocol is not None: + _error(thing.loc, "only one protocol definition per file") + tu.protocol = thing + else: + assert 0 + + # The "canonical" namespace of the tu, what it's considered to be + # in for the purposes of C++: |#include "foo/bar/TU.h"| + if tu.protocol: + assert tu.filetype == "protocol" + tu.namespaces = tu.protocol.namespaces + tu.name = tu.protocol.name + else: + assert tu.filetype == "header" + # There's not really a canonical "thing" in headers. So + # somewhat arbitrarily use the namespace of the last + # interesting thing that was declared. + for thing in reversed(tu.structsAndUnions): + tu.namespaces = thing.namespaces + break + + p[0] = tu + + +# -------------------- +# Preamble + + +def p_Preamble(p): + """Preamble : Preamble PreambleStmt ';' + |""" + if 1 == len(p): + p[0] = [] + else: + p[1].append(p[2]) + p[0] = p[1] + + +def p_PreambleStmt(p): + """PreambleStmt : CxxIncludeStmt + | IncludeStmt + | UsingStmt""" + p[0] = p[1] + + +def p_CxxIncludeStmt(p): + """CxxIncludeStmt : INCLUDE STRING""" + p[0] = CxxInclude(locFromTok(p, 1), p[2].value) + + +def p_IncludeStmt(p): + """IncludeStmt : INCLUDE PROTOCOL ID + | INCLUDE ID""" + loc = locFromTok(p, 1) + + Parser.current.loc = loc + if 4 == len(p): + id = p[3] + type = "protocol" + else: + id = p[2] + type = "header" + inc = Include(loc, type, id) + + path = Parser.current.resolveIncludePath(inc.file) + if path is None: + raise ParseError(loc, "can't locate include file `%s'" % (inc.file)) + + inc.tu = Parser(type, id).parse(open(path).read(), path, Parser.current.includedirs) + p[0] = inc + + +def p_UsingKind(p): + """UsingKind : CLASS + | STRUCT + |""" + p[0] = p[1] if 2 == len(p) else None + + +def p_UsingStmt(p): + """UsingStmt : Attributes USING UsingKind CxxType FROM STRING""" + p[0] = UsingStmt( + locFromTok(p, 2), + attributes=p[1], + kind=p[3], + cxxTypeSpec=p[4], + cxxHeader=p[6].value, + ) + + +# -------------------- +# Namespaced stuff + + +def p_NamespacedStuff(p): + """NamespacedStuff : NamespacedStuff NamespaceThing + | NamespaceThing""" + if 2 == len(p): + p[0] = p[1] + else: + p[1].extend(p[2]) + p[0] = p[1] + + +def p_NamespaceThing(p): + """NamespaceThing : NAMESPACE ID '{' NamespacedStuff '}' + | StructDecl + | UnionDecl + | ProtocolDefn""" + if 2 == len(p): + p[0] = [p[1]] + else: + for thing in p[4]: + thing.addOuterNamespace(Namespace(locFromTok(p, 1), p[2])) + p[0] = p[4] + + +def p_StructDecl(p): + """StructDecl : Attributes STRUCT ID '{' StructFields '}' ';' + | Attributes STRUCT ID '{' '}' ';'""" + if 8 == len(p): + p[0] = StructDecl(locFromTok(p, 2), p[3], p[5], p[1]) + else: + p[0] = StructDecl(locFromTok(p, 2), p[3], [], p[1]) + + +def p_StructFields(p): + """StructFields : StructFields StructField ';' + | StructField ';'""" + if 3 == len(p): + p[0] = [p[1]] + else: + p[1].append(p[2]) + p[0] = p[1] + + +def p_StructField(p): + """StructField : Type ID""" + p[0] = StructField(locFromTok(p, 1), p[1], p[2]) + + +def p_UnionDecl(p): + """UnionDecl : Attributes UNION ID '{' ComponentTypes '}' ';'""" + p[0] = UnionDecl(locFromTok(p, 2), p[3], p[5], p[1]) + + +def p_ComponentTypes(p): + """ComponentTypes : ComponentTypes Type ';' + | Type ';'""" + if 3 == len(p): + p[0] = [p[1]] + else: + p[1].append(p[2]) + p[0] = p[1] + + +def p_ProtocolDefn(p): + """ProtocolDefn : Attributes OptionalSendSemantics \ + PROTOCOL ID '{' ProtocolBody '}' ';'""" + protocol = p[6] + protocol.loc = locFromTok(p, 3) + protocol.name = p[4] + protocol.attributes = p[1] + protocol.sendSemantics = p[2] + p[0] = protocol + + if Parser.current.type == "header": + _error( + protocol.loc, + "can't define a protocol in a header. Do it in a protocol spec instead.", + ) + + +def p_ProtocolBody(p): + """ProtocolBody : ManagersStmtOpt""" + p[0] = p[1] + + +# -------------------- +# manager/manages stmts + + +def p_ManagersStmtOpt(p): + """ManagersStmtOpt : ManagersStmt ManagesStmtsOpt + | ManagesStmtsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].managers = p[1] + p[0] = p[2] + + +def p_ManagersStmt(p): + """ManagersStmt : MANAGER ManagerList ';'""" + if 1 == len(p): + p[0] = [] + else: + p[0] = p[2] + + +def p_ManagerList(p): + """ManagerList : ID + | ManagerList OR ID""" + if 2 == len(p): + p[0] = [Manager(locFromTok(p, 1), p[1])] + else: + p[1].append(Manager(locFromTok(p, 3), p[3])) + p[0] = p[1] + + +def p_ManagesStmtsOpt(p): + """ManagesStmtsOpt : ManagesStmt ManagesStmtsOpt + | MessageDeclsOpt""" + if 2 == len(p): + p[0] = p[1] + else: + p[2].managesStmts.insert(0, p[1]) + p[0] = p[2] + + +def p_ManagesStmt(p): + """ManagesStmt : MANAGES ID ';'""" + p[0] = ManagesStmt(locFromTok(p, 1), p[2]) + + +# -------------------- +# Message decls + + +def p_MessageDeclsOpt(p): + """MessageDeclsOpt : MessageDeclThing MessageDeclsOpt + |""" + if 1 == len(p): + # we fill in |loc| in the Protocol rule + p[0] = Protocol(None) + else: + p[2].messageDecls.insert(0, p[1]) + p[0] = p[2] + + +def p_MessageDeclThing(p): + """MessageDeclThing : MessageDirectionLabel ':' MessageDecl ';' + | MessageDecl ';'""" + if 3 == len(p): + p[0] = p[1] + else: + p[0] = p[3] + + +def p_MessageDirectionLabel(p): + """MessageDirectionLabel : PARENT + | CHILD + | BOTH""" + if p[1] == "parent": + Parser.current.direction = IN + elif p[1] == "child": + Parser.current.direction = OUT + elif p[1] == "both": + Parser.current.direction = INOUT + else: + assert 0 + + +def p_MessageDecl(p): + """MessageDecl : Attributes SendSemantics MessageBody""" + msg = p[3] + msg.attributes = p[1] + msg.sendSemantics = p[2] + + if Parser.current.direction is None: + _error(msg.loc, "missing message direction") + msg.direction = Parser.current.direction + + p[0] = msg + + +def p_MessageBody(p): + """MessageBody : ID MessageInParams MessageOutParams""" + # FIXME/cjones: need better loc info: use one of the quals + name = p[1] + msg = MessageDecl(locFromTok(p, 1)) + msg.name = name + msg.addInParams(p[2]) + msg.addOutParams(p[3]) + + p[0] = msg + + +def p_MessageInParams(p): + """MessageInParams : '(' ParamList ')'""" + p[0] = p[2] + + +def p_MessageOutParams(p): + """MessageOutParams : RETURNS '(' ParamList ')' + |""" + if 1 == len(p): + p[0] = [] + else: + p[0] = p[3] + + +# -------------------- +# Attributes +def p_Attributes(p): + """Attributes : '[' AttributeList ']' + |""" + p[0] = {} + if 4 == len(p): + for attr in p[2]: + if attr.name in p[0]: + _error(attr.loc, "Repeated extended attribute `%s'", attr.name) + p[0][attr.name] = attr + + +def p_AttributeList(p): + """AttributeList : Attribute ',' AttributeList + | Attribute""" + p[0] = [p[1]] + if 4 == len(p): + p[0] += p[3] + + +def p_Attribute(p): + """Attribute : ID AttributeValue""" + p[0] = Attribute(locFromTok(p, 1), p[1], p[2]) + + +def p_AttributeValue(p): + """AttributeValue : '=' ID + | '=' STRING + |""" + if 1 == len(p): + p[0] = None + else: + p[0] = p[2] + + +def p_SendSemantics(p): + """SendSemantics : ASYNC + | SYNC + | INTR""" + if p[1] == "async": + p[0] = ASYNC + elif p[1] == "sync": + p[0] = SYNC + else: + assert p[1] == "intr" + p[0] = INTR + + +def p_OptionalSendSemantics(p): + """OptionalSendSemantics : SendSemantics + |""" + if 2 == len(p): + p[0] = p[1] + else: + p[0] = ASYNC + + +# -------------------- +# Minor stuff + + +def p_ParamList(p): + """ParamList : ParamList ',' Param + | Param + |""" + if 1 == len(p): + p[0] = [] + elif 2 == len(p): + p[0] = [p[1]] + else: + p[1].append(p[3]) + p[0] = p[1] + + +def p_Param(p): + """Param : Attributes Type ID""" + p[0] = Param(locFromTok(p, 2), p[2], p[3], p[1]) + + +def p_Type(p): + """Type : MaybeNullable BasicType""" + # only some types are nullable; we check this in the type checker + p[2].nullable = p[1] + p[0] = p[2] + + +def p_BasicType(p): + """BasicType : CxxID + | CxxID '[' ']' + | CxxID '?' + | CxxUniquePtrInst""" + # ID == CxxType; we forbid qnames here, + # in favor of the |using| declaration + if not isinstance(p[1], TypeSpec): + assert (len(p[1]) == 2) or (len(p[1]) == 3) + if 2 == len(p[1]): + # p[1] is CxxID. isunique = False + p[1] = p[1] + (False,) + loc, id, isunique = p[1] + p[1] = TypeSpec(loc, id) + p[1].uniqueptr = isunique + if 4 == len(p): + p[1].array = True + if 3 == len(p): + p[1].maybe = True + p[0] = p[1] + + +def p_MaybeNullable(p): + """MaybeNullable : NULLABLE + |""" + p[0] = 2 == len(p) + + +# -------------------- +# C++ stuff + + +def p_CxxType(p): + """CxxType : QualifiedID + | CxxID""" + if isinstance(p[1], QualifiedId): + p[0] = p[1] + else: + loc, id = p[1] + p[0] = QualifiedId(loc, id) + + +def p_QualifiedID(p): + """QualifiedID : QualifiedID COLONCOLON CxxID + | CxxID COLONCOLON CxxID""" + if isinstance(p[1], QualifiedId): + loc, id = p[3] + p[1].qualify(id) + p[0] = p[1] + else: + loc1, id1 = p[1] + _, id2 = p[3] + p[0] = QualifiedId(loc1, id2, [id1]) + + +def p_CxxID(p): + """CxxID : ID + | CxxTemplateInst""" + if isinstance(p[1], tuple): + p[0] = p[1] + else: + p[0] = (locFromTok(p, 1), str(p[1])) + + +def p_CxxTemplateInst(p): + """CxxTemplateInst : ID '<' ID '>'""" + p[0] = (locFromTok(p, 1), str(p[1]) + "<" + str(p[3]) + ">") + + +def p_CxxUniquePtrInst(p): + """CxxUniquePtrInst : UNIQUEPTR '<' ID '>'""" + p[0] = (locFromTok(p, 1), str(p[3]), True) + + +def p_error(t): + lineno, value = _safeLinenoValue(t) + _error(Loc(Parser.current.filename, lineno), "bad syntax near `%s'", value) diff --git a/ipc/ipdl/ipdl/type.py b/ipc/ipdl/ipdl/type.py new file mode 100644 index 0000000000..76f908dd41 --- /dev/null +++ b/ipc/ipdl/ipdl/type.py @@ -0,0 +1,1748 @@ +# vim: set ts=4 sw=4 tw=99 et: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import os +import sys + +from ipdl.ast import CxxInclude, Decl, Loc, QualifiedId, StructDecl +from ipdl.ast import UnionDecl, UsingStmt, Visitor, StringLiteral +from ipdl.ast import ASYNC, SYNC, INTR +from ipdl.ast import IN, OUT, INOUT +from ipdl.ast import NOT_NESTED, INSIDE_SYNC_NESTED, INSIDE_CPOW_NESTED +from ipdl.ast import priorityList +import ipdl.builtin as builtin +from ipdl.util import hash_str + +_DELETE_MSG = "__delete__" + + +class TypeVisitor: + def __init__(self): + self.visited = set() + + def defaultVisit(self, node, *args): + raise Exception( + "INTERNAL ERROR: no visitor for node type `%s'" % (node.__class__.__name__) + ) + + def visitVoidType(self, v, *args): + pass + + def visitBuiltinCType(self, b, *args): + pass + + def visitImportedCxxType(self, t, *args): + pass + + def visitMessageType(self, m, *args): + for param in m.params: + param.accept(self, *args) + for ret in m.returns: + ret.accept(self, *args) + if m.cdtype is not None: + m.cdtype.accept(self, *args) + + def visitProtocolType(self, p, *args): + # NB: don't visit manager and manages. a naive default impl + # could result in an infinite loop + pass + + def visitActorType(self, a, *args): + a.protocol.accept(self, *args) + + def visitStructType(self, s, *args): + if s in self.visited: + return + + self.visited.add(s) + for field in s.fields: + field.accept(self, *args) + + def visitUnionType(self, u, *args): + if u in self.visited: + return + + self.visited.add(u) + for component in u.components: + component.accept(self, *args) + + def visitArrayType(self, a, *args): + a.basetype.accept(self, *args) + + def visitMaybeType(self, m, *args): + m.basetype.accept(self, *args) + + def visitUniquePtrType(self, m, *args): + m.basetype.accept(self, *args) + + def visitNotNullType(self, m, *args): + m.basetype.accept(self, *args) + + def visitShmemType(self, s, *args): + pass + + def visitByteBufType(self, s, *args): + pass + + def visitShmemChmodType(self, c, *args): + c.shmem.accept(self) + + def visitFDType(self, s, *args): + pass + + def visitEndpointType(self, s, *args): + pass + + def visitManagedEndpointType(self, s, *args): + pass + + +class Type: + def __cmp__(self, o): + return cmp(self.fullname(), o.fullname()) + + def __eq__(self, o): + return self.__class__ == o.__class__ and self.fullname() == o.fullname() + + def __hash__(self): + return hash_str(self.fullname()) + + # Is this a C++ type? + def isCxx(self): + return False + + # Is this an IPDL type? + + def isIPDL(self): + return False + + # Is this type neither compound nor an array? + + def isAtom(self): + return False + + def isRefcounted(self): + return False + + # Should this type be wrapped in `NotNull<T>` unless marked `nullable`? + + def supportsNullable(self): + return False + + def typename(self): + return self.__class__.__name__ + + def name(self): + raise NotImplementedError() + + def fullname(self): + raise NotImplementedError() + + def accept(self, visitor, *args): + visit = getattr(visitor, "visit" + self.__class__.__name__, None) + if visit is None: + return getattr(visitor, "defaultVisit")(self, *args) + return visit(self, *args) + + +class VoidType(Type): + def isCxx(self): + return True + + def isIPDL(self): + return False + + def isAtom(self): + return True + + def name(self): + return "void" + + def fullname(self): + return "void" + + +VOID = VoidType() + +# -------------------- + + +class BuiltinCType(Type): + def __init__(self, name): + self._name = name + + def isCxx(self): + return True + + def isAtom(self): + return True + + def isSendMoveOnly(self): + return False + + def isDataMoveOnly(self): + return False + + def name(self): + return self._name + + def fullname(self): + return self._name + + +class ImportedCxxType(Type): + def __init__(self, qname, refcounted, sendmoveonly, datamoveonly): + assert isinstance(qname, QualifiedId) + self.loc = qname.loc + self.qname = qname + self.refcounted = refcounted + self.sendmoveonly = sendmoveonly + self.datamoveonly = datamoveonly + + def isCxx(self): + return True + + def isAtom(self): + return True + + def isRefcounted(self): + return self.refcounted + + def supportsNullable(self): + return self.refcounted + + def isSendMoveOnly(self): + return self.sendmoveonly + + def isDataMoveOnly(self): + return self.datamoveonly + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +# -------------------- + + +class IPDLType(Type): + def isIPDL(self): + return True + + def isMessage(self): + return False + + def isProtocol(self): + return False + + def isActor(self): + return False + + def isStruct(self): + return False + + def isUnion(self): + return False + + def isArray(self): + return False + + def isMaybe(self): + return False + + def isUniquePtr(self): + return False + + def isNotNull(self): + return False + + def isAtom(self): + return True + + def isCompound(self): + return False + + def isShmem(self): + return False + + def isByteBuf(self): + return False + + def isFD(self): + return False + + def isEndpoint(self): + return False + + def isManagedEndpoint(self): + return False + + def hasBaseType(self): + return False + + +class SendSemanticsType(IPDLType): + def __init__(self, nestedRange, sendSemantics): + self.nestedRange = nestedRange + self.sendSemantics = sendSemantics + + def isAsync(self): + return self.sendSemantics == ASYNC + + def isSync(self): + return self.sendSemantics == SYNC + + def isInterrupt(self): + return self.sendSemantics is INTR + + def sendSemanticsSatisfiedBy(self, greater): + def _unwrap(nr): + if isinstance(nr, dict): + return _unwrap(nr["nested"]) + elif isinstance(nr, int): + return nr + else: + raise ValueError("Got unexpected nestedRange value: %s" % nr) + + lesser = self + lnr0, gnr0, lnr1, gnr1 = ( + _unwrap(lesser.nestedRange[0]), + _unwrap(greater.nestedRange[0]), + _unwrap(lesser.nestedRange[1]), + _unwrap(greater.nestedRange[1]), + ) + if lnr0 < gnr0 or lnr1 > gnr1: + return False + + # Protocols that use intr semantics are not allowed to use + # message nesting. + if greater.isInterrupt() and lesser.nestedRange != (NOT_NESTED, NOT_NESTED): + return False + + if lesser.isAsync(): + return True + elif lesser.isSync() and not greater.isAsync(): + return True + elif greater.isInterrupt(): + return True + + return False + + +class MessageType(SendSemanticsType): + def __init__( + self, + nested, + prio, + replyPrio, + sendSemantics, + direction, + ctor=False, + dtor=False, + cdtype=None, + compress=False, + tainted=False, + lazySend=False, + ): + assert not (ctor and dtor) + assert not (ctor or dtor) or cdtype is not None + + SendSemanticsType.__init__(self, (nested, nested), sendSemantics) + self.nested = nested + self.prio = prio + self.replyPrio = replyPrio + self.direction = direction + self.params = [] + self.returns = [] + self.ctor = ctor + self.dtor = dtor + self.cdtype = cdtype + self.compress = compress + self.tainted = tainted + self.lazySend = lazySend + + def isMessage(self): + return True + + def isCtor(self): + return self.ctor + + def isDtor(self): + return self.dtor + + def constructedType(self): + return self.cdtype + + def isIn(self): + return self.direction is IN + + def isOut(self): + return self.direction is OUT + + def isInout(self): + return self.direction is INOUT + + def hasReply(self): + return len(self.returns) or self.isSync() or self.isInterrupt() + + def hasImplicitActorParam(self): + return self.isCtor() + + +class ProtocolType(SendSemanticsType): + def __init__(self, qname, nested, sendSemantics, refcounted, needsotherpid): + SendSemanticsType.__init__(self, (NOT_NESTED, nested), sendSemantics) + self.qname = qname + self.managers = [] # ProtocolType + self.manages = [] + self.hasDelete = False + self.refcounted = refcounted + self.needsotherpid = needsotherpid + + def isProtocol(self): + return True + + def isRefcounted(self): + return self.refcounted + + def hasOtherPid(self): + return all(top.needsotherpid for top in self.toplevels()) + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + def addManager(self, mgrtype): + assert mgrtype.isIPDL() and mgrtype.isProtocol() + self.managers.append(mgrtype) + + def managedBy(self, mgr): + self.managers = list(mgr) + + def toplevel(self): + if self.isToplevel(): + return self + for mgr in self.managers: + if mgr is not self: + return mgr.toplevel() + + def toplevels(self): + if self.isToplevel(): + return [self] + toplevels = list() + for mgr in self.managers: + if mgr is not self: + toplevels.extend(mgr.toplevels()) + return set(toplevels) + + def isManagerOf(self, pt): + for managed in self.manages: + if pt is managed: + return True + return False + + def isManagedBy(self, pt): + return pt in self.managers + + def isManager(self): + return len(self.manages) > 0 + + def isManaged(self): + return 0 < len(self.managers) + + def isToplevel(self): + return not self.isManaged() + + def manager(self): + assert 1 == len(self.managers) + for mgr in self.managers: + return mgr + + +class ActorType(IPDLType): + def __init__(self, protocol): + self.protocol = protocol + + def isActor(self): + return True + + def isRefcounted(self): + return self.protocol.isRefcounted() + + def supportsNullable(self): + return True + + def name(self): + return self.protocol.name() + + def fullname(self): + return self.protocol.fullname() + + +class _CompoundType(IPDLType): + def __init__(self): + self.defined = False # bool + self.mutualRec = set() # set(_CompoundType | ArrayType) + + def isAtom(self): + return False + + def isCompound(self): + return True + + def itercomponents(self): + raise Exception('"pure virtual" method') + + def mutuallyRecursiveWith(self, t, exploring=None): + """|self| is mutually recursive with |t| iff |self| and |t| + are in a cycle in the type graph rooted at |self|. This function + looks for such a cycle and returns True if found.""" + if exploring is None: + exploring = set() + + if t.isAtom(): + return False + elif t is self or t in self.mutualRec: + return True + elif t.hasBaseType(): + isrec = self.mutuallyRecursiveWith(t.basetype, exploring) + if isrec: + self.mutualRec.add(t) + return isrec + elif t in exploring: + return False + + exploring.add(t) + for c in t.itercomponents(): + if self.mutuallyRecursiveWith(c, exploring): + self.mutualRec.add(c) + return True + exploring.remove(t) + + return False + + +class StructType(_CompoundType): + def __init__(self, qname, fields): + _CompoundType.__init__(self) + self.qname = qname + self.fields = fields # [ Type ] + + def isStruct(self): + return True + + def itercomponents(self): + for f in self.fields: + yield f + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class UnionType(_CompoundType): + def __init__(self, qname, components): + _CompoundType.__init__(self) + self.qname = qname + self.components = components # [ Type ] + + def isUnion(self): + return True + + def itercomponents(self): + for c in self.components: + yield c + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class ArrayType(IPDLType): + def __init__(self, basetype): + self.basetype = basetype + + def isAtom(self): + return False + + def isArray(self): + return True + + def hasBaseType(self): + return True + + def name(self): + return self.basetype.name() + "[]" + + def fullname(self): + return self.basetype.fullname() + "[]" + + +class MaybeType(IPDLType): + def __init__(self, basetype): + self.basetype = basetype + + def isAtom(self): + return False + + def isMaybe(self): + return True + + def hasBaseType(self): + return True + + def name(self): + return self.basetype.name() + "?" + + def fullname(self): + return self.basetype.fullname() + "?" + + +class ShmemType(IPDLType): + def __init__(self, qname): + self.qname = qname + + def isShmem(self): + return True + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class ByteBufType(IPDLType): + def __init__(self, qname): + self.qname = qname + + def isByteBuf(self): + return True + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class FDType(IPDLType): + def __init__(self, qname): + self.qname = qname + + def isFD(self): + return True + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class EndpointType(IPDLType): + def __init__(self, qname, actor): + self.qname = qname + self.actor = actor + + def isEndpoint(self): + return True + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class ManagedEndpointType(IPDLType): + def __init__(self, qname, actor): + self.qname = qname + self.actor = actor + + def isManagedEndpoint(self): + return True + + def name(self): + return self.qname.baseid + + def fullname(self): + return str(self.qname) + + +class UniquePtrType(IPDLType): + def __init__(self, basetype): + self.basetype = basetype + + def isAtom(self): + return False + + def isUniquePtr(self): + return True + + def hasBaseType(self): + return True + + def name(self): + return "UniquePtr<" + self.basetype.name() + ">" + + def fullname(self): + return "mozilla::UniquePtr<" + self.basetype.fullname() + ">" + + +class NotNullType(IPDLType): + def __init__(self, basetype): + self.basetype = basetype + + def isAtom(self): + return False + + def isNotNull(self): + return True + + def hasBaseType(self): + return True + + def name(self): + return "NotNull<" + self.basetype.name() + ">" + + def fullname(self): + return "mozilla::NotNull<" + self.basetype.fullname() + ">" + + +def iteractortypes(t, visited=None): + """Iterate over any actor(s) buried in |type|.""" + if visited is None: + visited = set() + + # XXX |yield| semantics makes it hard to use TypeVisitor + if not t.isIPDL(): + return + elif t.isActor(): + yield t + elif t.hasBaseType(): + for actor in iteractortypes(t.basetype, visited): + yield actor + elif t.isCompound() and t not in visited: + visited.add(t) + for c in t.itercomponents(): + for actor in iteractortypes(c, visited): + yield actor + + +def hasshmem(type): + """Return true iff |type| is shmem or has it buried within.""" + + class found(BaseException): + pass + + class findShmem(TypeVisitor): + def visitShmemType(self, s): + raise found() + + try: + type.accept(findShmem()) + except found: + return True + return False + + +# -------------------- +_builtinloc = Loc("<builtin>", 0) + + +def makeBuiltinUsing(tname): + quals = tname.split("::") + base = quals.pop() + quals = quals[0:] + return UsingStmt(_builtinloc, QualifiedId(_builtinloc, base, quals)) + + +builtinUsing = [makeBuiltinUsing(t) for t in builtin.Types] +builtinHeaderIncludes = [CxxInclude(_builtinloc, f) for f in builtin.HeaderIncludes] + + +def errormsg(loc, fmt, *args): + while not isinstance(loc, Loc): + if loc is None: + loc = Loc.NONE + else: + loc = loc.loc + return "%s: error: %s" % (str(loc), fmt % args) + + +# -------------------- + + +class SymbolTable: + def __init__(self, errors): + self.errors = errors + self.scopes = [{}] # stack({}) + self.currentScope = self.scopes[0] + + def enterScope(self): + assert isinstance(self.scopes[0], dict) + assert isinstance(self.currentScope, dict) + + self.scopes.append({}) + self.currentScope = self.scopes[-1] + + def exitScope(self): + symtab = self.scopes.pop() + assert self.currentScope is symtab + + self.currentScope = self.scopes[-1] + + assert isinstance(self.scopes[0], dict) + assert isinstance(self.currentScope, dict) + + def lookup(self, sym): + # NB: since IPDL doesn't allow any aliased names of different types, + # it doesn't matter in which order we walk the scope chain to resolve + # |sym| + for scope in self.scopes: + decl = scope.get(sym, None) + if decl is not None: + return decl + return None + + def declare(self, decl): + assert decl.progname or decl.shortname or decl.fullname + assert decl.loc + assert decl.type + + def tryadd(name): + olddecl = self.lookup(name) + if olddecl is not None: + self.errors.append( + errormsg( + decl.loc, + "redeclaration of symbol `%s', first declared at %s", + name, + olddecl.loc, + ) + ) + return + self.currentScope[name] = decl + decl.scope = self.currentScope + + if decl.progname: + tryadd(decl.progname) + if decl.shortname: + tryadd(decl.shortname) + if decl.fullname: + tryadd(decl.fullname) + + +class TypeCheck: + """This pass sets the .decl attribute of AST nodes for which that is relevant; + a decl says where, with what type, and under what name(s) a node was + declared. + + With this information, it type checks the AST.""" + + def __init__(self): + # NB: no IPDL compile will EVER print a warning. A program has + # one of two attributes: it is either well typed, or not well typed. + self.errors = [] # [ string ] + + def check(self, tu, errout=sys.stderr): + def runpass(tcheckpass): + tu.accept(tcheckpass) + if len(self.errors): + self.reportErrors(errout) + return False + return True + + # tag each relevant node with "decl" information, giving type, name, + # and location of declaration + if not runpass(GatherDecls(builtinUsing, self.errors)): + return False + + # now that the nodes have decls, type checking is much easier. + if not runpass(CheckTypes(self.errors)): + return False + + return True + + def reportErrors(self, errout): + for error in self.errors: + print(error, file=errout) + + +class TcheckVisitor(Visitor): + def __init__(self, errors): + self.errors = errors + + def error(self, loc, fmt, *args): + self.errors.append(errormsg(loc, fmt, *args)) + + +class GatherDecls(TcheckVisitor): + def __init__(self, builtinUsing, errors): + TcheckVisitor.__init__(self, errors) + + # |self.symtab| is the symbol table for the translation unit + # currently being visited + self.symtab = None + self.builtinUsing = builtinUsing + + def declare( + self, loc, type, shortname=None, fullname=None, progname=None, attributes={} + ): + d = Decl(loc) + d.type = type + d.progname = progname + d.shortname = shortname + d.fullname = fullname + d.attributes = attributes + self.symtab.declare(d) + return d + + # Check that only attributes allowed by an attribute spec are present + # within the given attribute dictionary. The spec value may be either + # `None`, for a valueless attribute, a list of valid attribute values, or a + # callable which returns a truthy value if the attribute is valid. + def checkAttributes(self, attributes, spec): + for attr in attributes.values(): + if attr.name not in spec: + self.error(attr.loc, "unknown attribute `%s'", attr.name) + continue + + aspec = spec[attr.name] + if aspec is None: + if attr.value is not None: + self.error( + attr.loc, + "unexpected value for valueless attribute `%s'", + attr.name, + ) + elif isinstance(aspec, (list, tuple)): + if not any( + isinstance(attr.value, s) + if isinstance(s, type) + else attr.value == s + for s in aspec + ): + self.error( + attr.loc, + "invalid value for attribute `%s', expected one of: %s", + attr.name, + ", ".join( + s.__name__ if isinstance(s, type) else str(s) for s in aspec + ), + ) + elif callable(aspec): + if not aspec(attr.value): + self.error(attr.loc, "invalid value for attribute `%s'", attr.name) + else: + raise Exception("INTERNAL ERROR: Invalid attribute spec") + + def visitTranslationUnit(self, tu): + # all TranslationUnits declare symbols in global scope + if hasattr(tu, "visited"): + return + tu.visited = True + savedSymtab = self.symtab + self.symtab = SymbolTable(self.errors) + + # pretend like the translation unit "using"-ed these for the + # sake of type checking and C++ code generation + tu.builtinUsing = self.builtinUsing + + # for everyone's sanity, enforce that the filename and tu name + # match + basefilename = os.path.basename(tu.filename) + expectedfilename = "%s.ipdl" % (tu.name) + if not tu.protocol: + # header + expectedfilename += "h" + if basefilename != expectedfilename: + self.error( + tu.loc, + "expected file for translation unit `%s' to be named `%s'; instead it's named `%s'", # NOQA: E501 + tu.name, + expectedfilename, + basefilename, + ) + + if tu.protocol: + assert tu.name == tu.protocol.name + + p = tu.protocol + + self.checkAttributes( + p.attributes, + { + "ManualDealloc": None, + "NestedUpTo": ("not", "inside_sync", "inside_cpow"), + "NeedsOtherPid": None, + "ChildImpl": ("virtual", StringLiteral), + "ParentImpl": ("virtual", StringLiteral), + }, + ) + + # FIXME/cjones: it's a little weird and counterintuitive + # to put both the namespace and non-namespaced name in the + # global scope. try to figure out something better; maybe + # a type-neutral |using| that works for C++ and protocol + # types? + qname = p.qname() + fullname = str(qname) + p.decl = self.declare( + loc=p.loc, + type=ProtocolType( + qname, + p.nestedUpTo(), + p.sendSemantics, + "ManualDealloc" not in p.attributes, + "NeedsOtherPid" in p.attributes, + ), + shortname=p.name, + fullname=fullname, + ) + + p.parentEndpointDecl = self.declare( + loc=p.loc, + type=EndpointType( + QualifiedId( + p.loc, "Endpoint<" + fullname + "Parent>", ["mozilla", "ipc"] + ), + ActorType(p.decl.type), + ), + shortname="Endpoint<" + p.name + "Parent>", + ) + p.childEndpointDecl = self.declare( + loc=p.loc, + type=EndpointType( + QualifiedId( + p.loc, "Endpoint<" + fullname + "Child>", ["mozilla", "ipc"] + ), + ActorType(p.decl.type), + ), + shortname="Endpoint<" + p.name + "Child>", + ) + + p.parentManagedEndpointDecl = self.declare( + loc=p.loc, + type=ManagedEndpointType( + QualifiedId( + p.loc, + "ManagedEndpoint<" + fullname + "Parent>", + ["mozilla", "ipc"], + ), + ActorType(p.decl.type), + ), + shortname="ManagedEndpoint<" + p.name + "Parent>", + ) + p.childManagedEndpointDecl = self.declare( + loc=p.loc, + type=ManagedEndpointType( + QualifiedId( + p.loc, + "ManagedEndpoint<" + fullname + "Child>", + ["mozilla", "ipc"], + ), + ActorType(p.decl.type), + ), + shortname="ManagedEndpoint<" + p.name + "Child>", + ) + + # XXX ugh, this sucks. but we need this information to compute + # what friend decls we need in generated C++ + p.decl.type._ast = p + + # make sure we have decls for all dependent protocols + for pinc in tu.includes: + pinc.accept(self) + + # declare imported (and builtin) C and C++ types + for ctype in builtin.CTypes: + self.declare( + loc=_builtinloc, + type=BuiltinCType(ctype), + shortname=ctype, + ) + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + + # first pass to "forward-declare" all structs and unions in + # order to support recursive definitions + for su in tu.structsAndUnions: + self.declareStructOrUnion(su) + + # second pass to check each definition + for su in tu.structsAndUnions: + su.accept(self) + + if tu.protocol: + # grab symbols in the protocol itself + p.accept(self) + + self.symtab = savedSymtab + + def declareStructOrUnion(self, su): + if hasattr(su, "decl"): + self.symtab.declare(su.decl) + return + + qname = su.qname() + fullname = str(qname) + + if isinstance(su, StructDecl): + sutype = StructType(qname, []) + elif isinstance(su, UnionDecl): + sutype = UnionType(qname, []) + else: + assert 0 and "unknown type" + + # XXX more suckage. this time for pickling structs/unions + # declared in headers. + sutype._ast = su + + su.decl = self.declare( + loc=su.loc, type=sutype, shortname=su.name, fullname=fullname + ) + + def visitInclude(self, inc): + if inc.tu is None: + self.error( + inc.loc, + "(type checking here will be unreliable because of an earlier error)", + ) + return + inc.tu.accept(self) + if inc.tu.protocol: + self.symtab.declare(inc.tu.protocol.decl) + self.symtab.declare(inc.tu.protocol.parentEndpointDecl) + self.symtab.declare(inc.tu.protocol.childEndpointDecl) + self.symtab.declare(inc.tu.protocol.parentManagedEndpointDecl) + self.symtab.declare(inc.tu.protocol.childManagedEndpointDecl) + else: + # This is a header. Import its "exported" globals into + # our scope. + for using in inc.tu.using: + using.accept(self) + for su in inc.tu.structsAndUnions: + self.declareStructOrUnion(su) + + def visitStructDecl(self, sd): + # If we've already processed this struct, don't do it again. + if hasattr(sd, "visited"): + return + + stype = sd.decl.type + + self.symtab.enterScope() + sd.visited = True + + self.checkAttributes(sd.attributes, {"Comparable": None}) + + for f in sd.fields: + ftypedecl = self.symtab.lookup(str(f.typespec)) + if ftypedecl is None: + self.error( + f.loc, + "field `%s' of struct `%s' has unknown type `%s'", + f.name, + sd.name, + str(f.typespec), + ) + continue + + f.decl = self.declare( + loc=f.loc, + type=self._canonicalType(ftypedecl.type, f.typespec), + shortname=f.name, + fullname=None, + ) + stype.fields.append(f.decl.type) + + self.symtab.exitScope() + + def visitUnionDecl(self, ud): + utype = ud.decl.type + + # If we've already processed this union, don't do it again. + if len(utype.components): + return + + self.checkAttributes(ud.attributes, {"Comparable": None}) + + for c in ud.components: + cdecl = self.symtab.lookup(str(c)) + if cdecl is None: + self.error( + c.loc, "unknown component type `%s' of union `%s'", str(c), ud.name + ) + continue + utype.components.append(self._canonicalType(cdecl.type, c)) + + def visitUsingStmt(self, using): + fullname = str(using.type) + + self.checkAttributes( + using.attributes, + { + "MoveOnly": (None, "data", "send"), + "RefCounted": None, + }, + ) + + if fullname == "::mozilla::ipc::Shmem": + ipdltype = ShmemType(using.type) + elif fullname == "::mozilla::ipc::ByteBuf": + ipdltype = ByteBufType(using.type) + elif fullname == "::mozilla::ipc::FileDescriptor": + ipdltype = FDType(using.type) + else: + ipdltype = ImportedCxxType( + using.type, + using.isRefcounted(), + using.isSendMoveOnly(), + using.isDataMoveOnly(), + ) + existingType = self.symtab.lookup(ipdltype.fullname()) + if existingType and existingType.fullname == ipdltype.fullname(): + if ipdltype.isRefcounted() != existingType.type.isRefcounted(): + self.error( + using.loc, + "inconsistent refcounted status of type `%s`", + str(using.type), + ) + if ( + ipdltype.isSendMoveOnly() != existingType.type.isSendMoveOnly() + or ipdltype.isDataMoveOnly() != existingType.type.isDataMoveOnly() + ): + self.error( + using.loc, + "inconsistent moveonly status of type `%s`", + str(using.type), + ) + using.decl = existingType + return + using.decl = self.declare( + loc=using.loc, + type=ipdltype, + shortname=using.type.baseid, + fullname=fullname, + ) + + def visitProtocol(self, p): + # protocol scope + self.symtab.enterScope() + + seenmgrs = set() + for mgr in p.managers: + if mgr.name in seenmgrs: + self.error(mgr.loc, "manager `%s' appears multiple times", mgr.name) + continue + + seenmgrs.add(mgr.name) + mgr.of = p + mgr.accept(self) + + for managed in p.managesStmts: + managed.manager = p + managed.accept(self) + + if not (p.managers or p.messageDecls or p.managesStmts): + self.error(p.loc, "top-level protocol `%s' cannot be empty", p.name) + + setattr(self, "currentProtocolDecl", p.decl) + for msg in p.messageDecls: + msg.accept(self) + del self.currentProtocolDecl + + p.decl.type.hasDelete = not not self.symtab.lookup(_DELETE_MSG) + if not (p.decl.type.hasDelete or p.decl.type.isToplevel()): + self.error( + p.loc, + "destructor declaration `%s(...)' required for managed protocol `%s'", + _DELETE_MSG, + p.name, + ) + + if not p.decl.type.isToplevel() and p.decl.type.needsotherpid: + self.error(p.loc, "[NeedsOtherPid] only applies to toplevel protocols") + + if p.decl.type.isToplevel() and not p.decl.type.isRefcounted(): + self.error(p.loc, "Toplevel protocols cannot be [ManualDealloc]") + + # FIXME/cjones declare all the little C++ thingies that will + # be generated. they're not relevant to IPDL itself, but + # those ("invisible") symbols can clash with others in the + # IPDL spec, and we'd like to catch those before C++ compilers + # are allowed to obfuscate the error + + self.symtab.exitScope() + + def visitManager(self, mgr): + mgrdecl = self.symtab.lookup(mgr.name) + pdecl = mgr.of.decl + assert pdecl + + pname, mgrname = pdecl.shortname, mgr.name + loc = mgr.loc + + if mgrdecl is None: + self.error( + loc, + "protocol `%s' referenced as |manager| of `%s' has not been declared", + mgrname, + pname, + ) + elif not isinstance(mgrdecl.type, ProtocolType): + self.error( + loc, + "entity `%s' referenced as |manager| of `%s' is not of `protocol' type; instead it is of type `%s'", # NOQA: E501 + mgrname, + pname, + mgrdecl.type.typename(), + ) + else: + mgr.decl = mgrdecl + pdecl.type.addManager(mgrdecl.type) + + def visitManagesStmt(self, mgs): + mgsdecl = self.symtab.lookup(mgs.name) + pdecl = mgs.manager.decl + assert pdecl + + pname, mgsname = pdecl.shortname, mgs.name + loc = mgs.loc + + if mgsdecl is None: + self.error( + loc, + "protocol `%s', managed by `%s', has not been declared", + mgsname, + pname, + ) + elif not isinstance(mgsdecl.type, ProtocolType): + self.error( + loc, + "%s declares itself managing a non-`protocol' entity `%s' of type `%s'", + pname, + mgsname, + mgsdecl.type.typename(), + ) + else: + mgs.decl = mgsdecl + pdecl.type.manages.append(mgsdecl.type) + + def visitMessageDecl(self, md): + msgname = md.name + loc = md.loc + + self.checkAttributes( + md.attributes, + { + "Tainted": None, + "Compress": (None, "all"), + "Priority": priorityList, + "ReplyPriority": priorityList, + "Nested": ("not", "inside_sync", "inside_cpow"), + "LegacyIntr": None, + "VirtualSendImpl": None, + "LazySend": None, + }, + ) + + if md.sendSemantics is INTR and "LegacyIntr" not in md.attributes: + self.error( + loc, + "intr message `%s' allowed only with [LegacyIntr]; DO NOT USE IN SHIPPING CODE", + msgname, + ) + + if md.sendSemantics is INTR and "Priority" in md.attributes: + self.error(loc, "intr message `%s' cannot specify [Priority]", msgname) + + if md.sendSemantics is INTR and "Nested" in md.attributes: + self.error(loc, "intr message `%s' cannot specify [Nested]", msgname) + + if md.sendSemantics is not ASYNC and "LazySend" in md.attributes: + self.error(loc, "non-async message `%s' cannot specify [LazySend]", msgname) + + if md.sendSemantics is not ASYNC and "ReplyPriority" in md.attributes: + self.error( + loc, "non-async message `%s' cannot specify [ReplyPriority]", msgname + ) + + if not md.outParams and "ReplyPriority" in md.attributes: + self.error( + loc, "non-returns message `%s' cannot specify [ReplyPriority]", msgname + ) + + isctor = False + isdtor = False + cdtype = None + + decl = self.symtab.lookup(msgname) + if decl is not None and decl.type.isProtocol(): + # probably a ctor. we'll check validity later. + msgname += "Constructor" + isctor = True + cdtype = decl.type + elif decl is not None: + self.error( + loc, + "message name `%s' already declared as `%s'", + msgname, + decl.type.typename(), + ) + # if we error here, no big deal; move on to find more + + if _DELETE_MSG == msgname: + isdtor = True + cdtype = self.currentProtocolDecl.type + + # enter message scope + self.symtab.enterScope() + + msgtype = MessageType( + nested=md.nested(), + prio=md.priority(), + replyPrio=md.replyPriority(), + sendSemantics=md.sendSemantics, + direction=md.direction, + ctor=isctor, + dtor=isdtor, + cdtype=cdtype, + compress=md.attributes.get("Compress"), + tainted="Tainted" in md.attributes, + lazySend="LazySend" in md.attributes, + ) + + # replace inparam Param nodes with proper Decls + def paramToDecl(param): + self.checkAttributes( + param.attributes, + { + # Passback indicates that the argument is unused by the Parent and is + # merely returned to the Child later. + # AllValid indicates that the entire span of values representable by + # the type are acceptable. e.g. 0-255 in a uint8 + "NoTaint": ("passback", "allvalid") + }, + ) + + ptname = param.typespec.basename() + ploc = param.typespec.loc + + if "NoTaint" in param.attributes and "Tainted" not in md.attributes: + self.error( + ploc, + "argument typename `%s' of message `%s' has a NoTaint attribute, but the message lacks the Tainted attribute", + ptname, + msgname, + ) + + ptdecl = self.symtab.lookup(ptname) + if ptdecl is None: + self.error( + ploc, + "argument typename `%s' of message `%s' has not been declared", + ptname, + msgname, + ) + ptype = VOID + else: + ptype = self._canonicalType(ptdecl.type, param.typespec) + return self.declare( + loc=ploc, type=ptype, progname=param.name, attributes=param.attributes + ) + + for i, inparam in enumerate(md.inParams): + pdecl = paramToDecl(inparam) + msgtype.params.append(pdecl.type) + md.inParams[i] = pdecl + for i, outparam in enumerate(md.outParams): + pdecl = paramToDecl(outparam) + msgtype.returns.append(pdecl.type) + md.outParams[i] = pdecl + + self.symtab.exitScope() + + md.decl = self.declare(loc=loc, type=msgtype, progname=msgname) + md.protocolDecl = self.currentProtocolDecl + md.decl._md = md + + def _canonicalType(self, itype, typespec): + loc = typespec.loc + if typespec.uniqueptr: + itype = UniquePtrType(itype) + + if itype.isIPDL() and itype.isProtocol(): + itype = ActorType(itype) + + if itype.supportsNullable(): + if not typespec.nullable: + itype = NotNullType(itype) + elif typespec.nullable: + self.error( + loc, "`nullable' qualifier for type `%s' is unsupported", itype.name() + ) + + if typespec.array: + itype = ArrayType(itype) + + if typespec.maybe: + itype = MaybeType(itype) + + return itype + + +# ----------------------------------------------------------------------------- + + +def checkcycles(p, stack=None): + cycles = [] + + if stack is None: + stack = [] + + for cp in p.manages: + # special case for self-managed protocols + if cp is p: + continue + + if cp in stack: + return [stack + [p, cp]] + cycles += checkcycles(cp, stack + [p]) + + return cycles + + +def formatcycles(cycles): + r = [] + for cycle in cycles: + s = " -> ".join([ptype.name() for ptype in cycle]) + r.append("`%s'" % s) + return ", ".join(r) + + +def fullyDefined(t, exploring=None): + """The rules for "full definition" of a type are + defined(atom) := true + defined(array basetype) := defined(basetype) + defined(struct f1 f2...) := defined(f1) and defined(f2) and ... + defined(union c1 c2 ...) := defined(c1) or defined(c2) or ...""" + if exploring is None: + exploring = set() + + if t.isAtom(): + return True + elif t.hasBaseType(): + return fullyDefined(t.basetype, exploring) + elif t.defined: + return True + assert t.isCompound() + + if t in exploring: + return False + + exploring.add(t) + for c in t.itercomponents(): + cdefined = fullyDefined(c, exploring) + if t.isStruct() and not cdefined: + t.defined = False + break + elif t.isUnion() and cdefined: + t.defined = True + break + else: + if t.isStruct(): + t.defined = True + elif t.isUnion(): + t.defined = False + exploring.remove(t) + + return t.defined + + +class CheckTypes(TcheckVisitor): + def __init__(self, errors): + TcheckVisitor.__init__(self, errors) + self.visited = set() + self.ptype = None + + def visitInclude(self, inc): + if inc.tu.filename in self.visited: + return + self.visited.add(inc.tu.filename) + if inc.tu.protocol: + inc.tu.protocol.accept(self) + + def visitStructDecl(self, sd): + if not fullyDefined(sd.decl.type): + self.error(sd.decl.loc, "struct `%s' is only partially defined", sd.name) + + def visitUnionDecl(self, ud): + if not fullyDefined(ud.decl.type): + self.error(ud.decl.loc, "union `%s' is only partially defined", ud.name) + + def visitProtocol(self, p): + self.ptype = p.decl.type + + # check that we require no more "power" than our manager protocols + ptype, pname = p.decl.type, p.decl.shortname + + for mgrtype in ptype.managers: + if mgrtype is not None and not ptype.sendSemanticsSatisfiedBy(mgrtype): + self.error( + p.decl.loc, + "protocol `%s' requires more powerful send semantics than its manager `%s' provides", # NOQA: E501 + pname, + mgrtype.name(), + ) + + if ptype.isInterrupt() and ptype.nestedRange != (NOT_NESTED, NOT_NESTED): + self.error( + p.decl.loc, "intr protocol `%s' cannot specify [NestedUpTo]", p.name + ) + + if ptype.isToplevel(): + cycles = checkcycles(p.decl.type) + if cycles: + self.error( + p.decl.loc, + "cycle(s) detected in manager/manages hierarchy: %s", + formatcycles(cycles), + ) + + if 1 == len(ptype.managers) and ptype is ptype.manager(): + self.error( + p.decl.loc, "top-level protocol `%s' cannot manage itself", p.name + ) + + return Visitor.visitProtocol(self, p) + + def visitManagesStmt(self, mgs): + pdecl = mgs.manager.decl + ptype, pname = pdecl.type, pdecl.shortname + + mgsdecl = mgs.decl + mgstype, mgsname = mgsdecl.type, mgsdecl.shortname + + loc = mgs.loc + + # we added this information; sanity check it + assert ptype.isManagerOf(mgstype) + + # check that the "managed" protocol agrees + if not mgstype.isManagedBy(ptype): + self.error( + loc, + "|manages| declaration in protocol `%s' does not match any |manager| declaration in protocol `%s'", # NOQA: E501 + pname, + mgsname, + ) + + def visitManager(self, mgr): + pdecl = mgr.of.decl + ptype, pname = pdecl.type, pdecl.shortname + + mgrdecl = mgr.decl + mgrtype, mgrname = mgrdecl.type, mgrdecl.shortname + + # we added this information; sanity check it + assert ptype.isManagedBy(mgrtype) + + loc = mgr.loc + + # check that the "manager" protocol agrees + if not mgrtype.isManagerOf(ptype): + self.error( + loc, + "|manager| declaration in protocol `%s' does not match any |manages| declaration in protocol `%s'", # NOQA: E501 + pname, + mgrname, + ) + + def visitMessageDecl(self, md): + mtype, mname = md.decl.type, md.decl.progname + ptype, pname = md.protocolDecl.type, md.protocolDecl.shortname + + loc = md.decl.loc + + if mtype.nested == INSIDE_SYNC_NESTED and not mtype.isSync(): + self.error( + loc, + "inside_sync nested messages must be sync (here, message `%s' in protocol `%s')", + mname, + pname, + ) + + if mtype.nested == INSIDE_CPOW_NESTED and (mtype.isOut() or mtype.isInout()): + self.error( + loc, + "inside_cpow nested parent-to-child messages are verboten (here, message `%s' in protocol `%s')", # NOQA: E501 + mname, + pname, + ) + + # We allow inside_sync messages that are themselves sync to be sent from the + # parent. Normal and inside_cpow nested messages that are sync can only come from + # the child. + if ( + mtype.isSync() + and mtype.nested == NOT_NESTED + and (mtype.isOut() or mtype.isInout()) + ): + self.error( + loc, + "sync parent-to-child messages are verboten (here, message `%s' in protocol `%s')", + mname, + pname, + ) + + if not mtype.sendSemanticsSatisfiedBy(ptype): + self.error( + loc, + "message `%s' requires more powerful send semantics than its protocol `%s' provides", # NOQA: E501 + mname, + pname, + ) + + if (mtype.isCtor() or mtype.isDtor()) and mtype.isAsync() and mtype.returns: + self.error( + loc, "asynchronous ctor/dtor message `%s' declares return values", mname + ) + + if mtype.compress and (not mtype.isAsync() or mtype.isCtor() or mtype.isDtor()): + + if mtype.isCtor() or mtype.isDtor(): + message_type = "constructor" if mtype.isCtor() else "destructor" + error_message = ( + "%s messages can't use compression (here, in protocol `%s')" + % (message_type, pname) + ) + else: + error_message = ( + "message `%s' in protocol `%s' requests compression but is not async" + % (mname, pname) # NOQA: E501 + ) + + self.error(loc, error_message) + + if mtype.isCtor() and not ptype.isManagerOf(mtype.constructedType()): + self.error( + loc, + "ctor for protocol `%s', which is not managed by protocol `%s'", + mname[: -len("constructor")], + pname, + ) diff --git a/ipc/ipdl/ipdl/util.py b/ipc/ipdl/ipdl/util.py new file mode 100644 index 0000000000..60d9c904e2 --- /dev/null +++ b/ipc/ipdl/ipdl/util.py @@ -0,0 +1,12 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import zlib + + +# The built-in hash over the str type is non-deterministic, so we need to do +# this instead. +def hash_str(s): + assert isinstance(s, str) + return zlib.adler32(s.encode("utf-8")) |