summaryrefslogtreecommitdiffstats
path: root/ipc/mscom/AsyncInvoker.h
diff options
context:
space:
mode:
Diffstat (limited to 'ipc/mscom/AsyncInvoker.h')
-rw-r--r--ipc/mscom/AsyncInvoker.h465
1 files changed, 465 insertions, 0 deletions
diff --git a/ipc/mscom/AsyncInvoker.h b/ipc/mscom/AsyncInvoker.h
new file mode 100644
index 0000000000..ce25a4224b
--- /dev/null
+++ b/ipc/mscom/AsyncInvoker.h
@@ -0,0 +1,465 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_AsyncInvoker_h
+#define mozilla_mscom_AsyncInvoker_h
+
+#include <objidl.h>
+#include <windows.h>
+
+#include <utility>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/Maybe.h"
+#include "mozilla/Mutex.h"
+#include "mozilla/mscom/Aggregation.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsISerialEventTarget.h"
+#include "nsISupportsImpl.h"
+#include "nsThreadUtils.h"
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+template <typename AsyncInterface>
+class ForgettableAsyncCall : public ISynchronize {
+ public:
+ explicit ForgettableAsyncCall(ICallFactory* aCallFactory)
+ : mRefCnt(0), mAsyncCall(nullptr) {
+ StabilizedRefCount<Atomic<ULONG>> stabilizer(mRefCnt);
+
+ HRESULT hr =
+ aCallFactory->CreateCall(__uuidof(AsyncInterface), this, IID_IUnknown,
+ getter_AddRefs(mInnerUnk));
+ if (FAILED(hr)) {
+ return;
+ }
+
+ hr = mInnerUnk->QueryInterface(__uuidof(AsyncInterface),
+ reinterpret_cast<void**>(&mAsyncCall));
+ if (SUCCEEDED(hr)) {
+ // Don't hang onto a ref. Because mAsyncCall is aggregated, its refcount
+ // is this->mRefCnt, so we'd create a cycle!
+ mAsyncCall->Release();
+ }
+ }
+
+ AsyncInterface* GetInterface() const { return mAsyncCall; }
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) final {
+ if (aIid == IID_ISynchronize || aIid == IID_IUnknown) {
+ RefPtr<ISynchronize> ptr(this);
+ ptr.forget(aOutInterface);
+ return S_OK;
+ }
+
+ return mInnerUnk->QueryInterface(aIid, aOutInterface);
+ }
+
+ STDMETHODIMP_(ULONG) AddRef() final {
+ ULONG result = ++mRefCnt;
+ NS_LOG_ADDREF(this, result, "ForgettableAsyncCall", sizeof(*this));
+ return result;
+ }
+
+ STDMETHODIMP_(ULONG) Release() final {
+ ULONG result = --mRefCnt;
+ NS_LOG_RELEASE(this, result, "ForgettableAsyncCall");
+ if (!result) {
+ delete this;
+ }
+ return result;
+ }
+
+ // ISynchronize
+ STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override {
+ return E_NOTIMPL;
+ }
+
+ STDMETHODIMP Signal() override {
+ // Even though this function is a no-op, we must return S_OK as opposed to
+ // E_NOTIMPL or else COM will consider the async call to have failed.
+ return S_OK;
+ }
+
+ STDMETHODIMP Reset() override {
+ // Even though this function is a no-op, we must return S_OK as opposed to
+ // E_NOTIMPL or else COM will consider the async call to have failed.
+ return S_OK;
+ }
+
+ protected:
+ virtual ~ForgettableAsyncCall() = default;
+
+ private:
+ Atomic<ULONG> mRefCnt;
+ RefPtr<IUnknown> mInnerUnk;
+ AsyncInterface* mAsyncCall; // weak reference
+};
+
+template <typename AsyncInterface>
+class WaitableAsyncCall : public ForgettableAsyncCall<AsyncInterface> {
+ public:
+ explicit WaitableAsyncCall(ICallFactory* aCallFactory)
+ : ForgettableAsyncCall<AsyncInterface>(aCallFactory),
+ mEvent(::CreateEventW(nullptr, FALSE, FALSE, nullptr)) {}
+
+ STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override {
+ const DWORD waitStart =
+ aTimeoutMilliseconds == INFINITE ? 0 : ::GetTickCount();
+ DWORD flags = aFlags;
+ if (XRE_IsContentProcess() && NS_IsMainThread()) {
+ flags |= COWAIT_ALERTABLE;
+ }
+
+ HRESULT hr;
+ DWORD signaledIdx;
+
+ DWORD elapsed = 0;
+
+ while (true) {
+ if (aTimeoutMilliseconds != INFINITE) {
+ elapsed = ::GetTickCount() - waitStart;
+ }
+ if (elapsed >= aTimeoutMilliseconds) {
+ return RPC_S_CALLPENDING;
+ }
+
+ ::SetLastError(ERROR_SUCCESS);
+
+ hr = ::CoWaitForMultipleHandles(flags, aTimeoutMilliseconds - elapsed, 1,
+ &mEvent, &signaledIdx);
+ if (hr == RPC_S_CALLPENDING || FAILED(hr)) {
+ return hr;
+ }
+
+ if (hr == S_OK && signaledIdx == 0) {
+ return hr;
+ }
+ }
+ }
+
+ STDMETHODIMP Signal() override {
+ if (!::SetEvent(mEvent)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ return S_OK;
+ }
+
+ protected:
+ ~WaitableAsyncCall() {
+ if (mEvent) {
+ ::CloseHandle(mEvent);
+ }
+ }
+
+ private:
+ HANDLE mEvent;
+};
+
+template <typename AsyncInterface>
+class EventDrivenAsyncCall : public ForgettableAsyncCall<AsyncInterface> {
+ public:
+ explicit EventDrivenAsyncCall(ICallFactory* aCallFactory)
+ : ForgettableAsyncCall<AsyncInterface>(aCallFactory) {}
+
+ bool HasCompletionRunnable() const { return !!mCompletionRunnable; }
+
+ void ClearCompletionRunnable() { mCompletionRunnable = nullptr; }
+
+ void SetCompletionRunnable(already_AddRefed<nsIRunnable> aRunnable) {
+ nsCOMPtr<nsIRunnable> innerRunnable(aRunnable);
+ MOZ_ASSERT(!!innerRunnable);
+ if (!innerRunnable) {
+ return;
+ }
+
+ // We need to retain a ref to ourselves to outlive the AsyncInvoker
+ // such that our callback can execute.
+ RefPtr<EventDrivenAsyncCall<AsyncInterface>> self(this);
+
+ mCompletionRunnable = NS_NewRunnableFunction(
+ "EventDrivenAsyncCall outer completion Runnable",
+ [innerRunnable = std::move(innerRunnable), self = std::move(self)]() {
+ innerRunnable->Run();
+ });
+ }
+
+ void SetEventTarget(nsISerialEventTarget* aTarget) { mEventTarget = aTarget; }
+
+ STDMETHODIMP Signal() override {
+ MOZ_ASSERT(!!mCompletionRunnable);
+ if (!mCompletionRunnable) {
+ return S_OK;
+ }
+
+ nsCOMPtr<nsISerialEventTarget> eventTarget(mEventTarget.forget());
+ if (!eventTarget) {
+ eventTarget = GetMainThreadSerialEventTarget();
+ }
+
+ DebugOnly<nsresult> rv =
+ eventTarget->Dispatch(mCompletionRunnable.forget(), NS_DISPATCH_NORMAL);
+ MOZ_ASSERT(NS_SUCCEEDED(rv));
+ return S_OK;
+ }
+
+ private:
+ nsCOMPtr<nsIRunnable> mCompletionRunnable;
+ nsCOMPtr<nsISerialEventTarget> mEventTarget;
+};
+
+template <typename AsyncInterface>
+class FireAndForgetInvoker {
+ protected:
+ void OnBeginInvoke() {}
+ void OnSyncInvoke(HRESULT aHr) {}
+ void OnAsyncInvokeFailed() {}
+
+ typedef ForgettableAsyncCall<AsyncInterface> AsyncCallType;
+
+ RefPtr<ForgettableAsyncCall<AsyncInterface>> mAsyncCall;
+};
+
+template <typename AsyncInterface>
+class WaitableInvoker {
+ public:
+ HRESULT Wait(DWORD aTimeout = INFINITE) const {
+ if (!mAsyncCall) {
+ // Nothing to wait for
+ return S_OK;
+ }
+
+ return mAsyncCall->Wait(0, aTimeout);
+ }
+
+ protected:
+ void OnBeginInvoke() {}
+ void OnSyncInvoke(HRESULT aHr) {}
+ void OnAsyncInvokeFailed() {}
+
+ typedef WaitableAsyncCall<AsyncInterface> AsyncCallType;
+
+ RefPtr<WaitableAsyncCall<AsyncInterface>> mAsyncCall;
+};
+
+template <typename AsyncInterface>
+class EventDrivenInvoker {
+ public:
+ void SetCompletionRunnable(already_AddRefed<nsIRunnable> aRunnable) {
+ if (mAsyncCall) {
+ mAsyncCall->SetCompletionRunnable(std::move(aRunnable));
+ return;
+ }
+
+ mCompletionRunnable = aRunnable;
+ }
+
+ void SetAsyncEventTarget(nsISerialEventTarget* aTarget) {
+ if (mAsyncCall) {
+ mAsyncCall->SetEventTarget(aTarget);
+ }
+ }
+
+ protected:
+ void OnBeginInvoke() {
+ MOZ_RELEASE_ASSERT(
+ mCompletionRunnable ||
+ (mAsyncCall && mAsyncCall->HasCompletionRunnable()),
+ "You should have called SetCompletionRunnable before invoking!");
+ }
+
+ void OnSyncInvoke(HRESULT aHr) {
+ nsCOMPtr<nsIRunnable> completionRunnable(mCompletionRunnable.forget());
+ if (FAILED(aHr)) {
+ return;
+ }
+
+ completionRunnable->Run();
+ }
+
+ void OnAsyncInvokeFailed() {
+ MOZ_ASSERT(!!mAsyncCall);
+ mAsyncCall->ClearCompletionRunnable();
+ }
+
+ typedef EventDrivenAsyncCall<AsyncInterface> AsyncCallType;
+
+ RefPtr<EventDrivenAsyncCall<AsyncInterface>> mAsyncCall;
+ nsCOMPtr<nsIRunnable> mCompletionRunnable;
+};
+
+} // namespace detail
+
+/**
+ * This class is intended for "fire-and-forget" asynchronous invocations of COM
+ * interfaces. This requires that an interface be annotated with the
+ * |async_uuid| attribute in midl. We also require that there be no outparams
+ * in the desired asynchronous interface (otherwise that would break the
+ * desired "fire-and-forget" semantics).
+ *
+ * For example, let us suppose we have some IDL as such:
+ * [object, uuid(...), async_uuid(...)]
+ * interface IFoo : IUnknown
+ * {
+ * HRESULT Bar([in] long baz);
+ * }
+ *
+ * Then, given an IFoo, we may construct an AsyncInvoker<IFoo, AsyncIFoo>:
+ *
+ * IFoo* foo = ...;
+ * AsyncInvoker<IFoo, AsyncIFoo> myInvoker(foo);
+ * HRESULT hr = myInvoker.Invoke(&IFoo::Bar, &AsyncIFoo::Begin_Bar, 7);
+ *
+ * Alternatively you may use the ASYNC_INVOKER_FOR and ASYNC_INVOKE macros,
+ * which automatically deduce the name of the asynchronous interface from the
+ * name of the synchronous interface:
+ *
+ * ASYNC_INVOKER_FOR(IFoo) myInvoker(foo);
+ * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7);
+ *
+ * This class may also be used when a synchronous COM call must be made that
+ * might reenter the content process. In this case, use the WaitableAsyncInvoker
+ * variant, or the WAITABLE_ASYNC_INVOKER_FOR macro:
+ *
+ * WAITABLE_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo);
+ * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7);
+ * if (SUCCEEDED(hr)) {
+ * myInvoker.Wait(); // <-- Wait for the COM call to complete.
+ * }
+ *
+ * In general you should avoid using the waitable version, but in some corner
+ * cases it is absolutely necessary in order to preserve correctness while
+ * avoiding deadlock.
+ *
+ * Finally, it is also possible to have the async invoker enqueue a runnable
+ * to the main thread when the async operation completes:
+ *
+ * EVENT_DRIVEN_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo);
+ * // myRunnable will be invoked on the main thread once the async operation
+ * // has completed. Note that we set this *before* we do the ASYNC_INVOKE!
+ * myInvoker.SetCompletionRunnable(myRunnable.forget());
+ * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7);
+ * // ...
+ */
+template <typename SyncInterface, typename AsyncInterface,
+ template <typename Iface> class WaitPolicy =
+ detail::FireAndForgetInvoker>
+class MOZ_RAII AsyncInvoker final : public WaitPolicy<AsyncInterface> {
+ using Base = WaitPolicy<AsyncInterface>;
+
+ public:
+ typedef SyncInterface SyncInterfaceT;
+ typedef AsyncInterface AsyncInterfaceT;
+
+ /**
+ * @param aSyncObj The COM object on which to invoke the asynchronous event.
+ * If this object is not a proxy to the synchronous variant
+ * of AsyncInterface, then it will be invoked synchronously
+ * instead (because it is an in-process virtual method call).
+ * @param aIsProxy An optional hint as to whether or not aSyncObj is a proxy.
+ * If not specified, AsyncInvoker will automatically detect
+ * whether aSyncObj is a proxy, however there may be a
+ * performance penalty associated with that.
+ */
+ explicit AsyncInvoker(SyncInterface* aSyncObj,
+ const Maybe<bool>& aIsProxy = Nothing()) {
+ MOZ_ASSERT(aSyncObj);
+
+ RefPtr<ICallFactory> callFactory;
+ if ((aIsProxy.isSome() && !aIsProxy.value()) ||
+ FAILED(aSyncObj->QueryInterface(IID_ICallFactory,
+ getter_AddRefs(callFactory)))) {
+ mSyncObj = aSyncObj;
+ return;
+ }
+
+ this->mAsyncCall = new typename Base::AsyncCallType(callFactory);
+ }
+
+ /**
+ * @brief Invoke a method on the object. Member function pointers are provided
+ * for both the sychronous and asynchronous variants of the interface.
+ * If this invoker's encapsulated COM object is a proxy, then Invoke
+ * will call the asynchronous member function. Otherwise the
+ * synchronous version must be used, as the invocation will simply be a
+ * virtual function call that executes in-process.
+ * @param aSyncMethod Pointer to the method that we would like to invoke on
+ * the synchronous interface.
+ * @param aAsyncMethod Pointer to the method that we would like to invoke on
+ * the asynchronous interface.
+ */
+ template <typename SyncMethod, typename AsyncMethod, typename... Args>
+ HRESULT Invoke(SyncMethod aSyncMethod, AsyncMethod aAsyncMethod,
+ Args&&... aArgs) {
+ this->OnBeginInvoke();
+ if (mSyncObj) {
+ HRESULT hr = (mSyncObj->*aSyncMethod)(std::forward<Args>(aArgs)...);
+ this->OnSyncInvoke(hr);
+ return hr;
+ }
+
+ MOZ_ASSERT(this->mAsyncCall);
+ if (!this->mAsyncCall) {
+ this->OnAsyncInvokeFailed();
+ return E_POINTER;
+ }
+
+ AsyncInterface* asyncInterface = this->mAsyncCall->GetInterface();
+ MOZ_ASSERT(asyncInterface);
+ if (!asyncInterface) {
+ this->OnAsyncInvokeFailed();
+ return E_POINTER;
+ }
+
+ HRESULT hr = (asyncInterface->*aAsyncMethod)(std::forward<Args>(aArgs)...);
+ if (FAILED(hr)) {
+ this->OnAsyncInvokeFailed();
+ }
+
+ return hr;
+ }
+
+ AsyncInvoker(const AsyncInvoker& aOther) = delete;
+ AsyncInvoker(AsyncInvoker&& aOther) = delete;
+ AsyncInvoker& operator=(const AsyncInvoker& aOther) = delete;
+ AsyncInvoker& operator=(AsyncInvoker&& aOther) = delete;
+
+ private:
+ RefPtr<SyncInterface> mSyncObj;
+};
+
+template <typename SyncInterface, typename AsyncInterface>
+using WaitableAsyncInvoker =
+ AsyncInvoker<SyncInterface, AsyncInterface, detail::WaitableInvoker>;
+
+template <typename SyncInterface, typename AsyncInterface>
+using EventDrivenAsyncInvoker =
+ AsyncInvoker<SyncInterface, AsyncInterface, detail::EventDrivenInvoker>;
+
+} // namespace mscom
+} // namespace mozilla
+
+#define ASYNC_INVOKER_FOR(SyncIface) \
+ mozilla::mscom::AsyncInvoker<SyncIface, Async##SyncIface>
+
+#define WAITABLE_ASYNC_INVOKER_FOR(SyncIface) \
+ mozilla::mscom::WaitableAsyncInvoker<SyncIface, Async##SyncIface>
+
+#define EVENT_DRIVEN_ASYNC_INVOKER_FOR(SyncIface) \
+ mozilla::mscom::EventDrivenAsyncInvoker<SyncIface, Async##SyncIface>
+
+#define ASYNC_INVOKE(InvokerObj, SyncMethodName, ...) \
+ InvokerObj.Invoke( \
+ &decltype(InvokerObj)::SyncInterfaceT::SyncMethodName, \
+ &decltype(InvokerObj)::AsyncInterfaceT::Begin_##SyncMethodName, \
+ ##__VA_ARGS__)
+
+#endif // mozilla_mscom_AsyncInvoker_h