/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=8 sts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ #ifndef mozilla_mscom_AsyncInvoker_h #define mozilla_mscom_AsyncInvoker_h #include #include #include #include "mozilla/Assertions.h" #include "mozilla/Attributes.h" #include "mozilla/DebugOnly.h" #include "mozilla/Maybe.h" #include "mozilla/Mutex.h" #include "mozilla/mscom/Aggregation.h" #include "mozilla/mscom/Utils.h" #include "nsISerialEventTarget.h" #include "nsISupportsImpl.h" #include "nsThreadUtils.h" namespace mozilla { namespace mscom { namespace detail { template class ForgettableAsyncCall : public ISynchronize { public: explicit ForgettableAsyncCall(ICallFactory* aCallFactory) : mRefCnt(0), mAsyncCall(nullptr) { StabilizedRefCount> stabilizer(mRefCnt); HRESULT hr = aCallFactory->CreateCall(__uuidof(AsyncInterface), this, IID_IUnknown, getter_AddRefs(mInnerUnk)); if (FAILED(hr)) { return; } hr = mInnerUnk->QueryInterface(__uuidof(AsyncInterface), reinterpret_cast(&mAsyncCall)); if (SUCCEEDED(hr)) { // Don't hang onto a ref. Because mAsyncCall is aggregated, its refcount // is this->mRefCnt, so we'd create a cycle! mAsyncCall->Release(); } } AsyncInterface* GetInterface() const { return mAsyncCall; } // IUnknown STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) final { if (aIid == IID_ISynchronize || aIid == IID_IUnknown) { RefPtr ptr(this); ptr.forget(aOutInterface); return S_OK; } return mInnerUnk->QueryInterface(aIid, aOutInterface); } STDMETHODIMP_(ULONG) AddRef() final { ULONG result = ++mRefCnt; NS_LOG_ADDREF(this, result, "ForgettableAsyncCall", sizeof(*this)); return result; } STDMETHODIMP_(ULONG) Release() final { ULONG result = --mRefCnt; NS_LOG_RELEASE(this, result, "ForgettableAsyncCall"); if (!result) { delete this; } return result; } // ISynchronize STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override { return E_NOTIMPL; } STDMETHODIMP Signal() override { // Even though this function is a no-op, we must return S_OK as opposed to // E_NOTIMPL or else COM will consider the async call to have failed. return S_OK; } STDMETHODIMP Reset() override { // Even though this function is a no-op, we must return S_OK as opposed to // E_NOTIMPL or else COM will consider the async call to have failed. return S_OK; } protected: virtual ~ForgettableAsyncCall() = default; private: Atomic mRefCnt; RefPtr mInnerUnk; AsyncInterface* mAsyncCall; // weak reference }; template class WaitableAsyncCall : public ForgettableAsyncCall { public: explicit WaitableAsyncCall(ICallFactory* aCallFactory) : ForgettableAsyncCall(aCallFactory), mEvent(::CreateEventW(nullptr, FALSE, FALSE, nullptr)) {} STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override { const DWORD waitStart = aTimeoutMilliseconds == INFINITE ? 0 : ::GetTickCount(); DWORD flags = aFlags; if (XRE_IsContentProcess() && NS_IsMainThread()) { flags |= COWAIT_ALERTABLE; } HRESULT hr; DWORD signaledIdx; DWORD elapsed = 0; while (true) { if (aTimeoutMilliseconds != INFINITE) { elapsed = ::GetTickCount() - waitStart; } if (elapsed >= aTimeoutMilliseconds) { return RPC_S_CALLPENDING; } ::SetLastError(ERROR_SUCCESS); hr = ::CoWaitForMultipleHandles(flags, aTimeoutMilliseconds - elapsed, 1, &mEvent, &signaledIdx); if (hr == RPC_S_CALLPENDING || FAILED(hr)) { return hr; } if (hr == S_OK && signaledIdx == 0) { return hr; } } } STDMETHODIMP Signal() override { if (!::SetEvent(mEvent)) { return HRESULT_FROM_WIN32(::GetLastError()); } return S_OK; } protected: ~WaitableAsyncCall() { if (mEvent) { ::CloseHandle(mEvent); } } private: HANDLE mEvent; }; template class EventDrivenAsyncCall : public ForgettableAsyncCall { public: explicit EventDrivenAsyncCall(ICallFactory* aCallFactory) : ForgettableAsyncCall(aCallFactory) {} bool HasCompletionRunnable() const { return !!mCompletionRunnable; } void ClearCompletionRunnable() { mCompletionRunnable = nullptr; } void SetCompletionRunnable(already_AddRefed aRunnable) { nsCOMPtr innerRunnable(aRunnable); MOZ_ASSERT(!!innerRunnable); if (!innerRunnable) { return; } // We need to retain a ref to ourselves to outlive the AsyncInvoker // such that our callback can execute. RefPtr> self(this); mCompletionRunnable = NS_NewRunnableFunction( "EventDrivenAsyncCall outer completion Runnable", [innerRunnable = std::move(innerRunnable), self = std::move(self)]() { innerRunnable->Run(); }); } void SetEventTarget(nsISerialEventTarget* aTarget) { mEventTarget = aTarget; } STDMETHODIMP Signal() override { MOZ_ASSERT(!!mCompletionRunnable); if (!mCompletionRunnable) { return S_OK; } nsCOMPtr eventTarget(mEventTarget.forget()); if (!eventTarget) { eventTarget = GetMainThreadSerialEventTarget(); } DebugOnly rv = eventTarget->Dispatch(mCompletionRunnable.forget(), NS_DISPATCH_NORMAL); MOZ_ASSERT(NS_SUCCEEDED(rv)); return S_OK; } private: nsCOMPtr mCompletionRunnable; nsCOMPtr mEventTarget; }; template class FireAndForgetInvoker { protected: void OnBeginInvoke() {} void OnSyncInvoke(HRESULT aHr) {} void OnAsyncInvokeFailed() {} typedef ForgettableAsyncCall AsyncCallType; RefPtr> mAsyncCall; }; template class WaitableInvoker { public: HRESULT Wait(DWORD aTimeout = INFINITE) const { if (!mAsyncCall) { // Nothing to wait for return S_OK; } return mAsyncCall->Wait(0, aTimeout); } protected: void OnBeginInvoke() {} void OnSyncInvoke(HRESULT aHr) {} void OnAsyncInvokeFailed() {} typedef WaitableAsyncCall AsyncCallType; RefPtr> mAsyncCall; }; template class EventDrivenInvoker { public: void SetCompletionRunnable(already_AddRefed aRunnable) { if (mAsyncCall) { mAsyncCall->SetCompletionRunnable(std::move(aRunnable)); return; } mCompletionRunnable = aRunnable; } void SetAsyncEventTarget(nsISerialEventTarget* aTarget) { if (mAsyncCall) { mAsyncCall->SetEventTarget(aTarget); } } protected: void OnBeginInvoke() { MOZ_RELEASE_ASSERT( mCompletionRunnable || (mAsyncCall && mAsyncCall->HasCompletionRunnable()), "You should have called SetCompletionRunnable before invoking!"); } void OnSyncInvoke(HRESULT aHr) { nsCOMPtr completionRunnable(mCompletionRunnable.forget()); if (FAILED(aHr)) { return; } completionRunnable->Run(); } void OnAsyncInvokeFailed() { MOZ_ASSERT(!!mAsyncCall); mAsyncCall->ClearCompletionRunnable(); } typedef EventDrivenAsyncCall AsyncCallType; RefPtr> mAsyncCall; nsCOMPtr mCompletionRunnable; }; } // namespace detail /** * This class is intended for "fire-and-forget" asynchronous invocations of COM * interfaces. This requires that an interface be annotated with the * |async_uuid| attribute in midl. We also require that there be no outparams * in the desired asynchronous interface (otherwise that would break the * desired "fire-and-forget" semantics). * * For example, let us suppose we have some IDL as such: * [object, uuid(...), async_uuid(...)] * interface IFoo : IUnknown * { * HRESULT Bar([in] long baz); * } * * Then, given an IFoo, we may construct an AsyncInvoker: * * IFoo* foo = ...; * AsyncInvoker myInvoker(foo); * HRESULT hr = myInvoker.Invoke(&IFoo::Bar, &AsyncIFoo::Begin_Bar, 7); * * Alternatively you may use the ASYNC_INVOKER_FOR and ASYNC_INVOKE macros, * which automatically deduce the name of the asynchronous interface from the * name of the synchronous interface: * * ASYNC_INVOKER_FOR(IFoo) myInvoker(foo); * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7); * * This class may also be used when a synchronous COM call must be made that * might reenter the content process. In this case, use the WaitableAsyncInvoker * variant, or the WAITABLE_ASYNC_INVOKER_FOR macro: * * WAITABLE_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo); * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7); * if (SUCCEEDED(hr)) { * myInvoker.Wait(); // <-- Wait for the COM call to complete. * } * * In general you should avoid using the waitable version, but in some corner * cases it is absolutely necessary in order to preserve correctness while * avoiding deadlock. * * Finally, it is also possible to have the async invoker enqueue a runnable * to the main thread when the async operation completes: * * EVENT_DRIVEN_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo); * // myRunnable will be invoked on the main thread once the async operation * // has completed. Note that we set this *before* we do the ASYNC_INVOKE! * myInvoker.SetCompletionRunnable(myRunnable.forget()); * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7); * // ... */ template class WaitPolicy = detail::FireAndForgetInvoker> class MOZ_RAII AsyncInvoker final : public WaitPolicy { using Base = WaitPolicy; public: typedef SyncInterface SyncInterfaceT; typedef AsyncInterface AsyncInterfaceT; /** * @param aSyncObj The COM object on which to invoke the asynchronous event. * If this object is not a proxy to the synchronous variant * of AsyncInterface, then it will be invoked synchronously * instead (because it is an in-process virtual method call). * @param aIsProxy An optional hint as to whether or not aSyncObj is a proxy. * If not specified, AsyncInvoker will automatically detect * whether aSyncObj is a proxy, however there may be a * performance penalty associated with that. */ explicit AsyncInvoker(SyncInterface* aSyncObj, const Maybe& aIsProxy = Nothing()) { MOZ_ASSERT(aSyncObj); RefPtr callFactory; if ((aIsProxy.isSome() && !aIsProxy.value()) || FAILED(aSyncObj->QueryInterface(IID_ICallFactory, getter_AddRefs(callFactory)))) { mSyncObj = aSyncObj; return; } this->mAsyncCall = new typename Base::AsyncCallType(callFactory); } /** * @brief Invoke a method on the object. Member function pointers are provided * for both the sychronous and asynchronous variants of the interface. * If this invoker's encapsulated COM object is a proxy, then Invoke * will call the asynchronous member function. Otherwise the * synchronous version must be used, as the invocation will simply be a * virtual function call that executes in-process. * @param aSyncMethod Pointer to the method that we would like to invoke on * the synchronous interface. * @param aAsyncMethod Pointer to the method that we would like to invoke on * the asynchronous interface. */ template HRESULT Invoke(SyncMethod aSyncMethod, AsyncMethod aAsyncMethod, Args&&... aArgs) { this->OnBeginInvoke(); if (mSyncObj) { HRESULT hr = (mSyncObj->*aSyncMethod)(std::forward(aArgs)...); this->OnSyncInvoke(hr); return hr; } MOZ_ASSERT(this->mAsyncCall); if (!this->mAsyncCall) { this->OnAsyncInvokeFailed(); return E_POINTER; } AsyncInterface* asyncInterface = this->mAsyncCall->GetInterface(); MOZ_ASSERT(asyncInterface); if (!asyncInterface) { this->OnAsyncInvokeFailed(); return E_POINTER; } HRESULT hr = (asyncInterface->*aAsyncMethod)(std::forward(aArgs)...); if (FAILED(hr)) { this->OnAsyncInvokeFailed(); } return hr; } AsyncInvoker(const AsyncInvoker& aOther) = delete; AsyncInvoker(AsyncInvoker&& aOther) = delete; AsyncInvoker& operator=(const AsyncInvoker& aOther) = delete; AsyncInvoker& operator=(AsyncInvoker&& aOther) = delete; private: RefPtr mSyncObj; }; template using WaitableAsyncInvoker = AsyncInvoker; template using EventDrivenAsyncInvoker = AsyncInvoker; } // namespace mscom } // namespace mozilla #define ASYNC_INVOKER_FOR(SyncIface) \ mozilla::mscom::AsyncInvoker #define WAITABLE_ASYNC_INVOKER_FOR(SyncIface) \ mozilla::mscom::WaitableAsyncInvoker #define EVENT_DRIVEN_ASYNC_INVOKER_FOR(SyncIface) \ mozilla::mscom::EventDrivenAsyncInvoker #define ASYNC_INVOKE(InvokerObj, SyncMethodName, ...) \ InvokerObj.Invoke( \ &decltype(InvokerObj)::SyncInterfaceT::SyncMethodName, \ &decltype(InvokerObj)::AsyncInterfaceT::Begin_##SyncMethodName, \ ##__VA_ARGS__) #endif // mozilla_mscom_AsyncInvoker_h