diff options
Diffstat (limited to 'ipc/ipdl/ipdl/lower.py')
-rw-r--r-- | ipc/ipdl/ipdl/lower.py | 5688 |
1 files changed, 5688 insertions, 0 deletions
diff --git a/ipc/ipdl/ipdl/lower.py b/ipc/ipdl/ipdl/lower.py new file mode 100644 index 0000000000..b8ef219b9b --- /dev/null +++ b/ipc/ipdl/ipdl/lower.py @@ -0,0 +1,5688 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +import re +from copy import deepcopy +from collections import OrderedDict +import itertools + +import ipdl.ast +import ipdl.builtin +from ipdl.cxx.ast import * +from ipdl.cxx.code import * +from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes +from ipdl.util import hash_str + + +# ----------------------------------------------------------------------------- +# "Public" interface to lowering +## + + +class LowerToCxx: + def lower(self, tu, segmentcapacitydict): + """returns |[ header: File ], [ cpp : File ]| representing the + lowered form of |tu|""" + # annotate the AST with IPDL/C++ IR-type stuff used later + tu.accept(_DecorateWithCxxStuff()) + + # Any modifications to the filename scheme here need corresponding + # modifications in the ipdl.py driver script. + name = tu.name + pheader, pcpp = File(name + ".h"), File(name + ".cpp") + + _GenerateProtocolCode().lower(tu, pheader, pcpp, segmentcapacitydict) + headers = [pheader] + cpps = [pcpp] + + if tu.protocol: + pname = tu.protocol.name + + parentheader, parentcpp = ( + File(pname + "Parent.h"), + File(pname + "Parent.cpp"), + ) + _GenerateProtocolParentCode().lower( + tu, pname + "Parent", parentheader, parentcpp + ) + + childheader, childcpp = File(pname + "Child.h"), File(pname + "Child.cpp") + _GenerateProtocolChildCode().lower( + tu, pname + "Child", childheader, childcpp + ) + + headers += [parentheader, childheader] + cpps += [parentcpp, childcpp] + + return headers, cpps + + +# ----------------------------------------------------------------------------- +# Helper code +## + + +def hashfunc(value): + h = hash_str(value) % 2 ** 32 + if h < 0: + h += 2 ** 32 + return h + + +_NULL_ACTOR_ID = ExprLiteral.ZERO +_FREED_ACTOR_ID = ExprLiteral.ONE + +_DISCLAIMER = Whitespace( + """// +// Automatically generated by ipdlc. +// Edit at your own risk +// + +""" +) + + +class _struct: + pass + + +def _namespacedHeaderName(name, namespaces): + pfx = "/".join([ns.name for ns in namespaces]) + if pfx: + return pfx + "/" + name + else: + return name + + +def _ipdlhHeaderName(tu): + assert tu.filetype == "header" + return _namespacedHeaderName(tu.name, tu.namespaces) + + +def _protocolHeaderName(p, side=""): + if side: + side = side.title() + base = p.name + side + return _namespacedHeaderName(base, p.namespaces) + + +def _includeGuardMacroName(headerfile): + return re.sub(r"[./]", "_", headerfile.name) + + +def _includeGuardStart(headerfile): + guard = _includeGuardMacroName(headerfile) + return [CppDirective("ifndef", guard), CppDirective("define", guard)] + + +def _includeGuardEnd(headerfile): + guard = _includeGuardMacroName(headerfile) + return [CppDirective("endif", "// ifndef " + guard)] + + +def _messageStartName(ptype): + return ptype.name() + "MsgStart" + + +def _protocolId(ptype): + return ExprVar(_messageStartName(ptype)) + + +def _protocolIdType(): + return Type.INT32 + + +def _actorName(pname, side): + """|pname| is the protocol name. |side| is 'Parent' or 'Child'.""" + tag = side + if not tag[0].isupper(): + tag = side.title() + return pname + tag + + +def _actorIdType(): + return Type.INT32 + + +def _actorTypeTagType(): + return Type.INT32 + + +def _actorId(actor=None): + if actor is not None: + return ExprCall(ExprSelect(actor, "->", "Id")) + return ExprCall(ExprVar("Id")) + + +def _actorHId(actorhandle): + return ExprSelect(actorhandle, ".", "mId") + + +def _backstagePass(): + return ExprCall(ExprVar("mozilla::ipc::PrivateIPDLInterface")) + + +def _deleteId(): + return ExprVar("Msg___delete____ID") + + +def _deleteReplyId(): + return ExprVar("Reply___delete____ID") + + +def _lookupListener(idexpr): + return ExprCall(ExprVar("Lookup"), args=[idexpr]) + + +def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False): + fd = ForwardDecl(clsname, cls=cls, struct=struct) + if 0 == len(quals): + return fd + + outerns = Namespace(quals[0]) + innerns = outerns + for ns in quals[1:]: + tmpns = Namespace(ns) + innerns.addstmt(tmpns) + innerns = tmpns + + innerns.addstmt(fd) + return outerns + + +def _makeForwardDeclForActor(ptype, side): + return _makeForwardDeclForQClass( + _actorName(ptype.qname.baseid, side), ptype.qname.quals + ) + + +def _makeForwardDecl(type): + return _makeForwardDeclForQClass(type.name(), type.qname.quals) + + +def _putInNamespaces(cxxthing, namespaces): + """|namespaces| is in order [ outer, ..., inner ]""" + if 0 == len(namespaces): + return cxxthing + + outerns = Namespace(namespaces[0].name) + innerns = outerns + for ns in namespaces[1:]: + newns = Namespace(ns.name) + innerns.addstmt(newns) + innerns = newns + innerns.addstmt(cxxthing) + return outerns + + +def _sendPrefix(msgtype): + """Prefix of the name of the C++ method that sends |msgtype|.""" + if msgtype.isInterrupt(): + return "Call" + return "Send" + + +def _recvPrefix(msgtype): + """Prefix of the name of the C++ method that handles |msgtype|.""" + if msgtype.isInterrupt(): + return "Answer" + return "Recv" + + +def _flatTypeName(ipdltype): + """Return a 'flattened' IPDL type name that can be used as an + identifier. + E.g., |Foo[]| --> |ArrayOfFoo|.""" + # NB: this logic depends heavily on what IPDL types are allowed to + # be constructed; e.g., Foo[][] is disallowed. needs to be kept in + # sync with grammar. + if ipdltype.isIPDL() and ipdltype.isArray(): + return "ArrayOf" + _flatTypeName(ipdltype.basetype) + if ipdltype.isIPDL() and ipdltype.isMaybe(): + return "Maybe" + _flatTypeName(ipdltype.basetype) + # NotNull types just assume the underlying variant name to avoid unnecessary + # noise, as a NotNull<T> and T should never exist in the same union. + if ipdltype.isIPDL() and ipdltype.isNotNull(): + return _flatTypeName(ipdltype.basetype) + return ipdltype.name() + + +def _hasVisibleActor(ipdltype): + """Return true iff a C++ decl of |ipdltype| would have an Actor* type. + For example: |Actor[]| would turn into |Array<ActorParent*>|, so this + function would return true for |Actor[]|.""" + return ipdltype.isIPDL() and ( + ipdltype.isActor() + or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype)) + ) + + +def _abortIfFalse(cond, msg): + return StmtExpr( + ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)]) + ) + + +def _refptr(T): + return Type("RefPtr", T=T) + + +def _alreadyaddrefed(T): + return Type("already_AddRefed", T=T) + + +def _tuple(types, const=False, ref=False): + return Type("std::tuple", T=types, const=const, ref=ref) + + +def _promise(resolvetype, rejecttype, tail, resolver=False): + inner = Type("Private") if resolver else None + return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner) + + +def _makePromise(returns, side, resolver=False): + if len(returns) > 1: + resolvetype = _tuple([d.bareType(side) for d in returns]) + else: + resolvetype = returns[0].bareType(side) + + # MozPromise is purposefully made to be exclusive only. Really, we mean it. + return _promise( + resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver + ) + + +def _resolveType(returns, side): + if len(returns) > 1: + return _tuple([d.inType(side, "send") for d in returns]) + return returns[0].inType(side, "send") + + +def _makeResolver(returns, side): + return TypeFunction([Decl(_resolveType(returns, side), "")]) + + +def _cxxArrayType(basetype, const=False, ref=False): + return Type("nsTArray", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False) + + +def _cxxSpanType(basetype, const=False, ref=False): + basetype = deepcopy(basetype) + basetype.rightconst = True + return Type( + "mozilla::Span", T=basetype, const=const, ref=ref, hasimplicitcopyctor=True + ) + + +def _cxxMaybeType(basetype, const=False, ref=False): + return Type( + "mozilla::Maybe", + T=basetype, + const=const, + ref=ref, + hasimplicitcopyctor=basetype.hasimplicitcopyctor, + ) + + +def _cxxReadResultType(basetype, const=False, ref=False): + return Type( + "IPC::ReadResult", + T=basetype, + const=const, + ref=ref, + hasimplicitcopyctor=basetype.hasimplicitcopyctor, + ) + + +def _cxxNotNullType(basetype, const=False, ref=False): + return Type( + "mozilla::NotNull", + T=basetype, + const=const, + ref=ref, + hasimplicitcopyctor=basetype.hasimplicitcopyctor, + ) + + +def _cxxManagedContainerType(basetype, const=False, ref=False): + return Type( + "ManagedContainer", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False + ) + + +def _cxxLifecycleProxyType(ptr=False): + return Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr) + + +def _otherSide(side): + if side == "child": + return "parent" + if side == "parent": + return "child" + assert 0 + + +def _ifLogging(topLevelProtocol, stmts): + return StmtCode( + """ + if (mozilla::ipc::LoggingEnabledFor(${proto})) { + $*{stmts} + } + """, + proto=topLevelProtocol, + stmts=stmts, + ) + + +# XXX we need to remove these and install proper error handling + + +def _printErrorMessage(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[msg])) + + +def _protocolErrorBreakpoint(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr( + ExprCall(ExprVar("mozilla::ipc::ProtocolErrorBreakpoint"), args=[msg]) + ) + + +def _printWarningMessage(msg): + if isinstance(msg, str): + msg = ExprLiteral.String(msg) + return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[msg])) + + +def _fatalError(msg): + return StmtExpr(ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)])) + + +def _logicError(msg): + return StmtExpr( + ExprCall(ExprVar("mozilla::ipc::LogicError"), args=[ExprLiteral.String(msg)]) + ) + + +def _sentinelReadError(classname): + return StmtExpr( + ExprCall( + ExprVar("mozilla::ipc::SentinelReadError"), + args=[ExprLiteral.String(classname)], + ) + ) + + +# Results that IPDL-generated code returns back to *Channel code. +# Users never see these + + +class _Result: + @staticmethod + def Type(): + return Type("Result") + + Processed = ExprVar("MsgProcessed") + NotKnown = ExprVar("MsgNotKnown") + NotAllowed = ExprVar("MsgNotAllowed") + PayloadError = ExprVar("MsgPayloadError") + ProcessingError = ExprVar("MsgProcessingError") + RouteError = ExprVar("MsgRouteError") + ValuError = ExprVar("MsgValueError") # [sic] + + +# these |errfn*| are functions that generate code to be executed on an +# error, such as "bad actor ID". each is given a Python string +# containing a description of the error + +# used in user-facing Send*() methods + + +def errfnSend(msg, errcode=ExprLiteral.FALSE): + return [_fatalError(msg), StmtReturn(errcode)] + + +def errfnSendCtor(msg): + return errfnSend(msg, errcode=ExprLiteral.NULL) + + +# TODO should this error handling be strengthened for dtors? + + +def errfnSendDtor(msg): + return [_printErrorMessage(msg), StmtReturn.FALSE] + + +# used in |OnMessage*()| handlers that hand in-messages off to Recv*() +# interface methods + + +def errfnRecv(msg, errcode=_Result.ValuError): + return [_fatalError(msg), StmtReturn(errcode)] + + +def errfnSentinel(rvalue=ExprLiteral.FALSE): + def inner(msg): + return [_sentinelReadError(msg), StmtReturn(rvalue)] + + return inner + + +def _destroyMethod(): + return ExprVar("ActorDestroy") + + +def errfnUnreachable(msg): + return [_logicError(msg)] + + +def readResultError(): + return ExprCode("{}") + + +class _DestroyReason: + @staticmethod + def Type(): + return Type("ActorDestroyReason") + + Deletion = ExprVar("Deletion") + AncestorDeletion = ExprVar("AncestorDeletion") + NormalShutdown = ExprVar("NormalShutdown") + AbnormalShutdown = ExprVar("AbnormalShutdown") + FailedConstructor = ExprVar("FailedConstructor") + ManagedEndpointDropped = ExprVar("ManagedEndpointDropped") + + +class _ResponseRejectReason: + @staticmethod + def Type(): + return Type("ResponseRejectReason") + + SendError = ExprVar("ResponseRejectReason::SendError") + ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed") + HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected") + ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed") + + +# ----------------------------------------------------------------------------- +# Intermediate representation (IR) nodes used during lowering + + +class _ConvertToCxxType(TypeVisitor): + def __init__(self, side, fq): + self.side = side + self.fq = fq + + def typename(self, thing): + if self.fq: + return thing.fullname() + return thing.name() + + def visitImportedCxxType(self, t): + cxxtype = Type(self.typename(t)) + if t.isRefcounted(): + cxxtype = _refptr(cxxtype) + return cxxtype + + def visitBuiltinCType(self, b): + return Type(self.typename(b)) + + def visitActorType(self, a): + if self.side is None: + return Type( + "::mozilla::ipc::SideVariant", + T=[ + _cxxBareType(a, "parent", self.fq), + _cxxBareType(a, "child", self.fq), + ], + ) + return Type(_actorName(self.typename(a.protocol), self.side), ptr=True) + + def visitStructType(self, s): + return Type(self.typename(s)) + + def visitUnionType(self, u): + return Type(self.typename(u)) + + def visitArrayType(self, a): + basecxxtype = a.basetype.accept(self) + return _cxxArrayType(basecxxtype) + + def visitMaybeType(self, m): + basecxxtype = m.basetype.accept(self) + return _cxxMaybeType(basecxxtype) + + def visitShmemType(self, s): + return Type(self.typename(s)) + + def visitByteBufType(self, s): + return Type(self.typename(s)) + + def visitFDType(self, s): + return Type(self.typename(s)) + + def visitEndpointType(self, s): + return Type(self.typename(s)) + + def visitManagedEndpointType(self, s): + return Type(self.typename(s)) + + def visitUniquePtrType(self, s): + return Type(self.typename(s)) + + def visitNotNullType(self, n): + basecxxtype = n.basetype.accept(self) + return _cxxNotNullType(basecxxtype) + + def visitProtocolType(self, p): + assert 0 + + def visitMessageType(self, m): + assert 0 + + def visitVoidType(self, v): + assert 0 + + +def _cxxBareType(ipdltype, side, fq=False): + return ipdltype.accept(_ConvertToCxxType(side, fq)) + + +def _cxxRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + t.ref = True + return t + + +def _cxxConstRefType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + return t + if ipdltype.isIPDL() and ipdltype.isShmem(): + t.ref = True + return t + if ipdltype.isIPDL() and ipdltype.isNotNull(): + # If the inner type chooses to use a raw pointer, wrap that instead. + inner = _cxxConstRefType(ipdltype.basetype, side) + if inner.ptr: + t = _cxxNotNullType(inner) + return t + if ipdltype.isIPDL() and ipdltype.hasBaseType(): + # Keep same constness as inner type. + inner = _cxxConstRefType(ipdltype.basetype, side) + t.const = inner.const or not inner.ref + t.ref = True + return t + if ipdltype.isCxx() and (ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly()): + t.const = True + t.ref = True + return t + if ipdltype.isCxx() and ipdltype.isRefcounted(): + # Use T* instead of const RefPtr<T>& + t = t.T + t.ptr = True + return t + t.const = True + t.ref = True + return t + + +def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None): + """Returns `True` if serializing ipdltype requires a mutable reference, e.g. + because the underlying resource represented by the value is being + transferred to another process. This is occasionally distinct from whether + the C++ type exposes a copy constructor, such as for types which are not + cheaply copiable, but are not mutated when serialized.""" + + if visited is None: + visited = set() + + visited.add(ipdltype) + + if ipdltype.isCxx(): + return ipdltype.isSendMoveOnly() + + if ipdltype.isIPDL(): + if ipdltype.hasBaseType(): + return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited) + if ipdltype.isStruct() or ipdltype.isUnion(): + return any( + _cxxTypeNeedsMoveForSend(t, "compound", visited) + for t in ipdltype.itercomponents() + if t not in visited + ) + + # For historical reasons, shmem is `const_cast` to a mutable reference + # when being stored in a struct or union (see + # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning + # that they do not cause the containing struct to require move for + # sending. + if ipdltype.isShmem(): + return context != "compound" + + return ( + ipdltype.isByteBuf() + or ipdltype.isEndpoint() + or ipdltype.isManagedEndpoint() + ) + + return False + + +def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None): + """Returns `True` if the bare C++ type corresponding to ipdltype does not + satisfy std::is_copy_constructible_v<T>. All C++ types supported by IPDL + must support std::is_move_constructible_v<T>, so non-movable types must be + passed behind a `UniquePtr`.""" + + if visited is None: + visited = set() + + visited.add(ipdltype) + + if ipdltype.isCxx(): + return ipdltype.isDataMoveOnly() + + if ipdltype.isIPDL(): + if ipdltype.isUniquePtr(): + return True + + # When nested within a maybe or array, arrays are no longer copyable. + if context == "wrapper" and ipdltype.isArray(): + return True + if ipdltype.hasBaseType(): + return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited) + if ipdltype.isStruct() or ipdltype.isUnion(): + return any( + _cxxTypeNeedsMoveForData(t, "compound", visited) + for t in ipdltype.itercomponents() + if t not in visited + ) + return ( + ipdltype.isByteBuf() + or ipdltype.isEndpoint() + or ipdltype.isManagedEndpoint() + ) + + return False + + +def _cxxTypeCanMove(ipdltype): + return not (ipdltype.isIPDL() and ipdltype.isActor()) + + +def _cxxForceMoveRefType(ipdltype, side): + assert _cxxTypeCanMove(ipdltype) + t = _cxxBareType(ipdltype, side) + t.rvalref = True + return t + + +def _cxxPtrToType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor() and side is not None: + t.ptr = False + t.ptrptr = True + return t + t.ptr = True + return t + + +def _cxxConstPtrToType(ipdltype, side): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor() and side is not None: + t.ptr = False + t.ptrconstptr = True + return t + t.const = True + t.ptr = True + return t + + +def _cxxInType(ipdltype, side, direction): + t = _cxxBareType(ipdltype, side) + if ipdltype.isIPDL() and ipdltype.isActor(): + return t + if ipdltype.isIPDL() and ipdltype.isNotNull(): + # If the inner type chooses to use a raw pointer, wrap that instead. + inner = _cxxInType(ipdltype.basetype, side, direction) + if inner.ptr: + t = _cxxNotNullType(inner) + return t + if _cxxTypeNeedsMoveForSend(ipdltype): + t.rvalref = True + return t + if ipdltype.isCxx(): + if ipdltype.isRefcounted(): + # Use T* instead of const RefPtr<T>& + t = t.T + t.ptr = True + return t + if ipdltype.name() == "nsCString": + t = Type("nsACString") + if ipdltype.name() == "nsString": + t = Type("nsAString") + # Use Span<T const> rather than nsTArray<T> for array types which aren't + # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not + # for recv signatures. + if direction == "send" and ipdltype.isIPDL() and ipdltype.isArray(): + inner = _cxxBareType(ipdltype.basetype, side) + return _cxxSpanType(inner) + + t.const = True + t.ref = True + return t + + +def _allocMethod(ptype, side): + return "Alloc" + ptype.name() + side.title() + + +def _deallocMethod(ptype, side): + return "Dealloc" + ptype.name() + side.title() + + +## +# A _HybridDecl straddles IPDL and C++ decls. It knows which C++ +# types correspond to which IPDL types, and it also knows how +# serialize and deserialize "special" IPDL C++ types. +## + + +class _HybridDecl: + """A hybrid decl stores both an IPDL type and all the C++ type + info needed by later passes, along with a basic name for the decl.""" + + def __init__(self, ipdltype, name, attributes={}): + self.ipdltype = ipdltype + self.name = name + self.attributes = attributes + + def var(self): + return ExprVar(self.name) + + def bareType(self, side, fq=False): + """Return this decl's unqualified C++ type.""" + return _cxxBareType(self.ipdltype, side, fq=fq) + + def refType(self, side): + """Return this decl's C++ type as a 'reference' type, which is not + necessarily a C++ reference.""" + return _cxxRefType(self.ipdltype, side) + + def constRefType(self, side): + """Return this decl's C++ type as a const, 'reference' type.""" + return _cxxConstRefType(self.ipdltype, side) + + def ptrToType(self, side): + return _cxxPtrToType(self.ipdltype, side) + + def constPtrToType(self, side): + return _cxxConstPtrToType(self.ipdltype, side) + + def inType(self, side, direction): + """Return this decl's C++ Type with sending inparam semantics.""" + return _cxxInType(self.ipdltype, side, direction) + + def outType(self, side): + """Return this decl's C++ Type with outparam semantics.""" + t = self.bareType(side) + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + t.ptr = False + t.ptrptr = True + return t + t.ptr = True + return t + + def forceMoveType(self, side): + """Return this decl's C++ Type with forced move semantics.""" + assert _cxxTypeCanMove(self.ipdltype) + return _cxxForceMoveRefType(self.ipdltype, side) + + +# -------------------------------------------------- + + +class HasFQName: + def fqClassName(self): + return self.decl.type.fullname() + + +class _CompoundTypeComponent(_HybridDecl): + # @override the following methods to make the side argument optional. + def bareType(self, side=None, fq=False): + return _HybridDecl.bareType(self, side, fq=fq) + + def refType(self, side=None): + return _HybridDecl.refType(self, side) + + def constRefType(self, side=None): + return _HybridDecl.constRefType(self, side) + + def ptrToType(self, side=None): + return _HybridDecl.ptrToType(self, side) + + def constPtrToType(self, side=None): + return _HybridDecl.constPtrToType(self, side) + + def forceMoveType(self, side=None): + return _HybridDecl.forceMoveType(self, side) + + +class StructDecl(ipdl.ast.StructDecl, HasFQName): + def fields_ipdl_order(self): + for f in self.fields: + yield f + + def fields_member_order(self): + assert len(self.packed_field_order) == len(self.fields) + + for i in self.packed_field_order: + yield self.fields[i] + + @staticmethod + def upgrade(structDecl): + assert isinstance(structDecl, ipdl.ast.StructDecl) + structDecl.__class__ = StructDecl + + +class _StructField(_CompoundTypeComponent): + def __init__(self, ipdltype, name, sd): + self.basename = name + + _CompoundTypeComponent.__init__(self, ipdltype, name) + + def getMethod(self, thisexpr=None, sel="."): + meth = self.var() + if thisexpr is not None: + return ExprSelect(thisexpr, sel, meth.name) + return meth + + def refExpr(self, thisexpr=None): + ref = self.memberVar() + if thisexpr is not None: + ref = ExprSelect(thisexpr, ".", ref.name) + return ref + + def constRefExpr(self, thisexpr=None): + # sigh, gross hack + refexpr = self.refExpr(thisexpr) + if "Shmem" == self.ipdltype.name(): + refexpr = ExprCast(refexpr, Type("Shmem", ref=True), const=True) + return refexpr + + def argVar(self): + return ExprVar("_" + self.name) + + def memberVar(self): + return ExprVar(self.name + "_") + + +class UnionDecl(ipdl.ast.UnionDecl, HasFQName): + def callType(self, var=None): + func = ExprVar("type") + if var is not None: + func = ExprSelect(var, ".", func.name) + return ExprCall(func) + + @staticmethod + def upgrade(unionDecl): + assert isinstance(unionDecl, ipdl.ast.UnionDecl) + unionDecl.__class__ = UnionDecl + + +class _UnionMember(_CompoundTypeComponent): + """Not in the AFL sense, but rather a member (e.g. |int;|) of an + IPDL union type.""" + + def __init__(self, ipdltype, ud): + flatname = _flatTypeName(ipdltype) + + _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname) + self.flattypename = flatname + + # To create a finite object with a mutually recursive type, a union must + # be present somewhere in the recursive loop. Because of that we only + # need to care about introducing indirections inside unions. + self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype) + + def enum(self): + return "T" + self.flattypename + + def enumvar(self): + return ExprVar(self.enum()) + + def internalType(self): + if self.recursive: + return self.ptrToType() + else: + return self.bareType() + + def unionType(self): + """Type used for storage in generated C union decl.""" + if self.recursive: + return self.ptrToType() + else: + return Type("mozilla::AlignedStorage2", T=self.internalType()) + + def unionValue(self): + # NB: knows that Union's storage C union is named |mValue| + return ExprSelect(ExprVar("mValue"), ".", self.name) + + def typedef(self): + return self.flattypename + "__tdef" + + def callGetConstPtr(self): + """Return an expression of type self.constptrToSelfType()""" + return ExprCall(ExprVar(self.getConstPtrName())) + + def callGetPtr(self): + """Return an expression of type self.ptrToSelfType()""" + return ExprCall(ExprVar(self.getPtrName())) + + def callCtor(self, expr=None): + assert not isinstance(expr, list) + + if expr is None: + args = None + elif ( + self.ipdltype.isIPDL() + and self.ipdltype.isArray() + and not isinstance(expr, ExprMove) + ): + args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])] + else: + args = [expr] + + if self.recursive: + return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args)) + else: + return ExprNew( + self.bareType(), + args=args, + newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()], + ) + + def callDtor(self): + if self.recursive: + return ExprDelete(self.callGetPtr()) + else: + return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef())) + + def getTypeName(self): + return "get_" + self.flattypename + + def getConstTypeName(self): + return "get_" + self.flattypename + + def getOtherTypeName(self): + return "get_" + self.otherflattypename + + def getPtrName(self): + return "ptr_" + self.flattypename + + def getConstPtrName(self): + return "constptr_" + self.flattypename + + def ptrToSelfExpr(self): + """|*ptrToSelfExpr()| has type |self.bareType()|""" + v = self.unionValue() + if self.recursive: + return v + else: + return ExprCall(ExprSelect(v, ".", "addr")) + + def constptrToSelfExpr(self): + """|*constptrToSelfExpr()| has type |self.constType()|""" + v = self.unionValue() + if self.recursive: + return v + return ExprCall(ExprSelect(v, ".", "addr")) + + def ptrToInternalType(self): + t = self.ptrToType() + if self.recursive: + t.ref = True + return t + + def defaultValue(self, fq=False): + # Use the default constructor for any class that does not have an + # implicit copy constructor. + if not self.bareType().hasimplicitcopyctor: + return None + + if self.ipdltype.isIPDL() and self.ipdltype.isActor(): + return ExprLiteral.NULL + # XXX sneaky here, maybe need ExprCtor()? + return ExprCall(self.bareType(fq=fq)) + + def getConstValue(self): + v = ExprDeref(self.callGetConstPtr()) + # sigh + if "Shmem" == self.ipdltype.name(): + v = ExprCast(v, Type("Shmem", ref=True), const=True) + return v + + +# -------------------------------------------------- + + +class MessageDecl(ipdl.ast.MessageDecl): + def baseName(self): + return self.name + + def recvMethod(self): + name = _recvPrefix(self.decl.type) + self.baseName() + if self.decl.type.isCtor(): + name += "Constructor" + return name + + def sendMethod(self): + name = _sendPrefix(self.decl.type) + self.baseName() + if self.decl.type.isCtor(): + name += "Constructor" + return name + + def hasReply(self): + return ( + self.decl.type.hasReply() + or self.decl.type.isCtor() + or self.decl.type.isDtor() + ) + + def hasAsyncReturns(self): + return self.decl.type.isAsync() and self.returns + + def msgCtorFunc(self): + return "Msg_%s" % (self.decl.progname) + + def prettyMsgName(self, pfx=""): + return pfx + self.msgCtorFunc() + + def pqMsgCtorFunc(self): + return "%s::%s" % (self.namespace, self.msgCtorFunc()) + + def msgId(self): + return self.msgCtorFunc() + "__ID" + + def pqMsgId(self): + return "%s::%s" % (self.namespace, self.msgId()) + + def replyCtorFunc(self): + return "Reply_%s" % (self.decl.progname) + + def pqReplyCtorFunc(self): + return "%s::%s" % (self.namespace, self.replyCtorFunc()) + + def replyId(self): + return self.replyCtorFunc() + "__ID" + + def pqReplyId(self): + return "%s::%s" % (self.namespace, self.replyId()) + + def prettyReplyName(self, pfx=""): + return pfx + self.replyCtorFunc() + + def promiseName(self): + name = self.baseName() + if self.decl.type.isCtor(): + name += "Constructor" + name += "Promise" + return name + + def resolverName(self): + return self.baseName() + "Resolver" + + def actorDecl(self): + return self.params[0] + + def makeCxxParams( + self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None + ): + """Return a list of C++ decls per the spec'd configuration. + |params| and |returns| is the C++ semantics of those: 'in', 'out', or None.""" + + def makeDecl(d, sems): + if ( + self.decl.type.tainted + and "NoTaint" not in d.attributes + and direction == "recv" + ): + # Tainted types are passed by-value, allowing the receiver to move them if desired. + assert sems != "out" + return Decl(Type("Tainted", T=d.bareType(side)), d.name) + + if sems == "in": + t = d.inType(side, direction) + # If this is the `recv` side, and we're not using "move" + # semantics, that means we're an alloc method, and cannot accept + # values by rvalue reference. Downgrade to an lvalue reference. + if direction == "recv" and t.rvalref: + t.rvalref = False + t.ref = True + return Decl(t, d.name) + elif sems == "move": + assert direction == "recv" + # For legacy reasons, use an rvalue reference when generating + # parameters for recv methods which accept arrays. + if d.ipdltype.isIPDL() and d.ipdltype.isArray(): + t = d.bareType(side) + t.rvalref = True + return Decl(t, d.name) + return Decl(d.inType(side, direction), d.name) + elif sems == "out": + return Decl(d.outType(side), d.name) + else: + assert 0 + + def makeResolverDecl(returns): + return Decl(Type(self.resolverName(), rvalref=True), "aResolve") + + def makeCallbackResolveDecl(returns): + if len(returns) > 1: + resolvetype = _tuple([d.bareType(side) for d in returns]) + else: + resolvetype = returns[0].bareType(side) + + return Decl( + Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True), + "aResolve", + ) + + def makeCallbackRejectDecl(returns): + return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject") + + cxxparams = [] + if paramsems is not None: + cxxparams.extend([makeDecl(d, paramsems) for d in self.params]) + + if returnsems == "promise" and self.returns: + pass + elif returnsems == "callback" and self.returns: + cxxparams.extend( + [ + makeCallbackResolveDecl(self.returns), + makeCallbackRejectDecl(self.returns), + ] + ) + elif returnsems == "resolver" and self.returns: + cxxparams.extend([makeResolverDecl(self.returns)]) + elif returnsems is not None: + cxxparams.extend([makeDecl(r, returnsems) for r in self.returns]) + + if not implicit and self.decl.type.hasImplicitActorParam(): + cxxparams = cxxparams[1:] + + return cxxparams + + def makeCxxArgs( + self, paramsems="in", retsems="out", retcallsems="out", implicit=True + ): + assert not retcallsems or retsems # retcallsems => returnsems + cxxargs = [] + + if paramsems == "move": + # We don't std::move() RefPtr<T> types because current Recv*() + # implementors take these parameters as T*, and + # std::move(RefPtr<T>) doesn't coerce to T*. + # We also don't move NotNull, as it has no move constructor. + cxxargs.extend( + [ + p.var() + if p.ipdltype.isRefcounted() + or (p.ipdltype.isIPDL() and p.ipdltype.isNotNull()) + else ExprMove(p.var()) + for p in self.params + ] + ) + elif paramsems == "in": + cxxargs.extend([p.var() for p in self.params]) + else: + assert False + + for ret in self.returns: + if retsems == "in": + if retcallsems == "in": + cxxargs.append(ret.var()) + elif retcallsems == "out": + cxxargs.append(ExprAddrOf(ret.var())) + else: + assert 0 + elif retsems == "out": + if retcallsems == "in": + cxxargs.append(ExprDeref(ret.var())) + elif retcallsems == "out": + cxxargs.append(ret.var()) + else: + assert 0 + elif retsems == "resolver": + pass + if retsems == "resolver": + cxxargs.append(ExprMove(ExprVar("resolver"))) + + if not implicit: + assert self.decl.type.hasImplicitActorParam() + cxxargs = cxxargs[1:] + + return cxxargs + + @staticmethod + def upgrade(messageDecl): + assert isinstance(messageDecl, ipdl.ast.MessageDecl) + if messageDecl.decl.type.hasImplicitActorParam(): + messageDecl.params.insert( + 0, + _HybridDecl( + ipdl.type.ActorType(messageDecl.decl.type.constructedType()), + "actor", + ), + ) + messageDecl.__class__ = MessageDecl + + +# -------------------------------------------------- +def _usesShmem(p): + for md in p.messageDecls: + for param in md.inParams: + if ipdl.type.hasshmem(param.type): + return True + for ret in md.outParams: + if ipdl.type.hasshmem(ret.type): + return True + return False + + +def _subtreeUsesShmem(p): + if _usesShmem(p): + return True + + ptype = p.decl.type + for mgd in ptype.manages: + if ptype is not mgd: + if _subtreeUsesShmem(mgd._ast): + return True + return False + + +class Protocol(ipdl.ast.Protocol): + def managerInterfaceType(self, ptr=False): + return Type("mozilla::ipc::IProtocol", ptr=ptr) + + def openedProtocolInterfaceType(self, ptr=False): + return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr) + + def _ipdlmgrtype(self): + assert 1 == len(self.decl.type.managers) + for mgr in self.decl.type.managers: + return mgr + + def managerActorType(self, side, ptr=False): + return Type(_actorName(self._ipdlmgrtype().name(), side), ptr=ptr) + + def unregisterMethod(self, actorThis=None): + if actorThis is not None: + return ExprSelect(actorThis, "->", "Unregister") + return ExprVar("Unregister") + + def removeManageeMethod(self): + return ExprVar("RemoveManagee") + + def deallocManageeMethod(self): + return ExprVar("DeallocManagee") + + def getChannelMethod(self): + return ExprVar("GetIPCChannel") + + def callGetChannel(self, actorThis=None): + fn = self.getChannelMethod() + if actorThis is not None: + fn = ExprSelect(actorThis, "->", fn.name) + return ExprCall(fn) + + def processingErrorVar(self): + assert self.decl.type.isToplevel() + return ExprVar("ProcessingError") + + def shouldContinueFromTimeoutVar(self): + assert self.decl.type.isToplevel() + return ExprVar("ShouldContinueFromReplyTimeout") + + def routingId(self, actorThis=None): + if self.decl.type.isToplevel(): + return ExprVar("MSG_ROUTING_CONTROL") + if actorThis is not None: + return ExprCall(ExprSelect(actorThis, "->", "Id")) + return ExprCall(ExprVar("Id")) + + def managerVar(self, thisexpr=None): + assert thisexpr is not None or not self.decl.type.isToplevel() + mvar = ExprCall(ExprVar("Manager"), args=[]) + if thisexpr is not None: + mvar = ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[]) + return mvar + + def managedCxxType(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return Type(_actorName(actortype.name(), side), ptr=True) + + def managedMethod(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return ExprVar("Managed" + _actorName(actortype.name(), side)) + + def managedVar(self, actortype, side): + assert self.decl.type.isManagerOf(actortype) + return ExprVar("mManaged" + _actorName(actortype.name(), side)) + + def managedVarType(self, actortype, side, const=False, ref=False): + assert self.decl.type.isManagerOf(actortype) + return _cxxManagedContainerType( + Type(_actorName(actortype.name(), side)), const=const, ref=ref + ) + + def subtreeUsesShmem(self): + return _subtreeUsesShmem(self) + + @staticmethod + def upgrade(protocol): + assert isinstance(protocol, ipdl.ast.Protocol) + protocol.__class__ = Protocol + + +class TranslationUnit(ipdl.ast.TranslationUnit): + @staticmethod + def upgrade(tu): + assert isinstance(tu, ipdl.ast.TranslationUnit) + tu.__class__ = TranslationUnit + + +# ----------------------------------------------------------------------------- + +pod_types = { + "::int8_t": 1, + "::uint8_t": 1, + "::int16_t": 2, + "::uint16_t": 2, + "::int32_t": 4, + "::uint32_t": 4, + "::int64_t": 8, + "::uint64_t": 8, + "float": 4, + "double": 8, +} +max_pod_size = max(pod_types.values()) +# We claim that all types we don't recognize are automatically "bigger" +# than pod types for ease of sorting. +pod_size_sentinel = max_pod_size * 2 + + +def pod_size(ipdltype): + if not ipdltype.isCxx(): + return pod_size_sentinel + + return pod_types.get(ipdltype.fullname(), pod_size_sentinel) + + +class _DecorateWithCxxStuff(ipdl.ast.Visitor): + """Phase 1 of lowering: decorate the IPDL AST with information + relevant to C++ code generation. + + This pass results in an AST that is a poor man's "IR"; in reality, a + "hybrid" AST mainly consisting of IPDL nodes with new C++ info along + with some new IPDL/C++ nodes that are tuned for C++ codegen.""" + + def __init__(self): + self.visitedTus = set() + self.protocolName = None + + def visitTranslationUnit(self, tu): + if tu not in self.visitedTus: + self.visitedTus.add(tu) + ipdl.ast.Visitor.visitTranslationUnit(self, tu) + if not isinstance(tu, TranslationUnit): + TranslationUnit.upgrade(tu) + + def visitInclude(self, inc): + if inc.tu.filetype == "header": + inc.tu.accept(self) + + def visitProtocol(self, pro): + self.protocolName = pro.name + Protocol.upgrade(pro) + return ipdl.ast.Visitor.visitProtocol(self, pro) + + def visitStructDecl(self, sd): + if not isinstance(sd, StructDecl): + newfields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields] + + # Compute a permutation of the fields for in-memory storage such + # that the memory layout of the structure will be well-packed. + permutation = list(range(len(newfields))) + + # Note that the results of `pod_size` ensure that non-POD fields + # sort before POD ones. + def size(idx): + return pod_size(newfields[idx].ipdltype) + + permutation.sort(key=size, reverse=True) + + sd.fields = newfields + sd.packed_field_order = permutation + StructDecl.upgrade(sd) + + def visitUnionDecl(self, ud): + ud.components = [_UnionMember(ctype, ud) for ctype in ud.decl.type.components] + UnionDecl.upgrade(ud) + + def visitDecl(self, decl): + return _HybridDecl(decl.type, decl.progname, decl.attributes) + + def visitMessageDecl(self, md): + md.namespace = self.protocolName + md.params = [param.accept(self) for param in md.inParams] + md.returns = [ret.accept(self) for ret in md.outParams] + MessageDecl.upgrade(md) + + +# ----------------------------------------------------------------------------- + + +def msgenums(protocol, pretty=False): + msgenum = TypeEnum("MessageType") + msgstart = _messageStartName(protocol.decl.type) + " << 16" + msgenum.addId(protocol.name + "Start", msgstart) + + for md in protocol.messageDecls: + msgenum.addId(md.prettyMsgName() if pretty else md.msgId()) + if md.hasReply(): + msgenum.addId(md.prettyReplyName() if pretty else md.replyId()) + + msgenum.addId(protocol.name + "End") + return msgenum + + +class _GenerateProtocolCode(ipdl.ast.Visitor): + """Creates code common to both the parent and child actors.""" + + def __init__(self): + self.protocol = None # protocol we're generating a class for + self.hdrfile = None # what will become Protocol.h + self.cppfile = None # what will become Protocol.cpp + self.cppIncludeHeaders = [] + self.structUnionDefns = [] + self.funcDefns = [] + + def lower(self, tu, cxxHeaderFile, cxxFile, segmentcapacitydict): + self.protocol = tu.protocol + self.hdrfile = cxxHeaderFile + self.cppfile = cxxFile + self.segmentcapacitydict = segmentcapacitydict + tu.accept(self) + + def visitTranslationUnit(self, tu): + hf = self.hdrfile + + hf.addthing(_DISCLAIMER) + hf.addthings(_includeGuardStart(hf)) + hf.addthing(Whitespace.NL) + + for inc in builtinHeaderIncludes: + self.visitBuiltinCxxInclude(inc) + + # Compute the set of includes we need for declared structure/union + # classes for this protocol. + typesToIncludes = {} + for using in tu.using: + typestr = str(using.type) + if typestr not in typesToIncludes: + typesToIncludes[typestr] = using.header + else: + assert typesToIncludes[typestr] == using.header + + aggregateTypeIncludes = set() + for su in tu.structsAndUnions: + typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes) + if isinstance(su, ipdl.ast.StructDecl): + aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h") + for f in su.fields: + f.ipdltype.accept(typedeps) + elif isinstance(su, ipdl.ast.UnionDecl): + for c in su.components: + c.ipdltype.accept(typedeps) + + aggregateTypeIncludes.update(typedeps.includeHeaders) + + if len(aggregateTypeIncludes) != 0: + hf.addthing(Whitespace.NL) + hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL]) + + for headername in sorted(iter(aggregateTypeIncludes)): + hf.addthing(CppDirective("include", '"' + headername + '"')) + + # Manually run Visitor.visitTranslationUnit. For dependency resolution + # we need to handle structs and unions separately. + for cxxInc in tu.cxxIncludes: + cxxInc.accept(self) + for inc in tu.includes: + inc.accept(self) + self.generateStructsAndUnions(tu) + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + if tu.protocol: + tu.protocol.accept(self) + + if tu.filetype == "header": + self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h") + + hf.addthing(Whitespace.NL) + hf.addthings(_includeGuardEnd(hf)) + + cf = self.cppfile + cf.addthings( + ( + [_DISCLAIMER, Whitespace.NL] + + [ + CppDirective("include", '"' + h + '"') + for h in self.cppIncludeHeaders + ] + + [Whitespace.NL] + + [ + CppDirective("include", '"%s"' % filename) + for filename in ipdl.builtin.CppIncludes + ] + + [Whitespace.NL] + ) + ) + + if self.protocol: + # construct the namespace into which we'll stick all our defns + ns = Namespace(self.protocol.name) + cf.addthing(_putInNamespaces(ns, self.protocol.namespaces)) + ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL])) + + cf.addthings(self.structUnionDefns) + + def visitBuiltinCxxInclude(self, inc): + self.hdrfile.addthing(CppDirective("include", '"' + inc.file + '"')) + + def visitCxxInclude(self, inc): + self.cppIncludeHeaders.append(inc.file) + + def visitInclude(self, inc): + if inc.tu.filetype == "header": + self.hdrfile.addthing( + CppDirective("include", '"' + _ipdlhHeaderName(inc.tu) + '.h"') + ) + # Inherit cpp includes defined by imported header files, as they may + # be required to serialize an imported `using` type. + for cxxinc in inc.tu.cxxIncludes: + cxxinc.accept(self) + else: + self.cppIncludeHeaders += [ + _protocolHeaderName(inc.tu.protocol, "parent") + ".h", + _protocolHeaderName(inc.tu.protocol, "child") + ".h", + ] + + def generateStructsAndUnions(self, tu): + """Generate the definitions for all structs and unions. This will + re-order the declarations if needed in the C++ code such that + dependencies have already been defined.""" + decls = OrderedDict() + for su in tu.structsAndUnions: + if isinstance(su, StructDecl): + which = "struct" + forwarddecls, fulldecltypes, cls = _generateCxxStruct(su) + traitsdecl, traitsdefns = _ParamTraits.structPickling(su.decl.type) + else: + assert isinstance(su, UnionDecl) + which = "union" + forwarddecls, fulldecltypes, cls = _generateCxxUnion(su) + traitsdecl, traitsdefns = _ParamTraits.unionPickling(su.decl.type) + + clsdecl, methoddefns = _splitClassDeclDefn(cls) + + # Store the declarations in the decls map so we can emit in + # dependency order. + decls[su.decl.type] = ( + fulldecltypes, + [Whitespace.NL] + + forwarddecls + + [ + Whitespace( + """ +//----------------------------------------------------------------------------- +// Declaration of the IPDL type |%s %s| +// +""" + % (which, su.name) + ), + _putInNamespaces(clsdecl, su.namespaces), + ] + + [Whitespace.NL, traitsdecl], + ) + + self.structUnionDefns.extend( + [ + Whitespace( + """ +//----------------------------------------------------------------------------- +// Method definitions for the IPDL type |%s %s| +// +""" + % (which, su.name) + ), + _putInNamespaces(methoddefns, su.namespaces), + Whitespace.NL, + traitsdefns, + ] + ) + + # Generate the declarations structs in dependency order. + def gen_struct(deps, defn): + for dep in deps: + if dep in decls: + d, t = decls[dep] + del decls[dep] + gen_struct(d, t) + self.hdrfile.addthings(defn) + + while len(decls) > 0: + _, (d, t) = decls.popitem(False) + gen_struct(d, t) + + def visitProtocol(self, p): + self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h") + self.cppIncludeHeaders.append( + _protocolHeaderName(self.protocol, "Parent") + ".h" + ) + self.cppIncludeHeaders.append( + _protocolHeaderName(self.protocol, "Child") + ".h" + ) + + # Forward declare our own actors. + self.hdrfile.addthings( + [ + Whitespace.NL, + _makeForwardDeclForActor(p.decl.type, "Parent"), + _makeForwardDeclForActor(p.decl.type, "Child"), + ] + ) + + self.hdrfile.addthing( + Whitespace( + """ +//----------------------------------------------------------------------------- +// Code common to %sChild and %sParent +// +""" + % (p.name, p.name) + ) + ) + + # construct the namespace into which we'll stick all our decls + ns = Namespace(self.protocol.name) + self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces)) + ns.addstmt(Whitespace.NL) + + for func in self.genEndpointFuncs(): + edecl, edefn = _splitFuncDeclDefn(func) + ns.addstmts([edecl, Whitespace.NL]) + self.funcDefns.append(edefn) + + # spit out message type enum and classes + msgenum = msgenums(self.protocol) + ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL]) + + for md in p.messageDecls: + decls = [] + + # Look up the segment capacity used for serializing this + # message. If the capacity is not specified, use '0' for + # the default capacity (defined in ipc_message.cc) + name = "%s::%s" % (md.namespace, md.decl.progname) + segmentcapacity = self.segmentcapacitydict.get(name, 0) + + mfDecl, mfDefn = _splitFuncDeclDefn( + _generateMessageConstructor(md, segmentcapacity, p, forReply=False) + ) + decls.append(mfDecl) + self.funcDefns.append(mfDefn) + + if md.hasReply(): + rfDecl, rfDefn = _splitFuncDeclDefn( + _generateMessageConstructor(md, 0, p, forReply=True) + ) + decls.append(rfDecl) + self.funcDefns.append(rfDefn) + + decls.append(Whitespace.NL) + ns.addstmts(decls) + + ns.addstmts([Whitespace.NL, Whitespace.NL]) + + # Generate code for PFoo::CreateEndpoints. + def genEndpointFuncs(self): + p = self.protocol.decl.type + tparent = _cxxBareType(ActorType(p), "Parent", fq=True) + tchild = _cxxBareType(ActorType(p), "Child", fq=True) + + def mkOverload(includepids): + params = [] + if includepids: + params = [ + Decl(Type("base::ProcessId"), "aParentDestPid"), + Decl(Type("base::ProcessId"), "aChildDestPid"), + ] + params += [ + Decl( + Type("mozilla::ipc::Endpoint<" + tparent.name + ">", ptr=True), + "aParent", + ), + Decl( + Type("mozilla::ipc::Endpoint<" + tchild.name + ">", ptr=True), + "aChild", + ), + ] + openfunc = MethodDefn( + MethodDecl("CreateEndpoints", params=params, ret=Type.NSRESULT) + ) + openfunc.addcode( + """ + return mozilla::ipc::CreateEndpoints( + mozilla::ipc::PrivateIPDLInterface(), + $,{args}); + """, + args=[ExprVar(d.name) for d in params], + ) + return openfunc + + funcs = [mkOverload(True)] + if not p.hasOtherPid(): + funcs.append(mkOverload(False)) + return funcs + + +# -------------------------------------------------- + +cppPriorityList = list( + map(lambda src: src.upper() + "_PRIORITY", ipdl.ast.priorityList) +) + + +def _generateMessageConstructor(md, segmentSize, protocol, forReply=False): + if forReply: + clsname = md.replyCtorFunc() + msgid = md.replyId() + replyEnum = "REPLY" + prioEnum = cppPriorityList[md.decl.type.replyPrio] + else: + clsname = md.msgCtorFunc() + msgid = md.msgId() + replyEnum = "NOT_REPLY" + prioEnum = cppPriorityList[md.decl.type.prio] + + nested = md.decl.type.nested + compress = md.decl.type.compress + lazySend = md.decl.type.lazySend + + routingId = ExprVar("routingId") + + func = FunctionDefn( + FunctionDecl( + clsname, + params=[Decl(Type("int32_t"), routingId.name)], + ret=Type("mozilla::UniquePtr<IPC::Message>"), + ) + ) + + if not compress: + compression = "COMPRESSION_NONE" + elif compress.value == "all": + compression = "COMPRESSION_ALL" + else: + assert compress.value is None + compression = "COMPRESSION_ENABLED" + + if lazySend: + lazySendEnum = "LAZY_SEND" + else: + lazySendEnum = "EAGER_SEND" + + if nested == ipdl.ast.NOT_NESTED: + nestedEnum = "NOT_NESTED" + elif nested == ipdl.ast.INSIDE_SYNC_NESTED: + nestedEnum = "NESTED_INSIDE_SYNC" + else: + assert nested == ipdl.ast.INSIDE_CPOW_NESTED + nestedEnum = "NESTED_INSIDE_CPOW" + + if md.decl.type.isSync(): + syncEnum = "SYNC" + else: + syncEnum = "ASYNC" + + # FIXME(bug ???) - remove support for interrupt messages from the IPDL compiler. + if md.decl.type.isInterrupt(): + func.addcode( + """ + static_assert( + false, + "runtime support for intr messages has been removed from IPDL"); + """ + ) + + if md.decl.type.isCtor(): + ctorEnum = "CONSTRUCTOR" + else: + ctorEnum = "NOT_CONSTRUCTOR" + + def messageEnum(valname): + return ExprVar("IPC::Message::" + valname) + + flags = ExprCall( + ExprVar("IPC::Message::HeaderFlags"), + args=[ + messageEnum(nestedEnum), + messageEnum(prioEnum), + messageEnum(compression), + messageEnum(lazySendEnum), + messageEnum(ctorEnum), + messageEnum(syncEnum), + messageEnum(replyEnum), + ], + ) + + segmentSize = int(segmentSize) + if not segmentSize: + segmentSize = 0 + func.addstmt( + StmtReturn( + ExprCall( + ExprVar("IPC::Message::IPDLMessage"), + args=[ + routingId, + ExprVar(msgid), + ExprLiteral.Int(int(segmentSize)), + flags, + ], + ) + ) + ) + + return func + + +# -------------------------------------------------- + + +class _ParamTraits: + var = ExprVar("aVar") + writervar = ExprVar("aWriter") + readervar = ExprVar("aReader") + + @classmethod + def ifsideis(cls, rdrwtr, side, then, els=None): + cxxside = ExprVar("mozilla::ipc::ChildSide") + if side == "parent": + cxxside = ExprVar("mozilla::ipc::ParentSide") + + ifstmt = StmtIf( + ExprBinary( + cxxside, + "==", + ExprCode("${rdrwtr}->GetActor()->GetSide()", rdrwtr=rdrwtr), + ) + ) + ifstmt.addifstmt(then) + if els is not None: + ifstmt.addelsestmt(els) + return ifstmt + + @classmethod + def fatalError(cls, rdrwtr, reason): + return StmtCode( + "${rdrwtr}->FatalError(${reason});", + rdrwtr=rdrwtr, + reason=ExprLiteral.String(reason), + ) + + @classmethod + def writeSentinel(cls, writervar, sentinelKey): + return [ + Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True), + StmtExpr( + ExprCall( + ExprSelect(writervar, "->", "WriteSentinel"), + args=[ExprLiteral.Int(hashfunc(sentinelKey))], + ) + ), + ] + + @classmethod + def readSentinel(cls, readervar, sentinelKey, sentinelFail): + # Read the sentinel + read = ExprCall( + ExprSelect(readervar, "->", "ReadSentinel"), + args=[ExprLiteral.Int(hashfunc(sentinelKey))], + ) + ifsentinel = StmtIf(ExprNot(read)) + ifsentinel.addifstmts(sentinelFail) + + return [ + Whitespace("// Sentinel = " + repr(sentinelKey) + "\n", indent=True), + ifsentinel, + ] + + @classmethod + def write(cls, var, writervar, ipdltype=None): + if ipdltype and _cxxTypeNeedsMoveForSend(ipdltype): + var = ExprMove(var) + return ExprCall(ExprVar("IPC::WriteParam"), args=[writervar, var]) + + @classmethod + def checkedWrite(cls, ipdltype, var, writervar, sentinelKey): + assert sentinelKey + block = Block() + + block.addstmts( + [ + StmtExpr(cls.write(var, writervar, ipdltype)), + ] + ) + block.addstmts(cls.writeSentinel(writervar, sentinelKey)) + return block + + @classmethod + def bulkSentinelKey(cls, fields): + return " | ".join(f.basename for f in fields) + + @classmethod + def checkedBulkWrite(cls, var, size, fields): + block = Block() + first = fields[0] + + block.addstmts( + [ + StmtExpr( + ExprCall( + ExprSelect(cls.writervar, "->", "WriteBytes"), + args=[ + ExprAddrOf( + ExprCall(first.getMethod(thisexpr=var, sel=".")) + ), + ExprLiteral.Int(size * len(fields)), + ], + ) + ) + ] + ) + block.addstmts(cls.writeSentinel(cls.writervar, cls.bulkSentinelKey(fields))) + + return block + + @classmethod + def checkedBulkRead(cls, var, size, fields): + block = Block() + first = fields[0] + + readbytes = ExprCall( + ExprSelect(cls.readervar, "->", "ReadBytesInto"), + args=[ + ExprAddrOf(ExprCall(first.getMethod(thisexpr=var, sel="->"))), + ExprLiteral.Int(size * len(fields)), + ], + ) + ifbad = StmtIf(ExprNot(readbytes)) + errmsg = "Error bulk reading fields from %s" % first.ipdltype.name() + ifbad.addifstmts( + [cls.fatalError(cls.readervar, errmsg), StmtReturn(readResultError())] + ) + block.addstmt(ifbad) + block.addstmts( + cls.readSentinel( + cls.readervar, + cls.bulkSentinelKey(fields), + errfnSentinel(readResultError())(errmsg), + ) + ) + + return block + + @classmethod + def checkedRead( + cls, + ipdltype, + cxxtype, + var, + readervar, + errfn, + paramtype, + sentinelKey, + errfnSentinel, + ): + assert isinstance(var, ExprVar) + + if not isinstance(paramtype, list): + paramtype = ["Error deserializing " + paramtype] + + block = Block() + + # Read the data + block.addcode( + """ + auto ${maybevar} = IPC::ReadParam<${ty}>(${reader}); + if (!${maybevar}) { + $*{errfn} + } + auto& ${var} = *${maybevar}; + """, + maybevar=ExprVar("maybe__" + var.name), + ty=cxxtype, + reader=readervar, + errfn=errfn(*paramtype), + var=var, + ) + + block.addstmts( + cls.readSentinel(readervar, sentinelKey, errfnSentinel(*paramtype)) + ) + + return block + + # Helper wrapper for checkedRead for use within _ParamTraits + @classmethod + def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what): + def errfn(msg): + return [cls.fatalError(cls.readervar, msg), StmtReturn(readResultError())] + + return cls.checkedRead( + ipdltype, + cxxtype, + var, + cls.readervar, + errfn=errfn, + paramtype=what, + sentinelKey=sentinelKey, + errfnSentinel=errfnSentinel(readResultError()), + ) + + @classmethod + def generateDecl(cls, fortype, write, read, needsmove=False): + # ParamTraits impls are selected ignoring constness, and references. + pt = Class( + "ParamTraits", + specializes=Type( + fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr + ), + struct=True, + ) + + # typedef T paramType; + pt.addstmt(Typedef(fortype, "paramType")) + + # static void Write(Message*, const T&); + if needsmove: + intype = Type("paramType", rvalref=True) + else: + intype = Type("paramType", ref=True, const=True) + writemthd = MethodDefn( + MethodDecl( + "Write", + params=[ + Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name), + Decl(intype, cls.var.name), + ], + methodspec=MethodSpec.STATIC, + ) + ) + writemthd.addstmts(write) + pt.addstmt(writemthd) + + # static ReadResult<T> Read(MessageReader*); + readmthd = MethodDefn( + MethodDecl( + "Read", + params=[ + Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name), + ], + ret=Type("IPC::ReadResult<paramType>"), + methodspec=MethodSpec.STATIC, + ) + ) + readmthd.addstmts(read) + pt.addstmt(readmthd) + + # Split the class into declaration and definition + clsdecl, methoddefns = _splitClassDeclDefn(pt) + + namespaces = [Namespace("IPC")] + clsns = _putInNamespaces(clsdecl, namespaces) + defns = _putInNamespaces(methoddefns, namespaces) + return clsns, defns + + @classmethod + def actorPickling(cls, actortype, side): + """Generates pickling for IPDL actors. This is a |nullable| deserializer. + Write and read callers will perform nullability validation.""" + + cxxtype = _cxxBareType(actortype, side, fq=True) + + write = StmtCode( + """ + MOZ_RELEASE_ASSERT( + ${writervar}->GetActor(), + "Cannot serialize managed actors without an actor"); + + int32_t id; + if (!${var}) { + id = 0; // kNullActorId + } else { + id = ${var}->Id(); + if (id == 1) { // kFreedActorId + ${var}->FatalError("Actor has been |delete|d"); + } + MOZ_RELEASE_ASSERT( + ${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(), + "Actor must be from the same channel as the" + " actor it's being sent over"); + MOZ_RELEASE_ASSERT( + ${var}->CanSend(), + "Actor must still be open when sending"); + } + + ${write}; + """, + var=cls.var, + writervar=cls.writervar, + write=cls.write(ExprVar("id"), cls.writervar), + ) + + # bool Read(..) impl + read = StmtCode( + """ + MOZ_RELEASE_ASSERT( + ${readervar}->GetActor(), + "Cannot deserialize managed actors without an actor"); + mozilla::Maybe<mozilla::ipc::IProtocol*> actor = ${readervar}->GetActor() + ->ReadActor(${readervar}, true, ${actortype}, ${protocolid}); + if (actor.isSome()) { + return static_cast<${cxxtype}>(actor.ref()); + } + return {}; + """, + readervar=cls.readervar, + actortype=ExprLiteral.String(actortype.name()), + protocolid=_protocolId(actortype), + cxxtype=cxxtype, + ) + + return cls.generateDecl(cxxtype, [write], [read]) + + @classmethod + def structPickling(cls, structtype): + sd = structtype._ast + # NOTE: Not using _cxxBareType here as we don't have a side + cxxtype = Type(structtype.fullname()) + + write = [] + read = [] + + # First serialize/deserialize all non-pod data in IPDL order. These need + # to be read/written first because they'll be used to invoke the IPDL + # struct's constructor. + ctorargs = [] + for f in sd.fields_ipdl_order(): + if pod_size(f.ipdltype) == pod_size_sentinel: + write.append( + cls.checkedWrite( + f.ipdltype, + ExprCall(f.getMethod(thisexpr=cls.var, sel=".")), + cls.writervar, + sentinelKey=f.basename, + ) + ) + read.append( + cls._checkedRead( + f.ipdltype, + f.bareType(fq=True), + f.argVar(), + f.basename, + "'" + + f.getMethod().name + + "' " + + "(" + + f.ipdltype.name() + + ") member of " + + "'" + + structtype.name() + + "'", + ) + ) + if _cxxTypeCanMove(f.ipdltype): + ctorargs.append(ExprMove(f.argVar())) + else: + ctorargs.append(f.argVar()) + else: + # We're going to bulk-read in this value later, so we'll just + # zero-initialize it for now. + ctorargs.append(ExprCode("${type}{0}", type=f.bareType(fq=True))) + + resultvar = ExprVar("result__") + read.append( + StmtDecl( + Decl(_cxxReadResultType(Type("paramType")), resultvar.name), + initargs=[ExprVar("std::in_place")] + ctorargs, + ) + ) + + # After non-pod data, bulk read/write pod data in member order. This has + # to be done after the result has been constructed, so that we have + # somewhere to read into. + for (size, fields) in itertools.groupby( + sd.fields_member_order(), lambda f: pod_size(f.ipdltype) + ): + if size != pod_size_sentinel: + fields = list(fields) + write.append(cls.checkedBulkWrite(cls.var, size, fields)) + read.append(cls.checkedBulkRead(resultvar, size, fields)) + + read.append(StmtReturn(resultvar)) + + return cls.generateDecl( + cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(structtype) + ) + + @classmethod + def unionPickling(cls, uniontype): + # NOTE: Not using _cxxBareType here as we don't have a side + cxxtype = Type(uniontype.fullname()) + ud = uniontype._ast + + # Use typedef to set up an alias so it's easier to reference the struct type. + alias = "union__" + typevar = ExprVar("type") + + prelude = [ + Typedef(cxxtype, alias), + ] + + writeswitch = StmtSwitch(typevar) + write = prelude + [ + StmtDecl(Decl(Type.INT, typevar.name), init=ud.callType(cls.var)), + cls.checkedWrite( + None, typevar, cls.writervar, sentinelKey=uniontype.name() + ), + Whitespace.NL, + writeswitch, + ] + + readswitch = StmtSwitch(typevar) + read = prelude + [ + cls._checkedRead( + None, + Type.INT, + typevar, + uniontype.name(), + "type of union " + uniontype.name(), + ), + Whitespace.NL, + readswitch, + ] + + for c in ud.components: + caselabel = CaseLabel(alias + "::" + c.enum()) + origenum = c.enum() + + writecase = StmtBlock() + wstmt = cls.checkedWrite( + c.ipdltype, + ExprCall(ExprSelect(cls.var, ".", c.getTypeName())), + cls.writervar, + sentinelKey=c.enum(), + ) + writecase.addstmts([wstmt, StmtReturn()]) + writeswitch.addcase(caselabel, writecase) + + readcase = StmtBlock() + tmpvar = ExprVar("tmp") + readcase.addstmts( + [ + cls._checkedRead( + c.ipdltype, + c.bareType(fq=True), + tmpvar, + origenum, + "variant " + origenum + " of union " + uniontype.name(), + ), + StmtReturn(ExprMove(tmpvar)), + ] + ) + readswitch.addcase(caselabel, readcase) + + # Add the error default case + writeswitch.addcase( + DefaultLabel(), + StmtBlock( + [ + cls.fatalError( + cls.writervar, "unknown variant of union " + uniontype.name() + ), + StmtReturn(), + ] + ), + ) + readswitch.addcase( + DefaultLabel(), + StmtBlock( + [ + cls.fatalError( + cls.readervar, "unknown variant of union " + uniontype.name() + ), + StmtReturn(readResultError()), + ] + ), + ) + + return cls.generateDecl( + cxxtype, write, read, needsmove=_cxxTypeNeedsMoveForSend(uniontype) + ) + + +# -------------------------------------------------- + + +class _ComputeTypeDeps(TypeVisitor): + """Pass that gathers the C++ types that a particular IPDL type + (recursively) depends on. There are three kinds of dependencies: (i) + types that need forward declaration; (ii) types that need a |using| + stmt; (iii) IPDL structs or unions which must be fully declared + before this struct. Some types generate multiple kinds.""" + + def __init__(self, fortype, typesToIncludes=None): + ipdl.type.TypeVisitor.__init__(self) + self.usingTypedefs = [] + self.forwardDeclStmts = [] + self.fullDeclTypes = [] + self.includeHeaders = set() + self.fortype = fortype + self.typesToIncludes = typesToIncludes + + def maybeTypedef(self, fqname, name, templateargs=[]): + assert fqname.startswith("::") + if fqname != name: + self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs)) + if self.typesToIncludes is not None and fqname in self.typesToIncludes: + self.includeHeaders.add(self.typesToIncludes[fqname]) + + def visitImportedCxxType(self, t): + if t in self.visited: + return + self.visited.add(t) + self.maybeTypedef(t.fullname(), t.name()) + + def visitActorType(self, t): + if t in self.visited: + return + self.visited.add(t) + + fqname, name = t.fullname(), t.name() + + self.includeHeaders.add("mozilla/ipc/SideVariant.h") + self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent")) + self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child")) + + self.forwardDeclStmts.extend( + [ + _makeForwardDeclForActor(t.protocol, "parent"), + Whitespace.NL, + _makeForwardDeclForActor(t.protocol, "child"), + Whitespace.NL, + ] + ) + + def visitStructOrUnionType(self, su, defaultVisit): + if su in self.visited or su == self.fortype: + return + self.visited.add(su) + self.maybeTypedef(su.fullname(), su.name()) + + # Mutually recursive fields in unions are behind indirection, so we only + # need a forward decl, and don't need a full type declaration. + if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith( + su + ): + self.forwardDeclStmts.append(_makeForwardDecl(su)) + else: + self.fullDeclTypes.append(su) + + return defaultVisit(self, su) + + def visitStructType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitStructType) + + def visitUnionType(self, t): + return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType) + + def visitArrayType(self, t): + return TypeVisitor.visitArrayType(self, t) + + def visitMaybeType(self, m): + return TypeVisitor.visitMaybeType(self, m) + + def visitShmemType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem") + + def visitByteBufType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf") + + def visitFDType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor") + + def visitEndpointType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"]) + self.visitActorType(s.actor) + + def visitManagedEndpointType(self, s): + if s in self.visited: + return + self.visited.add(s) + self.maybeTypedef( + "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"] + ) + self.visitActorType(s.actor) + + def visitUniquePtrType(self, s): + if s in self.visited: + return + self.visited.add(s) + + def visitVoidType(self, v): + assert 0 + + def visitMessageType(self, v): + assert 0 + + def visitProtocolType(self, v): + assert 0 + + +def _fieldStaticAssertions(sd): + staticasserts = [] + for (size, fields) in itertools.groupby( + sd.fields_member_order(), lambda f: pod_size(f.ipdltype) + ): + if size == pod_size_sentinel: + continue + + fields = list(fields) + if len(fields) == 1: + continue + + staticasserts.append( + StmtCode( + """ + static_assert( + (offsetof(${struct}, ${last}) - offsetof(${struct}, ${first})) == ${expected}, + "Bad assumptions about field layout!"); + """, + struct=sd.name, + first=fields[0].memberVar(), + last=fields[-1].memberVar(), + expected=ExprLiteral.Int(size * (len(fields) - 1)), + ) + ) + + return staticasserts + + +def _generateCxxStruct(sd): + """ """ + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps(sd.decl.type) + for f in sd.fields: + f.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + fulldecltypes = gettypedeps.fullDeclTypes + + struct = Class(sd.name, final=True) + struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC]) + + constreftype = Type(sd.name, const=True, ref=True) + + # Struct() + # We want the default constructor to be declared if it is available, but + # some of our members may not be default-constructible. Silence the + # warning which clang generates in that case. + # + # Members which need value initialization will be handled by wrapping + # the member in a template type when declaring them. + struct.addcode( + """ + #ifdef __clang__ + # pragma clang diagnostic push + # if __has_warning("-Wdefaulted-function-deleted") + # pragma clang diagnostic ignored "-Wdefaulted-function-deleted" + # endif + #endif + ${name}() = default; + #ifdef __clang__ + # pragma clang diagnostic pop + #endif + + """, + name=sd.name, + ) + + # If this is an empty struct (no fields), then the default ctor + # and "create-with-fields" ctors are equivalent. + if len(sd.fields): + assert len(sd.fields) == len(sd.packed_field_order) + + # Struct(const field1& _f1, ...) + valctor = ConstructorDefn( + ConstructorDecl( + sd.name, + params=[ + Decl( + f.forceMoveType() + if _cxxTypeNeedsMoveForData(f.ipdltype) + else f.constRefType(), + f.argVar().name, + ) + for f in sd.fields_ipdl_order() + ], + force_inline=True, + ) + ) + valctor.memberinits = [] + for f in sd.fields_member_order(): + arg = f.argVar() + if _cxxTypeNeedsMoveForData(f.ipdltype): + arg = ExprMove(arg) + valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg])) + + struct.addstmts([valctor, Whitespace.NL]) + + # If a constructor which moves each argument would be different from the + # `const T&` version, also generate that constructor. + if not all( + _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype) + for f in sd.fields_ipdl_order() + ): + # Struct(field1&& _f1, ...) + valmovector = ConstructorDefn( + ConstructorDecl( + sd.name, + params=[ + Decl( + f.forceMoveType() + if _cxxTypeCanMove(f.ipdltype) + else f.constRefType(), + f.argVar().name, + ) + for f in sd.fields_ipdl_order() + ], + force_inline=True, + ) + ) + + valmovector.memberinits = [] + for f in sd.fields_member_order(): + arg = f.argVar() + if _cxxTypeCanMove(f.ipdltype): + arg = ExprMove(arg) + valmovector.memberinits.append( + ExprMemberInit(f.memberVar(), args=[arg]) + ) + + struct.addstmts([valmovector, Whitespace.NL]) + + # The default copy, move, and assignment constructors, and the default + # destructor, will do the right thing. + + if "Comparable" in sd.attributes: + # bool operator==(const Struct& _o) + ovar = ExprVar("_o") + opeqeq = MethodDefn( + MethodDecl( + "operator==", + params=[Decl(constreftype, ovar.name)], + ret=Type.BOOL, + const=True, + ) + ) + for f in sd.fields_ipdl_order(): + ifneq = StmtIf( + ExprNot( + ExprBinary( + ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar)) + ) + ) + ) + ifneq.addifstmt(StmtReturn.FALSE) + opeqeq.addstmt(ifneq) + opeqeq.addstmt(StmtReturn.TRUE) + struct.addstmts([opeqeq, Whitespace.NL]) + + # bool operator!=(const Struct& _o) + opneq = MethodDefn( + MethodDecl( + "operator!=", + params=[Decl(constreftype, ovar.name)], + ret=Type.BOOL, + const=True, + ) + ) + opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar])))) + struct.addstmts([opneq, Whitespace.NL]) + + # field1& f1() + # const field1& f1() const + for f in sd.fields_ipdl_order(): + get = MethodDefn( + MethodDecl( + f.getMethod().name, params=[], ret=f.refType(), force_inline=True + ) + ) + get.addstmt(StmtReturn(f.refExpr())) + + getconstdecl = deepcopy(get.decl) + getconstdecl.ret = f.constRefType() + getconstdecl.const = True + getconst = MethodDefn(getconstdecl) + getconst.addstmt(StmtReturn(f.constRefExpr())) + + struct.addstmts([get, getconst, Whitespace.NL]) + + # private: + struct.addstmt(Label.PRIVATE) + + # Static assertions to ensure our assumptions about field layout match + # what the compiler is actually producing. We define this as a member + # function, rather than throwing the assertions in the constructor or + # similar, because we don't want to evaluate the static assertions every + # time the header file containing the structure is included. + staticasserts = _fieldStaticAssertions(sd) + if staticasserts: + method = MethodDefn( + MethodDecl("StaticAssertions", params=[], ret=Type.VOID, const=True) + ) + method.addstmts(staticasserts) + struct.addstmts([method]) + + # members + struct.addstmts( + [ + StmtDecl(Decl(_effectiveMemberType(f), f.memberVar().name)) + for f in sd.fields_member_order() + ] + ) + + return forwarddeclstmts, fulldecltypes, struct + + +def _effectiveMemberType(f): + effective_type = f.bareType() + # Structs must be copyable for backwards compatibility reasons, so we use + # CopyableTArray<T> as their member type for arrays. This is not exposed + # in the method signatures, these keep using nsTArray<T>, which is a base + # class of CopyableTArray<T>. + if effective_type.name == "nsTArray": + effective_type.name = "CopyableTArray" + return Type("::mozilla::ipc::IPDLStructMember", T=[effective_type]) + + +# -------------------------------------------------- + + +def _generateCxxUnion(ud): + # This Union class basically consists of a type (enum) and a + # union for storage. The union can contain POD and non-POD + # types. Each type needs a copy/move ctor, assignment operators, + # and dtor. + # + # Rather than templating this class and only providing + # specializations for the types we support, which is slightly + # "unsafe" in that C++ code can add additional specializations + # without the IPDL compiler's knowledge, we instead explicitly + # implement non-templated methods for each supported type. + # + # The one complication that arises is that C++, for arcane + # reasons, does not allow the placement destructor of a + # builtin type, like int, to be directly invoked. So we need + # to hack around this by internally typedef'ing all + # constituent types. Sigh. + # + # So, for each type, this "Union" class needs: + # (private) + # - entry in the type enum + # - entry in the storage union + # - [type]ptr() method to get a type* from the underlying union + # - same as above to get a const type* + # - typedef to hack around placement delete limitations + # (public) + # - placement delete case for dtor + # - copy ctor + # - move ctor + # - case in generic copy ctor + # - copy operator= impl + # - move operator= impl + # - case in generic operator= + # - operator [type&] + # - operator [const type&] const + # - [type&] get_[type]() + # - [const type&] get_[type]() const + # + cls = Class(ud.name, final=True) + # const Union&, i.e., Union type with inparam semantics + inClsType = Type(ud.name, const=True, ref=True) + refClsType = Type(ud.name, ref=True) + rvalueRefClsType = Type(ud.name, rvalref=True) + typetype = Type("Type") + valuetype = Type("Value") + mtypevar = ExprVar("mType") + mvaluevar = ExprVar("mValue") + maybedtorvar = ExprVar("MaybeDestroy") + assertsanityvar = ExprVar("AssertSanity") + tnonevar = ExprVar("T__None") + tlastvar = ExprVar("T__Last") + + def callAssertSanity(uvar=None, expectTypeVar=None): + func = assertsanityvar + args = [] + if uvar is not None: + func = ExprSelect(uvar, ".", assertsanityvar.name) + if expectTypeVar is not None: + args.append(expectTypeVar) + return ExprCall(func, args=args) + + def maybeDestroy(): + return StmtExpr(ExprCall(maybedtorvar)) + + # compute all the typedefs and forward decls we need to make + gettypedeps = _ComputeTypeDeps(ud.decl.type) + for c in ud.components: + c.ipdltype.accept(gettypedeps) + + usingTypedefs = gettypedeps.usingTypedefs + forwarddeclstmts = gettypedeps.forwardDeclStmts + fulldecltypes = gettypedeps.fullDeclTypes + + # the |Type| enum, used to switch on the discunion's real type + cls.addstmt(Label.PUBLIC) + typeenum = TypeEnum(typetype.name) + typeenum.addId(tnonevar.name, 0) + firstid = ud.components[0].enum() + typeenum.addId(firstid, 1) + for c in ud.components[1:]: + typeenum.addId(c.enum()) + typeenum.addId(tlastvar.name, ud.components[-1].enum()) + cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL]) + + cls.addstmt(Label.PRIVATE) + cls.addstmts( + usingTypedefs + # hacky typedef's that allow placement dtors of builtins + + [Typedef(c.internalType(), c.typedef()) for c in ud.components] + ) + cls.addstmt(Whitespace.NL) + + # the C++ union the discunion use for storage + valueunion = TypeUnion(valuetype.name) + for c in ud.components: + valueunion.addComponent(c.unionType(), c.name) + cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL]) + + # for each constituent type T, add private accessors that + # return a pointer to the Value union storage casted to |T*| + # and |const T*| + for c in ud.components: + getptr = MethodDefn( + MethodDecl( + c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True + ) + ) + getptr.addstmt(StmtReturn(c.ptrToSelfExpr())) + + getptrconst = MethodDefn( + MethodDecl( + c.getConstPtrName(), + params=[], + ret=c.constPtrToType(), + const=True, + force_inline=True, + ) + ) + getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr())) + + cls.addstmts([getptr, getptrconst]) + cls.addstmt(Whitespace.NL) + + # add a helper method that invokes the placement dtor on the + # current underlying value, only if |aNewType| is different + # than the current type, and returns true if the underlying + # value needs to be re-constructed + maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID)) + # wasn't /actually/ dtor'd, but it needs to be re-constructed + ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar)) + ifnone.addifstmt(StmtReturn()) + # need to destroy. switch on underlying type + dtorswitch = StmtSwitch(mtypevar) + for c in ud.components: + dtorswitch.addcase( + CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()]) + ) + dtorswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()]) + ) + maybedtor.addstmts([ifnone, dtorswitch]) + cls.addstmts([maybedtor, Whitespace.NL]) + + # add helper methods that ensure the discunion has a + # valid type + sanity = MethodDefn( + MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True) + ) + sanity.addstmts( + [ + _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"), + _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"), + ] + ) + cls.addstmt(sanity) + + atypevar = ExprVar("aType") + sanity2 = MethodDefn( + MethodDecl( + assertsanityvar.name, + params=[Decl(typetype, atypevar.name)], + ret=Type.VOID, + const=True, + force_inline=True, + ) + ) + sanity2.addstmts( + [ + StmtExpr(ExprCall(assertsanityvar)), + _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"), + ] + ) + cls.addstmts([sanity2, Whitespace.NL]) + + # ---- begin public methods ----- + + # Union() default ctor + cls.addstmts( + [ + Label.PUBLIC, + ConstructorDefn( + ConstructorDecl(ud.name, force_inline=True), + memberinits=[ExprMemberInit(mtypevar, [tnonevar])], + ), + Whitespace.NL, + ] + ) + + # Union(const T&) copy & Union(T&&) move ctors + othervar = ExprVar("aOther") + for c in ud.components: + if not _cxxTypeNeedsMoveForData(c.ipdltype): + copyctor = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)]) + ) + copyctor.addstmts( + [ + StmtExpr(c.callCtor(othervar)), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + ] + ) + cls.addstmts([copyctor, Whitespace.NL]) + + if not _cxxTypeCanMove(c.ipdltype): + continue + movector = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)]) + ) + movector.addstmts( + [ + StmtExpr(c.callCtor(ExprMove(othervar))), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + ] + ) + cls.addstmts([movector, Whitespace.NL]) + + unionNeedsMove = any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components) + + # Union(const Union&) copy ctor + if not unionNeedsMove: + copyctor = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)]) + ) + othertype = ud.callType(othervar) + copyswitch = StmtSwitch(othertype) + for c in ud.components: + copyswitch.addcase( + CaseLabel(c.enum()), + StmtBlock( + [ + StmtExpr( + c.callCtor( + ExprCall( + ExprSelect(othervar, ".", c.getConstTypeName()) + ) + ) + ), + StmtBreak(), + ] + ), + ) + copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) + copyswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) + ) + copyctor.addstmts( + [ + StmtExpr(callAssertSanity(uvar=othervar)), + copyswitch, + StmtExpr(ExprAssn(mtypevar, othertype)), + ] + ) + cls.addstmts([copyctor, Whitespace.NL]) + + # Union(Union&&) move ctor + movector = ConstructorDefn( + ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)]) + ) + othertypevar = ExprVar("t") + moveswitch = StmtSwitch(othertypevar) + for c in ud.components: + case = StmtBlock() + if c.recursive: + # This is sound as we set othervar.mTypeVar to T__None after the + # switch. The pointer in the union will be left dangling. + case.addstmts( + [ + # ptr_C() = other.ptr_C() + StmtExpr( + ExprAssn( + c.callGetPtr(), + ExprCall( + ExprSelect(othervar, ".", ExprVar(c.getPtrName())) + ), + ) + ) + ] + ) + else: + case.addstmts( + [ + # new ... (Move(other.get_C())) + StmtExpr( + c.callCtor( + ExprMove( + ExprCall(ExprSelect(othervar, ".", c.getTypeName())) + ) + ) + ), + # other.MaybeDestroy(T__None) + StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))), + ] + ) + case.addstmts([StmtBreak()]) + moveswitch.addcase(CaseLabel(c.enum()), case) + moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) + moveswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) + ) + movector.addstmts( + [ + StmtExpr(callAssertSanity(uvar=othervar)), + StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)), + moveswitch, + StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)), + StmtExpr(ExprAssn(mtypevar, othertypevar)), + ] + ) + cls.addstmts([movector, Whitespace.NL]) + + # ~Union() + dtor = DestructorDefn(DestructorDecl(ud.name)) + dtor.addstmt(maybeDestroy()) + cls.addstmts([dtor, Whitespace.NL]) + + # type() + typemeth = MethodDefn( + MethodDecl("type", ret=typetype, const=True, force_inline=True) + ) + typemeth.addstmt(StmtReturn(mtypevar)) + cls.addstmts([typemeth, Whitespace.NL]) + + # Union& operator= methods + rhsvar = ExprVar("aRhs") + for c in ud.components: + + def opeqBody(rhs): + return [ + # might need to placement-delete old value first + maybeDestroy(), + StmtExpr(c.callCtor(rhs)), + StmtExpr(ExprAssn(mtypevar, c.enumvar())), + StmtReturn(ExprDeref(ExprVar.THIS)), + ] + + if not _cxxTypeNeedsMoveForData(c.ipdltype): + # Union& operator=(const T&) + opeq = MethodDefn( + MethodDecl( + "operator=", + params=[Decl(c.constRefType(), rhsvar.name)], + ret=refClsType, + ) + ) + opeq.addstmts(opeqBody(rhsvar)) + cls.addstmts([opeq, Whitespace.NL]) + + # Union& operator=(T&&) + if not _cxxTypeCanMove(c.ipdltype): + continue + + opeq = MethodDefn( + MethodDecl( + "operator=", + params=[Decl(c.forceMoveType(), rhsvar.name)], + ret=refClsType, + ) + ) + opeq.addstmts(opeqBody(ExprMove(rhsvar))) + cls.addstmts([opeq, Whitespace.NL]) + + # Union& operator=(const Union&) + if not unionNeedsMove: + opeq = MethodDefn( + MethodDecl( + "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType + ) + ) + rhstypevar = ExprVar("t") + opeqswitch = StmtSwitch(rhstypevar) + for c in ud.components: + case = StmtBlock() + case.addstmts( + [ + maybeDestroy(), + StmtExpr( + c.callCtor( + ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName())) + ) + ), + StmtBreak(), + ] + ) + opeqswitch.addcase(CaseLabel(c.enum()), case) + opeqswitch.addcase( + CaseLabel(tnonevar.name), + StmtBlock([maybeDestroy(), StmtBreak()]), + ) + opeqswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) + ) + opeq.addstmts( + [ + StmtExpr(callAssertSanity(uvar=rhsvar)), + StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), + opeqswitch, + StmtExpr(ExprAssn(mtypevar, rhstypevar)), + StmtReturn(ExprDeref(ExprVar.THIS)), + ] + ) + cls.addstmts([opeq, Whitespace.NL]) + + # Union& operator=(Union&&) + opeq = MethodDefn( + MethodDecl( + "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType + ) + ) + rhstypevar = ExprVar("t") + opeqswitch = StmtSwitch(rhstypevar) + for c in ud.components: + case = StmtBlock() + if c.recursive: + case.addstmts( + [ + maybeDestroy(), + StmtExpr( + ExprAssn( + c.callGetPtr(), + ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))), + ) + ), + ] + ) + else: + case.addstmts( + [ + maybeDestroy(), + StmtExpr( + c.callCtor( + ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName()))) + ) + ), + # other.MaybeDestroy() + StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))), + ] + ) + case.addstmts([StmtBreak()]) + opeqswitch.addcase(CaseLabel(c.enum()), case) + opeqswitch.addcase( + CaseLabel(tnonevar.name), + StmtBlock([maybeDestroy(), StmtBreak()]), + ) + opeqswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) + ) + opeq.addstmts( + [ + StmtExpr(callAssertSanity(uvar=rhsvar)), + StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), + opeqswitch, + StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)), + StmtExpr(ExprAssn(mtypevar, rhstypevar)), + StmtReturn(ExprDeref(ExprVar.THIS)), + ] + ) + cls.addstmts([opeq, Whitespace.NL]) + + if "Comparable" in ud.attributes: + # bool operator==(const T&) + for c in ud.components: + opeqeq = MethodDefn( + MethodDecl( + "operator==", + params=[Decl(c.constRefType(), rhsvar.name)], + ret=Type.BOOL, + const=True, + ) + ) + opeqeq.addstmt( + StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar)) + ) + cls.addstmts([opeqeq, Whitespace.NL]) + + # bool operator==(const Union&) + opeqeq = MethodDefn( + MethodDecl( + "operator==", + params=[Decl(inClsType, rhsvar.name)], + ret=Type.BOOL, + const=True, + ) + ) + iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar))) + iftypesmismatch.addifstmt(StmtReturn.FALSE) + opeqeq.addstmts([iftypesmismatch, Whitespace.NL]) + + opeqeqswitch = StmtSwitch(ud.callType()) + for c in ud.components: + case = StmtBlock() + case.addstmt( + StmtReturn( + ExprBinary( + ExprCall(ExprVar(c.getTypeName())), + "==", + ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())), + ) + ) + ) + opeqeqswitch.addcase(CaseLabel(c.enum()), case) + opeqeqswitch.addcase( + DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE]) + ) + opeqeq.addstmt(opeqeqswitch) + + cls.addstmts([opeqeq, Whitespace.NL]) + + # accessors for each type: operator T&, operator const T&, + # T& get(), const T& get() + for c in ud.components: + getValueVar = ExprVar(c.getTypeName()) + getConstValueVar = ExprVar(c.getConstTypeName()) + + getvalue = MethodDefn( + MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True) + ) + getvalue.addstmts( + [ + StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), + StmtReturn(ExprDeref(c.callGetPtr())), + ] + ) + + getconstvalue = MethodDefn( + MethodDecl( + getConstValueVar.name, + ret=c.constRefType(), + const=True, + force_inline=True, + ) + ) + getconstvalue.addstmts( + [ + StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), + StmtReturn(c.getConstValue()), + ] + ) + + cls.addstmts([getvalue, getconstvalue]) + + optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True)) + optype.addstmt(StmtReturn(ExprCall(getValueVar))) + opconsttype = MethodDefn( + MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True) + ) + opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar))) + + cls.addstmts([optype, opconsttype, Whitespace.NL]) + # private vars + cls.addstmts( + [ + Label.PRIVATE, + StmtDecl(Decl(valuetype, mvaluevar.name)), + StmtDecl(Decl(typetype, mtypevar.name)), + ] + ) + + return forwarddeclstmts, fulldecltypes, cls + + +# ----------------------------------------------------------------------------- + + +class _FindFriends(ipdl.ast.Visitor): + def __init__(self): + self.mytype = None # ProtocolType + self.vtype = None # ProtocolType + self.friends = set() # set<ProtocolType> + + def findFriends(self, ptype): + self.mytype = ptype + for toplvl in ptype.toplevels(): + self.walkDownTheProtocolTree(toplvl) + return self.friends + + # TODO could make this into a _iterProtocolTreeHelper ... + def walkDownTheProtocolTree(self, ptype): + if ptype != self.mytype: + # don't want to |friend| ourself! + self.visit(ptype) + for mtype in ptype.manages: + if mtype is not ptype: + self.walkDownTheProtocolTree(mtype) + + def visit(self, ptype): + # |vtype| is the type currently being visited + savedptype = self.vtype + self.vtype = ptype + ptype._ast.accept(self) + self.vtype = savedptype + + def visitMessageDecl(self, md): + for it in self.iterActorParams(md): + if it.protocol == self.mytype: + self.friends.add(self.vtype) + + def iterActorParams(self, md): + for param in md.inParams: + for actor in ipdl.type.iteractortypes(param.type): + yield actor + for ret in md.outParams: + for actor in ipdl.type.iteractortypes(ret.type): + yield actor + + +class _GenerateProtocolActorCode(ipdl.ast.Visitor): + def __init__(self, myside): + self.side = myside # "parent" or "child" + self.prettyside = myside.title() + self.clsname = None + self.protocol = None + self.hdrfile = None + self.cppfile = None + self.ns = None + self.cls = None + self.protocolCxxIncludes = [] + self.actorForwardDecls = [] + self.usingDecls = [] + self.externalIncludes = set() + self.nonForwardDeclaredHeaders = set() + self.typedefSet = set( + [ + Typedef(Type("mozilla::ipc::ActorHandle"), "ActorHandle"), + Typedef(Type("base::ProcessId"), "ProcessId"), + Typedef(Type("mozilla::ipc::ProtocolId"), "ProtocolId"), + Typedef(Type("mozilla::ipc::Endpoint"), "Endpoint", ["FooSide"]), + Typedef( + Type("mozilla::ipc::ManagedEndpoint"), + "ManagedEndpoint", + ["FooSide"], + ), + Typedef(Type("mozilla::UniquePtr"), "UniquePtr", ["T"]), + Typedef( + Type("mozilla::ipc::ResponseRejectReason"), "ResponseRejectReason" + ), + ] + ) + + def lower(self, tu, clsname, cxxHeaderFile, cxxFile): + self.clsname = clsname + self.hdrfile = cxxHeaderFile + self.cppfile = cxxFile + tu.accept(self) + + def standardTypedefs(self): + return [ + Typedef(Type("mozilla::ipc::IProtocol"), "IProtocol"), + Typedef(Type("IPC::Message"), "Message"), + Typedef(Type("base::ProcessHandle"), "ProcessHandle"), + Typedef(Type("mozilla::ipc::MessageChannel"), "MessageChannel"), + Typedef(Type("mozilla::ipc::SharedMemory"), "SharedMemory"), + ] + + def visitTranslationUnit(self, tu): + self.protocol = tu.protocol + + hf = self.hdrfile + cf = self.cppfile + + # make the C++ header + hf.addthings( + [_DISCLAIMER] + + _includeGuardStart(hf) + + [ + Whitespace.NL, + CppDirective("include", '"' + _protocolHeaderName(tu.protocol) + '.h"'), + ] + ) + + for inc in tu.includes: + inc.accept(self) + for inc in tu.cxxIncludes: + inc.accept(self) + + for using in tu.builtinUsing: + using.accept(self) + for using in tu.using: + using.accept(self) + for su in tu.structsAndUnions: + su.accept(self) + + # this generates the actor's full impl in self.cls + tu.protocol.accept(self) + + clsdecl, clsdefn = _splitClassDeclDefn(self.cls) + + # XXX damn C++ ... return types in the method defn aren't in + # class scope + for stmt in clsdefn.stmts: + if isinstance(stmt, MethodDefn): + if stmt.decl.ret and stmt.decl.ret.name == "Result": + stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name + + def setToIncludes(s): + return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))] + + def makeNamespace(p, file): + if 0 == len(p.namespaces): + return file + ns = Namespace(p.namespaces[-1].name) + outerns = _putInNamespaces(ns, p.namespaces[:-1]) + file.addthing(outerns) + return ns + + if len(self.nonForwardDeclaredHeaders) != 0: + self.hdrfile.addthings( + [ + Whitespace("// Headers for things that cannot be forward declared"), + Whitespace.NL, + ] + + setToIncludes(self.nonForwardDeclaredHeaders) + + [Whitespace.NL] + ) + self.hdrfile.addthings(self.actorForwardDecls) + self.hdrfile.addthings(self.usingDecls) + + hdrns = makeNamespace(self.protocol, self.hdrfile) + hdrns.addstmts( + [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL] + ) + + actortype = ActorType(tu.protocol.decl.type) + traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side) + + self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf)) + + # If the implementation type is not overridden, add an implicit import + # for the default implementation header file. Explicit implementation + # types will specify their headers manually with `include`. + if self.protocol.implAttribute(self.side) is None: + assert self.protocol.name.startswith("P") + self.externalIncludes.add( + "".join(n.name + "/" for n in self.protocol.namespaces) + + self.protocol.name[1:] + + self.side.capitalize() + + ".h" + ) + + # make the .cpp file + cf.addthings( + [ + _DISCLAIMER, + Whitespace.NL, + CppDirective( + "include", + '"' + _protocolHeaderName(self.protocol, self.side) + '.h"', + ), + ] + + setToIncludes(self.externalIncludes) + ) + + cf.addthings( + ( + [Whitespace.NL] + + [ + CppDirective("include", '"%s.h"' % (inc)) + for inc in self.protocolCxxIncludes + ] + + [Whitespace.NL] + + [ + CppDirective("include", '"%s"' % filename) + for filename in ipdl.builtin.CppIncludes + ] + + [Whitespace.NL] + ) + ) + + cppns = makeNamespace(self.protocol, cf) + cppns.addstmts( + [Whitespace.NL, Whitespace.NL, clsdefn, Whitespace.NL, Whitespace.NL] + ) + + cf.addthing(traitsdefn) + + def visitUsingStmt(self, using): + if using.decl.fullname is not None: + self.typedefSet.add( + Typedef(Type(using.decl.fullname), using.decl.shortname) + ) + + if using.header is None: + return + + if using.canBeForwardDeclared(): + spec = using.type + + self.usingDecls.extend( + [ + _makeForwardDeclForQClass( + spec.baseid, + spec.quals, + cls=using.isClass(), + struct=using.isStruct(), + ), + Whitespace.NL, + ] + ) + self.externalIncludes.add(using.header) + else: + self.nonForwardDeclaredHeaders.add(using.header) + + def visitCxxInclude(self, inc): + self.externalIncludes.add(inc.file) + + def visitInclude(self, inc): + if inc.tu.filetype == "header": + # Including a header will declare any globals defined by "using" + # statements into our scope. To serialize these, we also may need + # cxx include statements, so visit them as well. + for cxxinc in inc.tu.cxxIncludes: + cxxinc.accept(self) + for using in inc.tu.using: + using.accept(self) + for su in inc.tu.structsAndUnions: + su.accept(self) + else: + # Includes for protocols only include types explicitly exported by + # those protocols. + ip = inc.tu.protocol + if ip == self.protocol: + return + + self.actorForwardDecls.extend( + [ + _makeForwardDeclForActor(ip.decl.type, self.side), + _makeForwardDeclForActor(ip.decl.type, _otherSide(self.side)), + Whitespace.NL, + ] + ) + self.protocolCxxIncludes.append(_protocolHeaderName(ip, self.side)) + + if ip.decl.fullname is not None: + self.typedefSet.add( + Typedef( + Type(_actorName(ip.decl.fullname, self.side.title())), + _actorName(ip.decl.shortname, self.side.title()), + ) + ) + + self.typedefSet.add( + Typedef( + Type( + _actorName(ip.decl.fullname, _otherSide(self.side).title()) + ), + _actorName(ip.decl.shortname, _otherSide(self.side).title()), + ) + ) + + def visitStructDecl(self, sd): + if sd.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name)) + + def visitUnionDecl(self, ud): + if ud.decl.fullname is not None: + self.typedefSet.add(Typedef(Type(ud.fqClassName()), ud.name)) + + def visitProtocol(self, p): + self.hdrfile.addcode( + """ + #ifdef DEBUG + #include "prenv.h" + #endif // DEBUG + + #include "mozilla/Tainting.h" + #include "mozilla/ipc/MessageChannel.h" + #include "mozilla/ipc/ProtocolUtils.h" + """ + ) + + self.protocol = p + ptype = p.decl.type + toplevel = p.decl.type.toplevel() + + hasAsyncReturns = False + for md in p.messageDecls: + if md.hasAsyncReturns(): + hasAsyncReturns = True + break + + inherits = [] + if ptype.isToplevel(): + inherits.append(Inherit(p.openedProtocolInterfaceType(), viz="public")) + else: + inherits.append(Inherit(p.managerInterfaceType(), viz="public")) + + if ptype.isToplevel() and self.side == "parent": + self.hdrfile.addthings( + [_makeForwardDeclForQClass("nsIFile", []), Whitespace.NL] + ) + + self.cls = Class(self.clsname, inherits=inherits, abstract=True) + + self.cls.addstmt(Label.PRIVATE) + friends = _FindFriends().findFriends(ptype) + if ptype.isManaged(): + friends.update(ptype.managers) + + # |friend| managed actors so that they can call our Dealloc*() + friends.update(ptype.manages) + + # don't friend ourself if we're a self-managed protocol + friends.discard(ptype) + + for friend in sorted(friends, key=lambda f: f.fullname()): + self.actorForwardDecls.extend( + [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL] + ) + self.cls.addstmt( + FriendClassDecl(_actorName(friend.fullname(), self.prettyside)) + ) + + self.cls.addstmt(Label.PROTECTED) + for typedef in sorted(self.typedefSet): + self.cls.addstmt(typedef) + + self.cls.addstmt(Whitespace.NL) + + if hasAsyncReturns: + self.cls.addstmt(Label.PUBLIC) + for md in p.messageDecls: + if self.sendsMessage(md) and md.hasAsyncReturns(): + self.cls.addstmt( + Typedef(_makePromise(md.returns, self.side), md.promiseName()) + ) + if self.receivesMessage(md) and md.hasAsyncReturns(): + self.cls.addstmt( + Typedef(_makeResolver(md.returns, self.side), md.resolverName()) + ) + self.cls.addstmt(Whitespace.NL) + + self.cls.addstmt(Label.PROTECTED) + # interface methods that the concrete subclass has to impl + for md in p.messageDecls: + isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor() + + if self.receivesMessage(md): + # generate Recv/Answer* interface + implicit = not isdtor + returnsems = "resolver" if md.decl.type.isAsync() else "out" + recvDecl = MethodDecl( + md.recvMethod(), + params=md.makeCxxParams( + paramsems="move", + returnsems=returnsems, + side=self.side, + implicit=implicit, + direction="recv", + ), + ret=Type("mozilla::ipc::IPCResult"), + methodspec=MethodSpec.VIRTUAL, + ) + + # These method implementations cause problems when trying to + # override them with different types in a direct call class. + # + # For the `isdtor` case there's a simple solution: it doesn't + # make much sense to specify arguments and then completely + # ignore them, and the no-arg case isn't a problem for + # overriding. + if isctor or (isdtor and not md.inParams): + defaultRecv = MethodDefn(recvDecl) + defaultRecv.addcode("return IPC_OK();\n") + self.cls.addstmt(defaultRecv) + elif self.protocol.implAttribute(self.side) == "virtual": + # If we're using virtual calls, we need the methods to be + # declared on the base class. + recvDecl.methodspec = MethodSpec.PURE + self.cls.addstmt(StmtDecl(recvDecl)) + + # If we're using virtual calls, we need the methods to be declared on + # the base class. + if self.protocol.implAttribute(self.side) == "virtual": + for md in p.messageDecls: + managed = md.decl.type.constructedType() + if not ptype.isManagerOf(managed) or md.decl.type.isDtor(): + continue + + # add the Alloc interface for managed actors + actortype = md.actorDecl().bareType(self.side) + + if managed.isRefcounted(): + if not self.receivesMessage(md): + continue + + actortype.ptr = False + actortype = _alreadyaddrefed(actortype) + + self.cls.addstmt( + StmtDecl( + MethodDecl( + _allocMethod(managed, self.side), + params=md.makeCxxParams( + side=self.side, implicit=False, direction="recv" + ), + ret=actortype, + methodspec=MethodSpec.PURE, + ) + ) + ) + + # add the Dealloc interface for all managed non-refcounted actors, + # even without ctors. This is useful for protocols which use + # ManagedEndpoint for construction. + for managed in ptype.manages: + if managed.isRefcounted(): + continue + + self.cls.addstmt( + StmtDecl( + MethodDecl( + _deallocMethod(managed, self.side), + params=[ + Decl(p.managedCxxType(managed, self.side), "aActor") + ], + ret=Type.BOOL, + methodspec=MethodSpec.PURE, + ) + ) + ) + + if ptype.isToplevel(): + # void ProcessingError(code); default to no-op + processingerror = MethodDefn( + MethodDecl( + p.processingErrorVar().name, + params=[ + Param(_Result.Type(), "aCode"), + Param(Type("char", const=True, ptr=True), "aReason"), + ], + methodspec=MethodSpec.OVERRIDE, + ) + ) + + # bool ShouldContinueFromReplyTimeout(); default to |true| + shouldcontinue = MethodDefn( + MethodDecl( + p.shouldContinueFromTimeoutVar().name, + ret=Type.BOOL, + methodspec=MethodSpec.OVERRIDE, + ) + ) + shouldcontinue.addcode("return true;\n") + + self.cls.addstmts( + [ + processingerror, + shouldcontinue, + Whitespace.NL, + ] + ) + + self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL])) + + self.cls.addstmt(Label.PUBLIC) + # Actor() + ctor = ConstructorDefn(ConstructorDecl(self.clsname)) + side = ExprVar("mozilla::ipc::" + self.side.title() + "Side") + if ptype.isToplevel(): + name = ExprLiteral.String(_actorName(p.name, self.side)) + ctor.memberinits = [ + ExprMemberInit( + ExprVar("mozilla::ipc::IToplevelProtocol"), + [name, _protocolId(ptype), side], + ) + ] + else: + ctor.memberinits = [ + ExprMemberInit( + ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side] + ) + ] + + ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname) + self.cls.addstmts([ctor, Whitespace.NL]) + + # ~Actor() + dtor = DestructorDefn( + DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL) + ) + dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname) + + self.cls.addstmts([dtor, Whitespace.NL]) + + if ptype.isRefcounted(): + if not ptype.isToplevel(): + self.cls.addcode( + """ + NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING + """ + ) + self.cls.addstmt(Label.PROTECTED) + self.cls.addcode( + """ + void ActorAlloc() final { AddRef(); } + void ActorDealloc() final { Release(); } + """ + ) + + self.cls.addstmt(Label.PUBLIC) + if ptype.hasOtherPid(): + otherpidmeth = MethodDefn( + MethodDecl("OtherPid", ret=Type("::base::ProcessId"), const=True) + ) + otherpidmeth.addcode( + """ + ::base::ProcessId pid = + ::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid(); + MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId); + return pid; + """ + ) + self.cls.addstmts([otherpidmeth, Whitespace.NL]) + + if not ptype.isToplevel(): + if 1 == len(p.managers): + # manager() const + managertype = p.managerActorType(self.side, ptr=True) + managermeth = MethodDefn( + MethodDecl("Manager", ret=managertype, const=True) + ) + managermeth.addcode( + """ + return static_cast<${type}>(IProtocol::Manager()); + """, + type=managertype, + ) + + self.cls.addstmts([managermeth, Whitespace.NL]) + + def actorFromIter(itervar): + return ExprCode("${iter}.Get()->GetKey()", iter=itervar) + + def forLoopOverHashtable(hashtable, itervar, const=False): + itermeth = "ConstIter" if const else "Iter" + return StmtFor( + init=ExprCode( + "auto ${itervar} = ${hashtable}.${itermeth}()", + itervar=itervar, + hashtable=hashtable, + itermeth=itermeth, + ), + cond=ExprCode("!${itervar}.Done()", itervar=itervar), + update=ExprCode("${itervar}.Next()", itervar=itervar), + ) + + # Managed[T](Array& inout) const + # const Array<T>& Managed() const + for managed in ptype.manages: + container = p.managedVar(managed, self.side) + + meth = MethodDefn( + MethodDecl( + p.managedMethod(managed, self.side).name, + params=[ + Decl( + _cxxArrayType( + p.managedCxxType(managed, self.side), ref=True + ), + "aArr", + ) + ], + const=True, + ) + ) + meth.addcode("${container}.ToArray(aArr);\n", container=container) + + refmeth = MethodDefn( + MethodDecl( + p.managedMethod(managed, self.side).name, + params=[], + ret=p.managedVarType(managed, self.side, const=True, ref=True), + const=True, + ) + ) + refmeth.addcode("return ${container};\n", container=container) + + self.cls.addstmts([meth, refmeth, Whitespace.NL]) + + # AllManagedActors(Array& inout) const + arrvar = ExprVar("arr__") + managedmeth = MethodDefn( + MethodDecl( + "AllManagedActors", + params=[ + Decl( + _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True), + arrvar.name, + ) + ], + methodspec=MethodSpec.OVERRIDE, + const=True, + ) + ) + + # Count the number of managed actors, and allocate space in the output array. + managedmeth.addcode( + """ + uint32_t total = 0; + """ + ) + for managed in ptype.manages: + managedmeth.addcode( + """ + total += ${container}.Count(); + """, + container=p.managedVar(managed, self.side), + ) + managedmeth.addcode( + """ + arr__.SetCapacity(total); + + """ + ) + + for managed in ptype.manages: + managedmeth.addcode( + """ + for (auto* key : ${container}) { + arr__.AppendElement(key->GetLifecycleProxy()); + } + + """, + container=p.managedVar(managed, self.side), + ) + + self.cls.addstmts([managedmeth, Whitespace.NL]) + + # OpenPEndpoint(...)/BindPEndpoint(...) + for managed in ptype.manages: + self.genManagedEndpoint(managed) + + # OnMessageReceived()/OnCallReceived() + + # save these away for use in message handler case stmts + msgvar = ExprVar("msg__") + self.msgvar = msgvar + replyvar = ExprVar("reply__") + self.replyvar = replyvar + var = ExprVar("v__") + self.var = var + # for ctor recv cases, we can't read the actor ID into a PFoo* + # because it doesn't exist on this side yet. Use a "special" + # actor handle instead + handlevar = ExprVar("handle__") + self.handlevar = handlevar + + msgtype = ExprCode("msg__.type()") + self.asyncSwitch = StmtSwitch(msgtype) + self.syncSwitch = None + self.interruptSwitch = None + if toplevel.isSync() or toplevel.isInterrupt(): + self.syncSwitch = StmtSwitch(msgtype) + if toplevel.isInterrupt(): + self.interruptSwitch = StmtSwitch(msgtype) + + # Add a handler for the MANAGED_ENDPOINT_BOUND and + # MANAGED_ENDPOINT_DROPPED message types for managed actors. + if not ptype.isToplevel(): + clearawaitingmanagedendpointbind = """ + if (!mAwaitingManagedEndpointBind) { + NS_WARNING("Unexpected managed endpoint lifecycle message after actor bound!"); + return MsgNotAllowed; + } + mAwaitingManagedEndpointBind = false; + """ + self.asyncSwitch.addcase( + CaseLabel("MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"), + StmtBlock( + [ + StmtCode(clearawaitingmanagedendpointbind), + StmtReturn(_Result.Processed), + ] + ), + ) + self.asyncSwitch.addcase( + CaseLabel("MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"), + StmtBlock( + [ + StmtCode(clearawaitingmanagedendpointbind), + *self.destroyActor( + None, + ExprVar.THIS, + why=_DestroyReason.ManagedEndpointDropped, + ), + StmtReturn(_Result.Processed), + ] + ), + ) + + # implement Send*() methods and add dispatcher cases to + # message switch()es + for md in p.messageDecls: + self.visitMessageDecl(md) + + # add default cases + default = StmtCode( + """ + return MsgNotKnown; + """ + ) + self.asyncSwitch.addcase(DefaultLabel(), default) + if toplevel.isSync() or toplevel.isInterrupt(): + self.syncSwitch.addcase(DefaultLabel(), default) + if toplevel.isInterrupt(): + self.interruptSwitch.addcase(DefaultLabel(), default) + + self.cls.addstmts(self.implementManagerIface()) + + def makeHandlerMethod(name, switch, hasReply, dispatches=False): + params = [Decl(Type("Message", const=True, ref=True), msgvar.name)] + if hasReply: + params.append(Decl(Type("UniquePtr<Message>", ref=True), replyvar.name)) + + method = MethodDefn( + MethodDecl( + name, + methodspec=MethodSpec.OVERRIDE, + params=params, + ret=_Result.Type(), + ) + ) + + if not switch: + method.addcode( + """ + MOZ_ASSERT_UNREACHABLE("message protocol not supported"); + return MsgNotKnown; + """ + ) + return method + + if dispatches: + if hasReply: + ondeadactor = [StmtReturn(_Result.RouteError)] + else: + ondeadactor = [ + self.logMessage( + None, ExprAddrOf(msgvar), "Ignored message for dead actor" + ), + StmtReturn(_Result.Processed), + ] + + method.addcode( + """ + int32_t route__ = ${msgvar}.routing_id(); + if (MSG_ROUTING_CONTROL != route__) { + IProtocol* routed__ = Lookup(route__); + if (!routed__ || !routed__->GetLifecycleProxy()) { + $*{ondeadactor} + } + + RefPtr<mozilla::ipc::ActorLifecycleProxy> proxy__ = + routed__->GetLifecycleProxy(); + return proxy__->Get()->${name}($,{args}); + } + + """, + msgvar=msgvar, + ondeadactor=ondeadactor, + name=name, + args=[p.name for p in params], + ) + + # bug 509581: don't generate the switch stmt if there + # is only the default case; MSVC doesn't like that + if switch.nr_cases > 1: + method.addstmt(switch) + else: + method.addstmt(StmtReturn(_Result.NotKnown)) + + return method + + dispatches = ptype.isToplevel() and ptype.isManager() + self.cls.addstmts( + [ + makeHandlerMethod( + "OnMessageReceived", + self.asyncSwitch, + hasReply=False, + dispatches=dispatches, + ), + Whitespace.NL, + ] + ) + self.cls.addstmts( + [ + makeHandlerMethod( + "OnMessageReceived", + self.syncSwitch, + hasReply=True, + dispatches=dispatches, + ), + Whitespace.NL, + ] + ) + self.cls.addstmts( + [ + makeHandlerMethod( + "OnCallReceived", + self.interruptSwitch, + hasReply=True, + dispatches=dispatches, + ), + Whitespace.NL, + ] + ) + + clearsubtreevar = ExprVar("ClearSubtree") + + if ptype.isToplevel(): + # OnChannelClose() + onclose = MethodDefn( + MethodDecl("OnChannelClose", methodspec=MethodSpec.OVERRIDE) + ) + onclose.addcode( + """ + DestroySubtree(NormalShutdown); + ClearSubtree(); + DeallocShmems(); + if (GetLifecycleProxy()) { + GetLifecycleProxy()->Release(); + } + """ + ) + self.cls.addstmts([onclose, Whitespace.NL]) + + # OnChannelError() + onerror = MethodDefn( + MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE) + ) + onerror.addcode( + """ + DestroySubtree(AbnormalShutdown); + ClearSubtree(); + DeallocShmems(); + if (GetLifecycleProxy()) { + GetLifecycleProxy()->Release(); + } + """ + ) + self.cls.addstmts([onerror, Whitespace.NL]) + + if ptype.isToplevel() and ptype.isInterrupt(): + processnative = MethodDefn( + MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID) + ) + processnative.addcode( + """ + #ifdef OS_WIN + GetIPCChannel()->ProcessNativeEventsInInterruptCall(); + #else + FatalError("This method is Windows-only"); + #endif + """ + ) + + self.cls.addstmts([processnative, Whitespace.NL]) + + # private methods + self.cls.addstmt(Label.PRIVATE) + + # ClearSubtree() + clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name)) + for managed in ptype.manages: + clearsubtree.addcode( + """ + for (auto* key : ${container}) { + key->ClearSubtree(); + } + for (auto* key : ${container}) { + // Recursively releasing ${container} kids. + auto* proxy = key->GetLifecycleProxy(); + NS_IF_RELEASE(proxy); + } + ${container}.Clear(); + + """, + container=p.managedVar(managed, self.side), + ) + + # don't release our own IPC reference: either the manager will do it, + # or we're toplevel + self.cls.addstmts([clearsubtree, Whitespace.NL]) + + if not ptype.isToplevel(): + self.cls.addstmts( + [ + StmtDecl( + Decl(Type.BOOL, "mAwaitingManagedEndpointBind"), + init=ExprLiteral.FALSE, + ), + Whitespace.NL, + ] + ) + + for managed in ptype.manages: + self.cls.addstmts( + [ + StmtDecl( + Decl( + p.managedVarType(managed, self.side), + p.managedVar(managed, self.side).name, + ) + ) + ] + ) + + def genManagedEndpoint(self, managed): + hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side) + thereEp = "ManagedEndpoint<%s>" % _actorName( + managed.name(), _otherSide(self.side) + ) + + actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor") + + # ManagedEndpoint<PThere> OpenPEndpoint(PHere* aActor) + openmeth = MethodDefn( + MethodDecl( + "Open%sEndpoint" % managed.name(), + params=[ + Decl(self.protocol.managedCxxType(managed, self.side), actor.name) + ], + ret=Type(thereEp), + ) + ) + openmeth.addcode( + """ + $*{bind} + // Mark our actor as awaiting the other side to be bound. This will + // be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is + // received. + aActor->mAwaitingManagedEndpointBind = true; + return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor); + """, + bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))), + thereEp=thereEp, + ) + + # void BindPEndpoint(ManagedEndpoint<PHere>&& aEndpoint, PHere* aActor) + bindmeth = MethodDefn( + MethodDecl( + "Bind%sEndpoint" % managed.name(), + params=[ + Decl(Type(hereEp), "aEndpoint"), + Decl(self.protocol.managedCxxType(managed, self.side), actor.name), + ], + ret=Type.BOOL, + ) + ) + bindmeth.addcode( + """ + return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this, ${container}); + """, + container=self.protocol.managedVar(managed, self.side), + ) + + self.cls.addstmts([openmeth, bindmeth, Whitespace.NL]) + + def implementManagerIface(self): + p = self.protocol + protocolbase = Type("IProtocol", ptr=True) + + methods = [] + + if p.decl.type.isToplevel(): + # FIXME: This used to be declared conditionally based on whether + # shmem appeared somewhere in the protocol hierarchy, however that + # caused issues due to Shmem instances hidden within custom C++ + # types. + self.asyncSwitch.addcase( + CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"), + self.genShmemCreatedHandler(), + ) + self.asyncSwitch.addcase( + CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"), + self.genShmemDestroyedHandler(), + ) + + # Keep track of types created with an INOUT ctor. We need to call + # Register() or RegisterID() for them depending on the side the managee + # is created. + inoutCtorTypes = [] + for msg in p.messageDecls: + msgtype = msg.decl.type + if msgtype.isCtor() and msgtype.isInout(): + inoutCtorTypes.append(msgtype.constructedType()) + + # all protocols share the "same" RemoveManagee() implementation + pvar = ExprVar("aProtocolId") + listenervar = ExprVar("aListener") + removemanagee = MethodDefn( + MethodDecl( + p.removeManageeMethod().name, + params=[ + Decl(_protocolIdType(), pvar.name), + Decl(protocolbase, listenervar.name), + ], + methodspec=MethodSpec.OVERRIDE, + ) + ) + + if not len(p.managesStmts): + removemanagee.addcode( + """ + FatalError("unreached"); + return; + """ + ) + else: + switchontype = StmtSwitch(pvar) + for managee in p.managesStmts: + manageeipdltype = managee.decl.type + manageecxxtype = _cxxBareType( + ipdl.type.ActorType(manageeipdltype), self.side + ) + case = ExprCode( + """ + { + ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener); + + const bool removed = ${container}.EnsureRemoved(actor); + MOZ_RELEASE_ASSERT(removed, "actor not managed by this!"); + + auto* proxy = actor->GetLifecycleProxy(); + NS_IF_RELEASE(proxy); + return; + } + """, + manageecxxtype=manageecxxtype, + container=p.managedVar(manageeipdltype, self.side), + ) + switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) + switchontype.addcase( + DefaultLabel(), + ExprCode( + """ + FatalError("unreached"); + return; + """ + ), + ) + removemanagee.addstmt(switchontype) + + # The `DeallocManagee` method is called for managed actors to trigger + # deallocation when ActorLifecycleProxy is freed. + deallocmanagee = MethodDefn( + MethodDecl( + p.deallocManageeMethod().name, + params=[ + Decl(_protocolIdType(), pvar.name), + Decl(protocolbase, listenervar.name), + ], + methodspec=MethodSpec.OVERRIDE, + ) + ) + + if not len(p.managesStmts): + deallocmanagee.addcode( + """ + FatalError("unreached"); + return; + """ + ) + else: + switchontype = StmtSwitch(pvar) + for managee in p.managesStmts: + manageeipdltype = managee.decl.type + # Reference counted actor types don't have corresponding + # `Dealloc` methods, as they are deallocated by releasing the + # IPDL-held reference. + if manageeipdltype.isRefcounted(): + continue + + case = StmtCode( + """ + ${concrete}->${dealloc}(static_cast<${type}>(aListener)); + return; + """, + concrete=self.concreteThis(), + dealloc=_deallocMethod(manageeipdltype, self.side), + type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side), + ) + switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) + switchontype.addcase( + DefaultLabel(), + StmtCode( + """ + FatalError("unreached"); + return; + """ + ), + ) + deallocmanagee.addstmt(switchontype) + + return methods + [removemanagee, deallocmanagee, Whitespace.NL] + + def genShmemCreatedHandler(self): + assert self.protocol.decl.type.isToplevel() + + return StmtCode( + """ + { + if (!ShmemCreated(${msgvar})) { + return MsgPayloadError; + } + return MsgProcessed; + } + """, + msgvar=self.msgvar, + ) + + def genShmemDestroyedHandler(self): + assert self.protocol.decl.type.isToplevel() + + return StmtCode( + """ + { + if (!ShmemDestroyed(${msgvar})) { + return MsgPayloadError; + } + return MsgProcessed; + } + """, + msgvar=self.msgvar, + ) + + # ------------------------------------------------------------------------- + # The next few functions are the crux of the IPDL code generator. + # They generate code for all the nasty work of message + # serialization/deserialization and dispatching handlers for + # received messages. + ## + + def concreteThis(self): + implAttr = self.protocol.implAttribute(self.side) + if implAttr == "virtual": + return ExprVar.THIS + + if implAttr is None: + assert self.protocol.name.startswith("P") + className = self.protocol.name[1:] + self.side.capitalize() + else: + assert isinstance(implAttr, ipdl.ast.StringLiteral) + className = implAttr.value + + return ExprCode("static_cast<${className}*>(this)", className=className) + + def thisCall(self, function, args): + return ExprCall(ExprSelect(self.concreteThis(), "->", function), args=args) + + def visitMessageDecl(self, md): + isctor = md.decl.type.isCtor() + isdtor = md.decl.type.isDtor() + decltype = md.decl.type + sendmethod = None + movesendmethod = None + promisesendmethod = None + recvlbl, recvcase = None, None + + def addRecvCase(lbl, case): + if decltype.isAsync(): + self.asyncSwitch.addcase(lbl, case) + elif decltype.isSync(): + self.syncSwitch.addcase(lbl, case) + elif decltype.isInterrupt(): + self.interruptSwitch.addcase(lbl, case) + else: + assert 0 + + if self.sendsMessage(md): + isasync = decltype.isAsync() + + # NOTE: Don't generate helper ctors for refcounted types. + # + # Safety concerns around providing your own actor to a ctor (namely + # that the return value won't be checked, and the argument will be + # `delete`-ed) are less critical with refcounted actors, due to the + # actor being held alive by the callsite. + # + # This allows refcounted actors to not implement crashing AllocPFoo + # methods on the sending side. + if isctor and not md.decl.type.constructedType().isRefcounted(): + self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL]) + + if isctor and isasync: + sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md) + elif isctor: + sendmethod = self.genBlockingCtorMethod(md) + elif isdtor and isasync: + sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md) + elif isdtor: + sendmethod = self.genBlockingDtorMethod(md) + elif isasync: + ( + sendmethod, + movesendmethod, + promisesendmethod, + (recvlbl, recvcase), + ) = self.genAsyncSendMethod(md) + else: + sendmethod, movesendmethod = self.genBlockingSendMethod(md) + + # XXX figure out what to do here + if isdtor and md.decl.type.constructedType().isToplevel(): + sendmethod = None + + if sendmethod is not None: + self.cls.addstmts([sendmethod, Whitespace.NL]) + if movesendmethod is not None: + self.cls.addstmts([movesendmethod, Whitespace.NL]) + if promisesendmethod is not None: + self.cls.addstmts([promisesendmethod, Whitespace.NL]) + if recvcase is not None: + addRecvCase(recvlbl, recvcase) + recvlbl, recvcase = None, None + + if self.receivesMessage(md): + if isctor: + recvlbl, recvcase = self.genCtorRecvCase(md) + elif isdtor: + recvlbl, recvcase = self.genDtorRecvCase(md) + else: + recvlbl, recvcase = self.genRecvCase(md) + + # XXX figure out what to do here + if isdtor and md.decl.type.constructedType().isToplevel(): + return + + addRecvCase(recvlbl, recvcase) + + def genAsyncCtor(self, md): + actor = md.actorDecl() + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, stmts = self.makeMessage(md, errfnSendCtor) + sendok, sendstmts = self.sendAsync(md, msgvar) + + method.addcode( + """ + $*{bind} + + // Build our constructor message. + $*{stmts} + + // Notify the other side about the newly created actor. This can + // fail if our manager has already been destroyed. + // + // NOTE: If the send call fails due to toplevel channel teardown, + // the `IProtocol::ChannelSend` wrapper absorbs the error for us, + // so we don't tear down actors unexpectedly. + $*{sendstmts} + + // Warn, destroy the actor, and return null if the message failed to + // send. Otherwise, return the successfully created actor reference. + if (!${sendok}) { + NS_WARNING("Error sending ${actorname} constructor"); + $*{destroy} + return nullptr; + } + return ${actor}; + """, + bind=self.bindManagedActor(actor), + stmts=stmts, + sendstmts=sendstmts, + sendok=sendok, + destroy=self.destroyActor( + md, actor.var(), why=_DestroyReason.FailedConstructor + ), + actor=actor.var(), + actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), + ) + + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + case.addstmt(StmtReturn(_Result.Processed)) + # TODO not really sure what to do with async ctor "replies" yet. + # destroy actor if there was an error? tricky ... + + return method, (lbl, case) + + def genBlockingCtorMethod(self, md): + actor = md.actorDecl() + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, stmts = self.makeMessage(md, errfnSendCtor) + + replyvar = self.replyvar + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) + replystmts = self.deserializeReply( + md, + replyvar, + self.side, + errfnSendCtor, + errfnSentinel(ExprLiteral.NULL), + ) + + method.addcode( + """ + $*{bind} + + // Build our constructor message. + $*{stmts} + + // Synchronously send the constructor message to the other side. If + // the send fails, e.g. due to the remote side shutting down, the + // actor will be destroyed and potentially freed. + UniquePtr<Message> ${replyvar}; + $*{sendstmts} + + if (!(${sendok})) { + // Warn, destroy the actor and return null if the message + // failed to send. + NS_WARNING("Error sending constructor"); + $*{destroy} + return nullptr; + } + + $*{replystmts} + return ${actor}; + """, + bind=self.bindManagedActor(actor), + stmts=stmts, + replyvar=replyvar, + sendstmts=sendstmts, + sendok=sendok, + destroy=self.destroyActor( + md, actor.var(), why=_DestroyReason.FailedConstructor + ), + replystmts=replystmts, + actor=actor.var(), + actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), + ) + + return method + + def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None): + actorproto = actordecl.ipdltype.protocol + + if idexpr is None: + setManagerArgs = [ExprVar.THIS] + else: + setManagerArgs = [ExprVar.THIS, idexpr] + + return [ + StmtCode( + """ + if (!${actor}) { + NS_WARNING("Cannot bind null ${actorname} actor"); + return ${errfn}; + } + + ${actor}->SetManagerAndRegister($,{setManagerArgs}); + ${container}.Insert(${actor}); + """, + actor=actordecl.var(), + actorname=actorproto.name() + self.side.capitalize(), + errfn=errfn, + setManagerArgs=setManagerArgs, + container=self.protocol.managedVar(actorproto, self.side), + ) + ] + + def genHelperCtor(self, md): + helperdecl = self.makeSendMethodDecl(md) + helperdecl.params = helperdecl.params[1:] + helper = MethodDefn(helperdecl) + + helper.addstmts( + [ + self.callAllocActor(md, retsems="out", side=self.side), + StmtReturn( + ExprCall( + ExprVar(helperdecl.name), args=md.makeCxxArgs(paramsems="move") + ) + ), + ] + ) + return helper + + def genAsyncDtor(self, md): + actorvar = ExprVar("actor") + method = MethodDefn(self.makeDtorMethodDecl(md, actorvar)) + + method.addstmt(self.dtorPrologue(actorvar)) + + msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) + sendok, sendstmts = self.sendAsync(md, msgvar, actorvar) + method.addstmts( + stmts + + sendstmts + + [Whitespace.NL] + + self.dtorEpilogue(md, actorvar) + + [StmtReturn(sendok)] + ) + + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + case.addstmt(StmtReturn(_Result.Processed)) + # TODO if the dtor is "inherently racy", keep the actor alive + # until the other side acks + + return method, (lbl, case) + + def genBlockingDtorMethod(self, md): + actorvar = ExprVar("actor") + method = MethodDefn(self.makeDtorMethodDecl(md, actorvar)) + + method.addstmt(self.dtorPrologue(actorvar)) + + msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) + + replyvar = self.replyvar + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar) + method.addstmts( + stmts + + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))] + + sendstmts + ) + + destmts = self.deserializeReply( + md, replyvar, self.side, errfnSend, errfnSentinel(), actorvar + ) + ifsendok = StmtIf(ExprLiteral.FALSE) + ifsendok.addifstmts(destmts) + ifsendok.addifstmts( + [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))] + ) + + method.addstmt(ifsendok) + + method.addstmts( + self.dtorEpilogue(md, actorvar) + [Whitespace.NL, StmtReturn(sendok)] + ) + + return method + + def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion): + if md and md.decl.type.isCtor(): + destroyedType = md.decl.type.constructedType() + else: + destroyedType = self.protocol.decl.type + + return [ + StmtCode( + """ + IProtocol* mgr = ${actor}->Manager(); + ${actor}->DestroySubtree(${why}); + ${actor}->ClearSubtree(); + mgr->RemoveManagee(${protoId}, ${actor}); + """, + actor=actorexpr, + why=why, + protoId=_protocolId(destroyedType), + ) + ] + + def dtorPrologue(self, actorexpr): + return StmtCode( + """ + if (!${actor} || !${actor}->CanSend()) { + NS_WARNING("Attempt to __delete__ missing or closed actor"); + return false; + } + """, + actor=actorexpr, + ) + + def dtorEpilogue(self, md, actorexpr): + return self.destroyActor(md, actorexpr) + + def genRecvAsyncReplyCase(self, md): + lbl = CaseLabel(md.pqReplyId()) + case = StmtBlock() + resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply( + md, self.side, errfnRecv, errfnSentinel(_Result.ValuError) + ) + + if len(md.returns) > 1: + resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) + resolvearg = ExprCall( + ExprVar("std::make_tuple"), args=[ExprMove(p.var()) for p in md.returns] + ) + else: + resolvetype = md.returns[0].bareType(self.side) + resolvearg = ExprMove(md.returns[0].var()) + + case.addcode( + """ + $*{prologue} + + UniquePtr<MessageChannel::UntypedCallbackHolder> untypedCallback = + GetIPCChannel()->PopCallback(${msgvar}, Id()); + + typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder; + auto* callback = static_cast<CallbackHolder*>(untypedCallback.get()); + if (!callback) { + FatalError("Error unknown callback"); + return MsgProcessingError; + } + + if (${resolve}) { + $*{desstmts} + callback->Resolve(${resolvearg}); + } else { + $*{desrej} + callback->Reject(std::move(${reason})); + } + return MsgProcessed; + """, + prologue=prologue, + msgvar=self.msgvar, + resolve=resolve, + resolvetype=resolvetype, + desstmts=desstmts, + resolvearg=resolvearg, + desrej=desrej, + reason=reason, + ) + + return (lbl, case) + + def genAsyncSendMethod(self, md): + decl = self.makeSendMethodDecl(md) + if "VirtualSendImpl" in md.attributes: + decl.methodspec = MethodSpec.VIRTUAL + method = MethodDefn(decl) + msgvar, stmts = self.makeMessage(md, errfnSend) + retvar, sendstmts = self.sendAsync(md, msgvar) + + method.addstmts(stmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)]) + + movemethod = None + + # Add the promise overload if we need one. + if md.returns: + decl = self.makeSendMethodDecl(md, promise=True) + if "VirtualSendImpl" in md.attributes: + decl.methodspec = MethodSpec.VIRTUAL + promisemethod = MethodDefn(decl) + stmts = self.sendAsyncWithPromise(md) + promisemethod.addstmts(stmts) + + (lbl, case) = self.genRecvAsyncReplyCase(md) + else: + (promisemethod, lbl, case) = (None, None, None) + + return method, movemethod, promisemethod, (lbl, case) + + def genBlockingSendMethod(self, md): + method = MethodDefn(self.makeSendMethodDecl(md)) + + msgvar, serstmts = self.makeMessage(md, errfnSend) + replyvar = self.replyvar + + sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) + failif = StmtIf(ExprNot(sendok)) + failif.addifstmt(StmtReturn.FALSE) + + desstmts = self.deserializeReply( + md, replyvar, self.side, errfnSend, errfnSentinel() + ) + + method.addstmts( + serstmts + + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr<Message>"), replyvar.name))] + + sendstmts + + [failif] + + desstmts + + [Whitespace.NL, StmtReturn.TRUE] + ) + + movemethod = None + + return method, movemethod + + def genCtorRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + actorhandle = self.handlevar + + stmts = self.deserializeMessage( + md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) + ) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + [ + StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) + for r in md.returns + ] + # alloc the actor, register it under the foreign ID + + [self.callAllocActor(md, retsems="in", side=self.side)] + + self.bindManagedActor( + md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(actorhandle) + ) + + [Whitespace.NL] + + saveIdStmts + + self.invokeRecvHandler(md) + + self.makeReply(md, errfnRecv, idvar) + + [Whitespace.NL, StmtReturn(_Result.Processed)] + ) + + return lbl, case + + def genDtorRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + + stmts = self.deserializeMessage( + md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) + ) + + idvar, saveIdStmts = self.saveActorId(md) + case.addstmts( + stmts + + [ + StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) + for r in md.returns + ] + + self.invokeRecvHandler(md) + + [Whitespace.NL] + + saveIdStmts + + self.makeReply(md, errfnRecv, routingId=idvar) + + [Whitespace.NL] + + self.dtorEpilogue(md, ExprVar.THIS) + + [Whitespace.NL, StmtReturn(_Result.Processed)] + ) + + return lbl, case + + def genRecvCase(self, md): + lbl = CaseLabel(md.pqMsgId()) + case = StmtBlock() + + stmts = self.deserializeMessage( + md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) + ) + + idvar, saveIdStmts = self.saveActorId(md) + declstmts = [ + StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) + for r in md.returns + ] + if md.decl.type.isAsync() and md.returns: + declstmts = self.makeResolver(md, errfnRecv, routingId=idvar) + case.addstmts( + stmts + + saveIdStmts + + declstmts + + self.invokeRecvHandler(md) + + [Whitespace.NL] + + self.makeReply(md, errfnRecv, routingId=idvar) + + [StmtReturn(_Result.Processed)] + ) + + return lbl, case + + # helper methods + + def makeMessage(self, md, errfn, fromActor=None): + msgvar = self.msgvar + writervar = ExprVar("writer__") + routingId = self.protocol.routingId(fromActor) + this = fromActor or ExprVar.THIS + + stmts = ( + [ + StmtDecl( + Decl(Type("UniquePtr<IPC::Message>"), msgvar.name), + init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]), + ), + StmtDecl( + Decl(Type("IPC::MessageWriter"), writervar.name), + initargs=[ExprDeref(msgvar), this], + ), + ] + + [Whitespace.NL] + + [ + _ParamTraits.checkedWrite( + p.ipdltype, + p.var(), + ExprAddrOf(writervar), + sentinelKey=p.name, + ) + for p in md.params + ] + + [Whitespace.NL] + + self.setMessageFlags(md, msgvar) + ) + return msgvar, stmts + + def makeResolver(self, md, errfn, routingId): + if routingId is None: + routingId = self.protocol.routingId() + if not md.decl.type.isAsync() or not md.hasReply(): + return [] + + def paramValue(idx): + assert idx < len(md.returns) + if len(md.returns) > 1: + return ExprCode("std::get<${idx}>(aParam)", idx=idx) + return ExprVar("aParam") + + serializeParams = [ + _ParamTraits.checkedWrite( + p.ipdltype, + paramValue(idx), + ExprAddrOf(ExprVar("writer__")), + sentinelKey=p.name, + ) + for idx, p in enumerate(md.returns) + ] + + return [ + StmtCode( + """ + UniquePtr<IPC::Message> ${replyvar}(${replyCtor}(${routingId})); + ${replyvar}->set_seqno(${msgvar}.seqno()); + + RefPtr<mozilla::ipc::IPDLResolverInner> resolver__ = + new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this); + + ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) { + resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) { + IPC::MessageWriter writer__(*${replyvar}, self__); + $*{serializeParams} + ${logSendingReply} + }); + }; + """, + msgvar=self.msgvar, + resolvertype=Type(md.resolverName()), + routingId=routingId, + resolveType=_resolveType(md.returns, self.side), + replyvar=self.replyvar, + replyCtor=ExprVar(md.pqReplyCtorFunc()), + serializeParams=serializeParams, + logSendingReply=self.logMessage( + md, + self.replyvar, + "Sending reply ", + actor=ExprVar("self__"), + ), + ) + ] + + def makeReply(self, md, errfn, routingId): + if routingId is None: + routingId = self.protocol.routingId() + # TODO special cases for async ctor/dtor replies + if not md.decl.type.hasReply(): + return [] + if md.decl.type.isAsync() and md.decl.type.hasReply(): + return [] + + replyvar = self.replyvar + return ( + [ + StmtExpr( + ExprAssn( + replyvar, + ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]), + ) + ), + StmtDecl( + Decl(Type("IPC::MessageWriter"), "writer__"), + initargs=[ExprDeref(replyvar), ExprVar.THIS], + ), + Whitespace.NL, + ] + + [ + _ParamTraits.checkedWrite( + r.ipdltype, + r.var(), + ExprAddrOf(ExprVar("writer__")), + sentinelKey=r.name, + ) + for r in md.returns + ] + + self.setMessageFlags(md, replyvar) + + [self.logMessage(md, replyvar, "Sending reply ")] + ) + + def setMessageFlags(self, md, var, seqno=None): + stmts = [] + + if seqno: + stmts.append( + StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno])) + ) + + return stmts + [Whitespace.NL] + + def deserializeMessage(self, md, side, errfn, errfnSent): + msgvar = self.msgvar + msgexpr = ExprAddrOf(msgvar) + readervar = ExprVar("reader__") + isctor = md.decl.type.isCtor() + stmts = [ + self.logMessage(md, msgexpr, "Received ", receiving=True), + self.profilerLabel(md), + Whitespace.NL, + ] + + if 0 == len(md.params): + return stmts + + start, reads = 0, [] + if isctor: + # return the raw actor handle so that its ID can be used + # to construct the "real" actor + handlevar = self.handlevar + handletype = Type("ActorHandle") + reads = [ + _ParamTraits.checkedRead( + None, + handletype, + handlevar, + ExprAddrOf(readervar), + errfn, + "'%s'" % handletype.name, + sentinelKey="actor", + errfnSentinel=errfnSent, + ) + ] + start = 1 + + def maybeTainted(p, side): + if md.decl.type.tainted and "NoTaint" not in p.attributes: + return Type("Tainted", T=p.bareType(side)) + return p.bareType(side) + + reads.extend( + [ + _ParamTraits.checkedRead( + p.ipdltype, + maybeTainted(p, side), + p.var(), + ExprAddrOf(readervar), + errfn, + "'%s'" % p.ipdltype.name(), + sentinelKey=p.name, + errfnSentinel=errfnSent, + ) + for p in md.params[start:] + ] + ) + + stmts.extend( + ( + [ + StmtDecl( + Decl(Type("IPC::MessageReader"), readervar.name), + initargs=[msgvar, ExprVar.THIS], + ) + ] + + [Whitespace.NL] + + reads + + [StmtCode("${reader}.EndRead();\n", reader=readervar)] + ) + ) + + return stmts + + def deserializeAsyncReply(self, md, side, errfn, errfnSent): + msgvar = self.msgvar + readervar = ExprVar("reader__") + msgexpr = ExprAddrOf(msgvar) + isctor = md.decl.type.isCtor() + resolve = ExprVar("resolve__") + reason = ExprVar("reason__") + + # NOTE: The `resolve__` and `reason__` parameters don't have sentinels, + # as they are serialized by the IPDLResolverInner type in + # ProtocolUtils.cpp rather than by generated code. + desresolve = [ + StmtCode( + """ + bool resolve__ = false; + if (!IPC::ReadParam(&${readervar}, &resolve__)) { + FatalError("Error deserializing bool"); + return MsgValueError; + } + """, + readervar=readervar, + ), + ] + desrej = [ + StmtCode( + """ + ResponseRejectReason reason__{}; + if (!IPC::ReadParam(&${readervar}, &reason__)) { + FatalError("Error deserializing ResponseRejectReason"); + return MsgValueError; + } + ${readervar}.EndRead(); + """, + readervar=readervar, + ), + ] + prologue = [ + self.logMessage(md, msgexpr, "Received ", receiving=True), + self.profilerLabel(md), + Whitespace.NL, + ] + + if not md.returns: + return prologue + + prologue.extend( + [ + StmtDecl( + Decl(Type("IPC::MessageReader"), readervar.name), + initargs=[msgvar, ExprVar.THIS], + ) + ] + + desresolve + ) + + start, reads = 0, [] + if isctor: + # return the raw actor handle so that its ID can be used + # to construct the "real" actor + handlevar = self.handlevar + handletype = Type("ActorHandle") + reads = [ + _ParamTraits.checkedRead( + None, + handletype, + handlevar, + ExprAddrOf(readervar), + errfn, + "'%s'" % handletype.name, + sentinelKey="actor", + errfnSentinel=errfnSent, + ) + ] + start = 1 + + stmts = ( + reads + + [ + _ParamTraits.checkedRead( + p.ipdltype, + p.bareType(side), + p.var(), + ExprAddrOf(readervar), + errfn, + "'%s'" % p.ipdltype.name(), + sentinelKey=p.name, + errfnSentinel=errfnSent, + ) + for p in md.returns[start:] + ] + + [StmtCode("${reader}.EndRead();", reader=readervar)] + ) + + return resolve, reason, prologue, desrej, stmts + + def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None): + stmts = [ + Whitespace.NL, + self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True), + ] + if 0 == len(md.returns): + return stmts + + def tempvar(r): + return ExprVar(r.var().name + "__reply") + + readervar = ExprVar("reader__") + stmts.extend( + [ + Whitespace.NL, + StmtDecl( + Decl(Type("IPC::MessageReader"), readervar.name), + initargs=[ExprDeref(self.replyvar), ExprVar.THIS], + ), + ] + + [Whitespace.NL] + + [ + _ParamTraits.checkedRead( + r.ipdltype, + r.bareType(side), + tempvar(r), + ExprAddrOf(readervar), + errfn, + "'%s'" % r.ipdltype.name(), + sentinelKey=r.name, + errfnSentinel=errfnSentinel, + ) + for r in md.returns + ] + # Move-assign the values out of the variables created with + # checkedRead into outparams. + + [ + StmtExpr(ExprAssn(ExprDeref(r.var()), ExprMove(tempvar(r)))) + for r in md.returns + ] + + [StmtCode("${reader}.EndRead();", reader=readervar)] + ) + + return stmts + + def sendAsync(self, md, msgexpr, actor=None): + sendok = ExprVar("sendok__") + resolvefn = ExprVar("aResolve") + rejectfn = ExprVar("aReject") + + stmts = [ + Whitespace.NL, + self.logMessage(md, msgexpr, "Sending ", actor), + self.profilerLabel(md), + ] + stmts.append(Whitespace.NL) + + # Generate the actual call expression. + send = ExprVar("ChannelSend") + if actor is not None: + send = ExprSelect(actor, "->", send.name) + if md.returns: + stmts.append( + StmtExpr( + ExprCall( + send, + args=[ + ExprMove(msgexpr), + ExprVar(md.pqReplyId()), + ExprMove(resolvefn), + ExprMove(rejectfn), + ], + ) + ) + ) + retvar = None + else: + stmts.append( + StmtDecl( + Decl(Type.BOOL, sendok.name), + init=ExprCall(send, args=[ExprMove(msgexpr)]), + ) + ) + retvar = sendok + + return (retvar, stmts) + + def sendBlocking(self, md, msgexpr, replyexpr, actor=None): + send = ExprVar("ChannelSend") + if md.decl.type.isInterrupt(): + send = ExprVar("ChannelCall") + if actor is not None: + send = ExprSelect(actor, "->", send.name) + + sendok = ExprVar("sendok__") + self.externalIncludes.add("mozilla/ProfilerMarkers.h") + return ( + sendok, + ( + [ + Whitespace.NL, + self.logMessage(md, msgexpr, "Sending ", actor), + self.profilerLabel(md), + ] + + [ + Whitespace.NL, + StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE), + StmtBlock( + [ + StmtExpr( + ExprCall( + ExprVar("AUTO_PROFILER_TRACING_MARKER"), + [ + ExprLiteral.String("Sync IPC"), + ExprLiteral.String( + self.protocol.name + + "::" + + md.prettyMsgName() + ), + ExprVar("IPC"), + ], + ) + ), + StmtExpr( + ExprAssn( + sendok, + ExprCall( + send, + args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)], + ), + ) + ), + ] + ), + ] + ), + ) + + def sendAsyncWithPromise(self, md): + # Create a new promise, and forward to the callback send overload. + promise = _makePromise(md.returns, self.side, resolver=True) + + if len(md.returns) > 1: + resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) + else: + resolvetype = md.returns[0].bareType(self.side) + + resolve = ExprCode( + """ + [promise__](${resolvetype}&& aValue) { + promise__->Resolve(std::move(aValue), __func__); + } + """, + resolvetype=resolvetype, + ) + reject = ExprCode( + """ + [promise__](ResponseRejectReason&& aReason) { + promise__->Reject(std::move(aReason), __func__); + } + """, + resolvetype=resolvetype, + ) + + args = [ExprMove(p.var()) for p in md.params] + [resolve, reject] + stmt = StmtCode( + """ + RefPtr<${promise}> promise__ = new ${promise}(__func__); + promise__->UseDirectTaskDispatch(__func__); + ${send}($,{args}); + return promise__; + """, + promise=promise, + send=md.sendMethod(), + args=args, + ) + return [stmt] + + def callAllocActor(self, md, retsems, side): + actortype = md.actorDecl().bareType(self.side) + if md.decl.type.constructedType().isRefcounted(): + actortype.ptr = False + actortype = _refptr(actortype) + + callalloc = self.thisCall( + _allocMethod(md.decl.type.constructedType(), side), + args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False), + ) + + return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=callalloc) + + def invokeRecvHandler(self, md): + retsems = "in" + if md.decl.type.isAsync() and md.returns: + retsems = "resolver" + okdecl = StmtDecl( + Decl(Type("mozilla::ipc::IPCResult"), "__ok"), + init=self.thisCall( + md.recvMethod(), + md.makeCxxArgs( + paramsems="move", + retsems=retsems, + retcallsems="out", + ), + ), + ) + failif = StmtIf(ExprNot(ExprVar("__ok"))) + failif.addifstmts( + [ + _protocolErrorBreakpoint("Handler returned error code!"), + Whitespace( + "// Error handled in mozilla::ipc::IPCResult\n", indent=True + ), + StmtReturn(_Result.ProcessingError), + ] + ) + return [okdecl, failif] + + def makeDtorMethodDecl(self, md, actorvar): + decl = self.makeSendMethodDecl(md) + decl.params.insert( + 0, + Decl( + _cxxInType( + ipdl.type.ActorType(md.decl.type.constructedType()), + side=self.side, + direction="send", + ), + actorvar.name, + ), + ) + decl.methodspec = MethodSpec.STATIC + return decl + + def makeSendMethodDecl(self, md, promise=False, paramsems="in"): + implicit = md.decl.type.hasImplicitActorParam() + if md.decl.type.isAsync() and md.returns: + if promise: + returnsems = "promise" + rettype = _refptr(Type(md.promiseName())) + else: + returnsems = "callback" + rettype = Type.VOID + else: + assert not promise + returnsems = "out" + rettype = Type.BOOL + decl = MethodDecl( + md.sendMethod(), + params=md.makeCxxParams( + paramsems, + returnsems=returnsems, + side=self.side, + implicit=implicit, + direction="send", + ), + warn_unused=( + (self.side == "parent" and returnsems != "callback") + or (md.decl.type.isCtor() and not md.decl.type.isAsync()) + ), + ret=rettype, + ) + if md.decl.type.isCtor(): + decl.ret = md.actorDecl().bareType(self.side) + return decl + + def logMessage(self, md, msgptr, pfx, actor=None, receiving=False): + actorname = _actorName(self.protocol.name, self.side) + return StmtCode( + """ + if (mozilla::ipc::LoggingEnabledFor(${actorname})) { + mozilla::ipc::LogMessageForProtocol( + ${actorname}, + ${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(), + ${pfx}, + ${msgptr}->type(), + mozilla::ipc::MessageDirection::${direction}); + } + """, + actorname=ExprLiteral.String(actorname), + actor=actor or ExprVar.THIS, + pfx=ExprLiteral.String(pfx), + msgptr=msgptr, + direction="eReceiving" if receiving else "eSending", + ) + + def profilerLabel(self, md): + self.externalIncludes.add("mozilla/ProfilerLabels.h") + return StmtCode( + """ + AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER); + """, + name=self.protocol.name, + msgname=md.prettyMsgName(), + ) + + def saveActorId(self, md): + idvar = ExprVar("id__") + if md.decl.type.hasReply(): + # only save the ID if we're actually going to use it, to + # avoid unused-variable warnings + saveIdStmts = [ + StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId()) + ] + else: + saveIdStmts = [] + return idvar, saveIdStmts + + +class _GenerateProtocolParentCode(_GenerateProtocolActorCode): + def __init__(self): + _GenerateProtocolActorCode.__init__(self, "parent") + + def sendsMessage(self, md): + return not md.decl.type.isIn() + + def receivesMessage(self, md): + return md.decl.type.isInout() or md.decl.type.isIn() + + +class _GenerateProtocolChildCode(_GenerateProtocolActorCode): + def __init__(self): + _GenerateProtocolActorCode.__init__(self, "child") + + def sendsMessage(self, md): + return not md.decl.type.isOut() + + def receivesMessage(self, md): + return md.decl.type.isInout() or md.decl.type.isOut() + + +# ----------------------------------------------------------------------------- +# Utility passes +## + + +def _splitClassDeclDefn(cls): + """Destructively split |cls| methods into declarations and + definitions (if |not methodDecl.force_inline|). Return classDecl, + methodDefns.""" + defns = Block() + + for i, stmt in enumerate(cls.stmts): + if isinstance(stmt, MethodDefn) and not stmt.decl.force_inline: + decl, defn = _splitMethodDeclDefn(stmt, cls) + cls.stmts[i] = StmtDecl(decl) + if defn: + defns.addstmts([defn, Whitespace.NL]) + + return cls, defns + + +def _splitMethodDeclDefn(md, cls): + # Pure methods have decls but no defns. + if md.decl.methodspec == MethodSpec.PURE: + return md.decl, None + + saveddecl = deepcopy(md.decl) + md.decl.cls = cls + # Don't emit method specifiers on method defns. + md.decl.methodspec = MethodSpec.NONE + md.decl.warn_unused = False + md.decl.only_for_definition = True + for param in md.decl.params: + if isinstance(param, Param): + param.default = None + return saveddecl, md + + +def _splitFuncDeclDefn(fun): + assert not fun.decl.force_inline + return StmtDecl(fun.decl), fun |