# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import re
from copy import deepcopy
from collections import OrderedDict
import itertools

import ipdl.ast
import ipdl.builtin
from ipdl.cxx.ast import *
from ipdl.cxx.code import *
from ipdl.type import ActorType, UnionType, TypeVisitor, builtinHeaderIncludes
from ipdl.util import hash_str


# -----------------------------------------------------------------------------
# "Public" interface to lowering
##
class LowerToCxx:
    """Entry point that lowers an IPDL translation unit to C++ File objects."""

    def lower(self, tu, segmentcapacitydict):
        """Return |(headers, cpps)|, two lists of File objects representing
        the lowered form of |tu|.

        The first entry of each list is the common protocol file; when |tu|
        declares a protocol, the Parent and Child files follow in that order.
        |segmentcapacitydict| is forwarded untouched to _GenerateProtocolCode.
        """
        # annotate the AST with IPDL/C++ IR-type stuff used later
        tu.accept(_DecorateWithCxxStuff())

        # Any modifications to the filename scheme here need corresponding
        # modifications in the ipdl.py driver script.
        name = tu.name
        pheader, pcpp = File(name + ".h"), File(name + ".cpp")

        _GenerateProtocolCode().lower(tu, pheader, pcpp, segmentcapacitydict)
        headers = [pheader]
        cpps = [pcpp]

        if tu.protocol:
            pname = tu.protocol.name

            parentheader, parentcpp = (
                File(pname + "Parent.h"),
                File(pname + "Parent.cpp"),
            )
            _GenerateProtocolParentCode().lower(
                tu, pname + "Parent", parentheader, parentcpp
            )

            childheader, childcpp = File(pname + "Child.h"), File(pname + "Child.cpp")
            _GenerateProtocolChildCode().lower(
                tu, pname + "Child", childheader, childcpp
            )

            headers += [parentheader, childheader]
            cpps += [parentcpp, childcpp]

        return headers, cpps


# -----------------------------------------------------------------------------
# Helper code
##


def hashfunc(value):
    """Return hash_str(|value|) reduced to the range [0, 2**32)."""
    # Python's `%` takes the sign of the (positive) modulus, so the result is
    # never negative and needs no further adjustment.
    return hash_str(value) % 2 ** 32


_NULL_ACTOR_ID = ExprLiteral.ZERO
_FREED_ACTOR_ID = ExprLiteral.ONE

_DISCLAIMER = Whitespace(
    """//
// Automatically generated by ipdlc.
// Edit at your own risk
//

"""
)


class _struct:
    pass


def _namespacedHeaderName(name, namespaces):
    """Return |name| prefixed with the '/'-joined names of |namespaces|."""
    pfx = "/".join([ns.name for ns in namespaces])
    if pfx:
        return pfx + "/" + name
    else:
        return name


def _ipdlhHeaderName(tu):
    """Return the namespaced header name for a "header"-type translation unit."""
    assert tu.filetype == "header"
    return _namespacedHeaderName(tu.name, tu.namespaces)


def _protocolHeaderName(p, side=""):
    """Return the namespaced header name for protocol |p|, with the
    title-cased |side| appended to the protocol name when |side| is given."""
    if side:
        side = side.title()
    base = p.name + side
    return _namespacedHeaderName(base, p.namespaces)


def _includeGuardMacroName(headerfile):
    """Return |headerfile.name| with every '.' and '/' replaced by '_'."""
    return re.sub(r"[./]", "_", headerfile.name)


def _includeGuardStart(headerfile):
    """Return the `#ifndef` / `#define` directives opening the include guard."""
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("ifndef", guard), CppDirective("define", guard)]


def _includeGuardEnd(headerfile):
    """Return the `#endif` directive closing the include guard."""
    guard = _includeGuardMacroName(headerfile)
    return [CppDirective("endif", "// ifndef " + guard)]


def _messageStartName(ptype):
    return ptype.name() + "MsgStart"


def _protocolId(ptype):
    return ExprVar(_messageStartName(ptype))


def _protocolIdType():
    return Type.INT32


def _actorName(pname, side):
    """|pname| is the protocol name. |side| is 'Parent' or 'Child'.

    A |side| whose first character is not upper-case is title-cased first.
    """
    tag = side
    if not tag[0].isupper():
        tag = side.title()
    return pname + tag


def _actorIdType():
    return Type.INT32


def _actorTypeTagType():
    return Type.INT32


def _actorId(actor=None):
    """Return a call expression for `Id()`: `actor->Id()` when |actor| is
    given, a bare `Id()` otherwise."""
    if actor is not None:
        return ExprCall(ExprSelect(actor, "->", "Id"))
    return ExprCall(ExprVar("Id"))


def _actorHId(actorhandle):
    return ExprSelect(actorhandle, ".", "mId")


def _backstagePass():
    return ExprCall(ExprVar("mozilla::ipc::PrivateIPDLInterface"))


def _deleteId():
    return ExprVar("Msg___delete____ID")


def _deleteReplyId():
    return ExprVar("Reply___delete____ID")


def _lookupListener(idexpr):
    return ExprCall(ExprVar("Lookup"), args=[idexpr])


def _makeForwardDeclForQClass(clsname, quals, cls=True, struct=False):
    """Return a forward declaration of |clsname| nested inside one Namespace
    per entry of |quals| (outermost first); the bare ForwardDecl when |quals|
    is empty."""
    fd = ForwardDecl(clsname, cls=cls, struct=struct)
    if not quals:
        return fd

    outerns = Namespace(quals[0])
    innerns = outerns
    for ns in quals[1:]:
        tmpns = Namespace(ns)
        innerns.addstmt(tmpns)
        innerns = tmpns

    innerns.addstmt(fd)
    return outerns


def _makeForwardDeclForActor(ptype, side):
    return _makeForwardDeclForQClass(
        _actorName(ptype.qname.baseid, side), ptype.qname.quals
    )


def _makeForwardDecl(type):
    return _makeForwardDeclForQClass(type.name(), type.qname.quals)


def _putInNamespaces(cxxthing, namespaces):
    """|namespaces| is in order [ outer, ..., inner ].

    Returns |cxxthing| itself when |namespaces| is empty, otherwise the
    outermost Namespace node wrapping it.
    """
    if not namespaces:
        return cxxthing

    outerns = Namespace(namespaces[0].name)
    innerns = outerns
    for ns in namespaces[1:]:
        newns = Namespace(ns.name)
        innerns.addstmt(newns)
        innerns = newns
    innerns.addstmt(cxxthing)
    return outerns


def _sendPrefix(msgtype):
    """Prefix of the name of the C++ method that sends |msgtype|."""
    if msgtype.isInterrupt():
        return "Call"
    return "Send"


def _recvPrefix(msgtype):
    """Prefix of the name of the C++ method that handles |msgtype|."""
    if msgtype.isInterrupt():
        return "Answer"
    return "Recv"


def _flatTypeName(ipdltype):
    """Return a 'flattened' IPDL type name that can be used as an
    identifier. E.g., |Foo[]| --> |ArrayOfFoo|."""
    # NB: this logic depends heavily on what IPDL types are allowed to
    # be constructed; e.g., Foo[][] is disallowed.  needs to be kept in
    # sync with grammar.
    if ipdltype.isIPDL() and ipdltype.isArray():
        return "ArrayOf" + _flatTypeName(ipdltype.basetype)
    if ipdltype.isIPDL() and ipdltype.isMaybe():
        return "Maybe" + _flatTypeName(ipdltype.basetype)
    # NotNull types just assume the underlying variant name to avoid unnecessary
    # noise, as a NotNull<T> and T should never exist in the same union.
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        return _flatTypeName(ipdltype.basetype)
    return ipdltype.name()


def _hasVisibleActor(ipdltype):
    """Return true iff a C++ decl of |ipdltype| would have an Actor* type.
    For example: |Actor[]| would turn into an array of |Actor*|, so this
    function would return true for |Actor[]|."""
    return ipdltype.isIPDL() and (
        ipdltype.isActor()
        or (ipdltype.hasBaseType() and _hasVisibleActor(ipdltype.basetype))
    )


def _abortIfFalse(cond, msg):
    """Return a `MOZ_RELEASE_ASSERT(cond, "msg")` statement."""
    return StmtExpr(
        ExprCall(ExprVar("MOZ_RELEASE_ASSERT"), [cond, ExprLiteral.String(msg)])
    )


def _refptr(T):
    return Type("RefPtr", T=T)


def _alreadyaddrefed(T):
    return Type("already_AddRefed", T=T)


def _tuple(types, const=False, ref=False):
    return Type("std::tuple", T=types, const=const, ref=ref)


def _promise(resolvetype, rejecttype, tail, resolver=False):
    """Return a `MozPromise` Type over the given template arguments; with
    |resolver| set, its inner `Private` type is selected."""
    inner = Type("Private") if resolver else None
    return Type("MozPromise", T=[resolvetype, rejecttype, tail], inner=inner)


def _makePromise(returns, side, resolver=False):
    """Return the MozPromise type resolving with the bare types of |returns|
    (a std::tuple when there is more than one). |returns| must be non-empty."""
    if len(returns) > 1:
        resolvetype = _tuple([d.bareType(side) for d in returns])
    else:
        resolvetype = returns[0].bareType(side)

    # MozPromise is purposefully made to be exclusive only. Really, we mean it.
    return _promise(
        resolvetype, _ResponseRejectReason.Type(), ExprLiteral.TRUE, resolver=resolver
    )


def _resolveType(returns, side):
    """Return the "send"-direction in-type of |returns| (a std::tuple when
    there is more than one). |returns| must be non-empty."""
    if len(returns) > 1:
        return _tuple([d.inType(side, "send") for d in returns])
    return returns[0].inType(side, "send")


def _makeResolver(returns, side):
    return TypeFunction([Decl(_resolveType(returns, side), "")])


def _cxxArrayType(basetype, const=False, ref=False):
    return Type("nsTArray", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False)


def _cxxSpanType(basetype, const=False, ref=False):
    # Work on a copy so the caller's type object is not marked right-const.
    basetype = deepcopy(basetype)
    basetype.rightconst = True
    return Type(
        "mozilla::Span", T=basetype, const=const, ref=ref, hasimplicitcopyctor=True
    )


def _cxxMaybeType(basetype, const=False, ref=False):
    return Type(
        "mozilla::Maybe",
        T=basetype,
        const=const,
        ref=ref,
        hasimplicitcopyctor=basetype.hasimplicitcopyctor,
    )


def _cxxReadResultType(basetype, const=False, ref=False):
    return Type(
        "IPC::ReadResult",
        T=basetype,
        const=const,
        ref=ref,
        hasimplicitcopyctor=basetype.hasimplicitcopyctor,
    )


def _cxxNotNullType(basetype, const=False, ref=False):
    return Type(
        "mozilla::NotNull",
        T=basetype,
        const=const,
        ref=ref,
        hasimplicitcopyctor=basetype.hasimplicitcopyctor,
    )


def _cxxManagedContainerType(basetype, const=False, ref=False):
    return Type(
        "ManagedContainer", T=basetype, const=const, ref=ref, hasimplicitcopyctor=False
    )


def _cxxLifecycleProxyType(ptr=False):
    return Type("mozilla::ipc::ActorLifecycleProxy", ptr=ptr)


def _otherSide(side):
    """Map "child" to "parent" and vice versa.

    Any other value is a programming error and raises AssertionError.
    """
    if side == "child":
        return "parent"
    if side == "parent":
        return "child"
    # Raised explicitly (rather than `assert 0`) so the check survives `-O`.
    raise AssertionError("unknown side: %r" % (side,))


def _ifLogging(topLevelProtocol, stmts):
    """Wrap |stmts| in a `LoggingEnabledFor(topLevelProtocol)` check."""
    return StmtCode(
        """
        if (mozilla::ipc::LoggingEnabledFor(${proto})) {
            $*{stmts}
        }
        """,
        proto=topLevelProtocol,
        stmts=stmts,
    )


# XXX we need to remove these and install proper error handling


def _printErrorMessage(msg):
    """Return an `NS_ERROR(msg)` statement; a Python str is turned into a
    string literal first."""
    if isinstance(msg, str):
        msg = ExprLiteral.String(msg)
    return StmtExpr(ExprCall(ExprVar("NS_ERROR"), args=[msg]))


def _protocolErrorBreakpoint(msg):
    """Return a `mozilla::ipc::ProtocolErrorBreakpoint(msg)` statement; a
    Python str is turned into a string literal first."""
    if isinstance(msg, str):
        msg = ExprLiteral.String(msg)
    return StmtExpr(
        ExprCall(ExprVar("mozilla::ipc::ProtocolErrorBreakpoint"), args=[msg])
    )


def _printWarningMessage(msg):
    """Return an `NS_WARNING(msg)` statement; a Python str is turned into a
    string literal first."""
    if isinstance(msg, str):
        msg = ExprLiteral.String(msg)
    return StmtExpr(ExprCall(ExprVar("NS_WARNING"), args=[msg]))


def _fatalError(msg):
    return StmtExpr(ExprCall(ExprVar("FatalError"), args=[ExprLiteral.String(msg)]))


def _logicError(msg):
    return StmtExpr(
        ExprCall(ExprVar("mozilla::ipc::LogicError"), args=[ExprLiteral.String(msg)])
    )


def _sentinelReadError(classname):
    return StmtExpr(
        ExprCall(
            ExprVar("mozilla::ipc::SentinelReadError"),
            args=[ExprLiteral.String(classname)],
        )
    )


# Results that IPDL-generated code returns back to *Channel code.
# Users never see these
class _Result:
    @staticmethod
    def Type():
        return Type("Result")

    Processed = ExprVar("MsgProcessed")
    NotKnown = ExprVar("MsgNotKnown")
    NotAllowed = ExprVar("MsgNotAllowed")
    PayloadError = ExprVar("MsgPayloadError")
    ProcessingError = ExprVar("MsgProcessingError")
    RouteError = ExprVar("MsgRouteError")
    ValuError = ExprVar("MsgValueError")  # [sic]


# these |errfn*| are functions that generate code to be executed on an
# error, such as "bad actor ID".  each is given a Python string
# containing a description of the error

# used in user-facing Send*() methods


def errfnSend(msg, errcode=ExprLiteral.FALSE):
    return [_fatalError(msg), StmtReturn(errcode)]


def errfnSendCtor(msg):
    return errfnSend(msg, errcode=ExprLiteral.NULL)


# TODO should this error handling be strengthened for dtors?
def errfnSendDtor(msg):
    """Return the error statements for a failed destructor send: an
    NS_ERROR with |msg| followed by `return false`."""
    return [_printErrorMessage(msg), StmtReturn.FALSE]


# used in |OnMessage*()| handlers that hand in-messages off to Recv*()
# interface methods


def errfnRecv(msg, errcode=_Result.ValuError):
    """Return the error statements for a receive handler: FatalError(|msg|)
    followed by `return errcode`."""
    return [_fatalError(msg), StmtReturn(errcode)]


def errfnSentinel(rvalue=ExprLiteral.FALSE):
    """Return an error function that emits SentinelReadError(msg) followed
    by `return rvalue`."""

    def inner(msg):
        return [_sentinelReadError(msg), StmtReturn(rvalue)]

    return inner


def _destroyMethod():
    return ExprVar("ActorDestroy")


def errfnUnreachable(msg):
    """Return the statements for a should-never-happen path: LogicError(|msg|)
    with no return statement."""
    return [_logicError(msg)]


def readResultError():
    return ExprCode("{}")


class _DestroyReason:
    """Names of the ActorDestroyReason values used by generated code."""

    @staticmethod
    def Type():
        return Type("ActorDestroyReason")

    Deletion = ExprVar("Deletion")
    AncestorDeletion = ExprVar("AncestorDeletion")
    NormalShutdown = ExprVar("NormalShutdown")
    AbnormalShutdown = ExprVar("AbnormalShutdown")
    FailedConstructor = ExprVar("FailedConstructor")
    ManagedEndpointDropped = ExprVar("ManagedEndpointDropped")


class _ResponseRejectReason:
    """Names of the ResponseRejectReason values used by generated code."""

    @staticmethod
    def Type():
        return Type("ResponseRejectReason")

    SendError = ExprVar("ResponseRejectReason::SendError")
    ChannelClosed = ExprVar("ResponseRejectReason::ChannelClosed")
    HandlerRejected = ExprVar("ResponseRejectReason::HandlerRejected")
    ActorDestroyed = ExprVar("ResponseRejectReason::ActorDestroyed")


# -----------------------------------------------------------------------------
# Intermediate representation (IR) nodes used during lowering


class _ConvertToCxxType(TypeVisitor):
    """Visitor mapping an IPDL type to its bare C++ Type.

    |side| is "parent", "child" or None; |fq| selects fully-qualified names
    (|fullname()|) instead of short ones (|name()|).
    """

    def __init__(self, side, fq):
        self.side = side
        self.fq = fq

    def typename(self, thing):
        if self.fq:
            return thing.fullname()
        return thing.name()

    def visitImportedCxxType(self, t):
        cxxtype = Type(self.typename(t))
        # Refcounted imported types are held through a RefPtr.
        if t.isRefcounted():
            cxxtype = _refptr(cxxtype)
        return cxxtype

    def visitBuiltinCType(self, b):
        return Type(self.typename(b))

    def visitActorType(self, a):
        # With no side chosen, the actor is represented as a SideVariant over
        # the parent-side and child-side types.
        if self.side is None:
            return Type(
                "::mozilla::ipc::SideVariant",
                T=[
                    _cxxBareType(a, "parent", self.fq),
                    _cxxBareType(a, "child", self.fq),
                ],
            )
        return Type(_actorName(self.typename(a.protocol), self.side), ptr=True)

    def visitStructType(self, s):
        return Type(self.typename(s))

    def visitUnionType(self, u):
        return Type(self.typename(u))

    def visitArrayType(self, a):
        basecxxtype = a.basetype.accept(self)
        return _cxxArrayType(basecxxtype)

    def visitMaybeType(self, m):
        basecxxtype = m.basetype.accept(self)
        return _cxxMaybeType(basecxxtype)

    def visitShmemType(self, s):
        return Type(self.typename(s))

    def visitByteBufType(self, s):
        return Type(self.typename(s))

    def visitFDType(self, s):
        return Type(self.typename(s))

    def visitEndpointType(self, s):
        return Type(self.typename(s))

    def visitManagedEndpointType(self, s):
        return Type(self.typename(s))

    def visitUniquePtrType(self, s):
        return Type(self.typename(s))

    def visitNotNullType(self, n):
        basecxxtype = n.basetype.accept(self)
        return _cxxNotNullType(basecxxtype)

    # Protocol, message and void types have no C++ value representation here.
    def visitProtocolType(self, p):
        assert 0

    def visitMessageType(self, m):
        assert 0

    def visitVoidType(self, v):
        assert 0


def _cxxBareType(ipdltype, side, fq=False):
    """Return the unqualified C++ Type for |ipdltype| on |side|."""
    return ipdltype.accept(_ConvertToCxxType(side, fq))


def _cxxRefType(ipdltype, side):
    """Return the bare C++ type for |ipdltype| marked as a reference."""
    t = _cxxBareType(ipdltype, side)
    t.ref = True
    return t


def _cxxConstRefType(ipdltype, side):
    """Return the C++ type used to pass |ipdltype| as a const 'reference'.

    The branches are tried in order and the first match wins; actors are
    returned as their bare type and refcounted C++ types as a raw pointer.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isShmem():
        # Shmem is passed by non-const reference.
        t.ref = True
        return t
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        if inner.ptr:
            t = _cxxNotNullType(inner)
        return t
    if ipdltype.isIPDL() and ipdltype.hasBaseType():
        # Keep same constness as inner type.
        inner = _cxxConstRefType(ipdltype.basetype, side)
        t.const = inner.const or not inner.ref
        t.ref = True
        return t
    if ipdltype.isCxx() and (ipdltype.isSendMoveOnly() or ipdltype.isDataMoveOnly()):
        t.const = True
        t.ref = True
        return t
    if ipdltype.isCxx() and ipdltype.isRefcounted():
        # Use T* instead of const RefPtr<T>&
        t = t.T
        t.ptr = True
        return t
    t.const = True
    t.ref = True
    return t


def _cxxTypeNeedsMoveForSend(ipdltype, context="root", visited=None):
    """Returns `True` if serializing ipdltype requires a mutable reference, e.g.
    because the underlying resource represented by the value is being
    transferred to another process. This is occasionally distinct from whether
    the C++ type exposes a copy constructor, such as for types which are not
    cheaply copiable, but are not mutated when serialized.

    |context| is "root", "wrapper" or "compound" depending on what encloses
    |ipdltype|; |visited| collects the types already examined so that
    recursive structs/unions terminate."""

    if visited is None:
        visited = set()

    visited.add(ipdltype)

    if ipdltype.isCxx():
        return ipdltype.isSendMoveOnly()

    if ipdltype.isIPDL():
        if ipdltype.hasBaseType():
            return _cxxTypeNeedsMoveForSend(ipdltype.basetype, "wrapper", visited)
        if ipdltype.isStruct() or ipdltype.isUnion():
            return any(
                _cxxTypeNeedsMoveForSend(t, "compound", visited)
                for t in ipdltype.itercomponents()
                if t not in visited
            )

        # For historical reasons, shmem is `const_cast` to a mutable reference
        # when being stored in a struct or union (see
        # `_StructField.constRefExpr` and `_UnionMember.getConstValue`), meaning
        # that they do not cause the containing struct to require move for
        # sending.
        if ipdltype.isShmem():
            return context != "compound"

        return (
            ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )

    return False


def _cxxTypeNeedsMoveForData(ipdltype, context="root", visited=None):
    """Returns `True` if the bare C++ type corresponding to ipdltype does not
    satisfy std::is_copy_constructible_v<T>. All C++ types supported by IPDL
    must support std::is_move_constructible_v<T>, so non-movable types must be
    passed behind a `UniquePtr`.

    |context| and |visited| have the same meaning as in
    `_cxxTypeNeedsMoveForSend`."""

    if visited is None:
        visited = set()

    visited.add(ipdltype)

    if ipdltype.isCxx():
        return ipdltype.isDataMoveOnly()

    if ipdltype.isIPDL():
        if ipdltype.isUniquePtr():
            return True

        # When nested within a maybe or array, arrays are no longer copyable.
        if context == "wrapper" and ipdltype.isArray():
            return True
        if ipdltype.hasBaseType():
            return _cxxTypeNeedsMoveForData(ipdltype.basetype, "wrapper", visited)
        if ipdltype.isStruct() or ipdltype.isUnion():
            return any(
                _cxxTypeNeedsMoveForData(t, "compound", visited)
                for t in ipdltype.itercomponents()
                if t not in visited
            )
        return (
            ipdltype.isByteBuf()
            or ipdltype.isEndpoint()
            or ipdltype.isManagedEndpoint()
        )

    return False


def _cxxTypeCanMove(ipdltype):
    """Return True for every type except IPDL actors."""
    return not (ipdltype.isIPDL() and ipdltype.isActor())


def _cxxForceMoveRefType(ipdltype, side):
    """Return the bare C++ type for |ipdltype| as an rvalue reference."""
    assert _cxxTypeCanMove(ipdltype)
    t = _cxxBareType(ipdltype, side)
    t.rvalref = True
    return t


def _cxxPtrToType(ipdltype, side):
    """Return a pointer to the C++ type for |ipdltype|; sided actors, whose
    bare type is already a pointer, become pointer-to-pointer."""
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
        t.ptr = False
        t.ptrptr = True
        return t
    t.ptr = True
    return t


def _cxxConstPtrToType(ipdltype, side):
    """Const counterpart of `_cxxPtrToType`; sided actors use the
    |ptrconstptr| form instead of a const pointer."""
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor() and side is not None:
        t.ptr = False
        t.ptrconstptr = True
        return t
    t.const = True
    t.ptr = True
    return t


def _cxxInType(ipdltype, side, direction):
    """Return the C++ type used for an in-parameter of |ipdltype|.

    |direction| is compared against "send" below; the branches are tried in
    order and the first match wins.
    """
    t = _cxxBareType(ipdltype, side)
    if ipdltype.isIPDL() and ipdltype.isActor():
        return t
    if ipdltype.isIPDL() and ipdltype.isNotNull():
        # If the inner type chooses to use a raw pointer, wrap that instead.
        inner = _cxxInType(ipdltype.basetype, side, direction)
        if inner.ptr:
            t = _cxxNotNullType(inner)
        return t
    if _cxxTypeNeedsMoveForSend(ipdltype):
        t.rvalref = True
        return t
    if ipdltype.isCxx():
        if ipdltype.isRefcounted():
            # Use T* instead of const RefPtr<T>&
            t = t.T
            t.ptr = True
            return t
        # Accept the abstract string classes for the concrete string types.
        if ipdltype.name() == "nsCString":
            t = Type("nsACString")
        if ipdltype.name() == "nsString":
            t = Type("nsAString")
    # Use Span rather than nsTArray for array types which aren't
    # `_cxxTypeNeedsMoveForSend`. This is only done for the "send" side, and not
    # for recv signatures.
    if direction == "send" and ipdltype.isIPDL() and ipdltype.isArray():
        inner = _cxxBareType(ipdltype.basetype, side)
        return _cxxSpanType(inner)
    t.const = True
    t.ref = True
    return t


def _allocMethod(ptype, side):
    return "Alloc" + ptype.name() + side.title()


def _deallocMethod(ptype, side):
    return "Dealloc" + ptype.name() + side.title()


##
# A _HybridDecl straddles IPDL and C++ decls.  It knows which C++
# types correspond to which IPDL types, and it also knows how to
# serialize and deserialize "special" IPDL C++ types.
## class _HybridDecl: """A hybrid decl stores both an IPDL type and all the C++ type info needed by later passes, along with a basic name for the decl.""" def __init__(self, ipdltype, name, attributes={}): self.ipdltype = ipdltype self.name = name self.attributes = attributes def var(self): return ExprVar(self.name) def bareType(self, side, fq=False): """Return this decl's unqualified C++ type.""" return _cxxBareType(self.ipdltype, side, fq=fq) def refType(self, side): """Return this decl's C++ type as a 'reference' type, which is not necessarily a C++ reference.""" return _cxxRefType(self.ipdltype, side) def constRefType(self, side): """Return this decl's C++ type as a const, 'reference' type.""" return _cxxConstRefType(self.ipdltype, side) def ptrToType(self, side): return _cxxPtrToType(self.ipdltype, side) def constPtrToType(self, side): return _cxxConstPtrToType(self.ipdltype, side) def inType(self, side, direction): """Return this decl's C++ Type with sending inparam semantics.""" return _cxxInType(self.ipdltype, side, direction) def outType(self, side): """Return this decl's C++ Type with outparam semantics.""" t = self.bareType(side) if self.ipdltype.isIPDL() and self.ipdltype.isActor(): t.ptr = False t.ptrptr = True return t t.ptr = True return t def forceMoveType(self, side): """Return this decl's C++ Type with forced move semantics.""" assert _cxxTypeCanMove(self.ipdltype) return _cxxForceMoveRefType(self.ipdltype, side) # -------------------------------------------------- class HasFQName: def fqClassName(self): return self.decl.type.fullname() class _CompoundTypeComponent(_HybridDecl): # @override the following methods to make the side argument optional. 
def bareType(self, side=None, fq=False): return _HybridDecl.bareType(self, side, fq=fq) def refType(self, side=None): return _HybridDecl.refType(self, side) def constRefType(self, side=None): return _HybridDecl.constRefType(self, side) def ptrToType(self, side=None): return _HybridDecl.ptrToType(self, side) def constPtrToType(self, side=None): return _HybridDecl.constPtrToType(self, side) def forceMoveType(self, side=None): return _HybridDecl.forceMoveType(self, side) class StructDecl(ipdl.ast.StructDecl, HasFQName): def fields_ipdl_order(self): for f in self.fields: yield f def fields_member_order(self): assert len(self.packed_field_order) == len(self.fields) for i in self.packed_field_order: yield self.fields[i] @staticmethod def upgrade(structDecl): assert isinstance(structDecl, ipdl.ast.StructDecl) structDecl.__class__ = StructDecl class _StructField(_CompoundTypeComponent): def __init__(self, ipdltype, name, sd): self.basename = name _CompoundTypeComponent.__init__(self, ipdltype, name) def getMethod(self, thisexpr=None, sel="."): meth = self.var() if thisexpr is not None: return ExprSelect(thisexpr, sel, meth.name) return meth def refExpr(self, thisexpr=None): ref = self.memberVar() if thisexpr is not None: ref = ExprSelect(thisexpr, ".", ref.name) return ref def constRefExpr(self, thisexpr=None): # sigh, gross hack refexpr = self.refExpr(thisexpr) if "Shmem" == self.ipdltype.name(): refexpr = ExprCast(refexpr, Type("Shmem", ref=True), const=True) return refexpr def argVar(self): return ExprVar("_" + self.name) def memberVar(self): return ExprVar(self.name + "_") class UnionDecl(ipdl.ast.UnionDecl, HasFQName): def callType(self, var=None): func = ExprVar("type") if var is not None: func = ExprSelect(var, ".", func.name) return ExprCall(func) @staticmethod def upgrade(unionDecl): assert isinstance(unionDecl, ipdl.ast.UnionDecl) unionDecl.__class__ = UnionDecl class _UnionMember(_CompoundTypeComponent): """Not in the AFL sense, but rather a member (e.g. 
|int;|) of an IPDL union type.""" def __init__(self, ipdltype, ud): flatname = _flatTypeName(ipdltype) _CompoundTypeComponent.__init__(self, ipdltype, "V" + flatname) self.flattypename = flatname # To create a finite object with a mutually recursive type, a union must # be present somewhere in the recursive loop. Because of that we only # need to care about introducing indirections inside unions. self.recursive = ud.decl.type.mutuallyRecursiveWith(ipdltype) def enum(self): return "T" + self.flattypename def enumvar(self): return ExprVar(self.enum()) def internalType(self): if self.recursive: return self.ptrToType() else: return self.bareType() def unionType(self): """Type used for storage in generated C union decl.""" if self.recursive: return self.ptrToType() else: return Type("mozilla::AlignedStorage2", T=self.internalType()) def unionValue(self): # NB: knows that Union's storage C union is named |mValue| return ExprSelect(ExprVar("mValue"), ".", self.name) def typedef(self): return self.flattypename + "__tdef" def callGetConstPtr(self): """Return an expression of type self.constptrToSelfType()""" return ExprCall(ExprVar(self.getConstPtrName())) def callGetPtr(self): """Return an expression of type self.ptrToSelfType()""" return ExprCall(ExprVar(self.getPtrName())) def callCtor(self, expr=None): assert not isinstance(expr, list) if expr is None: args = None elif ( self.ipdltype.isIPDL() and self.ipdltype.isArray() and not isinstance(expr, ExprMove) ): args = [ExprCall(ExprSelect(expr, ".", "Clone"), args=[])] else: args = [expr] if self.recursive: return ExprAssn(self.callGetPtr(), ExprNew(self.bareType(), args=args)) else: return ExprNew( self.bareType(), args=args, newargs=[ExprVar("mozilla::KnownNotNull"), self.callGetPtr()], ) def callDtor(self): if self.recursive: return ExprDelete(self.callGetPtr()) else: return ExprCall(ExprSelect(self.callGetPtr(), "->", "~" + self.typedef())) def getTypeName(self): return "get_" + self.flattypename def 
getConstTypeName(self): return "get_" + self.flattypename def getOtherTypeName(self): return "get_" + self.otherflattypename def getPtrName(self): return "ptr_" + self.flattypename def getConstPtrName(self): return "constptr_" + self.flattypename def ptrToSelfExpr(self): """|*ptrToSelfExpr()| has type |self.bareType()|""" v = self.unionValue() if self.recursive: return v else: return ExprCall(ExprSelect(v, ".", "addr")) def constptrToSelfExpr(self): """|*constptrToSelfExpr()| has type |self.constType()|""" v = self.unionValue() if self.recursive: return v return ExprCall(ExprSelect(v, ".", "addr")) def ptrToInternalType(self): t = self.ptrToType() if self.recursive: t.ref = True return t def defaultValue(self, fq=False): # Use the default constructor for any class that does not have an # implicit copy constructor. if not self.bareType().hasimplicitcopyctor: return None if self.ipdltype.isIPDL() and self.ipdltype.isActor(): return ExprLiteral.NULL # XXX sneaky here, maybe need ExprCtor()? 
return ExprCall(self.bareType(fq=fq)) def getConstValue(self): v = ExprDeref(self.callGetConstPtr()) # sigh if "Shmem" == self.ipdltype.name(): v = ExprCast(v, Type("Shmem", ref=True), const=True) return v # -------------------------------------------------- class MessageDecl(ipdl.ast.MessageDecl): def baseName(self): return self.name def recvMethod(self): name = _recvPrefix(self.decl.type) + self.baseName() if self.decl.type.isCtor(): name += "Constructor" return name def sendMethod(self): name = _sendPrefix(self.decl.type) + self.baseName() if self.decl.type.isCtor(): name += "Constructor" return name def hasReply(self): return ( self.decl.type.hasReply() or self.decl.type.isCtor() or self.decl.type.isDtor() ) def hasAsyncReturns(self): return self.decl.type.isAsync() and self.returns def msgCtorFunc(self): return "Msg_%s" % (self.decl.progname) def prettyMsgName(self, pfx=""): return pfx + self.msgCtorFunc() def pqMsgCtorFunc(self): return "%s::%s" % (self.namespace, self.msgCtorFunc()) def msgId(self): return self.msgCtorFunc() + "__ID" def pqMsgId(self): return "%s::%s" % (self.namespace, self.msgId()) def replyCtorFunc(self): return "Reply_%s" % (self.decl.progname) def pqReplyCtorFunc(self): return "%s::%s" % (self.namespace, self.replyCtorFunc()) def replyId(self): return self.replyCtorFunc() + "__ID" def pqReplyId(self): return "%s::%s" % (self.namespace, self.replyId()) def prettyReplyName(self, pfx=""): return pfx + self.replyCtorFunc() def promiseName(self): name = self.baseName() if self.decl.type.isCtor(): name += "Constructor" name += "Promise" return name def resolverName(self): return self.baseName() + "Resolver" def actorDecl(self): return self.params[0] def makeCxxParams( self, paramsems="in", returnsems="out", side=None, implicit=True, direction=None ): """Return a list of C++ decls per the spec'd configuration. 
|params| and |returns| is the C++ semantics of those: 'in', 'out', or None.""" def makeDecl(d, sems): if ( self.decl.type.tainted and "NoTaint" not in d.attributes and direction == "recv" ): # Tainted types are passed by-value, allowing the receiver to move them if desired. assert sems != "out" return Decl(Type("Tainted", T=d.bareType(side)), d.name) if sems == "in": t = d.inType(side, direction) # If this is the `recv` side, and we're not using "move" # semantics, that means we're an alloc method, and cannot accept # values by rvalue reference. Downgrade to an lvalue reference. if direction == "recv" and t.rvalref: t.rvalref = False t.ref = True return Decl(t, d.name) elif sems == "move": assert direction == "recv" # For legacy reasons, use an rvalue reference when generating # parameters for recv methods which accept arrays. if d.ipdltype.isIPDL() and d.ipdltype.isArray(): t = d.bareType(side) t.rvalref = True return Decl(t, d.name) return Decl(d.inType(side, direction), d.name) elif sems == "out": return Decl(d.outType(side), d.name) else: assert 0 def makeResolverDecl(returns): return Decl(Type(self.resolverName(), rvalref=True), "aResolve") def makeCallbackResolveDecl(returns): if len(returns) > 1: resolvetype = _tuple([d.bareType(side) for d in returns]) else: resolvetype = returns[0].bareType(side) return Decl( Type("mozilla::ipc::ResolveCallback", T=resolvetype, rvalref=True), "aResolve", ) def makeCallbackRejectDecl(returns): return Decl(Type("mozilla::ipc::RejectCallback", rvalref=True), "aReject") cxxparams = [] if paramsems is not None: cxxparams.extend([makeDecl(d, paramsems) for d in self.params]) if returnsems == "promise" and self.returns: pass elif returnsems == "callback" and self.returns: cxxparams.extend( [ makeCallbackResolveDecl(self.returns), makeCallbackRejectDecl(self.returns), ] ) elif returnsems == "resolver" and self.returns: cxxparams.extend([makeResolverDecl(self.returns)]) elif returnsems is not None: cxxparams.extend([makeDecl(r, 
returnsems) for r in self.returns]) if not implicit and self.decl.type.hasImplicitActorParam(): cxxparams = cxxparams[1:] return cxxparams def makeCxxArgs( self, paramsems="in", retsems="out", retcallsems="out", implicit=True ): assert not retcallsems or retsems # retcallsems => returnsems cxxargs = [] if paramsems == "move": # We don't std::move() RefPtr types because current Recv*() # implementors take these parameters as T*, and # std::move(RefPtr) doesn't coerce to T*. # We also don't move NotNull, as it has no move constructor. cxxargs.extend( [ p.var() if p.ipdltype.isRefcounted() or (p.ipdltype.isIPDL() and p.ipdltype.isNotNull()) else ExprMove(p.var()) for p in self.params ] ) elif paramsems == "in": cxxargs.extend([p.var() for p in self.params]) else: assert False for ret in self.returns: if retsems == "in": if retcallsems == "in": cxxargs.append(ret.var()) elif retcallsems == "out": cxxargs.append(ExprAddrOf(ret.var())) else: assert 0 elif retsems == "out": if retcallsems == "in": cxxargs.append(ExprDeref(ret.var())) elif retcallsems == "out": cxxargs.append(ret.var()) else: assert 0 elif retsems == "resolver": pass if retsems == "resolver": cxxargs.append(ExprMove(ExprVar("resolver"))) if not implicit: assert self.decl.type.hasImplicitActorParam() cxxargs = cxxargs[1:] return cxxargs @staticmethod def upgrade(messageDecl): assert isinstance(messageDecl, ipdl.ast.MessageDecl) if messageDecl.decl.type.hasImplicitActorParam(): messageDecl.params.insert( 0, _HybridDecl( ipdl.type.ActorType(messageDecl.decl.type.constructedType()), "actor", ), ) messageDecl.__class__ = MessageDecl # -------------------------------------------------- def _usesShmem(p): for md in p.messageDecls: for param in md.inParams: if ipdl.type.hasshmem(param.type): return True for ret in md.outParams: if ipdl.type.hasshmem(ret.type): return True return False def _subtreeUsesShmem(p): if _usesShmem(p): return True ptype = p.decl.type for mgd in ptype.manages: if ptype is not mgd: if 
_subtreeUsesShmem(mgd._ast): return True return False class Protocol(ipdl.ast.Protocol): def managerInterfaceType(self, ptr=False): return Type("mozilla::ipc::IProtocol", ptr=ptr) def openedProtocolInterfaceType(self, ptr=False): return Type("mozilla::ipc::IToplevelProtocol", ptr=ptr) def _ipdlmgrtype(self): assert 1 == len(self.decl.type.managers) for mgr in self.decl.type.managers: return mgr def managerActorType(self, side, ptr=False): return Type(_actorName(self._ipdlmgrtype().name(), side), ptr=ptr) def unregisterMethod(self, actorThis=None): if actorThis is not None: return ExprSelect(actorThis, "->", "Unregister") return ExprVar("Unregister") def removeManageeMethod(self): return ExprVar("RemoveManagee") def deallocManageeMethod(self): return ExprVar("DeallocManagee") def getChannelMethod(self): return ExprVar("GetIPCChannel") def callGetChannel(self, actorThis=None): fn = self.getChannelMethod() if actorThis is not None: fn = ExprSelect(actorThis, "->", fn.name) return ExprCall(fn) def processingErrorVar(self): assert self.decl.type.isToplevel() return ExprVar("ProcessingError") def shouldContinueFromTimeoutVar(self): assert self.decl.type.isToplevel() return ExprVar("ShouldContinueFromReplyTimeout") def routingId(self, actorThis=None): if self.decl.type.isToplevel(): return ExprVar("MSG_ROUTING_CONTROL") if actorThis is not None: return ExprCall(ExprSelect(actorThis, "->", "Id")) return ExprCall(ExprVar("Id")) def managerVar(self, thisexpr=None): assert thisexpr is not None or not self.decl.type.isToplevel() mvar = ExprCall(ExprVar("Manager"), args=[]) if thisexpr is not None: mvar = ExprCall(ExprSelect(thisexpr, "->", "Manager"), args=[]) return mvar def managedCxxType(self, actortype, side): assert self.decl.type.isManagerOf(actortype) return Type(_actorName(actortype.name(), side), ptr=True) def managedMethod(self, actortype, side): assert self.decl.type.isManagerOf(actortype) return ExprVar("Managed" + _actorName(actortype.name(), side)) def 
managedVar(self, actortype, side): assert self.decl.type.isManagerOf(actortype) return ExprVar("mManaged" + _actorName(actortype.name(), side)) def managedVarType(self, actortype, side, const=False, ref=False): assert self.decl.type.isManagerOf(actortype) return _cxxManagedContainerType( Type(_actorName(actortype.name(), side)), const=const, ref=ref ) def subtreeUsesShmem(self): return _subtreeUsesShmem(self) @staticmethod def upgrade(protocol): assert isinstance(protocol, ipdl.ast.Protocol) protocol.__class__ = Protocol class TranslationUnit(ipdl.ast.TranslationUnit): @staticmethod def upgrade(tu): assert isinstance(tu, ipdl.ast.TranslationUnit) tu.__class__ = TranslationUnit # ----------------------------------------------------------------------------- pod_types = { "::int8_t": 1, "::uint8_t": 1, "::int16_t": 2, "::uint16_t": 2, "::int32_t": 4, "::uint32_t": 4, "::int64_t": 8, "::uint64_t": 8, "float": 4, "double": 8, } max_pod_size = max(pod_types.values()) # We claim that all types we don't recognize are automatically "bigger" # than pod types for ease of sorting. pod_size_sentinel = max_pod_size * 2 def pod_size(ipdltype): if not ipdltype.isCxx(): return pod_size_sentinel return pod_types.get(ipdltype.fullname(), pod_size_sentinel) class _DecorateWithCxxStuff(ipdl.ast.Visitor): """Phase 1 of lowering: decorate the IPDL AST with information relevant to C++ code generation. 
This pass results in an AST that is a poor man's "IR"; in reality, a "hybrid" AST mainly consisting of IPDL nodes with new C++ info along with some new IPDL/C++ nodes that are tuned for C++ codegen.""" def __init__(self): self.visitedTus = set() self.protocolName = None def visitTranslationUnit(self, tu): if tu not in self.visitedTus: self.visitedTus.add(tu) ipdl.ast.Visitor.visitTranslationUnit(self, tu) if not isinstance(tu, TranslationUnit): TranslationUnit.upgrade(tu) def visitInclude(self, inc): if inc.tu.filetype == "header": inc.tu.accept(self) def visitProtocol(self, pro): self.protocolName = pro.name Protocol.upgrade(pro) return ipdl.ast.Visitor.visitProtocol(self, pro) def visitStructDecl(self, sd): if not isinstance(sd, StructDecl): newfields = [_StructField(f.decl.type, f.name, sd) for f in sd.fields] # Compute a permutation of the fields for in-memory storage such # that the memory layout of the structure will be well-packed. permutation = list(range(len(newfields))) # Note that the results of `pod_size` ensure that non-POD fields # sort before POD ones. 
def msgenums(protocol, pretty=False):
    """Build the |MessageType| enum for |protocol|.

    The enum opens with "<name>Start" (assigned "<msg start> << 16"), lists
    one id per message followed by its reply id when the message has a
    reply, and closes with "<name>End".  With |pretty| set, the pretty
    message/reply names are used in place of the raw ids."""
    enum = TypeEnum("MessageType")
    enum.addId(
        protocol.name + "Start",
        _messageStartName(protocol.decl.type) + " << 16",
    )

    for decl in protocol.messageDecls:
        if pretty:
            enum.addId(decl.prettyMsgName())
        else:
            enum.addId(decl.msgId())

        if not decl.hasReply():
            continue

        if pretty:
            enum.addId(decl.prettyReplyName())
        else:
            enum.addId(decl.replyId())

    enum.addId(protocol.name + "End")
    return enum
typesToIncludes = {} for using in tu.using: typestr = str(using.type) if typestr not in typesToIncludes: typesToIncludes[typestr] = using.header else: assert typesToIncludes[typestr] == using.header aggregateTypeIncludes = set() for su in tu.structsAndUnions: typedeps = _ComputeTypeDeps(su.decl.type, typesToIncludes) if isinstance(su, ipdl.ast.StructDecl): aggregateTypeIncludes.add("mozilla/ipc/IPDLStructMember.h") for f in su.fields: f.ipdltype.accept(typedeps) elif isinstance(su, ipdl.ast.UnionDecl): for c in su.components: c.ipdltype.accept(typedeps) aggregateTypeIncludes.update(typedeps.includeHeaders) if len(aggregateTypeIncludes) != 0: hf.addthing(Whitespace.NL) hf.addthings([Whitespace("// Headers for typedefs"), Whitespace.NL]) for headername in sorted(iter(aggregateTypeIncludes)): hf.addthing(CppDirective("include", '"' + headername + '"')) # Manually run Visitor.visitTranslationUnit. For dependency resolution # we need to handle structs and unions separately. for cxxInc in tu.cxxIncludes: cxxInc.accept(self) for inc in tu.includes: inc.accept(self) self.generateStructsAndUnions(tu) for using in tu.builtinUsing: using.accept(self) for using in tu.using: using.accept(self) if tu.protocol: tu.protocol.accept(self) if tu.filetype == "header": self.cppIncludeHeaders.append(_ipdlhHeaderName(tu) + ".h") hf.addthing(Whitespace.NL) hf.addthings(_includeGuardEnd(hf)) cf = self.cppfile cf.addthings( ( [_DISCLAIMER, Whitespace.NL] + [ CppDirective("include", '"' + h + '"') for h in self.cppIncludeHeaders ] + [Whitespace.NL] + [ CppDirective("include", '"%s"' % filename) for filename in ipdl.builtin.CppIncludes ] + [Whitespace.NL] ) ) if self.protocol: # construct the namespace into which we'll stick all our defns ns = Namespace(self.protocol.name) cf.addthing(_putInNamespaces(ns, self.protocol.namespaces)) ns.addstmts(([Whitespace.NL] + self.funcDefns + [Whitespace.NL])) cf.addthings(self.structUnionDefns) def visitBuiltinCxxInclude(self, inc): 
def visitInclude(self, inc):
    """Record the includes required by the IPDL include |inc|.

    A header include adds an #include of the generated ".h" to the header
    file and visits the included unit's C++ includes.  Any other include
    adds the included protocol's parent and child headers to the .cpp
    include list."""
    included = inc.tu

    if included.filetype != "header":
        self.cppIncludeHeaders.extend(
            [
                _protocolHeaderName(included.protocol, "parent") + ".h",
                _protocolHeaderName(included.protocol, "child") + ".h",
            ]
        )
        return

    quoted = '"' + _ipdlhHeaderName(included) + '.h"'
    self.hdrfile.addthing(CppDirective("include", quoted))

    # Inherit cpp includes defined by imported header files, as they may
    # be required to serialize an imported `using` type.
    for cxxinc in included.cxxIncludes:
        cxxinc.accept(self)
decls[su.decl.type] = ( fulldecltypes, [Whitespace.NL] + forwarddecls + [ Whitespace( """ //----------------------------------------------------------------------------- // Declaration of the IPDL type |%s %s| // """ % (which, su.name) ), _putInNamespaces(clsdecl, su.namespaces), ] + [Whitespace.NL, traitsdecl], ) self.structUnionDefns.extend( [ Whitespace( """ //----------------------------------------------------------------------------- // Method definitions for the IPDL type |%s %s| // """ % (which, su.name) ), _putInNamespaces(methoddefns, su.namespaces), Whitespace.NL, traitsdefns, ] ) # Generate the declarations structs in dependency order. def gen_struct(deps, defn): for dep in deps: if dep in decls: d, t = decls[dep] del decls[dep] gen_struct(d, t) self.hdrfile.addthings(defn) while len(decls) > 0: _, (d, t) = decls.popitem(False) gen_struct(d, t) def visitProtocol(self, p): self.cppIncludeHeaders.append(_protocolHeaderName(self.protocol, "") + ".h") self.cppIncludeHeaders.append( _protocolHeaderName(self.protocol, "Parent") + ".h" ) self.cppIncludeHeaders.append( _protocolHeaderName(self.protocol, "Child") + ".h" ) # Forward declare our own actors. 
self.hdrfile.addthings( [ Whitespace.NL, _makeForwardDeclForActor(p.decl.type, "Parent"), _makeForwardDeclForActor(p.decl.type, "Child"), ] ) self.hdrfile.addthing( Whitespace( """ //----------------------------------------------------------------------------- // Code common to %sChild and %sParent // """ % (p.name, p.name) ) ) # construct the namespace into which we'll stick all our decls ns = Namespace(self.protocol.name) self.hdrfile.addthing(_putInNamespaces(ns, p.namespaces)) ns.addstmt(Whitespace.NL) for func in self.genEndpointFuncs(): edecl, edefn = _splitFuncDeclDefn(func) ns.addstmts([edecl, Whitespace.NL]) self.funcDefns.append(edefn) # spit out message type enum and classes msgenum = msgenums(self.protocol) ns.addstmts([StmtDecl(Decl(msgenum, "")), Whitespace.NL]) for md in p.messageDecls: decls = [] # Look up the segment capacity used for serializing this # message. If the capacity is not specified, use '0' for # the default capacity (defined in ipc_message.cc) name = "%s::%s" % (md.namespace, md.decl.progname) segmentcapacity = self.segmentcapacitydict.get(name, 0) mfDecl, mfDefn = _splitFuncDeclDefn( _generateMessageConstructor(md, segmentcapacity, p, forReply=False) ) decls.append(mfDecl) self.funcDefns.append(mfDefn) if md.hasReply(): rfDecl, rfDefn = _splitFuncDeclDefn( _generateMessageConstructor(md, 0, p, forReply=True) ) decls.append(rfDecl) self.funcDefns.append(rfDefn) decls.append(Whitespace.NL) ns.addstmts(decls) ns.addstmts([Whitespace.NL, Whitespace.NL]) # Generate code for PFoo::CreateEndpoints. 
def _generateMessageConstructor(md, segmentSize, protocol, forReply=False):
    """Build the C++ function that constructs the IPC message for |md|.

    md          -- the message declaration being lowered.
    segmentSize -- segment capacity handed to IPC::Message::IPDLMessage.
                   Any falsy value (0, None, "") means "use the default"
                   and is emitted as 0; anything else is converted with
                   int().
    protocol    -- not read by this function; kept so existing callers keep
                   working.
    forReply    -- when true, generate the reply constructor (reply ctor
                   name, reply id, reply priority, REPLY flag) rather than
                   the request constructor.

    Returns a FunctionDefn taking an |int32_t routingId| whose body returns
    |IPC::Message::IPDLMessage(routingId, <msgid>, <segmentSize>, <flags>)|.
    """
    if forReply:
        clsname = md.replyCtorFunc()
        msgid = md.replyId()
        replyEnum = "REPLY"
        prioEnum = cppPriorityList[md.decl.type.replyPrio]
    else:
        clsname = md.msgCtorFunc()
        msgid = md.msgId()
        replyEnum = "NOT_REPLY"
        prioEnum = cppPriorityList[md.decl.type.prio]

    nested = md.decl.type.nested
    compress = md.decl.type.compress
    lazySend = md.decl.type.lazySend

    routingId = ExprVar("routingId")

    func = FunctionDefn(
        FunctionDecl(
            clsname,
            params=[Decl(Type("int32_t"), routingId.name)],
            # The template argument is required: a bare class template name
            # is not a valid C++ function return type.
            ret=Type("mozilla::UniquePtr<IPC::Message>"),
        )
    )

    if not compress:
        compression = "COMPRESSION_NONE"
    elif compress.value == "all":
        compression = "COMPRESSION_ALL"
    else:
        assert compress.value is None
        compression = "COMPRESSION_ENABLED"

    if lazySend:
        lazySendEnum = "LAZY_SEND"
    else:
        lazySendEnum = "EAGER_SEND"

    if nested == ipdl.ast.NOT_NESTED:
        nestedEnum = "NOT_NESTED"
    elif nested == ipdl.ast.INSIDE_SYNC_NESTED:
        nestedEnum = "NESTED_INSIDE_SYNC"
    else:
        assert nested == ipdl.ast.INSIDE_CPOW_NESTED
        nestedEnum = "NESTED_INSIDE_CPOW"

    if md.decl.type.isSync():
        syncEnum = "SYNC"
    else:
        syncEnum = "ASYNC"

    # FIXME(bug ???) - remove support for interrupt messages from the IPDL compiler.
    if md.decl.type.isInterrupt():
        # Make any use of an intr message fail at C++ compile time.
        func.addcode(
            """
            static_assert(
                false,
                "runtime support for intr messages has been removed from IPDL");
            """
        )

    if md.decl.type.isCtor():
        ctorEnum = "CONSTRUCTOR"
    else:
        ctorEnum = "NOT_CONSTRUCTOR"

    def messageEnum(valname):
        return ExprVar("IPC::Message::" + valname)

    flags = ExprCall(
        ExprVar("IPC::Message::HeaderFlags"),
        args=[
            messageEnum(nestedEnum),
            messageEnum(prioEnum),
            messageEnum(compression),
            messageEnum(lazySendEnum),
            messageEnum(ctorEnum),
            messageEnum(syncEnum),
            messageEnum(replyEnum),
        ],
    )

    # Normalize the capacity exactly once.  The falsy check has to come
    # before int(): int(None) and int("") raise, which would defeat the
    # "unspecified means 0" fallback.
    segmentSize = int(segmentSize) if segmentSize else 0

    func.addstmt(
        StmtReturn(
            ExprCall(
                ExprVar("IPC::Message::IPDLMessage"),
                args=[
                    routingId,
                    ExprVar(msgid),
                    ExprLiteral.Int(segmentSize),
                    flags,
                ],
            )
        )
    )

    return func
@classmethod
def checkedWrite(cls, ipdltype, var, writervar, sentinelKey):
    """Return a Block that writes |var| and then its sentinel.

    The value is written through |cls.write| (which receives |ipdltype|),
    followed by the statements from |cls.writeSentinel| for |sentinelKey|.
    |sentinelKey| must be truthy."""
    assert sentinelKey

    out = Block()
    payload = StmtExpr(cls.write(var, writervar, ipdltype))
    out.addstmts([payload])

    sentinel = cls.writeSentinel(writervar, sentinelKey)
    out.addstmts(sentinel)
    return out
# Convenience front-end to checkedRead for callers inside _ParamTraits
@classmethod
def _checkedRead(cls, ipdltype, cxxtype, var, sentinelKey, what):
    """Call |cls.checkedRead| with this class's reader and error handling.

    A failed read reports a fatal error on |cls.readervar| and returns the
    read-result error; |what| is forwarded as |paramtype| for the error
    text."""

    def failRead(msg):
        report = cls.fatalError(cls.readervar, msg)
        bail = StmtReturn(readResultError())
        return [report, bail]

    onBadSentinel = errfnSentinel(readResultError())
    return cls.checkedRead(
        ipdltype,
        cxxtype,
        var,
        cls.readervar,
        errfn=failRead,
        paramtype=what,
        sentinelKey=sentinelKey,
        errfnSentinel=onBadSentinel,
    )
pt = Class( "ParamTraits", specializes=Type( fortype.name, T=fortype.T, inner=fortype.inner, ptr=fortype.ptr ), struct=True, ) # typedef T paramType; pt.addstmt(Typedef(fortype, "paramType")) # static void Write(Message*, const T&); if needsmove: intype = Type("paramType", rvalref=True) else: intype = Type("paramType", ref=True, const=True) writemthd = MethodDefn( MethodDecl( "Write", params=[ Decl(Type("IPC::MessageWriter", ptr=True), cls.writervar.name), Decl(intype, cls.var.name), ], methodspec=MethodSpec.STATIC, ) ) writemthd.addstmts(write) pt.addstmt(writemthd) # static ReadResult Read(MessageReader*); readmthd = MethodDefn( MethodDecl( "Read", params=[ Decl(Type("IPC::MessageReader", ptr=True), cls.readervar.name), ], ret=Type("IPC::ReadResult"), methodspec=MethodSpec.STATIC, ) ) readmthd.addstmts(read) pt.addstmt(readmthd) # Split the class into declaration and definition clsdecl, methoddefns = _splitClassDeclDefn(pt) namespaces = [Namespace("IPC")] clsns = _putInNamespaces(clsdecl, namespaces) defns = _putInNamespaces(methoddefns, namespaces) return clsns, defns @classmethod def actorPickling(cls, actortype, side): """Generates pickling for IPDL actors. This is a |nullable| deserializer. Write and read callers will perform nullability validation.""" cxxtype = _cxxBareType(actortype, side, fq=True) write = StmtCode( """ MOZ_RELEASE_ASSERT( ${writervar}->GetActor(), "Cannot serialize managed actors without an actor"); int32_t id; if (!${var}) { id = 0; // kNullActorId } else { id = ${var}->Id(); if (id == 1) { // kFreedActorId ${var}->FatalError("Actor has been |delete|d"); } MOZ_RELEASE_ASSERT( ${writervar}->GetActor()->GetIPCChannel() == ${var}->GetIPCChannel(), "Actor must be from the same channel as the" " actor it's being sent over"); MOZ_RELEASE_ASSERT( ${var}->CanSend(), "Actor must still be open when sending"); } ${write}; """, var=cls.var, writervar=cls.writervar, write=cls.write(ExprVar("id"), cls.writervar), ) # bool Read(..) 
@classmethod
def structPickling(cls, structtype):
    """Generate the ParamTraits declaration/definition for an IPDL struct.

    Fields whose |pod_size| is the sentinel are written and read one by one
    in IPDL order; the values read become constructor arguments for the
    result.  The remaining fields are zero-initialized in the constructor
    call and then bulk written/read in member order, one run per group of
    equally sized fields."""
    decl = structtype._ast
    # NOTE: Not using _cxxBareType here as we don't have a side
    cxxtype = Type(structtype.fullname())

    writes = []
    reads = []
    ctorargs = []

    # Pass 1: non-pod data in IPDL order.  These values must exist before
    # the struct is constructed because they are handed to its constructor.
    for field in decl.fields_ipdl_order():
        if pod_size(field.ipdltype) != pod_size_sentinel:
            # We're going to bulk-read in this value later, so we'll just
            # zero-initialize it for now.
            ctorargs.append(ExprCode("${type}{0}", type=field.bareType(fq=True)))
            continue

        writes.append(
            cls.checkedWrite(
                field.ipdltype,
                ExprCall(field.getMethod(thisexpr=cls.var, sel=".")),
                cls.writervar,
                sentinelKey=field.basename,
            )
        )

        what = "".join(
            [
                "'",
                field.getMethod().name,
                "' ",
                "(",
                field.ipdltype.name(),
                ") member of ",
                "'",
                structtype.name(),
                "'",
            ]
        )
        reads.append(
            cls._checkedRead(
                field.ipdltype,
                field.bareType(fq=True),
                field.argVar(),
                field.basename,
                what,
            )
        )

        arg = field.argVar()
        ctorargs.append(ExprMove(arg) if _cxxTypeCanMove(field.ipdltype) else arg)

    resultvar = ExprVar("result__")
    reads.append(
        StmtDecl(
            Decl(_cxxReadResultType(Type("paramType")), resultvar.name),
            initargs=[ExprVar("std::in_place")] + ctorargs,
        )
    )

    # Pass 2: pod data in member order.  This has to follow construction of
    # the result so that there is somewhere to read into.
    runs = itertools.groupby(
        decl.fields_member_order(), lambda f: pod_size(f.ipdltype)
    )
    for size, run in runs:
        if size == pod_size_sentinel:
            continue
        members = list(run)
        writes.append(cls.checkedBulkWrite(cls.var, size, members))
        reads.append(cls.checkedBulkRead(resultvar, size, members))

    reads.append(StmtReturn(resultvar))

    return cls.generateDecl(
        cxxtype, writes, reads, needsmove=_cxxTypeNeedsMoveForSend(structtype)
    )
class _ComputeTypeDeps(TypeVisitor):
    """Pass that gathers the C++ types that a particular IPDL type
    (recursively) depends on.  There are three kinds of dependencies: (i)
    types that need forward declaration; (ii) types that need a |using|
    stmt; (iii) IPDL structs or unions which must be fully declared
    before this struct.  Some types generate multiple kinds."""

    def __init__(self, fortype, typesToIncludes=None):
        """|fortype| is the type whose dependencies are being gathered.
        |typesToIncludes|, when given, maps type-name strings to the header
        declaring them; matching headers are collected in
        |self.includeHeaders|."""
        TypeVisitor.__init__(self)
        self.usingTypedefs = []
        self.forwardDeclStmts = []
        self.fullDeclTypes = []
        self.includeHeaders = set()
        self.fortype = fortype
        self.typesToIncludes = typesToIncludes

    def _alreadyVisited(self, t):
        """Return True if |t| was seen before; otherwise mark it as visited
        and return False."""
        if t in self.visited:
            return True
        self.visited.add(t)
        return False

    def maybeTypedef(self, fqname, name, templateargs=None):
        """Record a typedef |name| for |fqname| when the two differ, and
        the header for |fqname| when |typesToIncludes| knows one."""
        assert fqname.startswith("::")
        # A fresh list per call: a shared default list would be handed to
        # every Typedef created without explicit template arguments.
        if templateargs is None:
            templateargs = []
        if fqname != name:
            self.usingTypedefs.append(Typedef(Type(fqname), name, templateargs))
        if self.typesToIncludes is not None and fqname in self.typesToIncludes:
            self.includeHeaders.add(self.typesToIncludes[fqname])

    def visitImportedCxxType(self, t):
        if self._alreadyVisited(t):
            return
        self.maybeTypedef(t.fullname(), t.name())

    def visitActorType(self, t):
        if self._alreadyVisited(t):
            return

        fqname, name = t.fullname(), t.name()

        self.includeHeaders.add("mozilla/ipc/SideVariant.h")
        self.maybeTypedef(_actorName(fqname, "Parent"), _actorName(name, "Parent"))
        self.maybeTypedef(_actorName(fqname, "Child"), _actorName(name, "Child"))

        self.forwardDeclStmts.extend(
            [
                _makeForwardDeclForActor(t.protocol, "parent"),
                Whitespace.NL,
                _makeForwardDeclForActor(t.protocol, "child"),
                Whitespace.NL,
            ]
        )

    def visitStructOrUnionType(self, su, defaultVisit):
        """Shared handling for struct and union types; |defaultVisit| is the
        base-class visit method used to descend into |su|."""
        # The type being analyzed is never recorded as its own dependency.
        if su in self.visited or su == self.fortype:
            return
        self.visited.add(su)
        self.maybeTypedef(su.fullname(), su.name())

        # Mutually recursive fields in unions are behind indirection, so we only
        # need a forward decl, and don't need a full type declaration.
        if isinstance(self.fortype, UnionType) and self.fortype.mutuallyRecursiveWith(
            su
        ):
            self.forwardDeclStmts.append(_makeForwardDecl(su))
        else:
            self.fullDeclTypes.append(su)

        return defaultVisit(self, su)

    def visitStructType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitStructType)

    def visitUnionType(self, t):
        return self.visitStructOrUnionType(t, TypeVisitor.visitUnionType)

    def visitArrayType(self, t):
        return TypeVisitor.visitArrayType(self, t)

    def visitMaybeType(self, m):
        return TypeVisitor.visitMaybeType(self, m)

    def visitShmemType(self, s):
        if self._alreadyVisited(s):
            return
        self.maybeTypedef("::mozilla::ipc::Shmem", "Shmem")

    def visitByteBufType(self, s):
        if self._alreadyVisited(s):
            return
        self.maybeTypedef("::mozilla::ipc::ByteBuf", "ByteBuf")

    def visitFDType(self, s):
        if self._alreadyVisited(s):
            return
        self.maybeTypedef("::mozilla::ipc::FileDescriptor", "FileDescriptor")

    def visitEndpointType(self, s):
        if self._alreadyVisited(s):
            return
        self.maybeTypedef("::mozilla::ipc::Endpoint", "Endpoint", ["FooSide"])
        self.visitActorType(s.actor)

    def visitManagedEndpointType(self, s):
        if self._alreadyVisited(s):
            return
        self.maybeTypedef(
            "::mozilla::ipc::ManagedEndpoint", "ManagedEndpoint", ["FooSide"]
        )
        self.visitActorType(s.actor)

    def visitUniquePtrType(self, s):
        # Only marks the type as visited; nothing is recorded for it here.
        if self._alreadyVisited(s):
            return

    def visitVoidType(self, v):
        assert 0

    def visitMessageType(self, v):
        assert 0

    def visitProtocolType(self, v):
        assert 0
layout!"); """, struct=sd.name, first=fields[0].memberVar(), last=fields[-1].memberVar(), expected=ExprLiteral.Int(size * (len(fields) - 1)), ) ) return staticasserts def _generateCxxStruct(sd): """ """ # compute all the typedefs and forward decls we need to make gettypedeps = _ComputeTypeDeps(sd.decl.type) for f in sd.fields: f.ipdltype.accept(gettypedeps) usingTypedefs = gettypedeps.usingTypedefs forwarddeclstmts = gettypedeps.forwardDeclStmts fulldecltypes = gettypedeps.fullDeclTypes struct = Class(sd.name, final=True) struct.addstmts([Label.PRIVATE] + usingTypedefs + [Whitespace.NL, Label.PUBLIC]) constreftype = Type(sd.name, const=True, ref=True) # Struct() # We want the default constructor to be declared if it is available, but # some of our members may not be default-constructible. Silence the # warning which clang generates in that case. # # Members which need value initialization will be handled by wrapping # the member in a template type when declaring them. struct.addcode( """ #ifdef __clang__ # pragma clang diagnostic push # if __has_warning("-Wdefaulted-function-deleted") # pragma clang diagnostic ignored "-Wdefaulted-function-deleted" # endif #endif ${name}() = default; #ifdef __clang__ # pragma clang diagnostic pop #endif """, name=sd.name, ) # If this is an empty struct (no fields), then the default ctor # and "create-with-fields" ctors are equivalent. if len(sd.fields): assert len(sd.fields) == len(sd.packed_field_order) # Struct(const field1& _f1, ...) 
valctor = ConstructorDefn( ConstructorDecl( sd.name, params=[ Decl( f.forceMoveType() if _cxxTypeNeedsMoveForData(f.ipdltype) else f.constRefType(), f.argVar().name, ) for f in sd.fields_ipdl_order() ], force_inline=True, ) ) valctor.memberinits = [] for f in sd.fields_member_order(): arg = f.argVar() if _cxxTypeNeedsMoveForData(f.ipdltype): arg = ExprMove(arg) valctor.memberinits.append(ExprMemberInit(f.memberVar(), args=[arg])) struct.addstmts([valctor, Whitespace.NL]) # If a constructor which moves each argument would be different from the # `const T&` version, also generate that constructor. if not all( _cxxTypeNeedsMoveForData(f.ipdltype) or not _cxxTypeCanMove(f.ipdltype) for f in sd.fields_ipdl_order() ): # Struct(field1&& _f1, ...) valmovector = ConstructorDefn( ConstructorDecl( sd.name, params=[ Decl( f.forceMoveType() if _cxxTypeCanMove(f.ipdltype) else f.constRefType(), f.argVar().name, ) for f in sd.fields_ipdl_order() ], force_inline=True, ) ) valmovector.memberinits = [] for f in sd.fields_member_order(): arg = f.argVar() if _cxxTypeCanMove(f.ipdltype): arg = ExprMove(arg) valmovector.memberinits.append( ExprMemberInit(f.memberVar(), args=[arg]) ) struct.addstmts([valmovector, Whitespace.NL]) # The default copy, move, and assignment constructors, and the default # destructor, will do the right thing. 
if "Comparable" in sd.attributes: # bool operator==(const Struct& _o) ovar = ExprVar("_o") opeqeq = MethodDefn( MethodDecl( "operator==", params=[Decl(constreftype, ovar.name)], ret=Type.BOOL, const=True, ) ) for f in sd.fields_ipdl_order(): ifneq = StmtIf( ExprNot( ExprBinary( ExprCall(f.getMethod()), "==", ExprCall(f.getMethod(ovar)) ) ) ) ifneq.addifstmt(StmtReturn.FALSE) opeqeq.addstmt(ifneq) opeqeq.addstmt(StmtReturn.TRUE) struct.addstmts([opeqeq, Whitespace.NL]) # bool operator!=(const Struct& _o) opneq = MethodDefn( MethodDecl( "operator!=", params=[Decl(constreftype, ovar.name)], ret=Type.BOOL, const=True, ) ) opneq.addstmt(StmtReturn(ExprNot(ExprCall(ExprVar("operator=="), args=[ovar])))) struct.addstmts([opneq, Whitespace.NL]) # field1& f1() # const field1& f1() const for f in sd.fields_ipdl_order(): get = MethodDefn( MethodDecl( f.getMethod().name, params=[], ret=f.refType(), force_inline=True ) ) get.addstmt(StmtReturn(f.refExpr())) getconstdecl = deepcopy(get.decl) getconstdecl.ret = f.constRefType() getconstdecl.const = True getconst = MethodDefn(getconstdecl) getconst.addstmt(StmtReturn(f.constRefExpr())) struct.addstmts([get, getconst, Whitespace.NL]) # private: struct.addstmt(Label.PRIVATE) # Static assertions to ensure our assumptions about field layout match # what the compiler is actually producing. We define this as a member # function, rather than throwing the assertions in the constructor or # similar, because we don't want to evaluate the static assertions every # time the header file containing the structure is included. 
def _effectiveMemberType(f):
    """Return the C++ type used to declare the struct member for field |f|:
    the field's bare type wrapped in |::mozilla::ipc::IPDLStructMember|."""
    member = f.bareType()

    # Structs must be copyable for backwards compatibility reasons, so we use
    # CopyableTArray as their member type for arrays.  This is not exposed
    # in the method signatures, these keep using nsTArray, which is a base
    # class of CopyableTArray.
    if member.name == "nsTArray":
        member.name = "CopyableTArray"

    wrapped = Type("::mozilla::ipc::IPDLStructMember", T=[member])
    return wrapped
# # So, for each type, this "Union" class needs: # (private) # - entry in the type enum # - entry in the storage union # - [type]ptr() method to get a type* from the underlying union # - same as above to get a const type* # - typedef to hack around placement delete limitations # (public) # - placement delete case for dtor # - copy ctor # - move ctor # - case in generic copy ctor # - copy operator= impl # - move operator= impl # - case in generic operator= # - operator [type&] # - operator [const type&] const # - [type&] get_[type]() # - [const type&] get_[type]() const # cls = Class(ud.name, final=True) # const Union&, i.e., Union type with inparam semantics inClsType = Type(ud.name, const=True, ref=True) refClsType = Type(ud.name, ref=True) rvalueRefClsType = Type(ud.name, rvalref=True) typetype = Type("Type") valuetype = Type("Value") mtypevar = ExprVar("mType") mvaluevar = ExprVar("mValue") maybedtorvar = ExprVar("MaybeDestroy") assertsanityvar = ExprVar("AssertSanity") tnonevar = ExprVar("T__None") tlastvar = ExprVar("T__Last") def callAssertSanity(uvar=None, expectTypeVar=None): func = assertsanityvar args = [] if uvar is not None: func = ExprSelect(uvar, ".", assertsanityvar.name) if expectTypeVar is not None: args.append(expectTypeVar) return ExprCall(func, args=args) def maybeDestroy(): return StmtExpr(ExprCall(maybedtorvar)) # compute all the typedefs and forward decls we need to make gettypedeps = _ComputeTypeDeps(ud.decl.type) for c in ud.components: c.ipdltype.accept(gettypedeps) usingTypedefs = gettypedeps.usingTypedefs forwarddeclstmts = gettypedeps.forwardDeclStmts fulldecltypes = gettypedeps.fullDeclTypes # the |Type| enum, used to switch on the discunion's real type cls.addstmt(Label.PUBLIC) typeenum = TypeEnum(typetype.name) typeenum.addId(tnonevar.name, 0) firstid = ud.components[0].enum() typeenum.addId(firstid, 1) for c in ud.components[1:]: typeenum.addId(c.enum()) typeenum.addId(tlastvar.name, ud.components[-1].enum()) 
cls.addstmts([StmtDecl(Decl(typeenum, "")), Whitespace.NL]) cls.addstmt(Label.PRIVATE) cls.addstmts( usingTypedefs # hacky typedef's that allow placement dtors of builtins + [Typedef(c.internalType(), c.typedef()) for c in ud.components] ) cls.addstmt(Whitespace.NL) # the C++ union the discunion use for storage valueunion = TypeUnion(valuetype.name) for c in ud.components: valueunion.addComponent(c.unionType(), c.name) cls.addstmts([StmtDecl(Decl(valueunion, "")), Whitespace.NL]) # for each constituent type T, add private accessors that # return a pointer to the Value union storage casted to |T*| # and |const T*| for c in ud.components: getptr = MethodDefn( MethodDecl( c.getPtrName(), params=[], ret=c.ptrToInternalType(), force_inline=True ) ) getptr.addstmt(StmtReturn(c.ptrToSelfExpr())) getptrconst = MethodDefn( MethodDecl( c.getConstPtrName(), params=[], ret=c.constPtrToType(), const=True, force_inline=True, ) ) getptrconst.addstmt(StmtReturn(c.constptrToSelfExpr())) cls.addstmts([getptr, getptrconst]) cls.addstmt(Whitespace.NL) # add a helper method that invokes the placement dtor on the # current underlying value, only if |aNewType| is different # than the current type, and returns true if the underlying # value needs to be re-constructed maybedtor = MethodDefn(MethodDecl(maybedtorvar.name, ret=Type.VOID)) # wasn't /actually/ dtor'd, but it needs to be re-constructed ifnone = StmtIf(ExprBinary(mtypevar, "==", tnonevar)) ifnone.addifstmt(StmtReturn()) # need to destroy. 
switch on underlying type dtorswitch = StmtSwitch(mtypevar) for c in ud.components: dtorswitch.addcase( CaseLabel(c.enum()), StmtBlock([StmtExpr(c.callDtor()), StmtBreak()]) ) dtorswitch.addcase( DefaultLabel(), StmtBlock([_logicError("not reached"), StmtBreak()]) ) maybedtor.addstmts([ifnone, dtorswitch]) cls.addstmts([maybedtor, Whitespace.NL]) # add helper methods that ensure the discunion has a # valid type sanity = MethodDefn( MethodDecl(assertsanityvar.name, ret=Type.VOID, const=True, force_inline=True) ) sanity.addstmts( [ _abortIfFalse(ExprBinary(tnonevar, "<=", mtypevar), "invalid type tag"), _abortIfFalse(ExprBinary(mtypevar, "<=", tlastvar), "invalid type tag"), ] ) cls.addstmt(sanity) atypevar = ExprVar("aType") sanity2 = MethodDefn( MethodDecl( assertsanityvar.name, params=[Decl(typetype, atypevar.name)], ret=Type.VOID, const=True, force_inline=True, ) ) sanity2.addstmts( [ StmtExpr(ExprCall(assertsanityvar)), _abortIfFalse(ExprBinary(mtypevar, "==", atypevar), "unexpected type tag"), ] ) cls.addstmts([sanity2, Whitespace.NL]) # ---- begin public methods ----- # Union() default ctor cls.addstmts( [ Label.PUBLIC, ConstructorDefn( ConstructorDecl(ud.name, force_inline=True), memberinits=[ExprMemberInit(mtypevar, [tnonevar])], ), Whitespace.NL, ] ) # Union(const T&) copy & Union(T&&) move ctors othervar = ExprVar("aOther") for c in ud.components: if not _cxxTypeNeedsMoveForData(c.ipdltype): copyctor = ConstructorDefn( ConstructorDecl(ud.name, params=[Decl(c.constRefType(), othervar.name)]) ) copyctor.addstmts( [ StmtExpr(c.callCtor(othervar)), StmtExpr(ExprAssn(mtypevar, c.enumvar())), ] ) cls.addstmts([copyctor, Whitespace.NL]) if not _cxxTypeCanMove(c.ipdltype): continue movector = ConstructorDefn( ConstructorDecl(ud.name, params=[Decl(c.forceMoveType(), othervar.name)]) ) movector.addstmts( [ StmtExpr(c.callCtor(ExprMove(othervar))), StmtExpr(ExprAssn(mtypevar, c.enumvar())), ] ) cls.addstmts([movector, Whitespace.NL]) unionNeedsMove = 
any(_cxxTypeNeedsMoveForData(c.ipdltype) for c in ud.components) # Union(const Union&) copy ctor if not unionNeedsMove: copyctor = ConstructorDefn( ConstructorDecl(ud.name, params=[Decl(inClsType, othervar.name)]) ) othertype = ud.callType(othervar) copyswitch = StmtSwitch(othertype) for c in ud.components: copyswitch.addcase( CaseLabel(c.enum()), StmtBlock( [ StmtExpr( c.callCtor( ExprCall( ExprSelect(othervar, ".", c.getConstTypeName()) ) ) ), StmtBreak(), ] ), ) copyswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) copyswitch.addcase( DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) ) copyctor.addstmts( [ StmtExpr(callAssertSanity(uvar=othervar)), copyswitch, StmtExpr(ExprAssn(mtypevar, othertype)), ] ) cls.addstmts([copyctor, Whitespace.NL]) # Union(Union&&) move ctor movector = ConstructorDefn( ConstructorDecl(ud.name, params=[Decl(rvalueRefClsType, othervar.name)]) ) othertypevar = ExprVar("t") moveswitch = StmtSwitch(othertypevar) for c in ud.components: case = StmtBlock() if c.recursive: # This is sound as we set othervar.mTypeVar to T__None after the # switch. The pointer in the union will be left dangling. case.addstmts( [ # ptr_C() = other.ptr_C() StmtExpr( ExprAssn( c.callGetPtr(), ExprCall( ExprSelect(othervar, ".", ExprVar(c.getPtrName())) ), ) ) ] ) else: case.addstmts( [ # new ... 
(Move(other.get_C())) StmtExpr( c.callCtor( ExprMove( ExprCall(ExprSelect(othervar, ".", c.getTypeName())) ) ) ), # other.MaybeDestroy(T__None) StmtExpr(ExprCall(ExprSelect(othervar, ".", maybedtorvar))), ] ) case.addstmts([StmtBreak()]) moveswitch.addcase(CaseLabel(c.enum()), case) moveswitch.addcase(CaseLabel(tnonevar.name), StmtBlock([StmtBreak()])) moveswitch.addcase( DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn()]) ) movector.addstmts( [ StmtExpr(callAssertSanity(uvar=othervar)), StmtDecl(Decl(typetype, othertypevar.name), init=ud.callType(othervar)), moveswitch, StmtExpr(ExprAssn(ExprSelect(othervar, ".", mtypevar), tnonevar)), StmtExpr(ExprAssn(mtypevar, othertypevar)), ] ) cls.addstmts([movector, Whitespace.NL]) # ~Union() dtor = DestructorDefn(DestructorDecl(ud.name)) dtor.addstmt(maybeDestroy()) cls.addstmts([dtor, Whitespace.NL]) # type() typemeth = MethodDefn( MethodDecl("type", ret=typetype, const=True, force_inline=True) ) typemeth.addstmt(StmtReturn(mtypevar)) cls.addstmts([typemeth, Whitespace.NL]) # Union& operator= methods rhsvar = ExprVar("aRhs") for c in ud.components: def opeqBody(rhs): return [ # might need to placement-delete old value first maybeDestroy(), StmtExpr(c.callCtor(rhs)), StmtExpr(ExprAssn(mtypevar, c.enumvar())), StmtReturn(ExprDeref(ExprVar.THIS)), ] if not _cxxTypeNeedsMoveForData(c.ipdltype): # Union& operator=(const T&) opeq = MethodDefn( MethodDecl( "operator=", params=[Decl(c.constRefType(), rhsvar.name)], ret=refClsType, ) ) opeq.addstmts(opeqBody(rhsvar)) cls.addstmts([opeq, Whitespace.NL]) # Union& operator=(T&&) if not _cxxTypeCanMove(c.ipdltype): continue opeq = MethodDefn( MethodDecl( "operator=", params=[Decl(c.forceMoveType(), rhsvar.name)], ret=refClsType, ) ) opeq.addstmts(opeqBody(ExprMove(rhsvar))) cls.addstmts([opeq, Whitespace.NL]) # Union& operator=(const Union&) if not unionNeedsMove: opeq = MethodDefn( MethodDecl( "operator=", params=[Decl(inClsType, rhsvar.name)], ret=refClsType ) ) 
rhstypevar = ExprVar("t") opeqswitch = StmtSwitch(rhstypevar) for c in ud.components: case = StmtBlock() case.addstmts( [ maybeDestroy(), StmtExpr( c.callCtor( ExprCall(ExprSelect(rhsvar, ".", c.getConstTypeName())) ) ), StmtBreak(), ] ) opeqswitch.addcase(CaseLabel(c.enum()), case) opeqswitch.addcase( CaseLabel(tnonevar.name), StmtBlock([maybeDestroy(), StmtBreak()]), ) opeqswitch.addcase( DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) ) opeq.addstmts( [ StmtExpr(callAssertSanity(uvar=rhsvar)), StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), opeqswitch, StmtExpr(ExprAssn(mtypevar, rhstypevar)), StmtReturn(ExprDeref(ExprVar.THIS)), ] ) cls.addstmts([opeq, Whitespace.NL]) # Union& operator=(Union&&) opeq = MethodDefn( MethodDecl( "operator=", params=[Decl(rvalueRefClsType, rhsvar.name)], ret=refClsType ) ) rhstypevar = ExprVar("t") opeqswitch = StmtSwitch(rhstypevar) for c in ud.components: case = StmtBlock() if c.recursive: case.addstmts( [ maybeDestroy(), StmtExpr( ExprAssn( c.callGetPtr(), ExprCall(ExprSelect(rhsvar, ".", ExprVar(c.getPtrName()))), ) ), ] ) else: case.addstmts( [ maybeDestroy(), StmtExpr( c.callCtor( ExprMove(ExprCall(ExprSelect(rhsvar, ".", c.getTypeName()))) ) ), # other.MaybeDestroy() StmtExpr(ExprCall(ExprSelect(rhsvar, ".", maybedtorvar))), ] ) case.addstmts([StmtBreak()]) opeqswitch.addcase(CaseLabel(c.enum()), case) opeqswitch.addcase( CaseLabel(tnonevar.name), StmtBlock([maybeDestroy(), StmtBreak()]), ) opeqswitch.addcase( DefaultLabel(), StmtBlock([_logicError("unreached"), StmtBreak()]) ) opeq.addstmts( [ StmtExpr(callAssertSanity(uvar=rhsvar)), StmtDecl(Decl(typetype, rhstypevar.name), init=ud.callType(rhsvar)), opeqswitch, StmtExpr(ExprAssn(ExprSelect(rhsvar, ".", mtypevar), tnonevar)), StmtExpr(ExprAssn(mtypevar, rhstypevar)), StmtReturn(ExprDeref(ExprVar.THIS)), ] ) cls.addstmts([opeq, Whitespace.NL]) if "Comparable" in ud.attributes: # bool operator==(const T&) for c in ud.components: opeqeq 
= MethodDefn( MethodDecl( "operator==", params=[Decl(c.constRefType(), rhsvar.name)], ret=Type.BOOL, const=True, ) ) opeqeq.addstmt( StmtReturn(ExprBinary(ExprCall(ExprVar(c.getTypeName())), "==", rhsvar)) ) cls.addstmts([opeqeq, Whitespace.NL]) # bool operator==(const Union&) opeqeq = MethodDefn( MethodDecl( "operator==", params=[Decl(inClsType, rhsvar.name)], ret=Type.BOOL, const=True, ) ) iftypesmismatch = StmtIf(ExprBinary(ud.callType(), "!=", ud.callType(rhsvar))) iftypesmismatch.addifstmt(StmtReturn.FALSE) opeqeq.addstmts([iftypesmismatch, Whitespace.NL]) opeqeqswitch = StmtSwitch(ud.callType()) for c in ud.components: case = StmtBlock() case.addstmt( StmtReturn( ExprBinary( ExprCall(ExprVar(c.getTypeName())), "==", ExprCall(ExprSelect(rhsvar, ".", c.getTypeName())), ) ) ) opeqeqswitch.addcase(CaseLabel(c.enum()), case) opeqeqswitch.addcase( DefaultLabel(), StmtBlock([_logicError("unreached"), StmtReturn.FALSE]) ) opeqeq.addstmt(opeqeqswitch) cls.addstmts([opeqeq, Whitespace.NL]) # accessors for each type: operator T&, operator const T&, # T& get(), const T& get() for c in ud.components: getValueVar = ExprVar(c.getTypeName()) getConstValueVar = ExprVar(c.getConstTypeName()) getvalue = MethodDefn( MethodDecl(getValueVar.name, ret=c.refType(), force_inline=True) ) getvalue.addstmts( [ StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), StmtReturn(ExprDeref(c.callGetPtr())), ] ) getconstvalue = MethodDefn( MethodDecl( getConstValueVar.name, ret=c.constRefType(), const=True, force_inline=True, ) ) getconstvalue.addstmts( [ StmtExpr(callAssertSanity(expectTypeVar=c.enumvar())), StmtReturn(c.getConstValue()), ] ) cls.addstmts([getvalue, getconstvalue]) optype = MethodDefn(MethodDecl("", typeop=c.refType(), force_inline=True)) optype.addstmt(StmtReturn(ExprCall(getValueVar))) opconsttype = MethodDefn( MethodDecl("", const=True, typeop=c.constRefType(), force_inline=True) ) opconsttype.addstmt(StmtReturn(ExprCall(getConstValueVar))) cls.addstmts([optype, 
class _FindFriends(ipdl.ast.Visitor):
    """Collect the protocols that mention a given protocol's actors.

    Starting from every toplevel protocol of the target type, the protocol
    tree is walked through ``manages`` links.  Each protocol other than the
    target has its AST visited; whenever one of its message declarations
    carries an actor (as in- or out-parameter) whose protocol equals the
    target, the protocol being visited is recorded as a friend.
    """

    def __init__(self):
        self.mytype = None  # ProtocolType we are collecting friends for
        self.vtype = None  # ProtocolType whose AST is being visited right now
        self.friends = set()  # ProtocolTypes recorded so far

    def findFriends(self, ptype):
        """Walk every protocol tree |ptype| belongs to; return the friend set."""
        self.mytype = ptype
        for root in ptype.toplevels():
            self.walkDownTheProtocolTree(root)
        return self.friends

    # TODO could make this into a _iterProtocolTreeHelper ...
    def walkDownTheProtocolTree(self, ptype):
        """Visit |ptype| (unless it is the target) and recurse into its managees."""
        if ptype != self.mytype:
            # don't want to |friend| ourself!
            self.visit(ptype)

        for managed in ptype.manages:
            # a self-managing protocol must not recurse into itself forever
            if managed is ptype:
                continue
            self.walkDownTheProtocolTree(managed)

    def visit(self, ptype):
        """Run this visitor over |ptype|'s AST with |vtype| set to |ptype|."""
        # |vtype| is the type currently being visited; put the previous
        # value back once the AST walk is done
        outer = self.vtype
        self.vtype = ptype
        ptype._ast.accept(self)
        self.vtype = outer

    def visitMessageDecl(self, md):
        """Record the visited protocol if |md| passes an actor of the target."""
        mentions = [
            actor for actor in self.iterActorParams(md)
            if actor.protocol == self.mytype
        ]
        if mentions:
            self.friends.add(self.vtype)

    def iterActorParams(self, md):
        """Yield every actor type found in |md|'s in- and out-parameters."""
        for decls in (md.inParams, md.outParams):
            for decl in decls:
                yield from ipdl.type.iteractortypes(decl.type)
def standardTypedefs(self):
    """Return the fixed list of typedefs for the actor class.

    Each entry aliases a fully qualified C++ type to its short name; a
    fresh list of Typedef objects is built on every call, in the order
    given below.
    """
    aliases = (
        ("mozilla::ipc::IProtocol", "IProtocol"),
        ("IPC::Message", "Message"),
        ("base::ProcessHandle", "ProcessHandle"),
        ("mozilla::ipc::MessageChannel", "MessageChannel"),
        ("mozilla::ipc::SharedMemory", "SharedMemory"),
    )
    return [Typedef(Type(qualified), short) for qualified, short in aliases]
return types in the method defn aren't in # class scope for stmt in clsdefn.stmts: if isinstance(stmt, MethodDefn): if stmt.decl.ret and stmt.decl.ret.name == "Result": stmt.decl.ret.name = clsdecl.name + "::" + stmt.decl.ret.name def setToIncludes(s): return [CppDirective("include", '"%s"' % i) for i in sorted(iter(s))] def makeNamespace(p, file): if 0 == len(p.namespaces): return file ns = Namespace(p.namespaces[-1].name) outerns = _putInNamespaces(ns, p.namespaces[:-1]) file.addthing(outerns) return ns if len(self.nonForwardDeclaredHeaders) != 0: self.hdrfile.addthings( [ Whitespace("// Headers for things that cannot be forward declared"), Whitespace.NL, ] + setToIncludes(self.nonForwardDeclaredHeaders) + [Whitespace.NL] ) self.hdrfile.addthings(self.actorForwardDecls) self.hdrfile.addthings(self.usingDecls) hdrns = makeNamespace(self.protocol, self.hdrfile) hdrns.addstmts( [Whitespace.NL, Whitespace.NL, clsdecl, Whitespace.NL, Whitespace.NL] ) actortype = ActorType(tu.protocol.decl.type) traitsdecl, traitsdefn = _ParamTraits.actorPickling(actortype, self.side) self.hdrfile.addthings([traitsdecl, Whitespace.NL] + _includeGuardEnd(hf)) # If the implementation type is not overridden, add an implicit import # for the default implementation header file. Explicit implementation # types will specify their headers manually with `include`. 
def visitUsingStmt(self, using):
    """Record what one ``using`` statement requires of the generated files.

    * If the declaration has a full name, a typedef from that full name
      to the short name is added to ``self.typedefSet``.
    * If the statement names no header, nothing else happens.
    * If the type can be forward declared, a forward declaration (plus a
      newline) is appended to ``self.usingDecls`` and the header is added
      to ``self.externalIncludes``.
    * Otherwise the header is added to ``self.nonForwardDeclaredHeaders``.
    """
    decl = using.decl
    if decl.fullname is not None:
        alias = Typedef(Type(decl.fullname), decl.shortname)
        self.typedefSet.add(alias)

    header = using.header
    if header is None:
        return

    if not using.canBeForwardDeclared():
        # the full definition is needed wherever the type is mentioned
        self.nonForwardDeclaredHeaders.add(header)
        return

    spec = using.type
    forwarddecl = _makeForwardDeclForQClass(
        spec.baseid,
        spec.quals,
        cls=using.isClass(),
        struct=using.isStruct(),
    )
    self.usingDecls.extend([forwarddecl, Whitespace.NL])
    self.externalIncludes.add(header)
def visitStructDecl(self, sd):
    """Add a typedef for struct |sd| to ``self.typedefSet``.

    The typedef maps the struct's fully qualified class name to its plain
    name.  Structs whose declaration has no full name are skipped.
    """
    if sd.decl.fullname is None:
        return
    self.typedefSet.add(Typedef(Type(sd.fqClassName()), sd.name))
f.fullname()): self.actorForwardDecls.extend( [_makeForwardDeclForActor(friend, self.prettyside), Whitespace.NL] ) self.cls.addstmt( FriendClassDecl(_actorName(friend.fullname(), self.prettyside)) ) self.cls.addstmt(Label.PROTECTED) for typedef in sorted(self.typedefSet): self.cls.addstmt(typedef) self.cls.addstmt(Whitespace.NL) if hasAsyncReturns: self.cls.addstmt(Label.PUBLIC) for md in p.messageDecls: if self.sendsMessage(md) and md.hasAsyncReturns(): self.cls.addstmt( Typedef(_makePromise(md.returns, self.side), md.promiseName()) ) if self.receivesMessage(md) and md.hasAsyncReturns(): self.cls.addstmt( Typedef(_makeResolver(md.returns, self.side), md.resolverName()) ) self.cls.addstmt(Whitespace.NL) self.cls.addstmt(Label.PROTECTED) # interface methods that the concrete subclass has to impl for md in p.messageDecls: isctor, isdtor = md.decl.type.isCtor(), md.decl.type.isDtor() if self.receivesMessage(md): # generate Recv/Answer* interface implicit = not isdtor returnsems = "resolver" if md.decl.type.isAsync() else "out" recvDecl = MethodDecl( md.recvMethod(), params=md.makeCxxParams( paramsems="move", returnsems=returnsems, side=self.side, implicit=implicit, direction="recv", ), ret=Type("mozilla::ipc::IPCResult"), methodspec=MethodSpec.VIRTUAL, ) # These method implementations cause problems when trying to # override them with different types in a direct call class. # # For the `isdtor` case there's a simple solution: it doesn't # make much sense to specify arguments and then completely # ignore them, and the no-arg case isn't a problem for # overriding. if isctor or (isdtor and not md.inParams): defaultRecv = MethodDefn(recvDecl) defaultRecv.addcode("return IPC_OK();\n") self.cls.addstmt(defaultRecv) elif self.protocol.implAttribute(self.side) == "virtual": # If we're using virtual calls, we need the methods to be # declared on the base class. 
recvDecl.methodspec = MethodSpec.PURE self.cls.addstmt(StmtDecl(recvDecl)) # If we're using virtual calls, we need the methods to be declared on # the base class. if self.protocol.implAttribute(self.side) == "virtual": for md in p.messageDecls: managed = md.decl.type.constructedType() if not ptype.isManagerOf(managed) or md.decl.type.isDtor(): continue # add the Alloc interface for managed actors actortype = md.actorDecl().bareType(self.side) if managed.isRefcounted(): if not self.receivesMessage(md): continue actortype.ptr = False actortype = _alreadyaddrefed(actortype) self.cls.addstmt( StmtDecl( MethodDecl( _allocMethod(managed, self.side), params=md.makeCxxParams( side=self.side, implicit=False, direction="recv" ), ret=actortype, methodspec=MethodSpec.PURE, ) ) ) # add the Dealloc interface for all managed non-refcounted actors, # even without ctors. This is useful for protocols which use # ManagedEndpoint for construction. for managed in ptype.manages: if managed.isRefcounted(): continue self.cls.addstmt( StmtDecl( MethodDecl( _deallocMethod(managed, self.side), params=[ Decl(p.managedCxxType(managed, self.side), "aActor") ], ret=Type.BOOL, methodspec=MethodSpec.PURE, ) ) ) if ptype.isToplevel(): # void ProcessingError(code); default to no-op processingerror = MethodDefn( MethodDecl( p.processingErrorVar().name, params=[ Param(_Result.Type(), "aCode"), Param(Type("char", const=True, ptr=True), "aReason"), ], methodspec=MethodSpec.OVERRIDE, ) ) # bool ShouldContinueFromReplyTimeout(); default to |true| shouldcontinue = MethodDefn( MethodDecl( p.shouldContinueFromTimeoutVar().name, ret=Type.BOOL, methodspec=MethodSpec.OVERRIDE, ) ) shouldcontinue.addcode("return true;\n") self.cls.addstmts( [ processingerror, shouldcontinue, Whitespace.NL, ] ) self.cls.addstmts(([Label.PUBLIC] + self.standardTypedefs() + [Whitespace.NL])) self.cls.addstmt(Label.PUBLIC) # Actor() ctor = ConstructorDefn(ConstructorDecl(self.clsname)) side = ExprVar("mozilla::ipc::" + 
self.side.title() + "Side") if ptype.isToplevel(): name = ExprLiteral.String(_actorName(p.name, self.side)) ctor.memberinits = [ ExprMemberInit( ExprVar("mozilla::ipc::IToplevelProtocol"), [name, _protocolId(ptype), side], ) ] else: ctor.memberinits = [ ExprMemberInit( ExprVar("mozilla::ipc::IProtocol"), [_protocolId(ptype), side] ) ] ctor.addcode("MOZ_COUNT_CTOR(${clsname});\n", clsname=self.clsname) self.cls.addstmts([ctor, Whitespace.NL]) # ~Actor() dtor = DestructorDefn( DestructorDecl(self.clsname, methodspec=MethodSpec.VIRTUAL) ) dtor.addcode("MOZ_COUNT_DTOR(${clsname});\n", clsname=self.clsname) self.cls.addstmts([dtor, Whitespace.NL]) if ptype.isRefcounted(): if not ptype.isToplevel(): self.cls.addcode( """ NS_INLINE_DECL_PURE_VIRTUAL_REFCOUNTING """ ) self.cls.addstmt(Label.PROTECTED) self.cls.addcode( """ void ActorAlloc() final { AddRef(); } void ActorDealloc() final { Release(); } """ ) self.cls.addstmt(Label.PUBLIC) if ptype.hasOtherPid(): otherpidmeth = MethodDefn( MethodDecl("OtherPid", ret=Type("::base::ProcessId"), const=True) ) otherpidmeth.addcode( """ ::base::ProcessId pid = ::mozilla::ipc::IProtocol::ToplevelProtocol()->OtherPidMaybeInvalid(); MOZ_RELEASE_ASSERT(pid != ::base::kInvalidProcessId); return pid; """ ) self.cls.addstmts([otherpidmeth, Whitespace.NL]) if not ptype.isToplevel(): if 1 == len(p.managers): # manager() const managertype = p.managerActorType(self.side, ptr=True) managermeth = MethodDefn( MethodDecl("Manager", ret=managertype, const=True) ) managermeth.addcode( """ return static_cast<${type}>(IProtocol::Manager()); """, type=managertype, ) self.cls.addstmts([managermeth, Whitespace.NL]) def actorFromIter(itervar): return ExprCode("${iter}.Get()->GetKey()", iter=itervar) def forLoopOverHashtable(hashtable, itervar, const=False): itermeth = "ConstIter" if const else "Iter" return StmtFor( init=ExprCode( "auto ${itervar} = ${hashtable}.${itermeth}()", itervar=itervar, hashtable=hashtable, itermeth=itermeth, ), 
cond=ExprCode("!${itervar}.Done()", itervar=itervar), update=ExprCode("${itervar}.Next()", itervar=itervar), ) # Managed[T](Array& inout) const # const Array& Managed() const for managed in ptype.manages: container = p.managedVar(managed, self.side) meth = MethodDefn( MethodDecl( p.managedMethod(managed, self.side).name, params=[ Decl( _cxxArrayType( p.managedCxxType(managed, self.side), ref=True ), "aArr", ) ], const=True, ) ) meth.addcode("${container}.ToArray(aArr);\n", container=container) refmeth = MethodDefn( MethodDecl( p.managedMethod(managed, self.side).name, params=[], ret=p.managedVarType(managed, self.side, const=True, ref=True), const=True, ) ) refmeth.addcode("return ${container};\n", container=container) self.cls.addstmts([meth, refmeth, Whitespace.NL]) # AllManagedActors(Array& inout) const arrvar = ExprVar("arr__") managedmeth = MethodDefn( MethodDecl( "AllManagedActors", params=[ Decl( _cxxArrayType(_refptr(_cxxLifecycleProxyType()), ref=True), arrvar.name, ) ], methodspec=MethodSpec.OVERRIDE, const=True, ) ) # Count the number of managed actors, and allocate space in the output array. managedmeth.addcode( """ uint32_t total = 0; """ ) for managed in ptype.manages: managedmeth.addcode( """ total += ${container}.Count(); """, container=p.managedVar(managed, self.side), ) managedmeth.addcode( """ arr__.SetCapacity(total); """ ) for managed in ptype.manages: managedmeth.addcode( """ for (auto* key : ${container}) { arr__.AppendElement(key->GetLifecycleProxy()); } """, container=p.managedVar(managed, self.side), ) self.cls.addstmts([managedmeth, Whitespace.NL]) # OpenPEndpoint(...)/BindPEndpoint(...) 
for managed in ptype.manages: self.genManagedEndpoint(managed) # OnMessageReceived()/OnCallReceived() # save these away for use in message handler case stmts msgvar = ExprVar("msg__") self.msgvar = msgvar replyvar = ExprVar("reply__") self.replyvar = replyvar var = ExprVar("v__") self.var = var # for ctor recv cases, we can't read the actor ID into a PFoo* # because it doesn't exist on this side yet. Use a "special" # actor handle instead handlevar = ExprVar("handle__") self.handlevar = handlevar msgtype = ExprCode("msg__.type()") self.asyncSwitch = StmtSwitch(msgtype) self.syncSwitch = None self.interruptSwitch = None if toplevel.isSync() or toplevel.isInterrupt(): self.syncSwitch = StmtSwitch(msgtype) if toplevel.isInterrupt(): self.interruptSwitch = StmtSwitch(msgtype) # Add a handler for the MANAGED_ENDPOINT_BOUND and # MANAGED_ENDPOINT_DROPPED message types for managed actors. if not ptype.isToplevel(): clearawaitingmanagedendpointbind = """ if (!mAwaitingManagedEndpointBind) { NS_WARNING("Unexpected managed endpoint lifecycle message after actor bound!"); return MsgNotAllowed; } mAwaitingManagedEndpointBind = false; """ self.asyncSwitch.addcase( CaseLabel("MANAGED_ENDPOINT_BOUND_MESSAGE_TYPE"), StmtBlock( [ StmtCode(clearawaitingmanagedendpointbind), StmtReturn(_Result.Processed), ] ), ) self.asyncSwitch.addcase( CaseLabel("MANAGED_ENDPOINT_DROPPED_MESSAGE_TYPE"), StmtBlock( [ StmtCode(clearawaitingmanagedendpointbind), *self.destroyActor( None, ExprVar.THIS, why=_DestroyReason.ManagedEndpointDropped, ), StmtReturn(_Result.Processed), ] ), ) # implement Send*() methods and add dispatcher cases to # message switch()es for md in p.messageDecls: self.visitMessageDecl(md) # add default cases default = StmtCode( """ return MsgNotKnown; """ ) self.asyncSwitch.addcase(DefaultLabel(), default) if toplevel.isSync() or toplevel.isInterrupt(): self.syncSwitch.addcase(DefaultLabel(), default) if toplevel.isInterrupt(): self.interruptSwitch.addcase(DefaultLabel(), 
default) self.cls.addstmts(self.implementManagerIface()) def makeHandlerMethod(name, switch, hasReply, dispatches=False): params = [Decl(Type("Message", const=True, ref=True), msgvar.name)] if hasReply: params.append(Decl(Type("UniquePtr", ref=True), replyvar.name)) method = MethodDefn( MethodDecl( name, methodspec=MethodSpec.OVERRIDE, params=params, ret=_Result.Type(), ) ) if not switch: method.addcode( """ MOZ_ASSERT_UNREACHABLE("message protocol not supported"); return MsgNotKnown; """ ) return method if dispatches: if hasReply: ondeadactor = [StmtReturn(_Result.RouteError)] else: ondeadactor = [ self.logMessage( None, ExprAddrOf(msgvar), "Ignored message for dead actor" ), StmtReturn(_Result.Processed), ] method.addcode( """ int32_t route__ = ${msgvar}.routing_id(); if (MSG_ROUTING_CONTROL != route__) { IProtocol* routed__ = Lookup(route__); if (!routed__ || !routed__->GetLifecycleProxy()) { $*{ondeadactor} } RefPtr proxy__ = routed__->GetLifecycleProxy(); return proxy__->Get()->${name}($,{args}); } """, msgvar=msgvar, ondeadactor=ondeadactor, name=name, args=[p.name for p in params], ) # bug 509581: don't generate the switch stmt if there # is only the default case; MSVC doesn't like that if switch.nr_cases > 1: method.addstmt(switch) else: method.addstmt(StmtReturn(_Result.NotKnown)) return method dispatches = ptype.isToplevel() and ptype.isManager() self.cls.addstmts( [ makeHandlerMethod( "OnMessageReceived", self.asyncSwitch, hasReply=False, dispatches=dispatches, ), Whitespace.NL, ] ) self.cls.addstmts( [ makeHandlerMethod( "OnMessageReceived", self.syncSwitch, hasReply=True, dispatches=dispatches, ), Whitespace.NL, ] ) self.cls.addstmts( [ makeHandlerMethod( "OnCallReceived", self.interruptSwitch, hasReply=True, dispatches=dispatches, ), Whitespace.NL, ] ) clearsubtreevar = ExprVar("ClearSubtree") if ptype.isToplevel(): # OnChannelClose() onclose = MethodDefn( MethodDecl("OnChannelClose", methodspec=MethodSpec.OVERRIDE) ) onclose.addcode( """ 
DestroySubtree(NormalShutdown); ClearSubtree(); DeallocShmems(); if (GetLifecycleProxy()) { GetLifecycleProxy()->Release(); } """ ) self.cls.addstmts([onclose, Whitespace.NL]) # OnChannelError() onerror = MethodDefn( MethodDecl("OnChannelError", methodspec=MethodSpec.OVERRIDE) ) onerror.addcode( """ DestroySubtree(AbnormalShutdown); ClearSubtree(); DeallocShmems(); if (GetLifecycleProxy()) { GetLifecycleProxy()->Release(); } """ ) self.cls.addstmts([onerror, Whitespace.NL]) if ptype.isToplevel() and ptype.isInterrupt(): processnative = MethodDefn( MethodDecl("ProcessNativeEventsInInterruptCall", ret=Type.VOID) ) processnative.addcode( """ #ifdef OS_WIN GetIPCChannel()->ProcessNativeEventsInInterruptCall(); #else FatalError("This method is Windows-only"); #endif """ ) self.cls.addstmts([processnative, Whitespace.NL]) # private methods self.cls.addstmt(Label.PRIVATE) # ClearSubtree() clearsubtree = MethodDefn(MethodDecl(clearsubtreevar.name)) for managed in ptype.manages: clearsubtree.addcode( """ for (auto* key : ${container}) { key->ClearSubtree(); } for (auto* key : ${container}) { // Recursively releasing ${container} kids. 
auto* proxy = key->GetLifecycleProxy(); NS_IF_RELEASE(proxy); } ${container}.Clear(); """, container=p.managedVar(managed, self.side), ) # don't release our own IPC reference: either the manager will do it, # or we're toplevel self.cls.addstmts([clearsubtree, Whitespace.NL]) if not ptype.isToplevel(): self.cls.addstmts( [ StmtDecl( Decl(Type.BOOL, "mAwaitingManagedEndpointBind"), init=ExprLiteral.FALSE, ), Whitespace.NL, ] ) for managed in ptype.manages: self.cls.addstmts( [ StmtDecl( Decl( p.managedVarType(managed, self.side), p.managedVar(managed, self.side).name, ) ) ] ) def genManagedEndpoint(self, managed): hereEp = "ManagedEndpoint<%s>" % _actorName(managed.name(), self.side) thereEp = "ManagedEndpoint<%s>" % _actorName( managed.name(), _otherSide(self.side) ) actor = _HybridDecl(ipdl.type.ActorType(managed), "aActor") # ManagedEndpoint OpenPEndpoint(PHere* aActor) openmeth = MethodDefn( MethodDecl( "Open%sEndpoint" % managed.name(), params=[ Decl(self.protocol.managedCxxType(managed, self.side), actor.name) ], ret=Type(thereEp), ) ) openmeth.addcode( """ $*{bind} // Mark our actor as awaiting the other side to be bound. This will // be cleared when a `MANAGED_ENDPOINT_{DROPPED,BOUND}` message is // received. 
aActor->mAwaitingManagedEndpointBind = true; return ${thereEp}(mozilla::ipc::PrivateIPDLInterface(), aActor); """, bind=self.bindManagedActor(actor, errfn=ExprCall(ExprVar(thereEp))), thereEp=thereEp, ) # void BindPEndpoint(ManagedEndpoint&& aEndpoint, PHere* aActor) bindmeth = MethodDefn( MethodDecl( "Bind%sEndpoint" % managed.name(), params=[ Decl(Type(hereEp), "aEndpoint"), Decl(self.protocol.managedCxxType(managed, self.side), actor.name), ], ret=Type.BOOL, ) ) bindmeth.addcode( """ return aEndpoint.Bind(mozilla::ipc::PrivateIPDLInterface(), aActor, this, ${container}); """, container=self.protocol.managedVar(managed, self.side), ) self.cls.addstmts([openmeth, bindmeth, Whitespace.NL]) def implementManagerIface(self): p = self.protocol protocolbase = Type("IProtocol", ptr=True) methods = [] if p.decl.type.isToplevel(): # FIXME: This used to be declared conditionally based on whether # shmem appeared somewhere in the protocol hierarchy, however that # caused issues due to Shmem instances hidden within custom C++ # types. self.asyncSwitch.addcase( CaseLabel("SHMEM_CREATED_MESSAGE_TYPE"), self.genShmemCreatedHandler(), ) self.asyncSwitch.addcase( CaseLabel("SHMEM_DESTROYED_MESSAGE_TYPE"), self.genShmemDestroyedHandler(), ) # Keep track of types created with an INOUT ctor. We need to call # Register() or RegisterID() for them depending on the side the managee # is created. 
inoutCtorTypes = [] for msg in p.messageDecls: msgtype = msg.decl.type if msgtype.isCtor() and msgtype.isInout(): inoutCtorTypes.append(msgtype.constructedType()) # all protocols share the "same" RemoveManagee() implementation pvar = ExprVar("aProtocolId") listenervar = ExprVar("aListener") removemanagee = MethodDefn( MethodDecl( p.removeManageeMethod().name, params=[ Decl(_protocolIdType(), pvar.name), Decl(protocolbase, listenervar.name), ], methodspec=MethodSpec.OVERRIDE, ) ) if not len(p.managesStmts): removemanagee.addcode( """ FatalError("unreached"); return; """ ) else: switchontype = StmtSwitch(pvar) for managee in p.managesStmts: manageeipdltype = managee.decl.type manageecxxtype = _cxxBareType( ipdl.type.ActorType(manageeipdltype), self.side ) case = ExprCode( """ { ${manageecxxtype} actor = static_cast<${manageecxxtype}>(aListener); const bool removed = ${container}.EnsureRemoved(actor); MOZ_RELEASE_ASSERT(removed, "actor not managed by this!"); auto* proxy = actor->GetLifecycleProxy(); NS_IF_RELEASE(proxy); return; } """, manageecxxtype=manageecxxtype, container=p.managedVar(manageeipdltype, self.side), ) switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) switchontype.addcase( DefaultLabel(), ExprCode( """ FatalError("unreached"); return; """ ), ) removemanagee.addstmt(switchontype) # The `DeallocManagee` method is called for managed actors to trigger # deallocation when ActorLifecycleProxy is freed. deallocmanagee = MethodDefn( MethodDecl( p.deallocManageeMethod().name, params=[ Decl(_protocolIdType(), pvar.name), Decl(protocolbase, listenervar.name), ], methodspec=MethodSpec.OVERRIDE, ) ) if not len(p.managesStmts): deallocmanagee.addcode( """ FatalError("unreached"); return; """ ) else: switchontype = StmtSwitch(pvar) for managee in p.managesStmts: manageeipdltype = managee.decl.type # Reference counted actor types don't have corresponding # `Dealloc` methods, as they are deallocated by releasing the # IPDL-held reference. 
if manageeipdltype.isRefcounted(): continue case = StmtCode( """ ${concrete}->${dealloc}(static_cast<${type}>(aListener)); return; """, concrete=self.concreteThis(), dealloc=_deallocMethod(manageeipdltype, self.side), type=_cxxBareType(ipdl.type.ActorType(manageeipdltype), self.side), ) switchontype.addcase(CaseLabel(_protocolId(manageeipdltype).name), case) switchontype.addcase( DefaultLabel(), StmtCode( """ FatalError("unreached"); return; """ ), ) deallocmanagee.addstmt(switchontype) return methods + [removemanagee, deallocmanagee, Whitespace.NL] def genShmemCreatedHandler(self): assert self.protocol.decl.type.isToplevel() return StmtCode( """ { if (!ShmemCreated(${msgvar})) { return MsgPayloadError; } return MsgProcessed; } """, msgvar=self.msgvar, ) def genShmemDestroyedHandler(self): assert self.protocol.decl.type.isToplevel() return StmtCode( """ { if (!ShmemDestroyed(${msgvar})) { return MsgPayloadError; } return MsgProcessed; } """, msgvar=self.msgvar, ) # ------------------------------------------------------------------------- # The next few functions are the crux of the IPDL code generator. # They generate code for all the nasty work of message # serialization/deserialization and dispatching handlers for # received messages. 
## def concreteThis(self): implAttr = self.protocol.implAttribute(self.side) if implAttr == "virtual": return ExprVar.THIS if implAttr is None: assert self.protocol.name.startswith("P") className = self.protocol.name[1:] + self.side.capitalize() else: assert isinstance(implAttr, ipdl.ast.StringLiteral) className = implAttr.value return ExprCode("static_cast<${className}*>(this)", className=className) def thisCall(self, function, args): return ExprCall(ExprSelect(self.concreteThis(), "->", function), args=args) def visitMessageDecl(self, md): isctor = md.decl.type.isCtor() isdtor = md.decl.type.isDtor() decltype = md.decl.type sendmethod = None movesendmethod = None promisesendmethod = None recvlbl, recvcase = None, None def addRecvCase(lbl, case): if decltype.isAsync(): self.asyncSwitch.addcase(lbl, case) elif decltype.isSync(): self.syncSwitch.addcase(lbl, case) elif decltype.isInterrupt(): self.interruptSwitch.addcase(lbl, case) else: assert 0 if self.sendsMessage(md): isasync = decltype.isAsync() # NOTE: Don't generate helper ctors for refcounted types. # # Safety concerns around providing your own actor to a ctor (namely # that the return value won't be checked, and the argument will be # `delete`-ed) are less critical with refcounted actors, due to the # actor being held alive by the callsite. # # This allows refcounted actors to not implement crashing AllocPFoo # methods on the sending side. 
if isctor and not md.decl.type.constructedType().isRefcounted(): self.cls.addstmts([self.genHelperCtor(md), Whitespace.NL]) if isctor and isasync: sendmethod, (recvlbl, recvcase) = self.genAsyncCtor(md) elif isctor: sendmethod = self.genBlockingCtorMethod(md) elif isdtor and isasync: sendmethod, (recvlbl, recvcase) = self.genAsyncDtor(md) elif isdtor: sendmethod = self.genBlockingDtorMethod(md) elif isasync: ( sendmethod, movesendmethod, promisesendmethod, (recvlbl, recvcase), ) = self.genAsyncSendMethod(md) else: sendmethod, movesendmethod = self.genBlockingSendMethod(md) # XXX figure out what to do here if isdtor and md.decl.type.constructedType().isToplevel(): sendmethod = None if sendmethod is not None: self.cls.addstmts([sendmethod, Whitespace.NL]) if movesendmethod is not None: self.cls.addstmts([movesendmethod, Whitespace.NL]) if promisesendmethod is not None: self.cls.addstmts([promisesendmethod, Whitespace.NL]) if recvcase is not None: addRecvCase(recvlbl, recvcase) recvlbl, recvcase = None, None if self.receivesMessage(md): if isctor: recvlbl, recvcase = self.genCtorRecvCase(md) elif isdtor: recvlbl, recvcase = self.genDtorRecvCase(md) else: recvlbl, recvcase = self.genRecvCase(md) # XXX figure out what to do here if isdtor and md.decl.type.constructedType().isToplevel(): return addRecvCase(recvlbl, recvcase) def genAsyncCtor(self, md): actor = md.actorDecl() method = MethodDefn(self.makeSendMethodDecl(md)) msgvar, stmts = self.makeMessage(md, errfnSendCtor) sendok, sendstmts = self.sendAsync(md, msgvar) method.addcode( """ $*{bind} // Build our constructor message. $*{stmts} // Notify the other side about the newly created actor. This can // fail if our manager has already been destroyed. // // NOTE: If the send call fails due to toplevel channel teardown, // the `IProtocol::ChannelSend` wrapper absorbs the error for us, // so we don't tear down actors unexpectedly. 
$*{sendstmts} // Warn, destroy the actor, and return null if the message failed to // send. Otherwise, return the successfully created actor reference. if (!${sendok}) { NS_WARNING("Error sending ${actorname} constructor"); $*{destroy} return nullptr; } return ${actor}; """, bind=self.bindManagedActor(actor), stmts=stmts, sendstmts=sendstmts, sendok=sendok, destroy=self.destroyActor( md, actor.var(), why=_DestroyReason.FailedConstructor ), actor=actor.var(), actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), ) lbl = CaseLabel(md.pqReplyId()) case = StmtBlock() case.addstmt(StmtReturn(_Result.Processed)) # TODO not really sure what to do with async ctor "replies" yet. # destroy actor if there was an error? tricky ... return method, (lbl, case) def genBlockingCtorMethod(self, md): actor = md.actorDecl() method = MethodDefn(self.makeSendMethodDecl(md)) msgvar, stmts = self.makeMessage(md, errfnSendCtor) replyvar = self.replyvar sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) replystmts = self.deserializeReply( md, replyvar, self.side, errfnSendCtor, errfnSentinel(ExprLiteral.NULL), ) method.addcode( """ $*{bind} // Build our constructor message. $*{stmts} // Synchronously send the constructor message to the other side. If // the send fails, e.g. due to the remote side shutting down, the // actor will be destroyed and potentially freed. UniquePtr ${replyvar}; $*{sendstmts} if (!(${sendok})) { // Warn, destroy the actor and return null if the message // failed to send. 
NS_WARNING("Error sending constructor"); $*{destroy} return nullptr; } $*{replystmts} return ${actor}; """, bind=self.bindManagedActor(actor), stmts=stmts, replyvar=replyvar, sendstmts=sendstmts, sendok=sendok, destroy=self.destroyActor( md, actor.var(), why=_DestroyReason.FailedConstructor ), replystmts=replystmts, actor=actor.var(), actorname=actor.ipdltype.protocol.name() + self.side.capitalize(), ) return method def bindManagedActor(self, actordecl, errfn=ExprLiteral.NULL, idexpr=None): actorproto = actordecl.ipdltype.protocol if idexpr is None: setManagerArgs = [ExprVar.THIS] else: setManagerArgs = [ExprVar.THIS, idexpr] return [ StmtCode( """ if (!${actor}) { NS_WARNING("Cannot bind null ${actorname} actor"); return ${errfn}; } ${actor}->SetManagerAndRegister($,{setManagerArgs}); ${container}.Insert(${actor}); """, actor=actordecl.var(), actorname=actorproto.name() + self.side.capitalize(), errfn=errfn, setManagerArgs=setManagerArgs, container=self.protocol.managedVar(actorproto, self.side), ) ] def genHelperCtor(self, md): helperdecl = self.makeSendMethodDecl(md) helperdecl.params = helperdecl.params[1:] helper = MethodDefn(helperdecl) helper.addstmts( [ self.callAllocActor(md, retsems="out", side=self.side), StmtReturn( ExprCall( ExprVar(helperdecl.name), args=md.makeCxxArgs(paramsems="move") ) ), ] ) return helper def genAsyncDtor(self, md): actorvar = ExprVar("actor") method = MethodDefn(self.makeDtorMethodDecl(md, actorvar)) method.addstmt(self.dtorPrologue(actorvar)) msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) sendok, sendstmts = self.sendAsync(md, msgvar, actorvar) method.addstmts( stmts + sendstmts + [Whitespace.NL] + self.dtorEpilogue(md, actorvar) + [StmtReturn(sendok)] ) lbl = CaseLabel(md.pqReplyId()) case = StmtBlock() case.addstmt(StmtReturn(_Result.Processed)) # TODO if the dtor is "inherently racy", keep the actor alive # until the other side acks return method, (lbl, case) def genBlockingDtorMethod(self, md): actorvar = 
ExprVar("actor") method = MethodDefn(self.makeDtorMethodDecl(md, actorvar)) method.addstmt(self.dtorPrologue(actorvar)) msgvar, stmts = self.makeMessage(md, errfnSendDtor, actorvar) replyvar = self.replyvar sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar, actorvar) method.addstmts( stmts + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr"), replyvar.name))] + sendstmts ) destmts = self.deserializeReply( md, replyvar, self.side, errfnSend, errfnSentinel(), actorvar ) ifsendok = StmtIf(ExprLiteral.FALSE) ifsendok.addifstmts(destmts) ifsendok.addifstmts( [Whitespace.NL, StmtExpr(ExprAssn(sendok, ExprLiteral.FALSE, "&="))] ) method.addstmt(ifsendok) method.addstmts( self.dtorEpilogue(md, actorvar) + [Whitespace.NL, StmtReturn(sendok)] ) return method def destroyActor(self, md, actorexpr, why=_DestroyReason.Deletion): if md and md.decl.type.isCtor(): destroyedType = md.decl.type.constructedType() else: destroyedType = self.protocol.decl.type return [ StmtCode( """ IProtocol* mgr = ${actor}->Manager(); ${actor}->DestroySubtree(${why}); ${actor}->ClearSubtree(); mgr->RemoveManagee(${protoId}, ${actor}); """, actor=actorexpr, why=why, protoId=_protocolId(destroyedType), ) ] def dtorPrologue(self, actorexpr): return StmtCode( """ if (!${actor} || !${actor}->CanSend()) { NS_WARNING("Attempt to __delete__ missing or closed actor"); return false; } """, actor=actorexpr, ) def dtorEpilogue(self, md, actorexpr): return self.destroyActor(md, actorexpr) def genRecvAsyncReplyCase(self, md): lbl = CaseLabel(md.pqReplyId()) case = StmtBlock() resolve, reason, prologue, desrej, desstmts = self.deserializeAsyncReply( md, self.side, errfnRecv, errfnSentinel(_Result.ValuError) ) if len(md.returns) > 1: resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) resolvearg = ExprCall( ExprVar("std::make_tuple"), args=[ExprMove(p.var()) for p in md.returns] ) else: resolvetype = md.returns[0].bareType(self.side) resolvearg = ExprMove(md.returns[0].var()) case.addcode( """ 
$*{prologue} UniquePtr untypedCallback = GetIPCChannel()->PopCallback(${msgvar}, Id()); typedef MessageChannel::CallbackHolder<${resolvetype}> CallbackHolder; auto* callback = static_cast(untypedCallback.get()); if (!callback) { FatalError("Error unknown callback"); return MsgProcessingError; } if (${resolve}) { $*{desstmts} callback->Resolve(${resolvearg}); } else { $*{desrej} callback->Reject(std::move(${reason})); } return MsgProcessed; """, prologue=prologue, msgvar=self.msgvar, resolve=resolve, resolvetype=resolvetype, desstmts=desstmts, resolvearg=resolvearg, desrej=desrej, reason=reason, ) return (lbl, case) def genAsyncSendMethod(self, md): decl = self.makeSendMethodDecl(md) if "VirtualSendImpl" in md.attributes: decl.methodspec = MethodSpec.VIRTUAL method = MethodDefn(decl) msgvar, stmts = self.makeMessage(md, errfnSend) retvar, sendstmts = self.sendAsync(md, msgvar) method.addstmts(stmts + [Whitespace.NL] + sendstmts + [StmtReturn(retvar)]) movemethod = None # Add the promise overload if we need one. 
if md.returns: decl = self.makeSendMethodDecl(md, promise=True) if "VirtualSendImpl" in md.attributes: decl.methodspec = MethodSpec.VIRTUAL promisemethod = MethodDefn(decl) stmts = self.sendAsyncWithPromise(md) promisemethod.addstmts(stmts) (lbl, case) = self.genRecvAsyncReplyCase(md) else: (promisemethod, lbl, case) = (None, None, None) return method, movemethod, promisemethod, (lbl, case) def genBlockingSendMethod(self, md): method = MethodDefn(self.makeSendMethodDecl(md)) msgvar, serstmts = self.makeMessage(md, errfnSend) replyvar = self.replyvar sendok, sendstmts = self.sendBlocking(md, msgvar, replyvar) failif = StmtIf(ExprNot(sendok)) failif.addifstmt(StmtReturn.FALSE) desstmts = self.deserializeReply( md, replyvar, self.side, errfnSend, errfnSentinel() ) method.addstmts( serstmts + [Whitespace.NL, StmtDecl(Decl(Type("UniquePtr"), replyvar.name))] + sendstmts + [failif] + desstmts + [Whitespace.NL, StmtReturn.TRUE] ) movemethod = None return method, movemethod def genCtorRecvCase(self, md): lbl = CaseLabel(md.pqMsgId()) case = StmtBlock() actorhandle = self.handlevar stmts = self.deserializeMessage( md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) ) idvar, saveIdStmts = self.saveActorId(md) case.addstmts( stmts + [ StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) for r in md.returns ] # alloc the actor, register it under the foreign ID + [self.callAllocActor(md, retsems="in", side=self.side)] + self.bindManagedActor( md.actorDecl(), errfn=_Result.ValuError, idexpr=_actorHId(actorhandle) ) + [Whitespace.NL] + saveIdStmts + self.invokeRecvHandler(md) + self.makeReply(md, errfnRecv, idvar) + [Whitespace.NL, StmtReturn(_Result.Processed)] ) return lbl, case def genDtorRecvCase(self, md): lbl = CaseLabel(md.pqMsgId()) case = StmtBlock() stmts = self.deserializeMessage( md, self.side, errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) ) idvar, saveIdStmts = self.saveActorId(md) case.addstmts( stmts + [ 
StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) for r in md.returns ] + self.invokeRecvHandler(md) + [Whitespace.NL] + saveIdStmts + self.makeReply(md, errfnRecv, routingId=idvar) + [Whitespace.NL] + self.dtorEpilogue(md, ExprVar.THIS) + [Whitespace.NL, StmtReturn(_Result.Processed)] ) return lbl, case def genRecvCase(self, md): lbl = CaseLabel(md.pqMsgId()) case = StmtBlock() stmts = self.deserializeMessage( md, self.side, errfn=errfnRecv, errfnSent=errfnSentinel(_Result.ValuError) ) idvar, saveIdStmts = self.saveActorId(md) declstmts = [ StmtDecl(Decl(r.bareType(self.side), r.var().name), initargs=[]) for r in md.returns ] if md.decl.type.isAsync() and md.returns: declstmts = self.makeResolver(md, errfnRecv, routingId=idvar) case.addstmts( stmts + saveIdStmts + declstmts + self.invokeRecvHandler(md) + [Whitespace.NL] + self.makeReply(md, errfnRecv, routingId=idvar) + [StmtReturn(_Result.Processed)] ) return lbl, case # helper methods def makeMessage(self, md, errfn, fromActor=None): msgvar = self.msgvar writervar = ExprVar("writer__") routingId = self.protocol.routingId(fromActor) this = fromActor or ExprVar.THIS stmts = ( [ StmtDecl( Decl(Type("UniquePtr"), msgvar.name), init=ExprCall(ExprVar(md.pqMsgCtorFunc()), args=[routingId]), ), StmtDecl( Decl(Type("IPC::MessageWriter"), writervar.name), initargs=[ExprDeref(msgvar), this], ), ] + [Whitespace.NL] + [ _ParamTraits.checkedWrite( p.ipdltype, p.var(), ExprAddrOf(writervar), sentinelKey=p.name, ) for p in md.params ] + [Whitespace.NL] + self.setMessageFlags(md, msgvar) ) return msgvar, stmts def makeResolver(self, md, errfn, routingId): if routingId is None: routingId = self.protocol.routingId() if not md.decl.type.isAsync() or not md.hasReply(): return [] def paramValue(idx): assert idx < len(md.returns) if len(md.returns) > 1: return ExprCode("std::get<${idx}>(aParam)", idx=idx) return ExprVar("aParam") serializeParams = [ _ParamTraits.checkedWrite( p.ipdltype, paramValue(idx), 
ExprAddrOf(ExprVar("writer__")), sentinelKey=p.name, ) for idx, p in enumerate(md.returns) ] return [ StmtCode( """ UniquePtr ${replyvar}(${replyCtor}(${routingId})); ${replyvar}->set_seqno(${msgvar}.seqno()); RefPtr resolver__ = new mozilla::ipc::IPDLResolverInner(std::move(${replyvar}), this); ${resolvertype} resolver = [resolver__ = std::move(resolver__)](${resolveType} aParam) { resolver__->Resolve([&] (IPC::Message* ${replyvar}, IProtocol* self__) { IPC::MessageWriter writer__(*${replyvar}, self__); $*{serializeParams} ${logSendingReply} }); }; """, msgvar=self.msgvar, resolvertype=Type(md.resolverName()), routingId=routingId, resolveType=_resolveType(md.returns, self.side), replyvar=self.replyvar, replyCtor=ExprVar(md.pqReplyCtorFunc()), serializeParams=serializeParams, logSendingReply=self.logMessage( md, self.replyvar, "Sending reply ", actor=ExprVar("self__"), ), ) ] def makeReply(self, md, errfn, routingId): if routingId is None: routingId = self.protocol.routingId() # TODO special cases for async ctor/dtor replies if not md.decl.type.hasReply(): return [] if md.decl.type.isAsync() and md.decl.type.hasReply(): return [] replyvar = self.replyvar return ( [ StmtExpr( ExprAssn( replyvar, ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[routingId]), ) ), StmtDecl( Decl(Type("IPC::MessageWriter"), "writer__"), initargs=[ExprDeref(replyvar), ExprVar.THIS], ), Whitespace.NL, ] + [ _ParamTraits.checkedWrite( r.ipdltype, r.var(), ExprAddrOf(ExprVar("writer__")), sentinelKey=r.name, ) for r in md.returns ] + self.setMessageFlags(md, replyvar) + [self.logMessage(md, replyvar, "Sending reply ")] ) def setMessageFlags(self, md, var, seqno=None): stmts = [] if seqno: stmts.append( StmtExpr(ExprCall(ExprSelect(var, "->", "set_seqno"), args=[seqno])) ) return stmts + [Whitespace.NL] def deserializeMessage(self, md, side, errfn, errfnSent): msgvar = self.msgvar msgexpr = ExprAddrOf(msgvar) readervar = ExprVar("reader__") isctor = md.decl.type.isCtor() stmts = [ 
self.logMessage(md, msgexpr, "Received ", receiving=True), self.profilerLabel(md), Whitespace.NL, ] if 0 == len(md.params): return stmts start, reads = 0, [] if isctor: # return the raw actor handle so that its ID can be used # to construct the "real" actor handlevar = self.handlevar handletype = Type("ActorHandle") reads = [ _ParamTraits.checkedRead( None, handletype, handlevar, ExprAddrOf(readervar), errfn, "'%s'" % handletype.name, sentinelKey="actor", errfnSentinel=errfnSent, ) ] start = 1 def maybeTainted(p, side): if md.decl.type.tainted and "NoTaint" not in p.attributes: return Type("Tainted", T=p.bareType(side)) return p.bareType(side) reads.extend( [ _ParamTraits.checkedRead( p.ipdltype, maybeTainted(p, side), p.var(), ExprAddrOf(readervar), errfn, "'%s'" % p.ipdltype.name(), sentinelKey=p.name, errfnSentinel=errfnSent, ) for p in md.params[start:] ] ) stmts.extend( ( [ StmtDecl( Decl(Type("IPC::MessageReader"), readervar.name), initargs=[msgvar, ExprVar.THIS], ) ] + [Whitespace.NL] + reads + [StmtCode("${reader}.EndRead();\n", reader=readervar)] ) ) return stmts def deserializeAsyncReply(self, md, side, errfn, errfnSent): msgvar = self.msgvar readervar = ExprVar("reader__") msgexpr = ExprAddrOf(msgvar) isctor = md.decl.type.isCtor() resolve = ExprVar("resolve__") reason = ExprVar("reason__") # NOTE: The `resolve__` and `reason__` parameters don't have sentinels, # as they are serialized by the IPDLResolverInner type in # ProtocolUtils.cpp rather than by generated code. 
desresolve = [ StmtCode( """ bool resolve__ = false; if (!IPC::ReadParam(&${readervar}, &resolve__)) { FatalError("Error deserializing bool"); return MsgValueError; } """, readervar=readervar, ), ] desrej = [ StmtCode( """ ResponseRejectReason reason__{}; if (!IPC::ReadParam(&${readervar}, &reason__)) { FatalError("Error deserializing ResponseRejectReason"); return MsgValueError; } ${readervar}.EndRead(); """, readervar=readervar, ), ] prologue = [ self.logMessage(md, msgexpr, "Received ", receiving=True), self.profilerLabel(md), Whitespace.NL, ] if not md.returns: return prologue prologue.extend( [ StmtDecl( Decl(Type("IPC::MessageReader"), readervar.name), initargs=[msgvar, ExprVar.THIS], ) ] + desresolve ) start, reads = 0, [] if isctor: # return the raw actor handle so that its ID can be used # to construct the "real" actor handlevar = self.handlevar handletype = Type("ActorHandle") reads = [ _ParamTraits.checkedRead( None, handletype, handlevar, ExprAddrOf(readervar), errfn, "'%s'" % handletype.name, sentinelKey="actor", errfnSentinel=errfnSent, ) ] start = 1 stmts = ( reads + [ _ParamTraits.checkedRead( p.ipdltype, p.bareType(side), p.var(), ExprAddrOf(readervar), errfn, "'%s'" % p.ipdltype.name(), sentinelKey=p.name, errfnSentinel=errfnSent, ) for p in md.returns[start:] ] + [StmtCode("${reader}.EndRead();", reader=readervar)] ) return resolve, reason, prologue, desrej, stmts def deserializeReply(self, md, replyexpr, side, errfn, errfnSentinel, actor=None): stmts = [ Whitespace.NL, self.logMessage(md, replyexpr, "Received reply ", actor, receiving=True), ] if 0 == len(md.returns): return stmts def tempvar(r): return ExprVar(r.var().name + "__reply") readervar = ExprVar("reader__") stmts.extend( [ Whitespace.NL, StmtDecl( Decl(Type("IPC::MessageReader"), readervar.name), initargs=[ExprDeref(self.replyvar), ExprVar.THIS], ), ] + [Whitespace.NL] + [ _ParamTraits.checkedRead( r.ipdltype, r.bareType(side), tempvar(r), ExprAddrOf(readervar), errfn, "'%s'" % 
r.ipdltype.name(), sentinelKey=r.name, errfnSentinel=errfnSentinel, ) for r in md.returns ] # Move-assign the values out of the variables created with # checkedRead into outparams. + [ StmtExpr(ExprAssn(ExprDeref(r.var()), ExprMove(tempvar(r)))) for r in md.returns ] + [StmtCode("${reader}.EndRead();", reader=readervar)] ) return stmts def sendAsync(self, md, msgexpr, actor=None): sendok = ExprVar("sendok__") resolvefn = ExprVar("aResolve") rejectfn = ExprVar("aReject") stmts = [ Whitespace.NL, self.logMessage(md, msgexpr, "Sending ", actor), self.profilerLabel(md), ] stmts.append(Whitespace.NL) # Generate the actual call expression. send = ExprVar("ChannelSend") if actor is not None: send = ExprSelect(actor, "->", send.name) if md.returns: stmts.append( StmtExpr( ExprCall( send, args=[ ExprMove(msgexpr), ExprVar(md.pqReplyId()), ExprMove(resolvefn), ExprMove(rejectfn), ], ) ) ) retvar = None else: stmts.append( StmtDecl( Decl(Type.BOOL, sendok.name), init=ExprCall(send, args=[ExprMove(msgexpr)]), ) ) retvar = sendok return (retvar, stmts) def sendBlocking(self, md, msgexpr, replyexpr, actor=None): send = ExprVar("ChannelSend") if md.decl.type.isInterrupt(): send = ExprVar("ChannelCall") if actor is not None: send = ExprSelect(actor, "->", send.name) sendok = ExprVar("sendok__") self.externalIncludes.add("mozilla/ProfilerMarkers.h") return ( sendok, ( [ Whitespace.NL, self.logMessage(md, msgexpr, "Sending ", actor), self.profilerLabel(md), ] + [ Whitespace.NL, StmtDecl(Decl(Type.BOOL, sendok.name), init=ExprLiteral.FALSE), StmtBlock( [ StmtExpr( ExprCall( ExprVar("AUTO_PROFILER_TRACING_MARKER"), [ ExprLiteral.String("Sync IPC"), ExprLiteral.String( self.protocol.name + "::" + md.prettyMsgName() ), ExprVar("IPC"), ], ) ), StmtExpr( ExprAssn( sendok, ExprCall( send, args=[ExprMove(msgexpr), ExprAddrOf(replyexpr)], ), ) ), ] ), ] ), ) def sendAsyncWithPromise(self, md): # Create a new promise, and forward to the callback send overload. 
promise = _makePromise(md.returns, self.side, resolver=True) if len(md.returns) > 1: resolvetype = _tuple([d.bareType(self.side) for d in md.returns]) else: resolvetype = md.returns[0].bareType(self.side) resolve = ExprCode( """ [promise__](${resolvetype}&& aValue) { promise__->Resolve(std::move(aValue), __func__); } """, resolvetype=resolvetype, ) reject = ExprCode( """ [promise__](ResponseRejectReason&& aReason) { promise__->Reject(std::move(aReason), __func__); } """, resolvetype=resolvetype, ) args = [ExprMove(p.var()) for p in md.params] + [resolve, reject] stmt = StmtCode( """ RefPtr<${promise}> promise__ = new ${promise}(__func__); promise__->UseDirectTaskDispatch(__func__); ${send}($,{args}); return promise__; """, promise=promise, send=md.sendMethod(), args=args, ) return [stmt] def callAllocActor(self, md, retsems, side): actortype = md.actorDecl().bareType(self.side) if md.decl.type.constructedType().isRefcounted(): actortype.ptr = False actortype = _refptr(actortype) callalloc = self.thisCall( _allocMethod(md.decl.type.constructedType(), side), args=md.makeCxxArgs(retsems=retsems, retcallsems="out", implicit=False), ) return StmtDecl(Decl(actortype, md.actorDecl().var().name), init=callalloc) def invokeRecvHandler(self, md): retsems = "in" if md.decl.type.isAsync() and md.returns: retsems = "resolver" okdecl = StmtDecl( Decl(Type("mozilla::ipc::IPCResult"), "__ok"), init=self.thisCall( md.recvMethod(), md.makeCxxArgs( paramsems="move", retsems=retsems, retcallsems="out", ), ), ) failif = StmtIf(ExprNot(ExprVar("__ok"))) failif.addifstmts( [ _protocolErrorBreakpoint("Handler returned error code!"), Whitespace( "// Error handled in mozilla::ipc::IPCResult\n", indent=True ), StmtReturn(_Result.ProcessingError), ] ) return [okdecl, failif] def makeDtorMethodDecl(self, md, actorvar): decl = self.makeSendMethodDecl(md) decl.params.insert( 0, Decl( _cxxInType( ipdl.type.ActorType(md.decl.type.constructedType()), side=self.side, direction="send", ), 
actorvar.name, ), ) decl.methodspec = MethodSpec.STATIC return decl def makeSendMethodDecl(self, md, promise=False, paramsems="in"): implicit = md.decl.type.hasImplicitActorParam() if md.decl.type.isAsync() and md.returns: if promise: returnsems = "promise" rettype = _refptr(Type(md.promiseName())) else: returnsems = "callback" rettype = Type.VOID else: assert not promise returnsems = "out" rettype = Type.BOOL decl = MethodDecl( md.sendMethod(), params=md.makeCxxParams( paramsems, returnsems=returnsems, side=self.side, implicit=implicit, direction="send", ), warn_unused=( (self.side == "parent" and returnsems != "callback") or (md.decl.type.isCtor() and not md.decl.type.isAsync()) ), ret=rettype, ) if md.decl.type.isCtor(): decl.ret = md.actorDecl().bareType(self.side) return decl def logMessage(self, md, msgptr, pfx, actor=None, receiving=False): actorname = _actorName(self.protocol.name, self.side) return StmtCode( """ if (mozilla::ipc::LoggingEnabledFor(${actorname})) { mozilla::ipc::LogMessageForProtocol( ${actorname}, ${actor}->ToplevelProtocol()->OtherPidMaybeInvalid(), ${pfx}, ${msgptr}->type(), mozilla::ipc::MessageDirection::${direction}); } """, actorname=ExprLiteral.String(actorname), actor=actor or ExprVar.THIS, pfx=ExprLiteral.String(pfx), msgptr=msgptr, direction="eReceiving" if receiving else "eSending", ) def profilerLabel(self, md): self.externalIncludes.add("mozilla/ProfilerLabels.h") return StmtCode( """ AUTO_PROFILER_LABEL("${name}::${msgname}", OTHER); """, name=self.protocol.name, msgname=md.prettyMsgName(), ) def saveActorId(self, md): idvar = ExprVar("id__") if md.decl.type.hasReply(): # only save the ID if we're actually going to use it, to # avoid unused-variable warnings saveIdStmts = [ StmtDecl(Decl(_actorIdType(), idvar.name), self.protocol.routingId()) ] else: saveIdStmts = [] return idvar, saveIdStmts class _GenerateProtocolParentCode(_GenerateProtocolActorCode): def __init__(self): _GenerateProtocolActorCode.__init__(self, "parent") 
def sendsMessage(self, md): return not md.decl.type.isIn() def receivesMessage(self, md): return md.decl.type.isInout() or md.decl.type.isIn() class _GenerateProtocolChildCode(_GenerateProtocolActorCode): def __init__(self): _GenerateProtocolActorCode.__init__(self, "child") def sendsMessage(self, md): return not md.decl.type.isOut() def receivesMessage(self, md): return md.decl.type.isInout() or md.decl.type.isOut() # ----------------------------------------------------------------------------- # Utility passes ## def _splitClassDeclDefn(cls): """Destructively split |cls| methods into declarations and definitions (if |not methodDecl.force_inline|). Return classDecl, methodDefns.""" defns = Block() for i, stmt in enumerate(cls.stmts): if isinstance(stmt, MethodDefn) and not stmt.decl.force_inline: decl, defn = _splitMethodDeclDefn(stmt, cls) cls.stmts[i] = StmtDecl(decl) if defn: defns.addstmts([defn, Whitespace.NL]) return cls, defns def _splitMethodDeclDefn(md, cls): # Pure methods have decls but no defns. if md.decl.methodspec == MethodSpec.PURE: return md.decl, None saveddecl = deepcopy(md.decl) md.decl.cls = cls # Don't emit method specifiers on method defns. md.decl.methodspec = MethodSpec.NONE md.decl.warn_unused = False md.decl.only_for_definition = True for param in md.decl.params: if isinstance(param, Param): param.default = None return saveddecl, md def _splitFuncDeclDefn(fun): assert not fun.decl.force_inline return StmtDecl(fun.decl), fun