summaryrefslogtreecommitdiffstats
path: root/ipc/mscom
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-07 09:22:09 +0000
commit43a97878ce14b72f0981164f87f2e35e14151312 (patch)
tree620249daf56c0258faa40cbdcf9cfba06de2a846 /ipc/mscom
parentInitial commit. (diff)
downloadfirefox-upstream.tar.xz
firefox-upstream.zip
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ipc/mscom')
-rw-r--r--ipc/mscom/ActivationContext.cpp223
-rw-r--r--ipc/mscom/ActivationContext.h101
-rw-r--r--ipc/mscom/Aggregation.h83
-rw-r--r--ipc/mscom/AgileReference.cpp223
-rw-r--r--ipc/mscom/AgileReference.h143
-rw-r--r--ipc/mscom/ApartmentRegion.h93
-rw-r--r--ipc/mscom/AsyncInvoker.h465
-rw-r--r--ipc/mscom/COMPtrHolder.h201
-rw-r--r--ipc/mscom/COMWrappers.cpp101
-rw-r--r--ipc/mscom/COMWrappers.h44
-rw-r--r--ipc/mscom/DispatchForwarder.cpp156
-rw-r--r--ipc/mscom/DispatchForwarder.h79
-rw-r--r--ipc/mscom/EnsureMTA.cpp256
-rw-r--r--ipc/mscom/EnsureMTA.h191
-rw-r--r--ipc/mscom/FastMarshaler.cpp164
-rw-r--r--ipc/mscom/FastMarshaler.h66
-rw-r--r--ipc/mscom/IHandlerProvider.h51
-rw-r--r--ipc/mscom/Interceptor.cpp860
-rw-r--r--ipc/mscom/Interceptor.h199
-rw-r--r--ipc/mscom/InterceptorLog.cpp527
-rw-r--r--ipc/mscom/InterceptorLog.h36
-rw-r--r--ipc/mscom/MainThreadHandoff.cpp697
-rw-r--r--ipc/mscom/MainThreadHandoff.h105
-rw-r--r--ipc/mscom/MainThreadInvoker.cpp177
-rw-r--r--ipc/mscom/MainThreadInvoker.h56
-rw-r--r--ipc/mscom/Objref.cpp411
-rw-r--r--ipc/mscom/Objref.h53
-rw-r--r--ipc/mscom/PassthruProxy.cpp393
-rw-r--r--ipc/mscom/PassthruProxy.h127
-rw-r--r--ipc/mscom/ProcessRuntime.cpp480
-rw-r--r--ipc/mscom/ProcessRuntime.h96
-rw-r--r--ipc/mscom/ProfilerMarkers.cpp236
-rw-r--r--ipc/mscom/ProfilerMarkers.h18
-rw-r--r--ipc/mscom/ProxyStream.cpp411
-rw-r--r--ipc/mscom/ProxyStream.h88
-rw-r--r--ipc/mscom/Ptr.h306
-rw-r--r--ipc/mscom/Registration.cpp534
-rw-r--r--ipc/mscom/Registration.h142
-rw-r--r--ipc/mscom/RegistrationAnnotator.cpp385
-rw-r--r--ipc/mscom/RegistrationAnnotator.h19
-rw-r--r--ipc/mscom/SpinEvent.cpp77
-rw-r--r--ipc/mscom/SpinEvent.h40
-rw-r--r--ipc/mscom/StructStream.cpp23
-rw-r--r--ipc/mscom/StructStream.h239
-rw-r--r--ipc/mscom/Utils.cpp600
-rw-r--r--ipc/mscom/Utils.h168
-rw-r--r--ipc/mscom/VTableBuilder.c54
-rw-r--r--ipc/mscom/VTableBuilder.h37
-rw-r--r--ipc/mscom/WeakRef.cpp225
-rw-r--r--ipc/mscom/WeakRef.h142
-rw-r--r--ipc/mscom/moz.build97
-rw-r--r--ipc/mscom/mozglue/ActCtxResource.cpp239
-rw-r--r--ipc/mscom/mozglue/ActCtxResource.h40
-rw-r--r--ipc/mscom/mozglue/ProcessRuntimeShared.cpp31
-rw-r--r--ipc/mscom/mozglue/ProcessRuntimeShared.h55
-rw-r--r--ipc/mscom/mozglue/moz.build19
-rw-r--r--ipc/mscom/oop/Factory.h142
-rw-r--r--ipc/mscom/oop/Handler.cpp281
-rw-r--r--ipc/mscom/oop/Handler.h134
-rw-r--r--ipc/mscom/oop/Module.cpp292
-rw-r--r--ipc/mscom/oop/Module.h98
-rw-r--r--ipc/mscom/oop/moz.build43
62 files changed, 12072 insertions, 0 deletions
diff --git a/ipc/mscom/ActivationContext.cpp b/ipc/mscom/ActivationContext.cpp
new file mode 100644
index 0000000000..c5ee9d1419
--- /dev/null
+++ b/ipc/mscom/ActivationContext.cpp
@@ -0,0 +1,223 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/ActivationContext.h"
+
+#include "mozilla/Assertions.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/mscom/Utils.h"
+
+namespace mozilla {
+namespace mscom {
+
+ActivationContext::ActivationContext(ActCtxResource aResource)
+ : ActivationContext(aResource.mModule, aResource.mId) {}
+
+ActivationContext::ActivationContext(HMODULE aLoadFromModule, WORD aResourceId)
+ : mActCtx(INVALID_HANDLE_VALUE) {
+ ACTCTXW actCtx = {sizeof(actCtx)};
+ actCtx.dwFlags = ACTCTX_FLAG_RESOURCE_NAME_VALID | ACTCTX_FLAG_HMODULE_VALID;
+ actCtx.lpResourceName = MAKEINTRESOURCEW(aResourceId);
+ actCtx.hModule = aLoadFromModule;
+
+ Init(actCtx);
+}
+
+void ActivationContext::Init(ACTCTXW& aActCtx) {
+ MOZ_ASSERT(mActCtx == INVALID_HANDLE_VALUE);
+ mActCtx = ::CreateActCtxW(&aActCtx);
+ MOZ_ASSERT(mActCtx != INVALID_HANDLE_VALUE);
+}
+
+void ActivationContext::AddRef() {
+ if (mActCtx == INVALID_HANDLE_VALUE) {
+ return;
+ }
+ ::AddRefActCtx(mActCtx);
+}
+
+ActivationContext::ActivationContext(ActivationContext&& aOther)
+ : mActCtx(aOther.mActCtx) {
+ aOther.mActCtx = INVALID_HANDLE_VALUE;
+}
+
+ActivationContext& ActivationContext::operator=(ActivationContext&& aOther) {
+ Release();
+
+ mActCtx = aOther.mActCtx;
+ aOther.mActCtx = INVALID_HANDLE_VALUE;
+ return *this;
+}
+
+ActivationContext::ActivationContext(const ActivationContext& aOther)
+ : mActCtx(aOther.mActCtx) {
+ AddRef();
+}
+
+ActivationContext& ActivationContext::operator=(
+ const ActivationContext& aOther) {
+ Release();
+ mActCtx = aOther.mActCtx;
+ AddRef();
+ return *this;
+}
+
+void ActivationContext::Release() {
+ if (mActCtx == INVALID_HANDLE_VALUE) {
+ return;
+ }
+ ::ReleaseActCtx(mActCtx);
+ mActCtx = INVALID_HANDLE_VALUE;
+}
+
+ActivationContext::~ActivationContext() { Release(); }
+
+#if defined(MOZILLA_INTERNAL_API)
+
+/* static */ Result<uintptr_t, HRESULT> ActivationContext::GetCurrent() {
+ HANDLE actCtx;
+ if (!::GetCurrentActCtx(&actCtx)) {
+ return Result<uintptr_t, HRESULT>(HRESULT_FROM_WIN32(::GetLastError()));
+ }
+
+ return reinterpret_cast<uintptr_t>(actCtx);
+}
+
+/* static */
+HRESULT ActivationContext::GetCurrentManifestPath(nsAString& aOutManifestPath) {
+ aOutManifestPath.Truncate();
+
+ SIZE_T bytesNeeded;
+ BOOL ok = ::QueryActCtxW(QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr,
+ nullptr, ActivationContextDetailedInformation,
+ nullptr, 0, &bytesNeeded);
+ if (!ok) {
+ DWORD err = ::GetLastError();
+ if (err != ERROR_INSUFFICIENT_BUFFER) {
+ return HRESULT_FROM_WIN32(err);
+ }
+ }
+
+ auto ctxBuf = MakeUnique<BYTE[]>(bytesNeeded);
+
+ ok = ::QueryActCtxW(QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr, nullptr,
+ ActivationContextDetailedInformation, ctxBuf.get(),
+ bytesNeeded, nullptr);
+ if (!ok) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ auto ctxInfo =
+ reinterpret_cast<ACTIVATION_CONTEXT_DETAILED_INFORMATION*>(ctxBuf.get());
+
+ // assemblyIndex is 1-based, and we want the last index, so we can just copy
+ // ctxInfo->ulAssemblyCount directly.
+ DWORD assemblyIndex = ctxInfo->ulAssemblyCount;
+ ok = ::QueryActCtxW(
+ QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr, &assemblyIndex,
+ AssemblyDetailedInformationInActivationContext, nullptr, 0, &bytesNeeded);
+ if (!ok) {
+ DWORD err = ::GetLastError();
+ if (err != ERROR_INSUFFICIENT_BUFFER) {
+ return HRESULT_FROM_WIN32(err);
+ }
+ }
+
+ auto assemblyBuf = MakeUnique<BYTE[]>(bytesNeeded);
+
+ ok = ::QueryActCtxW(QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr,
+ &assemblyIndex,
+ AssemblyDetailedInformationInActivationContext,
+ assemblyBuf.get(), bytesNeeded, &bytesNeeded);
+ if (!ok) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ auto assemblyInfo =
+ reinterpret_cast<ACTIVATION_CONTEXT_ASSEMBLY_DETAILED_INFORMATION*>(
+ assemblyBuf.get());
+ aOutManifestPath = nsDependentString(
+ assemblyInfo->lpAssemblyManifestPath,
+ (assemblyInfo->ulManifestPathLength + 1) / sizeof(wchar_t));
+
+ return S_OK;
+}
+
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ActivationContextRegion::ActivationContextRegion() : mActCookie(0) {}
+
+ActivationContextRegion::ActivationContextRegion(
+ const ActivationContext& aActCtx)
+ : mActCtx(aActCtx), mActCookie(0) {
+ Activate();
+}
+
+ActivationContextRegion& ActivationContextRegion::operator=(
+ const ActivationContext& aActCtx) {
+ Deactivate();
+ mActCtx = aActCtx;
+ Activate();
+ return *this;
+}
+
+ActivationContextRegion::ActivationContextRegion(ActivationContext&& aActCtx)
+ : mActCtx(std::move(aActCtx)), mActCookie(0) {
+ Activate();
+}
+
+ActivationContextRegion& ActivationContextRegion::operator=(
+ ActivationContext&& aActCtx) {
+ Deactivate();
+ mActCtx = std::move(aActCtx);
+ Activate();
+ return *this;
+}
+
+ActivationContextRegion::ActivationContextRegion(ActivationContextRegion&& aRgn)
+ : mActCtx(std::move(aRgn.mActCtx)), mActCookie(aRgn.mActCookie) {
+ aRgn.mActCookie = 0;
+}
+
+ActivationContextRegion& ActivationContextRegion::operator=(
+ ActivationContextRegion&& aRgn) {
+ Deactivate();
+ mActCtx = std::move(aRgn.mActCtx);
+ mActCookie = aRgn.mActCookie;
+ aRgn.mActCookie = 0;
+ return *this;
+}
+
+void ActivationContextRegion::Activate() {
+ if (mActCtx.mActCtx == INVALID_HANDLE_VALUE) {
+ return;
+ }
+
+#ifdef MOZ_DIAGNOSTIC_ASSERT_ENABLED
+ BOOL activated =
+#endif
+ ::ActivateActCtx(mActCtx.mActCtx, &mActCookie);
+ MOZ_DIAGNOSTIC_ASSERT(activated);
+}
+
+bool ActivationContextRegion::Deactivate() {
+ if (!mActCookie) {
+ return true;
+ }
+
+ BOOL deactivated = ::DeactivateActCtx(0, mActCookie);
+ MOZ_DIAGNOSTIC_ASSERT(deactivated);
+ if (deactivated) {
+ mActCookie = 0;
+ }
+
+ return !!deactivated;
+}
+
+ActivationContextRegion::~ActivationContextRegion() { Deactivate(); }
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/ActivationContext.h b/ipc/mscom/ActivationContext.h
new file mode 100644
index 0000000000..3d0144528d
--- /dev/null
+++ b/ipc/mscom/ActivationContext.h
@@ -0,0 +1,101 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_ActivationContext_h
+#define mozilla_mscom_ActivationContext_h
+
+#include <utility>
+
+#include "mozilla/Attributes.h"
+#include "mozilla/mscom/ActCtxResource.h"
+
+#if defined(MOZILLA_INTERNAL_API)
+# include "mozilla/ResultVariant.h"
+# include "nsString.h"
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include <windows.h>
+
+namespace mozilla {
+namespace mscom {
+
+class ActivationContext final {
+ public:
+ // This is the default resource ID that the Windows dynamic linker searches
+ // for when seeking a manifest while loading a DLL.
+ static constexpr WORD kDllManifestDefaultResourceId = 2;
+
+ ActivationContext() : mActCtx(INVALID_HANDLE_VALUE) {}
+
+ explicit ActivationContext(ActCtxResource aResource);
+ explicit ActivationContext(HMODULE aLoadFromModule,
+ WORD aResourceId = kDllManifestDefaultResourceId);
+
+ ActivationContext(ActivationContext&& aOther);
+ ActivationContext& operator=(ActivationContext&& aOther);
+
+ ActivationContext(const ActivationContext& aOther);
+ ActivationContext& operator=(const ActivationContext& aOther);
+
+ ~ActivationContext();
+
+ explicit operator bool() const { return mActCtx != INVALID_HANDLE_VALUE; }
+
+#if defined(MOZILLA_INTERNAL_API)
+ static Result<uintptr_t, HRESULT> GetCurrent();
+ static HRESULT GetCurrentManifestPath(nsAString& aOutManifestPath);
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ private:
+ void Init(ACTCTXW& aActCtx);
+ void AddRef();
+ void Release();
+
+ private:
+ HANDLE mActCtx;
+
+ friend class ActivationContextRegion;
+};
+
+class MOZ_NON_TEMPORARY_CLASS ActivationContextRegion final {
+ public:
+ template <typename... Args>
+ explicit ActivationContextRegion(Args&&... aArgs)
+ : mActCtx(std::forward<Args>(aArgs)...), mActCookie(0) {
+ Activate();
+ }
+
+ ActivationContextRegion();
+
+ explicit ActivationContextRegion(const ActivationContext& aActCtx);
+ ActivationContextRegion& operator=(const ActivationContext& aActCtx);
+
+ explicit ActivationContextRegion(ActivationContext&& aActCtx);
+ ActivationContextRegion& operator=(ActivationContext&& aActCtx);
+
+ ActivationContextRegion(ActivationContextRegion&& aRgn);
+ ActivationContextRegion& operator=(ActivationContextRegion&& aRgn);
+
+ ~ActivationContextRegion();
+
+ explicit operator bool() const { return !!mActCookie; }
+
+ ActivationContextRegion(const ActivationContextRegion&) = delete;
+ ActivationContextRegion& operator=(const ActivationContextRegion&) = delete;
+
+ bool Deactivate();
+
+ private:
+ void Activate();
+
+ ActivationContext mActCtx;
+ ULONG_PTR mActCookie;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_ActivationContext_h
diff --git a/ipc/mscom/Aggregation.h b/ipc/mscom/Aggregation.h
new file mode 100644
index 0000000000..ca171e39cc
--- /dev/null
+++ b/ipc/mscom/Aggregation.h
@@ -0,0 +1,83 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Aggregation_h
+#define mozilla_mscom_Aggregation_h
+
+#include "mozilla/Attributes.h"
+
+#include <stddef.h>
+#include <unknwn.h>
+
+namespace mozilla {
+namespace mscom {
+
+/**
+ * This is used for stabilizing a COM object's reference count during
+ * construction when that object aggregates other objects. Since the aggregated
+ * object(s) may AddRef() or Release(), we need to artifically boost the
+ * refcount to prevent premature destruction. Note that we increment/decrement
+ * instead of AddRef()/Release() in this class because we want to adjust the
+ * refcount without causing any other side effects (like object destruction).
+ */
+template <typename RefCntT>
+class MOZ_RAII StabilizedRefCount {
+ public:
+ explicit StabilizedRefCount(RefCntT& aRefCnt) : mRefCnt(aRefCnt) {
+ ++aRefCnt;
+ }
+
+ ~StabilizedRefCount() { --mRefCnt; }
+
+ StabilizedRefCount(const StabilizedRefCount&) = delete;
+ StabilizedRefCount(StabilizedRefCount&&) = delete;
+ StabilizedRefCount& operator=(const StabilizedRefCount&) = delete;
+ StabilizedRefCount& operator=(StabilizedRefCount&&) = delete;
+
+ private:
+ RefCntT& mRefCnt;
+};
+
+namespace detail {
+
+template <typename T>
+class InternalUnknown : public IUnknown {
+ public:
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override {
+ return This()->InternalQueryInterface(aIid, aOutInterface);
+ }
+
+ STDMETHODIMP_(ULONG) AddRef() override { return This()->InternalAddRef(); }
+
+ STDMETHODIMP_(ULONG) Release() override { return This()->InternalRelease(); }
+
+ private:
+ T* This() {
+ return reinterpret_cast<T*>(reinterpret_cast<char*>(this) -
+ offsetof(T, mInternalUnknown));
+ }
+};
+
+} // namespace detail
+} // namespace mscom
+} // namespace mozilla
+
+#define DECLARE_AGGREGATABLE(Type) \
+ public: \
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override { \
+ return mOuter->QueryInterface(riid, ppv); \
+ } \
+ STDMETHODIMP_(ULONG) AddRef() override { return mOuter->AddRef(); } \
+ STDMETHODIMP_(ULONG) Release() override { return mOuter->Release(); } \
+ \
+ protected: \
+ STDMETHODIMP InternalQueryInterface(REFIID riid, void** ppv); \
+ STDMETHODIMP_(ULONG) InternalAddRef(); \
+ STDMETHODIMP_(ULONG) InternalRelease(); \
+ friend class mozilla::mscom::detail::InternalUnknown<Type>; \
+ mozilla::mscom::detail::InternalUnknown<Type> mInternalUnknown
+
+#endif // mozilla_mscom_Aggregation_h
diff --git a/ipc/mscom/AgileReference.cpp b/ipc/mscom/AgileReference.cpp
new file mode 100644
index 0000000000..7a1b94822d
--- /dev/null
+++ b/ipc/mscom/AgileReference.cpp
@@ -0,0 +1,223 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/AgileReference.h"
+
+#include <utility>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/DynamicallyLinkedFunctionPtr.h"
+#include "mozilla/mscom/Utils.h"
+
+#if defined(MOZILLA_INTERNAL_API)
+# include "nsDebug.h"
+# include "nsPrintfCString.h"
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#if NTDDI_VERSION < NTDDI_WINBLUE
+
+// Declarations from Windows SDK specific to Windows 8.1
+
+enum AgileReferenceOptions {
+ AGILEREFERENCE_DEFAULT = 0,
+ AGILEREFERENCE_DELAYEDMARSHAL = 1,
+};
+
+HRESULT WINAPI RoGetAgileReference(AgileReferenceOptions options, REFIID riid,
+ IUnknown* pUnk,
+ IAgileReference** ppAgileReference);
+
+#endif // NTDDI_VERSION < NTDDI_WINBLUE
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+GlobalInterfaceTableCookie::GlobalInterfaceTableCookie(IUnknown* aObject,
+ REFIID aIid,
+ HRESULT& aOutHResult)
+ : mCookie(0) {
+ IGlobalInterfaceTable* git = ObtainGit();
+ MOZ_ASSERT(git);
+ if (!git) {
+ aOutHResult = E_POINTER;
+ return;
+ }
+
+ aOutHResult = git->RegisterInterfaceInGlobal(aObject, aIid, &mCookie);
+ MOZ_ASSERT(SUCCEEDED(aOutHResult));
+}
+
+GlobalInterfaceTableCookie::~GlobalInterfaceTableCookie() {
+ IGlobalInterfaceTable* git = ObtainGit();
+ MOZ_ASSERT(git);
+ if (!git) {
+ return;
+ }
+
+ DebugOnly<HRESULT> hr = git->RevokeInterfaceFromGlobal(mCookie);
+#if defined(MOZILLA_INTERNAL_API)
+ NS_WARNING_ASSERTION(
+ SUCCEEDED(hr),
+ nsPrintfCString("IGlobalInterfaceTable::RevokeInterfaceFromGlobal failed "
+ "with HRESULT 0x%08lX",
+ ((HRESULT)hr))
+ .get());
+#else
+ MOZ_ASSERT(SUCCEEDED(hr));
+#endif // defined(MOZILLA_INTERNAL_API)
+ mCookie = 0;
+}
+
+HRESULT GlobalInterfaceTableCookie::GetInterface(REFIID aIid,
+ void** aOutInterface) const {
+ IGlobalInterfaceTable* git = ObtainGit();
+ MOZ_ASSERT(git);
+ if (!git) {
+ return E_UNEXPECTED;
+ }
+
+ MOZ_ASSERT(IsValid());
+ return git->GetInterfaceFromGlobal(mCookie, aIid, aOutInterface);
+}
+
+/* static */
+IGlobalInterfaceTable* GlobalInterfaceTableCookie::ObtainGit() {
+ // Internally to COM, the Global Interface Table is a singleton, therefore we
+ // don't worry about holding onto this reference indefinitely.
+ static IGlobalInterfaceTable* sGit = []() -> IGlobalInterfaceTable* {
+ IGlobalInterfaceTable* result = nullptr;
+ DebugOnly<HRESULT> hr = ::CoCreateInstance(
+ CLSID_StdGlobalInterfaceTable, nullptr, CLSCTX_INPROC_SERVER,
+ IID_IGlobalInterfaceTable, reinterpret_cast<void**>(&result));
+ MOZ_ASSERT(SUCCEEDED(hr));
+ return result;
+ }();
+
+ return sGit;
+}
+
+} // namespace detail
+
+AgileReference::AgileReference() : mIid(), mHResult(E_NOINTERFACE) {}
+
+AgileReference::AgileReference(REFIID aIid, IUnknown* aObject)
+ : mIid(aIid), mHResult(E_UNEXPECTED) {
+ AssignInternal(aObject);
+}
+
+AgileReference::AgileReference(AgileReference&& aOther)
+ : mIid(aOther.mIid),
+ mAgileRef(std::move(aOther.mAgileRef)),
+ mGitCookie(std::move(aOther.mGitCookie)),
+ mHResult(aOther.mHResult) {
+ aOther.mHResult = CO_E_RELEASED;
+}
+
+void AgileReference::Assign(REFIID aIid, IUnknown* aObject) {
+ Clear();
+ mIid = aIid;
+ AssignInternal(aObject);
+}
+
+void AgileReference::AssignInternal(IUnknown* aObject) {
+ // We expect mIid to already be set
+ DebugOnly<IID> zeroIid = {};
+ MOZ_ASSERT(mIid != zeroIid);
+
+ /*
+ * There are two possible techniques for creating agile references. Starting
+ * with Windows 8.1, we may use the RoGetAgileReference API, which is faster.
+ * If that API is not available, we fall back to using the Global Interface
+ * Table.
+ */
+ static const StaticDynamicallyLinkedFunctionPtr<
+ decltype(&::RoGetAgileReference)>
+ pRoGetAgileReference(L"ole32.dll", "RoGetAgileReference");
+
+ MOZ_ASSERT(aObject);
+
+ if (pRoGetAgileReference &&
+ SUCCEEDED(mHResult =
+ pRoGetAgileReference(AGILEREFERENCE_DEFAULT, mIid, aObject,
+ getter_AddRefs(mAgileRef)))) {
+ return;
+ }
+
+ mGitCookie = new detail::GlobalInterfaceTableCookie(aObject, mIid, mHResult);
+ MOZ_ASSERT(mGitCookie->IsValid());
+}
+
+AgileReference::~AgileReference() { Clear(); }
+
+void AgileReference::Clear() {
+ mIid = {};
+ mAgileRef = nullptr;
+ mGitCookie = nullptr;
+ mHResult = E_NOINTERFACE;
+}
+
+AgileReference& AgileReference::operator=(const AgileReference& aOther) {
+ Clear();
+ mIid = aOther.mIid;
+ mAgileRef = aOther.mAgileRef;
+ mGitCookie = aOther.mGitCookie;
+ mHResult = aOther.mHResult;
+ return *this;
+}
+
+AgileReference& AgileReference::operator=(AgileReference&& aOther) {
+ Clear();
+ mIid = aOther.mIid;
+ mAgileRef = std::move(aOther.mAgileRef);
+ mGitCookie = std::move(aOther.mGitCookie);
+ mHResult = aOther.mHResult;
+ aOther.mHResult = CO_E_RELEASED;
+ return *this;
+}
+
+HRESULT
+AgileReference::Resolve(REFIID aIid, void** aOutInterface) const {
+ MOZ_ASSERT(aOutInterface);
+ // This check is exclusive-OR; we should have one or the other, but not both
+ MOZ_ASSERT((mAgileRef || mGitCookie) && !(mAgileRef && mGitCookie));
+ MOZ_ASSERT(IsCOMInitializedOnCurrentThread());
+
+ if (!aOutInterface) {
+ return E_INVALIDARG;
+ }
+
+ *aOutInterface = nullptr;
+
+ if (mAgileRef) {
+ // IAgileReference lets you directly resolve the interface you want...
+ return mAgileRef->Resolve(aIid, aOutInterface);
+ }
+
+ if (!mGitCookie) {
+ return E_UNEXPECTED;
+ }
+
+ RefPtr<IUnknown> originalInterface;
+ HRESULT hr =
+ mGitCookie->GetInterface(mIid, getter_AddRefs(originalInterface));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ if (aIid == mIid) {
+ originalInterface.forget(aOutInterface);
+ return S_OK;
+ }
+
+ // ...Whereas the GIT requires us to obtain the same interface that we
+ // requested and then QI for the desired interface afterward.
+ return originalInterface->QueryInterface(aIid, aOutInterface);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/AgileReference.h b/ipc/mscom/AgileReference.h
new file mode 100644
index 0000000000..d39e444494
--- /dev/null
+++ b/ipc/mscom/AgileReference.h
@@ -0,0 +1,143 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_AgileReference_h
+#define mozilla_mscom_AgileReference_h
+
+#include "mozilla/Attributes.h"
+#include "mozilla/RefPtr.h"
+#include "nsISupportsImpl.h"
+
+#include <objidl.h>
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+class MOZ_HEAP_CLASS GlobalInterfaceTableCookie final {
+ public:
+ GlobalInterfaceTableCookie(IUnknown* aObject, REFIID aIid,
+ HRESULT& aOutHResult);
+
+ bool IsValid() const { return !!mCookie; }
+ HRESULT GetInterface(REFIID aIid, void** aOutInterface) const;
+
+ NS_INLINE_DECL_THREADSAFE_REFCOUNTING(GlobalInterfaceTableCookie)
+
+ GlobalInterfaceTableCookie(const GlobalInterfaceTableCookie&) = delete;
+ GlobalInterfaceTableCookie(GlobalInterfaceTableCookie&&) = delete;
+
+ GlobalInterfaceTableCookie& operator=(const GlobalInterfaceTableCookie&) =
+ delete;
+ GlobalInterfaceTableCookie& operator=(GlobalInterfaceTableCookie&&) = delete;
+
+ private:
+ ~GlobalInterfaceTableCookie();
+
+ private:
+ DWORD mCookie;
+
+ private:
+ static IGlobalInterfaceTable* ObtainGit();
+};
+
+} // namespace detail
+
+/**
+ * This class encapsulates an "agile reference." These are references that
+ * allow you to pass COM interfaces between apartments. When you have an
+ * interface that you would like to pass between apartments, you wrap that
+ * interface in an AgileReference and pass the agile reference instead. Then
+ * you unwrap the interface by calling AgileReference::Resolve.
+ *
+ * Sample usage:
+ *
+ * // In the multithreaded apartment, foo is an IFoo*
+ * auto myAgileRef = MakeUnique<AgileReference>(IID_IFoo, foo);
+ *
+ * // myAgileRef is passed to our main thread, which runs in a single-threaded
+ * // apartment:
+ *
+ * RefPtr<IFoo> foo;
+ * HRESULT hr = myAgileRef->Resolve(IID_IFoo, getter_AddRefs(foo));
+ * // Now foo may be called from the main thread
+ */
+class AgileReference final {
+ public:
+ AgileReference();
+
+ template <typename InterfaceT>
+ explicit AgileReference(RefPtr<InterfaceT>& aObject)
+ : AgileReference(__uuidof(InterfaceT), aObject) {}
+
+ AgileReference(REFIID aIid, IUnknown* aObject);
+
+ AgileReference(const AgileReference& aOther) = default;
+ AgileReference(AgileReference&& aOther);
+
+ ~AgileReference();
+
+ explicit operator bool() const {
+ return mAgileRef || (mGitCookie && mGitCookie->IsValid());
+ }
+
+ HRESULT GetHResult() const { return mHResult; }
+
+ template <typename T>
+ void Assign(const RefPtr<T>& aOther) {
+ Assign(__uuidof(T), aOther);
+ }
+
+ template <typename T>
+ AgileReference& operator=(const RefPtr<T>& aOther) {
+ Assign(aOther);
+ return *this;
+ }
+
+ HRESULT Resolve(REFIID aIid, void** aOutInterface) const;
+
+ AgileReference& operator=(const AgileReference& aOther);
+ AgileReference& operator=(AgileReference&& aOther);
+
+ AgileReference& operator=(decltype(nullptr)) {
+ Clear();
+ return *this;
+ }
+
+ void Clear();
+
+ private:
+ void Assign(REFIID aIid, IUnknown* aObject);
+ void AssignInternal(IUnknown* aObject);
+
+ private:
+ IID mIid;
+ RefPtr<IAgileReference> mAgileRef;
+ RefPtr<detail::GlobalInterfaceTableCookie> mGitCookie;
+ HRESULT mHResult;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+template <typename T>
+RefPtr<T>::RefPtr(const mozilla::mscom::AgileReference& aAgileRef)
+ : mRawPtr(nullptr) {
+ (*this) = aAgileRef;
+}
+
+template <typename T>
+RefPtr<T>& RefPtr<T>::operator=(
+ const mozilla::mscom::AgileReference& aAgileRef) {
+ void* newRawPtr;
+ if (FAILED(aAgileRef.Resolve(__uuidof(T), &newRawPtr))) {
+ newRawPtr = nullptr;
+ }
+ assign_assuming_AddRef(static_cast<T*>(newRawPtr));
+ return *this;
+}
+
+#endif // mozilla_mscom_AgileReference_h
diff --git a/ipc/mscom/ApartmentRegion.h b/ipc/mscom/ApartmentRegion.h
new file mode 100644
index 0000000000..41915710df
--- /dev/null
+++ b/ipc/mscom/ApartmentRegion.h
@@ -0,0 +1,93 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_ApartmentRegion_h
+#define mozilla_mscom_ApartmentRegion_h
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/mscom/COMWrappers.h"
+
+namespace mozilla {
+namespace mscom {
+
+class MOZ_NON_TEMPORARY_CLASS ApartmentRegion final {
+ public:
+ /**
+ * This constructor is to be used when we want to instantiate the object but
+ * we do not yet know which type of apartment we want. Call Init() to
+ * complete initialization.
+ */
+ constexpr ApartmentRegion() : mInitResult(CO_E_NOTINITIALIZED) {}
+
+ explicit ApartmentRegion(COINIT aAptType)
+ : mInitResult(wrapped::CoInitializeEx(nullptr, aAptType)) {
+ // If this fires then we're probably mixing apartments on the same thread
+ MOZ_ASSERT(IsValid());
+ }
+
+ ~ApartmentRegion() {
+ if (IsValid()) {
+ wrapped::CoUninitialize();
+ }
+ }
+
+ explicit operator bool() const { return IsValid(); }
+
+ bool IsValidOutermost() const { return mInitResult == S_OK; }
+
+ bool IsValid() const { return SUCCEEDED(mInitResult); }
+
+ bool Init(COINIT aAptType) {
+ MOZ_ASSERT(mInitResult == CO_E_NOTINITIALIZED);
+ mInitResult = wrapped::CoInitializeEx(nullptr, aAptType);
+ MOZ_ASSERT(IsValid());
+ return IsValid();
+ }
+
+ HRESULT
+ GetHResult() const { return mInitResult; }
+
+ private:
+ ApartmentRegion(const ApartmentRegion&) = delete;
+ ApartmentRegion& operator=(const ApartmentRegion&) = delete;
+ ApartmentRegion(ApartmentRegion&&) = delete;
+ ApartmentRegion& operator=(ApartmentRegion&&) = delete;
+
+ HRESULT mInitResult;
+};
+
+template <COINIT T>
+class MOZ_NON_TEMPORARY_CLASS ApartmentRegionT final {
+ public:
+ ApartmentRegionT() : mAptRgn(T) {}
+
+ ~ApartmentRegionT() = default;
+
+ explicit operator bool() const { return mAptRgn.IsValid(); }
+
+ bool IsValidOutermost() const { return mAptRgn.IsValidOutermost(); }
+
+ bool IsValid() const { return mAptRgn.IsValid(); }
+
+ HRESULT GetHResult() const { return mAptRgn.GetHResult(); }
+
+ private:
+ ApartmentRegionT(const ApartmentRegionT&) = delete;
+ ApartmentRegionT& operator=(const ApartmentRegionT&) = delete;
+ ApartmentRegionT(ApartmentRegionT&&) = delete;
+ ApartmentRegionT& operator=(ApartmentRegionT&&) = delete;
+
+ ApartmentRegion mAptRgn;
+};
+
+typedef ApartmentRegionT<COINIT_APARTMENTTHREADED> STARegion;
+typedef ApartmentRegionT<COINIT_MULTITHREADED> MTARegion;
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_ApartmentRegion_h
diff --git a/ipc/mscom/AsyncInvoker.h b/ipc/mscom/AsyncInvoker.h
new file mode 100644
index 0000000000..ce25a4224b
--- /dev/null
+++ b/ipc/mscom/AsyncInvoker.h
@@ -0,0 +1,465 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_AsyncInvoker_h
+#define mozilla_mscom_AsyncInvoker_h
+
+#include <objidl.h>
+#include <windows.h>
+
+#include <utility>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/Maybe.h"
+#include "mozilla/Mutex.h"
+#include "mozilla/mscom/Aggregation.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsISerialEventTarget.h"
+#include "nsISupportsImpl.h"
+#include "nsThreadUtils.h"
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+template <typename AsyncInterface>
+class ForgettableAsyncCall : public ISynchronize {
+ public:
+ explicit ForgettableAsyncCall(ICallFactory* aCallFactory)
+ : mRefCnt(0), mAsyncCall(nullptr) {
+ StabilizedRefCount<Atomic<ULONG>> stabilizer(mRefCnt);
+
+ HRESULT hr =
+ aCallFactory->CreateCall(__uuidof(AsyncInterface), this, IID_IUnknown,
+ getter_AddRefs(mInnerUnk));
+ if (FAILED(hr)) {
+ return;
+ }
+
+ hr = mInnerUnk->QueryInterface(__uuidof(AsyncInterface),
+ reinterpret_cast<void**>(&mAsyncCall));
+ if (SUCCEEDED(hr)) {
+ // Don't hang onto a ref. Because mAsyncCall is aggregated, its refcount
+ // is this->mRefCnt, so we'd create a cycle!
+ mAsyncCall->Release();
+ }
+ }
+
+ AsyncInterface* GetInterface() const { return mAsyncCall; }
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) final {
+ if (aIid == IID_ISynchronize || aIid == IID_IUnknown) {
+ RefPtr<ISynchronize> ptr(this);
+ ptr.forget(aOutInterface);
+ return S_OK;
+ }
+
+ return mInnerUnk->QueryInterface(aIid, aOutInterface);
+ }
+
+ STDMETHODIMP_(ULONG) AddRef() final {
+ ULONG result = ++mRefCnt;
+ NS_LOG_ADDREF(this, result, "ForgettableAsyncCall", sizeof(*this));
+ return result;
+ }
+
+ STDMETHODIMP_(ULONG) Release() final {
+ ULONG result = --mRefCnt;
+ NS_LOG_RELEASE(this, result, "ForgettableAsyncCall");
+ if (!result) {
+ delete this;
+ }
+ return result;
+ }
+
+ // ISynchronize
+ STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override {
+ return E_NOTIMPL;
+ }
+
+ STDMETHODIMP Signal() override {
+ // Even though this function is a no-op, we must return S_OK as opposed to
+ // E_NOTIMPL or else COM will consider the async call to have failed.
+ return S_OK;
+ }
+
+ STDMETHODIMP Reset() override {
+ // Even though this function is a no-op, we must return S_OK as opposed to
+ // E_NOTIMPL or else COM will consider the async call to have failed.
+ return S_OK;
+ }
+
+ protected:
+ virtual ~ForgettableAsyncCall() = default;
+
+ private:
+ Atomic<ULONG> mRefCnt;
+ RefPtr<IUnknown> mInnerUnk;
+ AsyncInterface* mAsyncCall; // weak reference
+};
+
+template <typename AsyncInterface>
+class WaitableAsyncCall : public ForgettableAsyncCall<AsyncInterface> {
+ public:
+ explicit WaitableAsyncCall(ICallFactory* aCallFactory)
+ : ForgettableAsyncCall<AsyncInterface>(aCallFactory),
+ mEvent(::CreateEventW(nullptr, FALSE, FALSE, nullptr)) {}
+
+ STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override {
+ const DWORD waitStart =
+ aTimeoutMilliseconds == INFINITE ? 0 : ::GetTickCount();
+ DWORD flags = aFlags;
+ if (XRE_IsContentProcess() && NS_IsMainThread()) {
+ flags |= COWAIT_ALERTABLE;
+ }
+
+ HRESULT hr;
+ DWORD signaledIdx;
+
+ DWORD elapsed = 0;
+
+ while (true) {
+ if (aTimeoutMilliseconds != INFINITE) {
+ elapsed = ::GetTickCount() - waitStart;
+ }
+ if (elapsed >= aTimeoutMilliseconds) {
+ return RPC_S_CALLPENDING;
+ }
+
+ ::SetLastError(ERROR_SUCCESS);
+
+ hr = ::CoWaitForMultipleHandles(flags, aTimeoutMilliseconds - elapsed, 1,
+ &mEvent, &signaledIdx);
+ if (hr == RPC_S_CALLPENDING || FAILED(hr)) {
+ return hr;
+ }
+
+ if (hr == S_OK && signaledIdx == 0) {
+ return hr;
+ }
+ }
+ }
+
+ STDMETHODIMP Signal() override {
+ if (!::SetEvent(mEvent)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ return S_OK;
+ }
+
+ protected:
+ ~WaitableAsyncCall() {
+ if (mEvent) {
+ ::CloseHandle(mEvent);
+ }
+ }
+
+ private:
+ HANDLE mEvent;
+};
+
+template <typename AsyncInterface>
+class EventDrivenAsyncCall : public ForgettableAsyncCall<AsyncInterface> {
+ public:
+ explicit EventDrivenAsyncCall(ICallFactory* aCallFactory)
+ : ForgettableAsyncCall<AsyncInterface>(aCallFactory) {}
+
+ bool HasCompletionRunnable() const { return !!mCompletionRunnable; }
+
+ void ClearCompletionRunnable() { mCompletionRunnable = nullptr; }
+
+ void SetCompletionRunnable(already_AddRefed<nsIRunnable> aRunnable) {
+ nsCOMPtr<nsIRunnable> innerRunnable(aRunnable);
+ MOZ_ASSERT(!!innerRunnable);
+ if (!innerRunnable) {
+ return;
+ }
+
+ // We need to retain a ref to ourselves to outlive the AsyncInvoker
+ // such that our callback can execute.
+ RefPtr<EventDrivenAsyncCall<AsyncInterface>> self(this);
+
+ mCompletionRunnable = NS_NewRunnableFunction(
+ "EventDrivenAsyncCall outer completion Runnable",
+ [innerRunnable = std::move(innerRunnable), self = std::move(self)]() {
+ innerRunnable->Run();
+ });
+ }
+
+ void SetEventTarget(nsISerialEventTarget* aTarget) { mEventTarget = aTarget; }
+
+ STDMETHODIMP Signal() override {
+ MOZ_ASSERT(!!mCompletionRunnable);
+ if (!mCompletionRunnable) {
+ return S_OK;
+ }
+
+ nsCOMPtr<nsISerialEventTarget> eventTarget(mEventTarget.forget());
+ if (!eventTarget) {
+ eventTarget = GetMainThreadSerialEventTarget();
+ }
+
+ DebugOnly<nsresult> rv =
+ eventTarget->Dispatch(mCompletionRunnable.forget(), NS_DISPATCH_NORMAL);
+ MOZ_ASSERT(NS_SUCCEEDED(rv));
+ return S_OK;
+ }
+
+ private:
+ nsCOMPtr<nsIRunnable> mCompletionRunnable;
+ nsCOMPtr<nsISerialEventTarget> mEventTarget;
+};
+
+template <typename AsyncInterface>
+class FireAndForgetInvoker {
+ protected:
+ void OnBeginInvoke() {}
+ void OnSyncInvoke(HRESULT aHr) {}
+ void OnAsyncInvokeFailed() {}
+
+ typedef ForgettableAsyncCall<AsyncInterface> AsyncCallType;
+
+ RefPtr<ForgettableAsyncCall<AsyncInterface>> mAsyncCall;
+};
+
+template <typename AsyncInterface>
+class WaitableInvoker {
+ public:
+ HRESULT Wait(DWORD aTimeout = INFINITE) const {
+ if (!mAsyncCall) {
+ // Nothing to wait for
+ return S_OK;
+ }
+
+ return mAsyncCall->Wait(0, aTimeout);
+ }
+
+ protected:
+ void OnBeginInvoke() {}
+ void OnSyncInvoke(HRESULT aHr) {}
+ void OnAsyncInvokeFailed() {}
+
+ typedef WaitableAsyncCall<AsyncInterface> AsyncCallType;
+
+ RefPtr<WaitableAsyncCall<AsyncInterface>> mAsyncCall;
+};
+
+template <typename AsyncInterface>
+class EventDrivenInvoker {
+ public:
+ void SetCompletionRunnable(already_AddRefed<nsIRunnable> aRunnable) {
+ if (mAsyncCall) {
+ mAsyncCall->SetCompletionRunnable(std::move(aRunnable));
+ return;
+ }
+
+ mCompletionRunnable = aRunnable;
+ }
+
+ void SetAsyncEventTarget(nsISerialEventTarget* aTarget) {
+ if (mAsyncCall) {
+ mAsyncCall->SetEventTarget(aTarget);
+ }
+ }
+
+ protected:
+ void OnBeginInvoke() {
+ MOZ_RELEASE_ASSERT(
+ mCompletionRunnable ||
+ (mAsyncCall && mAsyncCall->HasCompletionRunnable()),
+ "You should have called SetCompletionRunnable before invoking!");
+ }
+
+ void OnSyncInvoke(HRESULT aHr) {
+ nsCOMPtr<nsIRunnable> completionRunnable(mCompletionRunnable.forget());
+ if (FAILED(aHr)) {
+ return;
+ }
+
+ completionRunnable->Run();
+ }
+
+ void OnAsyncInvokeFailed() {
+ MOZ_ASSERT(!!mAsyncCall);
+ mAsyncCall->ClearCompletionRunnable();
+ }
+
+ typedef EventDrivenAsyncCall<AsyncInterface> AsyncCallType;
+
+ RefPtr<EventDrivenAsyncCall<AsyncInterface>> mAsyncCall;
+ nsCOMPtr<nsIRunnable> mCompletionRunnable;
+};
+
+} // namespace detail
+
+/**
+ * This class is intended for "fire-and-forget" asynchronous invocations of COM
+ * interfaces. This requires that an interface be annotated with the
+ * |async_uuid| attribute in midl. We also require that there be no outparams
+ * in the desired asynchronous interface (otherwise that would break the
+ * desired "fire-and-forget" semantics).
+ *
+ * For example, let us suppose we have some IDL as such:
+ * [object, uuid(...), async_uuid(...)]
+ * interface IFoo : IUnknown
+ * {
+ * HRESULT Bar([in] long baz);
+ * }
+ *
+ * Then, given an IFoo, we may construct an AsyncInvoker<IFoo, AsyncIFoo>:
+ *
+ * IFoo* foo = ...;
+ * AsyncInvoker<IFoo, AsyncIFoo> myInvoker(foo);
+ * HRESULT hr = myInvoker.Invoke(&IFoo::Bar, &AsyncIFoo::Begin_Bar, 7);
+ *
+ * Alternatively you may use the ASYNC_INVOKER_FOR and ASYNC_INVOKE macros,
+ * which automatically deduce the name of the asynchronous interface from the
+ * name of the synchronous interface:
+ *
+ * ASYNC_INVOKER_FOR(IFoo) myInvoker(foo);
+ * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7);
+ *
+ * This class may also be used when a synchronous COM call must be made that
+ * might reenter the content process. In this case, use the WaitableAsyncInvoker
+ * variant, or the WAITABLE_ASYNC_INVOKER_FOR macro:
+ *
+ * WAITABLE_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo);
+ * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7);
+ * if (SUCCEEDED(hr)) {
+ * myInvoker.Wait(); // <-- Wait for the COM call to complete.
+ * }
+ *
+ * In general you should avoid using the waitable version, but in some corner
+ * cases it is absolutely necessary in order to preserve correctness while
+ * avoiding deadlock.
+ *
+ * Finally, it is also possible to have the async invoker enqueue a runnable
+ * to the main thread when the async operation completes:
+ *
+ * EVENT_DRIVEN_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo);
+ * // myRunnable will be invoked on the main thread once the async operation
+ * // has completed. Note that we set this *before* we do the ASYNC_INVOKE!
+ * myInvoker.SetCompletionRunnable(myRunnable.forget());
+ * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7);
+ * // ...
+ */
+template <typename SyncInterface, typename AsyncInterface,
+ template <typename Iface> class WaitPolicy =
+ detail::FireAndForgetInvoker>
+class MOZ_RAII AsyncInvoker final : public WaitPolicy<AsyncInterface> {
+ using Base = WaitPolicy<AsyncInterface>;
+
+ public:
+ typedef SyncInterface SyncInterfaceT;
+ typedef AsyncInterface AsyncInterfaceT;
+
+ /**
+ * @param aSyncObj The COM object on which to invoke the asynchronous event.
+ * If this object is not a proxy to the synchronous variant
+ * of AsyncInterface, then it will be invoked synchronously
+ * instead (because it is an in-process virtual method call).
+ * @param aIsProxy An optional hint as to whether or not aSyncObj is a proxy.
+ * If not specified, AsyncInvoker will automatically detect
+ * whether aSyncObj is a proxy, however there may be a
+ * performance penalty associated with that.
+ */
+ explicit AsyncInvoker(SyncInterface* aSyncObj,
+ const Maybe<bool>& aIsProxy = Nothing()) {
+ MOZ_ASSERT(aSyncObj);
+
+ RefPtr<ICallFactory> callFactory;
+ if ((aIsProxy.isSome() && !aIsProxy.value()) ||
+ FAILED(aSyncObj->QueryInterface(IID_ICallFactory,
+ getter_AddRefs(callFactory)))) {
+ mSyncObj = aSyncObj;
+ return;
+ }
+
+ this->mAsyncCall = new typename Base::AsyncCallType(callFactory);
+ }
+
+ /**
+ * @brief Invoke a method on the object. Member function pointers are provided
+ * for both the sychronous and asynchronous variants of the interface.
+ * If this invoker's encapsulated COM object is a proxy, then Invoke
+ * will call the asynchronous member function. Otherwise the
+ * synchronous version must be used, as the invocation will simply be a
+ * virtual function call that executes in-process.
+ * @param aSyncMethod Pointer to the method that we would like to invoke on
+ * the synchronous interface.
+ * @param aAsyncMethod Pointer to the method that we would like to invoke on
+ * the asynchronous interface.
+ */
+ template <typename SyncMethod, typename AsyncMethod, typename... Args>
+ HRESULT Invoke(SyncMethod aSyncMethod, AsyncMethod aAsyncMethod,
+ Args&&... aArgs) {
+ this->OnBeginInvoke();
+ if (mSyncObj) {
+ HRESULT hr = (mSyncObj->*aSyncMethod)(std::forward<Args>(aArgs)...);
+ this->OnSyncInvoke(hr);
+ return hr;
+ }
+
+ MOZ_ASSERT(this->mAsyncCall);
+ if (!this->mAsyncCall) {
+ this->OnAsyncInvokeFailed();
+ return E_POINTER;
+ }
+
+ AsyncInterface* asyncInterface = this->mAsyncCall->GetInterface();
+ MOZ_ASSERT(asyncInterface);
+ if (!asyncInterface) {
+ this->OnAsyncInvokeFailed();
+ return E_POINTER;
+ }
+
+ HRESULT hr = (asyncInterface->*aAsyncMethod)(std::forward<Args>(aArgs)...);
+ if (FAILED(hr)) {
+ this->OnAsyncInvokeFailed();
+ }
+
+ return hr;
+ }
+
+ AsyncInvoker(const AsyncInvoker& aOther) = delete;
+ AsyncInvoker(AsyncInvoker&& aOther) = delete;
+ AsyncInvoker& operator=(const AsyncInvoker& aOther) = delete;
+ AsyncInvoker& operator=(AsyncInvoker&& aOther) = delete;
+
+ private:
+ RefPtr<SyncInterface> mSyncObj;
+};
+
+template <typename SyncInterface, typename AsyncInterface>
+using WaitableAsyncInvoker =
+ AsyncInvoker<SyncInterface, AsyncInterface, detail::WaitableInvoker>;
+
+template <typename SyncInterface, typename AsyncInterface>
+using EventDrivenAsyncInvoker =
+ AsyncInvoker<SyncInterface, AsyncInterface, detail::EventDrivenInvoker>;
+
+} // namespace mscom
+} // namespace mozilla
+
+#define ASYNC_INVOKER_FOR(SyncIface) \
+ mozilla::mscom::AsyncInvoker<SyncIface, Async##SyncIface>
+
+#define WAITABLE_ASYNC_INVOKER_FOR(SyncIface) \
+ mozilla::mscom::WaitableAsyncInvoker<SyncIface, Async##SyncIface>
+
+#define EVENT_DRIVEN_ASYNC_INVOKER_FOR(SyncIface) \
+ mozilla::mscom::EventDrivenAsyncInvoker<SyncIface, Async##SyncIface>
+
+#define ASYNC_INVOKE(InvokerObj, SyncMethodName, ...) \
+ InvokerObj.Invoke( \
+ &decltype(InvokerObj)::SyncInterfaceT::SyncMethodName, \
+ &decltype(InvokerObj)::AsyncInterfaceT::Begin_##SyncMethodName, \
+ ##__VA_ARGS__)
+
+#endif // mozilla_mscom_AsyncInvoker_h
diff --git a/ipc/mscom/COMPtrHolder.h b/ipc/mscom/COMPtrHolder.h
new file mode 100644
index 0000000000..e99698b8e7
--- /dev/null
+++ b/ipc/mscom/COMPtrHolder.h
@@ -0,0 +1,201 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_COMPtrHolder_h
+#define mozilla_mscom_COMPtrHolder_h
+
+#include <utility>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/mscom/ProxyStream.h"
+#include "mozilla/mscom/Ptr.h"
+#if defined(MOZ_SANDBOX)
+# include "mozilla/SandboxSettings.h"
+#endif // defined(MOZ_SANDBOX)
+#include "nsExceptionHandler.h"
+
+namespace mozilla {
+namespace mscom {
+
+template <typename Interface, const IID& _IID>
+class COMPtrHolder {
+ public:
+ typedef ProxyUniquePtr<Interface> COMPtrType;
+ typedef COMPtrHolder<Interface, _IID> ThisType;
+ typedef typename detail::EnvironmentSelector<Interface>::Type EnvType;
+
+ COMPtrHolder() {}
+
+ MOZ_IMPLICIT COMPtrHolder(decltype(nullptr)) {}
+
+ explicit COMPtrHolder(COMPtrType&& aPtr)
+ : mPtr(std::forward<COMPtrType>(aPtr)) {}
+
+ COMPtrHolder(COMPtrType&& aPtr, const ActivationContext& aActCtx)
+ : mPtr(std::forward<COMPtrType>(aPtr)), mActCtx(aActCtx) {}
+
+ Interface* Get() const { return mPtr.get(); }
+
+ [[nodiscard]] Interface* Release() { return mPtr.release(); }
+
+ void Set(COMPtrType&& aPtr) { mPtr = std::forward<COMPtrType>(aPtr); }
+
+ void SetActCtx(const ActivationContext& aActCtx) { mActCtx = aActCtx; }
+
+#if defined(MOZ_SANDBOX)
+ // This method is const because we need to call it during IPC write, where
+ // we are passed as a const argument. At higher sandboxing levels we need to
+ // save this artifact from the serialization process for later deletion.
+ void PreserveStream(PreservedStreamPtr aPtr) const {
+ MOZ_ASSERT(!mMarshaledStream);
+ mMarshaledStream = std::move(aPtr);
+ }
+
+ PreservedStreamPtr GetPreservedStream() {
+ return std::move(mMarshaledStream);
+ }
+#endif // defined(MOZ_SANDBOX)
+
+ COMPtrHolder(const COMPtrHolder& aOther) = delete;
+
+ COMPtrHolder(COMPtrHolder&& aOther)
+ : mPtr(std::move(aOther.mPtr))
+#if defined(MOZ_SANDBOX)
+ ,
+ mMarshaledStream(std::move(aOther.mMarshaledStream))
+#endif // defined(MOZ_SANDBOX)
+ {
+ }
+
+ // COMPtrHolder is eventually added as a member of a struct that is declared
+ // in IPDL. The generated C++ code for that IPDL struct includes copy
+ // constructors and assignment operators that assume that all members are
+ // copyable. I don't think that those copy constructors and operator= are
+ // actually used by any generated code, but they are made available. Since no
+ // move semantics are available, this terrible hack makes COMPtrHolder build
+ // when used as a member of an IPDL struct.
+ ThisType& operator=(const ThisType& aOther) {
+ Set(std::move(aOther.mPtr));
+
+#if defined(MOZ_SANDBOX)
+ mMarshaledStream = std::move(aOther.mMarshaledStream);
+#endif // defined(MOZ_SANDBOX)
+
+ return *this;
+ }
+
+ ThisType& operator=(ThisType&& aOther) {
+ Set(std::move(aOther.mPtr));
+
+#if defined(MOZ_SANDBOX)
+ mMarshaledStream = std::move(aOther.mMarshaledStream);
+#endif // defined(MOZ_SANDBOX)
+
+ return *this;
+ }
+
+ bool operator==(const ThisType& aOther) const { return mPtr == aOther.mPtr; }
+
+ bool IsNull() const { return !mPtr; }
+
+ private:
+ // This is mutable to facilitate the above operator= hack
+ mutable COMPtrType mPtr;
+ ActivationContext mActCtx;
+
+#if defined(MOZ_SANDBOX)
+ // This is mutable so that we may optionally store a reference to a marshaled
+ // stream to be cleaned up later via PreserveStream().
+ mutable PreservedStreamPtr mMarshaledStream;
+#endif // defined(MOZ_SANDBOX)
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+namespace IPC {
+
+template <typename Interface, const IID& _IID>
+struct ParamTraits<mozilla::mscom::COMPtrHolder<Interface, _IID>> {
+ typedef mozilla::mscom::COMPtrHolder<Interface, _IID> paramType;
+
+ static void Write(MessageWriter* aWriter, const paramType& aParam) {
+#if defined(MOZ_SANDBOX)
+ static const bool sIsStreamPreservationNeeded =
+ XRE_IsParentProcess() &&
+ mozilla::GetEffectiveContentSandboxLevel() >= 3;
+#else
+ const bool sIsStreamPreservationNeeded = false;
+#endif // defined(MOZ_SANDBOX)
+
+ typename paramType::EnvType env;
+
+ mozilla::mscom::ProxyStreamFlags flags =
+ sIsStreamPreservationNeeded
+ ? mozilla::mscom::ProxyStreamFlags::ePreservable
+ : mozilla::mscom::ProxyStreamFlags::eDefault;
+
+ mozilla::mscom::ProxyStream proxyStream(_IID, aParam.Get(), &env, flags);
+ int bufLen;
+ const BYTE* buf = proxyStream.GetBuffer(bufLen);
+ MOZ_ASSERT(buf || !bufLen);
+ aWriter->WriteInt(bufLen);
+ if (bufLen) {
+ aWriter->WriteBytes(reinterpret_cast<const char*>(buf), bufLen);
+ }
+
+#if defined(MOZ_SANDBOX)
+ if (sIsStreamPreservationNeeded) {
+ /**
+ * When we're sending a ProxyStream from parent to content and the
+ * content sandboxing level is >= 3, content is unable to communicate
+ * its releasing of its reference to the proxied object. We preserve the
+ * marshaled proxy data here and later manually release it on content's
+ * behalf.
+ */
+ aParam.PreserveStream(proxyStream.GetPreservedStream());
+ }
+#endif // defined(MOZ_SANDBOX)
+ }
+
+ static bool Read(MessageReader* aReader, paramType* aResult) {
+ int length;
+ if (!aReader->ReadLength(&length)) {
+ return false;
+ }
+
+ mozilla::UniquePtr<BYTE[]> buf;
+ if (length) {
+ buf = mozilla::MakeUnique<BYTE[]>(length);
+ if (!aReader->ReadBytesInto(buf.get(), length)) {
+ return false;
+ }
+ }
+
+ typename paramType::EnvType env;
+
+ mozilla::mscom::ProxyStream proxyStream(_IID, buf.get(), length, &env);
+ if (!proxyStream.IsValid()) {
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::ProxyStreamValid, "false"_ns);
+ return false;
+ }
+
+ typename paramType::COMPtrType ptr;
+ if (!proxyStream.GetInterface(mozilla::mscom::getter_AddRefs(ptr))) {
+ return false;
+ }
+
+ aResult->Set(std::move(ptr));
+ return true;
+ }
+};
+
+} // namespace IPC
+
+#endif // mozilla_mscom_COMPtrHolder_h
diff --git a/ipc/mscom/COMWrappers.cpp b/ipc/mscom/COMWrappers.cpp
new file mode 100644
index 0000000000..1bc48a4927
--- /dev/null
+++ b/ipc/mscom/COMWrappers.cpp
@@ -0,0 +1,101 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/COMWrappers.h"
+
+#include <objbase.h>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/DynamicallyLinkedFunctionPtr.h"
+
+namespace mozilla::mscom::wrapped {
+
+HRESULT CoInitializeEx(LPVOID pvReserved, DWORD dwCoInit) {
+ static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoInitializeEx)>
+ pCoInitializeEx(L"combase.dll", "CoInitializeEx");
+ if (!pCoInitializeEx) {
+ return ::CoInitializeEx(pvReserved, dwCoInit);
+ }
+
+ return pCoInitializeEx(pvReserved, dwCoInit);
+}
+
+void CoUninitialize() {
+ static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoUninitialize)>
+ pCoUninitialize(L"combase.dll", "CoUninitialize");
+ if (!pCoUninitialize) {
+ return ::CoUninitialize();
+ }
+
+ return pCoUninitialize();
+}
+
+HRESULT CoIncrementMTAUsage(CO_MTA_USAGE_COOKIE* pCookie) {
+ static const StaticDynamicallyLinkedFunctionPtr<
+ decltype(&::CoIncrementMTAUsage)>
+ pCoIncrementMTAUsage(L"combase.dll", "CoIncrementMTAUsage");
+ // This API is only available beginning with Windows 8.
+ if (!pCoIncrementMTAUsage) {
+ return E_NOTIMPL;
+ }
+
+ HRESULT hr = pCoIncrementMTAUsage(pCookie);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ return hr;
+}
+
+HRESULT CoGetApartmentType(APTTYPE* pAptType, APTTYPEQUALIFIER* pAptQualifier) {
+ static const StaticDynamicallyLinkedFunctionPtr<
+ decltype(&::CoGetApartmentType)>
+ pCoGetApartmentType(L"combase.dll", "CoGetApartmentType");
+ if (!pCoGetApartmentType) {
+ return ::CoGetApartmentType(pAptType, pAptQualifier);
+ }
+
+ return pCoGetApartmentType(pAptType, pAptQualifier);
+}
+
+HRESULT CoInitializeSecurity(PSECURITY_DESCRIPTOR pSecDesc, LONG cAuthSvc,
+ SOLE_AUTHENTICATION_SERVICE* asAuthSvc,
+ void* pReserved1, DWORD dwAuthnLevel,
+ DWORD dwImpLevel, void* pAuthList,
+ DWORD dwCapabilities, void* pReserved3) {
+ static const StaticDynamicallyLinkedFunctionPtr<
+ decltype(&::CoInitializeSecurity)>
+ pCoInitializeSecurity(L"combase.dll", "CoInitializeSecurity");
+ if (!pCoInitializeSecurity) {
+ return ::CoInitializeSecurity(pSecDesc, cAuthSvc, asAuthSvc, pReserved1,
+ dwAuthnLevel, dwImpLevel, pAuthList,
+ dwCapabilities, pReserved3);
+ }
+
+ return pCoInitializeSecurity(pSecDesc, cAuthSvc, asAuthSvc, pReserved1,
+ dwAuthnLevel, dwImpLevel, pAuthList,
+ dwCapabilities, pReserved3);
+}
+
+HRESULT CoCreateInstance(REFCLSID rclsid, LPUNKNOWN pUnkOuter,
+ DWORD dwClsContext, REFIID riid, LPVOID* ppv) {
+ static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoCreateInstance)>
+ pCoCreateInstance(L"combase.dll", "CoCreateInstance");
+ if (!pCoCreateInstance) {
+ return ::CoCreateInstance(rclsid, pUnkOuter, dwClsContext, riid, ppv);
+ }
+
+ return pCoCreateInstance(rclsid, pUnkOuter, dwClsContext, riid, ppv);
+}
+
+HRESULT CoCreateGuid(GUID* pguid) {
+ static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoCreateGuid)>
+ pCoCreateGuid(L"combase.dll", "CoCreateGuid");
+ if (!pCoCreateGuid) {
+ return ::CoCreateGuid(pguid);
+ }
+
+ return pCoCreateGuid(pguid);
+}
+
+} // namespace mozilla::mscom::wrapped
diff --git a/ipc/mscom/COMWrappers.h b/ipc/mscom/COMWrappers.h
new file mode 100644
index 0000000000..38bef10749
--- /dev/null
+++ b/ipc/mscom/COMWrappers.h
@@ -0,0 +1,44 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_COMWrappers_h
+#define mozilla_mscom_COMWrappers_h
+
+#include <objbase.h>
+
+#if (NTDDI_VERSION < NTDDI_WIN8)
+// Win8+ API that we use very carefully
+DECLARE_HANDLE(CO_MTA_USAGE_COOKIE);
+HRESULT WINAPI CoIncrementMTAUsage(CO_MTA_USAGE_COOKIE* pCookie);
+#endif // (NTDDI_VERSION < NTDDI_WIN8)
+
+// A set of wrapped COM functions, so that we can dynamically link to the
+// functions in combase.dll on win8+. This prevents ole32.dll and many other
+// DLLs loading, which are not required when we have win32k locked down.
+namespace mozilla::mscom::wrapped {
+
+HRESULT CoInitializeEx(LPVOID pvReserved, DWORD dwCoInit);
+
+void CoUninitialize();
+
+HRESULT CoIncrementMTAUsage(CO_MTA_USAGE_COOKIE* pCookie);
+
+HRESULT CoGetApartmentType(APTTYPE* pAptType, APTTYPEQUALIFIER* pAptQualifier);
+
+HRESULT CoInitializeSecurity(PSECURITY_DESCRIPTOR pSecDesc, LONG cAuthSvc,
+ SOLE_AUTHENTICATION_SERVICE* asAuthSvc,
+ void* pReserved1, DWORD dwAuthnLevel,
+ DWORD dwImpLevel, void* pAuthList,
+ DWORD dwCapabilities, void* pReserved3);
+
+HRESULT CoCreateInstance(REFCLSID rclsid, LPUNKNOWN pUnkOuter,
+ DWORD dwClsContext, REFIID riid, LPVOID* ppv);
+
+HRESULT CoCreateGuid(GUID* pguid);
+
+} // namespace mozilla::mscom::wrapped
+
+#endif // mozilla_mscom_COMWrappers_h
diff --git a/ipc/mscom/DispatchForwarder.cpp b/ipc/mscom/DispatchForwarder.cpp
new file mode 100644
index 0000000000..e55b1316d0
--- /dev/null
+++ b/ipc/mscom/DispatchForwarder.cpp
@@ -0,0 +1,156 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/DispatchForwarder.h"
+
+#include <oleauto.h>
+
+#include <utility>
+
+#include "mozilla/mscom/MainThreadInvoker.h"
+
+namespace mozilla {
+namespace mscom {
+
+/* static */
+HRESULT DispatchForwarder::Create(IInterceptor* aInterceptor,
+ STAUniquePtr<IDispatch>& aTarget,
+ IUnknown** aOutput) {
+ MOZ_ASSERT(aInterceptor && aOutput);
+ if (!aOutput) {
+ return E_INVALIDARG;
+ }
+ *aOutput = nullptr;
+ if (!aInterceptor) {
+ return E_INVALIDARG;
+ }
+ DispatchForwarder* forwarder = new DispatchForwarder(aInterceptor, aTarget);
+ HRESULT hr = forwarder->QueryInterface(IID_IDispatch, (void**)aOutput);
+ forwarder->Release();
+ return hr;
+}
+
+DispatchForwarder::DispatchForwarder(IInterceptor* aInterceptor,
+ STAUniquePtr<IDispatch>& aTarget)
+ : mRefCnt(1), mInterceptor(aInterceptor), mTarget(std::move(aTarget)) {}
+
+DispatchForwarder::~DispatchForwarder() {}
+
+HRESULT
+DispatchForwarder::QueryInterface(REFIID riid, void** ppv) {
+ if (!ppv) {
+ return E_INVALIDARG;
+ }
+
+ // Since this class implements a tearoff, any interfaces that are not
+ // IDispatch must be routed to the original object's QueryInterface.
+ // This is especially important for IUnknown since COM uses that interface
+ // to determine object identity.
+ if (riid != IID_IDispatch) {
+ return mInterceptor->QueryInterface(riid, ppv);
+ }
+
+ IUnknown* punk = static_cast<IDispatch*>(this);
+ *ppv = punk;
+ if (!punk) {
+ return E_NOINTERFACE;
+ }
+
+ punk->AddRef();
+ return S_OK;
+}
+
+ULONG
+DispatchForwarder::AddRef() {
+ return (ULONG)InterlockedIncrement((LONG*)&mRefCnt);
+}
+
+ULONG
+DispatchForwarder::Release() {
+ ULONG newRefCnt = (ULONG)InterlockedDecrement((LONG*)&mRefCnt);
+ if (newRefCnt == 0) {
+ delete this;
+ }
+ return newRefCnt;
+}
+
+HRESULT
+DispatchForwarder::GetTypeInfoCount(UINT* pctinfo) {
+ if (!pctinfo) {
+ return E_INVALIDARG;
+ }
+ *pctinfo = 1;
+ return S_OK;
+}
+
+HRESULT
+DispatchForwarder::GetTypeInfo(UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo) {
+ // ITypeInfo as implemented by COM is apartment-neutral, so we don't need
+ // to wrap it (yay!)
+ if (mTypeInfo) {
+ RefPtr<ITypeInfo> copy(mTypeInfo);
+ copy.forget(ppTInfo);
+ return S_OK;
+ }
+ HRESULT hr = E_UNEXPECTED;
+ auto fn = [&]() -> void { hr = mTarget->GetTypeInfo(iTInfo, lcid, ppTInfo); };
+ MainThreadInvoker invoker;
+ if (!invoker.Invoke(
+ NS_NewRunnableFunction("DispatchForwarder::GetTypeInfo", fn))) {
+ return E_UNEXPECTED;
+ }
+ if (FAILED(hr)) {
+ return hr;
+ }
+ mTypeInfo = *ppTInfo;
+ return hr;
+}
+
+HRESULT
+DispatchForwarder::GetIDsOfNames(REFIID riid, LPOLESTR* rgszNames, UINT cNames,
+ LCID lcid, DISPID* rgDispId) {
+ HRESULT hr = E_UNEXPECTED;
+ auto fn = [&]() -> void {
+ hr = mTarget->GetIDsOfNames(riid, rgszNames, cNames, lcid, rgDispId);
+ };
+ MainThreadInvoker invoker;
+ if (!invoker.Invoke(
+ NS_NewRunnableFunction("DispatchForwarder::GetIDsOfNames", fn))) {
+ return E_UNEXPECTED;
+ }
+ return hr;
+}
+
+HRESULT
+DispatchForwarder::Invoke(DISPID dispIdMember, REFIID riid, LCID lcid,
+ WORD wFlags, DISPPARAMS* pDispParams,
+ VARIANT* pVarResult, EXCEPINFO* pExcepInfo,
+ UINT* puArgErr) {
+ HRESULT hr;
+ if (!mInterface) {
+ if (!mTypeInfo) {
+ return E_UNEXPECTED;
+ }
+ TYPEATTR* typeAttr = nullptr;
+ hr = mTypeInfo->GetTypeAttr(&typeAttr);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ hr = mInterceptor->QueryInterface(typeAttr->guid,
+ (void**)getter_AddRefs(mInterface));
+ mTypeInfo->ReleaseTypeAttr(typeAttr);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+ // We don't invoke IDispatch on the target, but rather on the interceptor!
+ hr = ::DispInvoke(mInterface.get(), mTypeInfo, dispIdMember, wFlags,
+ pDispParams, pVarResult, pExcepInfo, puArgErr);
+ return hr;
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/DispatchForwarder.h b/ipc/mscom/DispatchForwarder.h
new file mode 100644
index 0000000000..55d3fbb4f2
--- /dev/null
+++ b/ipc/mscom/DispatchForwarder.h
@@ -0,0 +1,79 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_DispatchForwarder_h
+#define mozilla_mscom_DispatchForwarder_h
+
+#include <oaidl.h>
+
+#include "mozilla/mscom/Interceptor.h"
+#include "mozilla/mscom/Ptr.h"
+
+namespace mozilla {
+namespace mscom {
+
+class DispatchForwarder final : public IDispatch {
+ public:
+ static HRESULT Create(IInterceptor* aInterceptor,
+ STAUniquePtr<IDispatch>& aTarget, IUnknown** aOutput);
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // IDispatch
+ STDMETHODIMP GetTypeInfoCount(
+ /* [out] */ __RPC__out UINT* pctinfo) override;
+
+ STDMETHODIMP GetTypeInfo(
+ /* [in] */ UINT iTInfo,
+ /* [in] */ LCID lcid,
+ /* [out] */ __RPC__deref_out_opt ITypeInfo** ppTInfo) override;
+
+ STDMETHODIMP GetIDsOfNames(
+ /* [in] */ __RPC__in REFIID riid,
+ /* [size_is][in] */ __RPC__in_ecount_full(cNames) LPOLESTR* rgszNames,
+ /* [range][in] */ __RPC__in_range(0, 16384) UINT cNames,
+ /* [in] */ LCID lcid,
+ /* [size_is][out] */ __RPC__out_ecount_full(cNames) DISPID* rgDispId)
+ override;
+
+ STDMETHODIMP Invoke(
+ /* [annotation][in] */
+ _In_ DISPID dispIdMember,
+ /* [annotation][in] */
+ _In_ REFIID riid,
+ /* [annotation][in] */
+ _In_ LCID lcid,
+ /* [annotation][in] */
+ _In_ WORD wFlags,
+ /* [annotation][out][in] */
+ _In_ DISPPARAMS* pDispParams,
+ /* [annotation][out] */
+ _Out_opt_ VARIANT* pVarResult,
+ /* [annotation][out] */
+ _Out_opt_ EXCEPINFO* pExcepInfo,
+ /* [annotation][out] */
+ _Out_opt_ UINT* puArgErr) override;
+
+ private:
+ DispatchForwarder(IInterceptor* aInterceptor,
+ STAUniquePtr<IDispatch>& aTarget);
+ ~DispatchForwarder();
+
+ private:
+ ULONG mRefCnt;
+ RefPtr<IInterceptor> mInterceptor;
+ STAUniquePtr<IDispatch> mTarget;
+ RefPtr<ITypeInfo> mTypeInfo;
+ RefPtr<IUnknown> mInterface;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_DispatchForwarder_h
diff --git a/ipc/mscom/EnsureMTA.cpp b/ipc/mscom/EnsureMTA.cpp
new file mode 100644
index 0000000000..2258dd9dcd
--- /dev/null
+++ b/ipc/mscom/EnsureMTA.cpp
@@ -0,0 +1,256 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/EnsureMTA.h"
+
+#include "mozilla/Assertions.h"
+#include "mozilla/ClearOnShutdown.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/mscom/COMWrappers.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/SchedulerGroup.h"
+#include "mozilla/StaticLocalPtr.h"
+#include "nsThreadManager.h"
+#include "nsThreadUtils.h"
+
+#include "private/pprthred.h"
+
+namespace {
+
+class EnterMTARunnable : public mozilla::Runnable {
+ public:
+ EnterMTARunnable() : mozilla::Runnable("EnterMTARunnable") {}
+ NS_IMETHOD Run() override {
+ mozilla::DebugOnly<HRESULT> hr =
+ mozilla::mscom::wrapped::CoInitializeEx(nullptr, COINIT_MULTITHREADED);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ return NS_OK;
+ }
+};
+
+class BackgroundMTAData {
+ public:
+ BackgroundMTAData() {
+ nsCOMPtr<nsIRunnable> runnable = new EnterMTARunnable();
+ mozilla::DebugOnly<nsresult> rv = NS_NewNamedThread(
+ "COM MTA", getter_AddRefs(mThread), runnable.forget());
+ NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_NewNamedThread failed");
+ MOZ_ASSERT(NS_SUCCEEDED(rv));
+ }
+
+ ~BackgroundMTAData() {
+ if (mThread) {
+ mThread->Dispatch(
+ NS_NewRunnableFunction("BackgroundMTAData::~BackgroundMTAData",
+ &mozilla::mscom::wrapped::CoUninitialize),
+ NS_DISPATCH_NORMAL);
+ mThread->Shutdown();
+ }
+ }
+
+ nsCOMPtr<nsIThread> GetThread() const { return mThread; }
+
+ private:
+ nsCOMPtr<nsIThread> mThread;
+};
+
+} // anonymous namespace
+
+namespace mozilla {
+namespace mscom {
+
+EnsureMTA::EnsureMTA() {
+ MOZ_ASSERT(NS_IsMainThread());
+
+ // It is possible that we're running so early that we might need to start
+ // the thread manager ourselves. We do this here to guarantee that we have
+ // the ability to start the persistent MTA thread at any moment beyond this
+ // point.
+ nsresult rv = nsThreadManager::get().Init();
+ // We intentionally don't check rv unless we need it when
+ // CoIncremementMTAUsage is unavailable.
+
+ // Calling this function initializes the MTA without needing to explicitly
+ // create a thread and call CoInitializeEx to do it.
+ // We don't retain the cookie because once we've incremented the MTA, we
+ // leave it that way for the lifetime of the process.
+ CO_MTA_USAGE_COOKIE mtaCookie = nullptr;
+ HRESULT hr = wrapped::CoIncrementMTAUsage(&mtaCookie);
+ if (SUCCEEDED(hr)) {
+ if (NS_SUCCEEDED(rv)) {
+ // Start the persistent MTA thread (mostly) asynchronously.
+ Unused << GetPersistentMTAThread();
+ }
+
+ return;
+ }
+
+ // In the fallback case, we simply initialize our persistent MTA thread.
+
+ // Make sure thread manager init succeeded before trying to initialize the
+ // persistent MTA thread.
+ MOZ_DIAGNOSTIC_ASSERT(NS_SUCCEEDED(rv));
+ if (NS_FAILED(rv)) {
+ return;
+ }
+
+ // Before proceeding any further, pump a runnable through the persistent MTA
+ // thread to ensure that it is up and running and has finished initializing
+ // the multi-threaded apartment.
+ nsCOMPtr<nsIRunnable> runnable(NS_NewRunnableFunction(
+ "EnsureMTA::EnsureMTA()",
+ []() { MOZ_RELEASE_ASSERT(IsCurrentThreadExplicitMTA()); }));
+ SyncDispatchToPersistentThread(runnable);
+}
+
+/* static */
+RefPtr<EnsureMTA::CreateInstanceAgileRefPromise>
+EnsureMTA::CreateInstanceInternal(REFCLSID aClsid, REFIID aIid) {
+ MOZ_ASSERT(IsCurrentThreadExplicitMTA());
+
+ RefPtr<IUnknown> iface;
+ HRESULT hr = wrapped::CoCreateInstance(aClsid, nullptr, CLSCTX_INPROC_SERVER,
+ aIid, getter_AddRefs(iface));
+ if (FAILED(hr)) {
+ return CreateInstanceAgileRefPromise::CreateAndReject(hr, __func__);
+ }
+
+ // We need to use the two argument constructor for AgileReference because our
+ // RefPtr is not parameterized on the specific interface being requested.
+ AgileReference agileRef(aIid, iface);
+ if (!agileRef) {
+ return CreateInstanceAgileRefPromise::CreateAndReject(agileRef.GetHResult(),
+ __func__);
+ }
+
+ return CreateInstanceAgileRefPromise::CreateAndResolve(std::move(agileRef),
+ __func__);
+}
+
+/* static */
+RefPtr<EnsureMTA::CreateInstanceAgileRefPromise> EnsureMTA::CreateInstance(
+ REFCLSID aClsid, REFIID aIid) {
+ MOZ_ASSERT(IsCOMInitializedOnCurrentThread());
+
+ const bool isClassOk = IsClassThreadAwareInprocServer(aClsid);
+ MOZ_ASSERT(isClassOk,
+ "mozilla::mscom::EnsureMTA::CreateInstance is not "
+ "safe/performant/necessary to use with this CLSID. This CLSID "
+ "either does not support creation from within a multithreaded "
+ "apartment, or it is not an in-process server.");
+ if (!isClassOk) {
+ return CreateInstanceAgileRefPromise::CreateAndReject(CO_E_NOT_SUPPORTED,
+ __func__);
+ }
+
+ if (IsCurrentThreadExplicitMTA()) {
+ // It's safe to immediately call CreateInstanceInternal
+ return CreateInstanceInternal(aClsid, aIid);
+ }
+
+ // aClsid and aIid are references. Make local copies that we can put into the
+ // lambda in case the sources of aClsid or aIid are not static data
+ CLSID localClsid = aClsid;
+ IID localIid = aIid;
+
+ auto invoker = [localClsid,
+ localIid]() -> RefPtr<CreateInstanceAgileRefPromise> {
+ return CreateInstanceInternal(localClsid, localIid);
+ };
+
+ nsCOMPtr<nsIThread> mtaThread(GetPersistentMTAThread());
+
+ return InvokeAsync(mtaThread->SerialEventTarget(), __func__,
+ std::move(invoker));
+}
+
+/* static */
+nsCOMPtr<nsIThread> EnsureMTA::GetPersistentMTAThread() {
+ static StaticLocalAutoPtr<BackgroundMTAData> sMTAData(
+ []() -> BackgroundMTAData* {
+ BackgroundMTAData* bgData = new BackgroundMTAData();
+
+ auto setClearOnShutdown = [ptr = &sMTAData]() -> void {
+ ClearOnShutdown(ptr, ShutdownPhase::XPCOMShutdownThreads);
+ };
+
+ if (NS_IsMainThread()) {
+ setClearOnShutdown();
+ return bgData;
+ }
+
+ SchedulerGroup::Dispatch(
+ TaskCategory::Other,
+ NS_NewRunnableFunction("mscom::EnsureMTA::GetPersistentMTAThread",
+ std::move(setClearOnShutdown)));
+
+ return bgData;
+ }());
+
+ MOZ_ASSERT(sMTAData);
+
+ return sMTAData->GetThread();
+}
+
+/* static */
+void EnsureMTA::SyncDispatchToPersistentThread(nsIRunnable* aRunnable) {
+ nsCOMPtr<nsIThread> thread(GetPersistentMTAThread());
+ MOZ_ASSERT(thread);
+ if (!thread) {
+ return;
+ }
+
+ // Note that, due to APC dispatch, we might reenter this function while we
+ // wait on this event. We therefore need a unique event object for each
+ // entry into this function. If perf becomes an issue then we will want to
+ // maintain an array of events where the Nth event is unique to the Nth
+ // reentry.
+ nsAutoHandle event(::CreateEventW(nullptr, FALSE, FALSE, nullptr));
+ if (!event) {
+ return;
+ }
+
+ HANDLE eventHandle = event.get();
+ auto eventSetter = [&aRunnable, eventHandle]() -> void {
+ aRunnable->Run();
+ ::SetEvent(eventHandle);
+ };
+
+ nsresult rv = thread->Dispatch(
+ NS_NewRunnableFunction("mscom::EnsureMTA::SyncDispatchToPersistentThread",
+ std::move(eventSetter)),
+ NS_DISPATCH_NORMAL);
+ MOZ_ASSERT(NS_SUCCEEDED(rv));
+ if (NS_FAILED(rv)) {
+ return;
+ }
+
+#if defined(ACCESSIBILITY)
+ const BOOL alertable = XRE_IsContentProcess() && NS_IsMainThread();
+#else
+ const BOOL alertable = FALSE;
+#endif // defined(ACCESSIBILITY)
+
+ AUTO_PROFILER_THREAD_SLEEP;
+ DWORD waitResult;
+ while ((waitResult = ::WaitForSingleObjectEx(event, INFINITE, alertable)) ==
+ WAIT_IO_COMPLETION) {
+ }
+ MOZ_ASSERT(waitResult == WAIT_OBJECT_0);
+}
+
+/**
+ * While this function currently appears to be redundant, it may become more
+ * sophisticated in the future. For example, we could optionally dispatch to an
+ * MTA context if we wanted to utilize the MTA thread pool.
+ */
+/* static */
+void EnsureMTA::SyncDispatch(nsCOMPtr<nsIRunnable>&& aRunnable, Option aOpt) {
+ SyncDispatchToPersistentThread(aRunnable);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/EnsureMTA.h b/ipc/mscom/EnsureMTA.h
new file mode 100644
index 0000000000..5e410d4daa
--- /dev/null
+++ b/ipc/mscom/EnsureMTA.h
@@ -0,0 +1,191 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_EnsureMTA_h
+#define mozilla_mscom_EnsureMTA_h
+
+#include "MainThreadUtils.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/MozPromise.h"
+#include "mozilla/Unused.h"
+#include "mozilla/mscom/AgileReference.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/RefPtr.h"
+#include "nsCOMPtr.h"
+#include "nsIThread.h"
+#include "nsThreadUtils.h"
+#include "nsWindowsHelpers.h"
+
+#include <windows.h>
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+// Forward declarations
+template <typename T>
+struct MTADelete;
+
+template <typename T>
+struct MTARelease;
+
+template <typename T>
+struct MTAReleaseInChildProcess;
+
+struct PreservedStreamDeleter;
+
+} // namespace detail
+
+class ProcessRuntime;
+
+// This class is OK to use as a temporary on the stack.
+class MOZ_STACK_CLASS EnsureMTA final {
+ public:
+ enum class Option {
+ Default,
+ // Forcibly dispatch to the thread returned by GetPersistentMTAThread(),
+ // even if the current thread is already inside a MTA.
+ ForceDispatchToPersistentThread,
+ };
+
+ /**
+ * Synchronously run |aClosure| on a thread living in the COM multithreaded
+ * apartment. If the current thread lives inside the COM MTA, then it runs
+ * |aClosure| immediately unless |aOpt| ==
+ * Option::ForceDispatchToPersistentThread.
+ */
+ template <typename FuncT>
+ explicit EnsureMTA(FuncT&& aClosure, Option aOpt = Option::Default) {
+ if (aOpt != Option::ForceDispatchToPersistentThread &&
+ IsCurrentThreadMTA()) {
+ // We're already on the MTA, we can run aClosure directly
+ aClosure();
+ return;
+ }
+
+ // In this case we need to run aClosure on a background thread in the MTA
+ nsCOMPtr<nsIRunnable> runnable(
+ NS_NewRunnableFunction("EnsureMTA::EnsureMTA", std::move(aClosure)));
+ SyncDispatch(std::move(runnable), aOpt);
+ }
+
+ using CreateInstanceAgileRefPromise =
+ MozPromise<AgileReference, HRESULT, false>;
+
+ /**
+ * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! ***
+ *
+ * Asynchronously instantiate a new COM object from a MTA thread, unless the
+ * current thread is already living inside the multithreaded apartment, in
+ * which case the object is immediately instantiated.
+ *
+ * This function only supports the most common configurations for creating
+ * a new object, so it only supports in-process servers. Furthermore, this
+ * function does not support aggregation (ie. the |pUnkOuter| parameter to
+ * CoCreateInstance).
+ *
+ * Given that attempting to instantiate an Apartment-threaded COM object
+ * inside the MTA results in a *loss* of performance, we assert when that
+ * situation arises.
+ *
+ * The resulting promise, once resolved, provides an AgileReference that may
+ * be passed between any COM-initialized thread in the current process.
+ *
+ * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! ***
+ *
+ * WARNING:
+ * Some COM objects do not support creation in the multithreaded apartment,
+ * in which case this function is not available as an option. In this case,
+ * the promise will always be rejected. In debug builds we will assert.
+ *
+ * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! ***
+ *
+ * WARNING:
+ * Any in-process COM objects whose interfaces accept HWNDs are probably
+ * *not* safe to instantiate in the multithreaded apartment! Even if this
+ * function succeeds when creating such an object, you *MUST NOT* do so, as
+ * these failures might not become apparent until your code is running out in
+ * the wild on the release channel!
+ *
+ * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! ***
+ *
+ * WARNING:
+ * When you obtain an interface from the AgileReference, it may or may not be
+ * a proxy to the real object. This depends entirely on the implementation of
+ * the underlying class and the multithreading capabilities that the class
+ * declares to the COM runtime. If the interface is proxied, it might be
+ * expensive to invoke methods on that interface! *Always* test the
+ * performance of your method calls when calling interfaces that are resolved
+ * via this function!
+ *
+ * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! ***
+ *
+ * (Despite this myriad of warnings, it is still *much* safer to use this
+ * function to asynchronously create COM objects than it is to roll your own!)
+ *
+ * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! ***
+ */
+ static RefPtr<CreateInstanceAgileRefPromise> CreateInstance(REFCLSID aClsid,
+ REFIID aIid);
+
+ private:
+ static RefPtr<CreateInstanceAgileRefPromise> CreateInstanceInternal(
+ REFCLSID aClsid, REFIID aIid);
+
+ static nsCOMPtr<nsIThread> GetPersistentMTAThread();
+
+ static void SyncDispatch(nsCOMPtr<nsIRunnable>&& aRunnable, Option aOpt);
+ static void SyncDispatchToPersistentThread(nsIRunnable* aRunnable);
+
+ // The following function is private in order to force any consumers to be
+ // declared as friends of EnsureMTA. The intention is to prevent
+ // AsyncOperation from becoming some kind of free-for-all mechanism for
+ // asynchronously executing work on a background thread.
+ template <typename FuncT>
+ static void AsyncOperation(FuncT&& aClosure) {
+ if (IsCurrentThreadMTA()) {
+ aClosure();
+ return;
+ }
+
+ nsCOMPtr<nsIThread> thread(GetPersistentMTAThread());
+ MOZ_ASSERT(thread);
+ if (!thread) {
+ return;
+ }
+
+ DebugOnly<nsresult> rv = thread->Dispatch(
+ NS_NewRunnableFunction("mscom::EnsureMTA::AsyncOperation",
+ std::move(aClosure)),
+ NS_DISPATCH_NORMAL);
+ MOZ_ASSERT(NS_SUCCEEDED(rv));
+ }
+
+ /**
+ * This constructor just ensures that the MTA is up and running. This should
+ * only be called by ProcessRuntime.
+ */
+ EnsureMTA();
+
+ friend class mozilla::mscom::ProcessRuntime;
+
+ template <typename T>
+ friend struct mozilla::mscom::detail::MTADelete;
+
+ template <typename T>
+ friend struct mozilla::mscom::detail::MTARelease;
+
+ template <typename T>
+ friend struct mozilla::mscom::detail::MTAReleaseInChildProcess;
+
+ friend struct mozilla::mscom::detail::PreservedStreamDeleter;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_EnsureMTA_h
diff --git a/ipc/mscom/FastMarshaler.cpp b/ipc/mscom/FastMarshaler.cpp
new file mode 100644
index 0000000000..4ae62394a2
--- /dev/null
+++ b/ipc/mscom/FastMarshaler.cpp
@@ -0,0 +1,164 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/FastMarshaler.h"
+
+#include "mozilla/mscom/Utils.h"
+
+#include <objbase.h>
+
+namespace mozilla {
+namespace mscom {
+
+HRESULT
+FastMarshaler::Create(IUnknown* aOuter, IUnknown** aOutMarshalerUnk) {
+ MOZ_ASSERT(XRE_IsContentProcess());
+
+ if (!aOuter || !aOutMarshalerUnk) {
+ return E_INVALIDARG;
+ }
+
+ *aOutMarshalerUnk = nullptr;
+
+ HRESULT hr;
+ RefPtr<FastMarshaler> fm(new FastMarshaler(aOuter, &hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ return fm->InternalQueryInterface(IID_IUnknown, (void**)aOutMarshalerUnk);
+}
+
+FastMarshaler::FastMarshaler(IUnknown* aOuter, HRESULT* aResult)
+ : mRefCnt(0), mOuter(aOuter), mStdMarshalWeak(nullptr) {
+ *aResult =
+ ::CoGetStdMarshalEx(aOuter, SMEXF_SERVER, getter_AddRefs(mStdMarshalUnk));
+ if (FAILED(*aResult)) {
+ return;
+ }
+
+ *aResult =
+ mStdMarshalUnk->QueryInterface(IID_IMarshal, (void**)&mStdMarshalWeak);
+ if (FAILED(*aResult)) {
+ return;
+ }
+
+ // mStdMarshalWeak is weak
+ mStdMarshalWeak->Release();
+}
+
+HRESULT
+FastMarshaler::InternalQueryInterface(REFIID riid, void** ppv) {
+ if (!ppv) {
+ return E_INVALIDARG;
+ }
+
+ if (riid == IID_IUnknown) {
+ RefPtr<IUnknown> punk(static_cast<IUnknown*>(&mInternalUnknown));
+ punk.forget(ppv);
+ return S_OK;
+ }
+
+ if (riid == IID_IMarshal) {
+ RefPtr<IMarshal> ptr(this);
+ ptr.forget(ppv);
+ return S_OK;
+ }
+
+ return mStdMarshalUnk->QueryInterface(riid, ppv);
+}
+
+ULONG
+FastMarshaler::InternalAddRef() { return ++mRefCnt; }
+
+ULONG
+FastMarshaler::InternalRelease() {
+ ULONG result = --mRefCnt;
+ if (!result) {
+ delete this;
+ }
+
+ return result;
+}
+
+DWORD
+FastMarshaler::GetMarshalFlags(DWORD aDestContext, DWORD aMshlFlags) {
+ // Only worry about local contexts.
+ if (aDestContext != MSHCTX_LOCAL) {
+ return aMshlFlags;
+ }
+
+ return aMshlFlags | MSHLFLAGS_NOPING;
+}
+
+HRESULT
+FastMarshaler::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) {
+ if (!mStdMarshalWeak) {
+ return E_POINTER;
+ }
+
+ return mStdMarshalWeak->GetUnmarshalClass(
+ riid, pv, dwDestContext, pvDestContext,
+ GetMarshalFlags(dwDestContext, mshlflags), pCid);
+}
+
+HRESULT
+FastMarshaler::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) {
+ if (!mStdMarshalWeak) {
+ return E_POINTER;
+ }
+
+ return mStdMarshalWeak->GetMarshalSizeMax(
+ riid, pv, dwDestContext, pvDestContext,
+ GetMarshalFlags(dwDestContext, mshlflags), pSize);
+}
+
+HRESULT
+FastMarshaler::MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) {
+ if (!mStdMarshalWeak) {
+ return E_POINTER;
+ }
+
+ return mStdMarshalWeak->MarshalInterface(
+ pStm, riid, pv, dwDestContext, pvDestContext,
+ GetMarshalFlags(dwDestContext, mshlflags));
+}
+
+HRESULT
+FastMarshaler::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) {
+ if (!mStdMarshalWeak) {
+ return E_POINTER;
+ }
+
+ return mStdMarshalWeak->UnmarshalInterface(pStm, riid, ppv);
+}
+
+HRESULT
+FastMarshaler::ReleaseMarshalData(IStream* pStm) {
+ if (!mStdMarshalWeak) {
+ return E_POINTER;
+ }
+
+ return mStdMarshalWeak->ReleaseMarshalData(pStm);
+}
+
+HRESULT
+FastMarshaler::DisconnectObject(DWORD dwReserved) {
+ if (!mStdMarshalWeak) {
+ return E_POINTER;
+ }
+
+ return mStdMarshalWeak->DisconnectObject(dwReserved);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/FastMarshaler.h b/ipc/mscom/FastMarshaler.h
new file mode 100644
index 0000000000..e1f3e88801
--- /dev/null
+++ b/ipc/mscom/FastMarshaler.h
@@ -0,0 +1,66 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_FastMarshaler_h
+#define mozilla_mscom_FastMarshaler_h
+
+#include "mozilla/Atomics.h"
+#include "mozilla/mscom/Aggregation.h"
+#include "mozilla/RefPtr.h"
+
+#include <objidl.h>
+
+namespace mozilla {
+namespace mscom {
+
+/**
+ * COM ping functionality is enabled by default and is designed to free strong
+ * references held by defunct client processes. However, this incurs a
+ * significant performance penalty in a11y code due to large numbers of remote
+ * objects being created and destroyed within a short period of time. Thus, we
+ * turn off pings to improve performance.
+ * ACHTUNG! When COM pings are disabled, Release calls from remote clients are
+ * never sent to the server! If you use this marshaler, you *must* explicitly
+ * disconnect clients using CoDisconnectObject when the object is no longer
+ * relevant. Otherwise, references to the object will never be released, causing
+ * a leak.
+ */
+class FastMarshaler final : public IMarshal {
+ public:
+ static HRESULT Create(IUnknown* aOuter, IUnknown** aOutMarshalerUnk);
+
+ // IMarshal
+ STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) override;
+ STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) override;
+ STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) override;
+ STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid,
+ void** ppv) override;
+ STDMETHODIMP ReleaseMarshalData(IStream* pStm) override;
+ STDMETHODIMP DisconnectObject(DWORD dwReserved) override;
+
+ private:
+ FastMarshaler(IUnknown* aOuter, HRESULT* aResult);
+ ~FastMarshaler() = default;
+
+ static DWORD GetMarshalFlags(DWORD aDestContext, DWORD aMshlFlags);
+
+ Atomic<ULONG> mRefCnt;
+ IUnknown* mOuter;
+ RefPtr<IUnknown> mStdMarshalUnk;
+ IMarshal* mStdMarshalWeak;
+ DECLARE_AGGREGATABLE(FastMarshaler);
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_FastMarshaler_h
diff --git a/ipc/mscom/IHandlerProvider.h b/ipc/mscom/IHandlerProvider.h
new file mode 100644
index 0000000000..76963fb9f6
--- /dev/null
+++ b/ipc/mscom/IHandlerProvider.h
@@ -0,0 +1,51 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_IHandlerProvider_h
+#define mozilla_mscom_IHandlerProvider_h
+
+#include "mozilla/NotNull.h"
+#include "mozilla/mscom/Ptr.h"
+
+#include <objidl.h>
+
+namespace mozilla {
+namespace mscom {
+
+struct IInterceptor;
+
+struct HandlerProvider {
+ virtual STDMETHODIMP GetHandler(NotNull<CLSID*> aHandlerClsid) = 0;
+ virtual STDMETHODIMP GetHandlerPayloadSize(
+ NotNull<IInterceptor*> aInterceptor, NotNull<DWORD*> aOutPayloadSize) = 0;
+ virtual STDMETHODIMP WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor,
+ NotNull<IStream*> aStream) = 0;
+ virtual STDMETHODIMP_(REFIID) MarshalAs(REFIID aIid) = 0;
+ virtual STDMETHODIMP DisconnectHandlerRemotes() = 0;
+
+ /**
+ * Determine whether this interface might be supported by objects using
+ * this HandlerProvider.
+ * This is used to avoid unnecessary cross-thread QueryInterface calls for
+ * interfaces known to be unsupported.
+ * Return S_OK if the interface might be supported, E_NOINTERFACE if it
+ * definitely isn't supported.
+ */
+ virtual STDMETHODIMP IsInterfaceMaybeSupported(REFIID aIid) { return S_OK; }
+};
+
+struct IHandlerProvider : public IUnknown, public HandlerProvider {
+ virtual STDMETHODIMP_(REFIID)
+ GetEffectiveOutParamIid(REFIID aCallIid, ULONG aCallMethod) = 0;
+ virtual STDMETHODIMP NewInstance(
+ REFIID aIid, InterceptorTargetPtr<IUnknown> aTarget,
+ NotNull<IHandlerProvider**> aOutNewPayload) = 0;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_IHandlerProvider_h
diff --git a/ipc/mscom/Interceptor.cpp b/ipc/mscom/Interceptor.cpp
new file mode 100644
index 0000000000..0e223bba26
--- /dev/null
+++ b/ipc/mscom/Interceptor.cpp
@@ -0,0 +1,860 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#define INITGUID
+
+#include "mozilla/mscom/Interceptor.h"
+
+#include <utility>
+
+#include "MainThreadUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/ThreadLocal.h"
+#include "mozilla/Unused.h"
+#include "mozilla/mscom/DispatchForwarder.h"
+#include "mozilla/mscom/FastMarshaler.h"
+#include "mozilla/mscom/InterceptorLog.h"
+#include "mozilla/mscom/MainThreadInvoker.h"
+#include "mozilla/mscom/Objref.h"
+#include "mozilla/mscom/Registration.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsDirectoryServiceDefs.h"
+#include "nsDirectoryServiceUtils.h"
+#include "nsPrintfCString.h"
+#include "nsRefPtrHashtable.h"
+#include "nsThreadUtils.h"
+#include "nsXULAppAPI.h"
+
+#define ENSURE_HR_SUCCEEDED(hr) \
+ MOZ_ASSERT(SUCCEEDED((HRESULT)hr)); \
+ if (FAILED((HRESULT)hr)) { \
+ return hr; \
+ }
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+class MOZ_CAPABILITY("mutex") LiveSet final {
+ public:
+ LiveSet() : mMutex("mozilla::mscom::LiveSet::mMutex") {}
+
+ void Lock() MOZ_CAPABILITY_ACQUIRE(mMutex) { mMutex.Lock(); }
+
+ void Unlock() MOZ_CAPABILITY_RELEASE(mMutex) { mMutex.Unlock(); }
+
+ void Put(IUnknown* aKey, already_AddRefed<IWeakReference> aValue) {
+ mMutex.AssertCurrentThreadOwns();
+ mLiveSet.InsertOrUpdate(aKey, RefPtr<IWeakReference>{std::move(aValue)});
+ }
+
+ RefPtr<IWeakReference> Get(IUnknown* aKey) {
+ mMutex.AssertCurrentThreadOwns();
+ RefPtr<IWeakReference> result;
+ mLiveSet.Get(aKey, getter_AddRefs(result));
+ return result;
+ }
+
+ void Remove(IUnknown* aKey) {
+ mMutex.AssertCurrentThreadOwns();
+ mLiveSet.Remove(aKey);
+ }
+
+ private:
+ Mutex mMutex MOZ_UNANNOTATED;
+ nsRefPtrHashtable<nsPtrHashKey<IUnknown>, IWeakReference> mLiveSet;
+};
+
+/**
+ * We don't use the normal XPCOM BaseAutoLock because we need the ability
+ * to explicitly Unlock.
+ */
+class MOZ_RAII MOZ_SCOPED_CAPABILITY LiveSetAutoLock final {
+ public:
+ explicit LiveSetAutoLock(LiveSet& aLiveSet) MOZ_CAPABILITY_ACQUIRE(aLiveSet)
+ : mLiveSet(&aLiveSet) {
+ aLiveSet.Lock();
+ }
+
+ ~LiveSetAutoLock() MOZ_CAPABILITY_RELEASE() {
+ if (mLiveSet) {
+ mLiveSet->Unlock();
+ }
+ }
+
+ void Unlock() MOZ_CAPABILITY_RELEASE() {
+ MOZ_ASSERT(mLiveSet);
+ if (mLiveSet) {
+ mLiveSet->Unlock();
+ mLiveSet = nullptr;
+ }
+ }
+
+ LiveSetAutoLock(const LiveSetAutoLock& aOther) = delete;
+ LiveSetAutoLock(LiveSetAutoLock&& aOther) = delete;
+ LiveSetAutoLock& operator=(const LiveSetAutoLock& aOther) = delete;
+ LiveSetAutoLock& operator=(LiveSetAutoLock&& aOther) = delete;
+
+ private:
+ LiveSet* mLiveSet;
+};
+
+class MOZ_RAII ReentrySentinel final {
+ public:
+ explicit ReentrySentinel(Interceptor* aCurrent) : mCurInterceptor(aCurrent) {
+ static const bool kHasTls = tlsSentinelStackTop.init();
+ MOZ_RELEASE_ASSERT(kHasTls);
+
+ mPrevSentinel = tlsSentinelStackTop.get();
+ tlsSentinelStackTop.set(this);
+ }
+
+ ~ReentrySentinel() { tlsSentinelStackTop.set(mPrevSentinel); }
+
+ bool IsOutermost() const {
+ return !(mPrevSentinel && mPrevSentinel->IsMarshaling(mCurInterceptor));
+ }
+
+ ReentrySentinel(const ReentrySentinel&) = delete;
+ ReentrySentinel(ReentrySentinel&&) = delete;
+ ReentrySentinel& operator=(const ReentrySentinel&) = delete;
+ ReentrySentinel& operator=(ReentrySentinel&&) = delete;
+
+ private:
+ bool IsMarshaling(Interceptor* aTopInterceptor) const {
+ return aTopInterceptor == mCurInterceptor ||
+ (mPrevSentinel && mPrevSentinel->IsMarshaling(aTopInterceptor));
+ }
+
+ private:
+ Interceptor* mCurInterceptor;
+ ReentrySentinel* mPrevSentinel;
+
+ static MOZ_THREAD_LOCAL(ReentrySentinel*) tlsSentinelStackTop;
+};
+
+MOZ_THREAD_LOCAL(ReentrySentinel*) ReentrySentinel::tlsSentinelStackTop;
+
+class MOZ_RAII LoggedQIResult final {
+ public:
+ explicit LoggedQIResult(REFIID aIid)
+ : mIid(aIid),
+ mHr(E_UNEXPECTED),
+ mTarget(nullptr),
+ mInterceptor(nullptr),
+ mBegin(TimeStamp::Now()) {}
+
+ ~LoggedQIResult() {
+ if (!mTarget) {
+ return;
+ }
+
+ TimeStamp end(TimeStamp::Now());
+ TimeDuration total(end - mBegin);
+ TimeDuration overhead(total - mNonOverheadDuration);
+
+ InterceptorLog::QI(mHr, mTarget, mIid, mInterceptor, &overhead,
+ &mNonOverheadDuration);
+ }
+
+ void Log(IUnknown* aTarget, IUnknown* aInterceptor) {
+ mTarget = aTarget;
+ mInterceptor = aInterceptor;
+ }
+
+ void operator=(HRESULT aHr) { mHr = aHr; }
+
+ operator HRESULT() { return mHr; }
+
+ operator TimeDuration*() { return &mNonOverheadDuration; }
+
+ LoggedQIResult(const LoggedQIResult&) = delete;
+ LoggedQIResult(LoggedQIResult&&) = delete;
+ LoggedQIResult& operator=(const LoggedQIResult&) = delete;
+ LoggedQIResult& operator=(LoggedQIResult&&) = delete;
+
+ private:
+ REFIID mIid;
+ HRESULT mHr;
+ IUnknown* mTarget;
+ IUnknown* mInterceptor;
+ TimeDuration mNonOverheadDuration;
+ TimeStamp mBegin;
+};
+
+} // namespace detail
+
+static detail::LiveSet& GetLiveSet() {
+ static detail::LiveSet sLiveSet;
+ return sLiveSet;
+}
+
+MOZ_THREAD_LOCAL(bool) Interceptor::tlsCreatingStdMarshal;
+
+/* static */
+HRESULT Interceptor::Create(STAUniquePtr<IUnknown> aTarget,
+ IInterceptorSink* aSink, REFIID aInitialIid,
+ void** aOutInterface) {
+ MOZ_ASSERT(aOutInterface && aTarget && aSink);
+ if (!aOutInterface) {
+ return E_INVALIDARG;
+ }
+
+ detail::LiveSetAutoLock lock(GetLiveSet());
+
+ RefPtr<IWeakReference> existingWeak(GetLiveSet().Get(aTarget.get()));
+ if (existingWeak) {
+ RefPtr<IWeakReferenceSource> existingStrong;
+ if (SUCCEEDED(existingWeak->ToStrongRef(getter_AddRefs(existingStrong)))) {
+ // QI on existingStrong may touch other threads. Since we now hold a
+ // strong ref on the interceptor, we may now release the lock.
+ lock.Unlock();
+ return existingStrong->QueryInterface(aInitialIid, aOutInterface);
+ }
+ }
+
+ *aOutInterface = nullptr;
+
+ if (!aTarget || !aSink) {
+ return E_INVALIDARG;
+ }
+
+ RefPtr<Interceptor> intcpt(new Interceptor(aSink));
+ return intcpt->GetInitialInterceptorForIID(lock, aInitialIid,
+ std::move(aTarget), aOutInterface);
+}
+
+Interceptor::Interceptor(IInterceptorSink* aSink)
+ : WeakReferenceSupport(WeakReferenceSupport::Flags::eDestroyOnMainThread),
+ mEventSink(aSink),
+ mInterceptorMapMutex("mozilla::mscom::Interceptor::mInterceptorMapMutex"),
+ mStdMarshalMutex("mozilla::mscom::Interceptor::mStdMarshalMutex"),
+ mStdMarshal(nullptr) {
+ static const bool kHasTls = tlsCreatingStdMarshal.init();
+ MOZ_ASSERT(kHasTls);
+ Unused << kHasTls;
+
+ MOZ_ASSERT(aSink);
+ RefPtr<IWeakReference> weakRef;
+ if (SUCCEEDED(GetWeakReference(getter_AddRefs(weakRef)))) {
+ aSink->SetInterceptor(weakRef);
+ }
+}
+
+Interceptor::~Interceptor() {
+ { // Scope for lock
+ detail::LiveSetAutoLock lock(GetLiveSet());
+ GetLiveSet().Remove(mTarget.get());
+ }
+
+ // This needs to run on the main thread because it releases target interface
+ // reference counts which may not be thread-safe.
+ MOZ_ASSERT(NS_IsMainThread());
+ for (uint32_t index = 0, len = mInterceptorMap.Length(); index < len;
+ ++index) {
+ MapEntry& entry = mInterceptorMap[index];
+ entry.mInterceptor = nullptr;
+ entry.mTargetInterface->Release();
+ }
+}
+
+HRESULT
+Interceptor::GetClassForHandler(DWORD aDestContext, void* aDestContextPtr,
+ CLSID* aHandlerClsid) {
+ if (aDestContextPtr || !aHandlerClsid ||
+ aDestContext == MSHCTX_DIFFERENTMACHINE) {
+ return E_INVALIDARG;
+ }
+
+ MOZ_ASSERT(mEventSink);
+ return mEventSink->GetHandler(WrapNotNull(aHandlerClsid));
+}
+
+REFIID
+Interceptor::MarshalAs(REFIID aIid) const {
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ return IsCallerExternalProcess() ? aIid : mEventSink->MarshalAs(aIid);
+#else
+ return mEventSink->MarshalAs(aIid);
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+}
+
+HRESULT
+Interceptor::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) {
+ return mStdMarshal->GetUnmarshalClass(MarshalAs(riid), pv, dwDestContext,
+ pvDestContext, mshlflags, pCid);
+}
+
+HRESULT
+Interceptor::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) {
+ detail::ReentrySentinel sentinel(this);
+
+ HRESULT hr = mStdMarshal->GetMarshalSizeMax(
+ MarshalAs(riid), pv, dwDestContext, pvDestContext, mshlflags, pSize);
+ if (FAILED(hr) || !sentinel.IsOutermost()) {
+ return hr;
+ }
+
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ if (XRE_IsContentProcess() && IsCallerExternalProcess()) {
+ // The caller isn't our chrome process, so we do not provide a handler
+ // payload. Even though we're only getting the size here, calculating the
+ // payload size might actually require building the payload.
+ return hr;
+ }
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+
+ DWORD payloadSize = 0;
+ hr = mEventSink->GetHandlerPayloadSize(WrapNotNull(this),
+ WrapNotNull(&payloadSize));
+ if (hr == E_NOTIMPL) {
+ return S_OK;
+ }
+
+ if (SUCCEEDED(hr)) {
+ *pSize += payloadSize;
+ }
+ return hr;
+}
+
+HRESULT
+Interceptor::MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) {
+ detail::ReentrySentinel sentinel(this);
+
+ HRESULT hr;
+
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ // Save the current stream position
+ LARGE_INTEGER seekTo;
+ seekTo.QuadPart = 0;
+
+ ULARGE_INTEGER objrefPos;
+
+ hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &objrefPos);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+
+ hr = mStdMarshal->MarshalInterface(pStm, MarshalAs(riid), pv, dwDestContext,
+ pvDestContext, mshlflags);
+ if (FAILED(hr) || !sentinel.IsOutermost()) {
+ return hr;
+ }
+
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ if (XRE_IsContentProcess() && IsCallerExternalProcess()) {
+ // The caller isn't our chrome process, so do not provide a handler.
+
+ // First, save the current position that marks the current end of the
+ // OBJREF in the stream.
+ ULARGE_INTEGER endPos;
+ hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &endPos);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // Now strip out the handler.
+ if (!StripHandlerFromOBJREF(WrapNotNull(pStm), objrefPos.QuadPart,
+ endPos.QuadPart)) {
+ return E_FAIL;
+ }
+
+ return S_OK;
+ }
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+
+ hr = mEventSink->WriteHandlerPayload(WrapNotNull(this), WrapNotNull(pStm));
+ if (hr == E_NOTIMPL) {
+ return S_OK;
+ }
+
+ return hr;
+}
+
+HRESULT
+Interceptor::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) {
+ return mStdMarshal->UnmarshalInterface(pStm, riid, ppv);
+}
+
+HRESULT
+Interceptor::ReleaseMarshalData(IStream* pStm) {
+ return mStdMarshal->ReleaseMarshalData(pStm);
+}
+
+HRESULT
+Interceptor::DisconnectObject(DWORD dwReserved) {
+ mEventSink->DisconnectHandlerRemotes();
+ return mStdMarshal->DisconnectObject(dwReserved);
+}
+
+Interceptor::MapEntry* Interceptor::Lookup(REFIID aIid) {
+ mInterceptorMapMutex.AssertCurrentThreadOwns();
+
+ for (uint32_t index = 0, len = mInterceptorMap.Length(); index < len;
+ ++index) {
+ if (mInterceptorMap[index].mIID == aIid) {
+ return &mInterceptorMap[index];
+ }
+ }
+ return nullptr;
+}
+
+HRESULT
+Interceptor::GetTargetForIID(REFIID aIid,
+ InterceptorTargetPtr<IUnknown>& aTarget) {
+ MutexAutoLock lock(mInterceptorMapMutex);
+ MapEntry* entry = Lookup(aIid);
+ if (entry) {
+ aTarget.reset(entry->mTargetInterface);
+ return S_OK;
+ }
+
+ return E_NOINTERFACE;
+}
+
+// CoGetInterceptor requires type metadata to be able to generate its emulated
+// vtable. If no registered metadata is available, CoGetInterceptor returns
+// kFileNotFound.
+static const HRESULT kFileNotFound = HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
+
+HRESULT
+Interceptor::CreateInterceptor(REFIID aIid, IUnknown* aOuter,
+ IUnknown** aOutput) {
+ // In order to aggregate, we *must* request IID_IUnknown as the initial
+ // interface for the interceptor, as that IUnknown is non-delegating.
+ // This is a fundamental rule for creating aggregated objects in COM.
+ HRESULT hr = ::CoGetInterceptor(aIid, aOuter, IID_IUnknown, (void**)aOutput);
+ if (hr != kFileNotFound) {
+ return hr;
+ }
+
+ // In the case that CoGetInterceptor returns kFileNotFound, we can try to
+ // explicitly load typelib data from our runtime registration facility and
+ // pass that into CoGetInterceptorFromTypeInfo.
+
+ RefPtr<ITypeInfo> typeInfo;
+ bool found = RegisteredProxy::Find(aIid, getter_AddRefs(typeInfo));
+ // If this assert fires then we have omitted registering the typelib for a
+ // required interface. To fix this, review our calls to mscom::RegisterProxy
+ // and mscom::RegisterTypelib, and add the additional typelib as necessary.
+ MOZ_ASSERT(found);
+ if (!found) {
+ return kFileNotFound;
+ }
+
+ hr = ::CoGetInterceptorFromTypeInfo(aIid, aOuter, typeInfo, IID_IUnknown,
+ (void**)aOutput);
+ // If this assert fires then the interceptor doesn't like something about
+ // the format of the typelib. One thing in particular that it doesn't like
+ // is complex types that contain unions.
+ MOZ_ASSERT(SUCCEEDED(hr));
+ return hr;
+}
+
+HRESULT
+Interceptor::PublishTarget(detail::LiveSetAutoLock& aLiveSetLock,
+ RefPtr<IUnknown> aInterceptor, REFIID aTargetIid,
+ STAUniquePtr<IUnknown> aTarget)
+ MOZ_NO_THREAD_SAFETY_ANALYSIS {
+ // Suppress thread safety analysis as this conditionally releases locks.
+ RefPtr<IWeakReference> weakRef;
+ HRESULT hr = GetWeakReference(getter_AddRefs(weakRef));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // mTarget is a weak reference to aTarget. This is safe because we transfer
+ // ownership of aTarget into mInterceptorMap which remains live for the
+ // lifetime of this Interceptor.
+ mTarget = ToInterceptorTargetPtr(aTarget);
+ GetLiveSet().Put(mTarget.get(), weakRef.forget());
+
+ // Now we transfer aTarget's ownership into mInterceptorMap.
+ mInterceptorMap.AppendElement(
+ MapEntry(aTargetIid, aInterceptor, aTarget.release()));
+
+ // Release the live set lock because subsequent operations may post work to
+ // the main thread, creating potential for deadlocks.
+ aLiveSetLock.Unlock();
+ return S_OK;
+}
+
+HRESULT
+Interceptor::GetInitialInterceptorForIID(
+ detail::LiveSetAutoLock& aLiveSetLock, REFIID aTargetIid,
+ STAUniquePtr<IUnknown> aTarget,
+ void** aOutInterceptor) MOZ_NO_THREAD_SAFETY_ANALYSIS {
+ // Suppress thread safety analysis as this conditionally releases locks.
+ MOZ_ASSERT(aOutInterceptor);
+ MOZ_ASSERT(aTargetIid != IID_IMarshal);
+ MOZ_ASSERT(!IsProxy(aTarget.get()));
+
+ HRESULT hr = E_UNEXPECTED;
+
+ auto hasFailed = [&hr]() -> bool { return FAILED(hr); };
+
+ MOZ_PUSH_IGNORE_THREAD_SAFETY // Avoid the lambda upsetting analysis.
+ auto cleanup = [&aLiveSetLock]() -> void { aLiveSetLock.Unlock(); };
+ MOZ_POP_THREAD_SAFETY
+
+ ExecuteWhen<decltype(hasFailed), decltype(cleanup)> onFail(hasFailed,
+ cleanup);
+
+ if (aTargetIid == IID_IUnknown) {
+ // We must lock mInterceptorMapMutex so that nothing can race with us once
+ // we have been published to the live set.
+ MutexAutoLock lock(mInterceptorMapMutex);
+
+ hr = PublishTarget(aLiveSetLock, nullptr, aTargetIid, std::move(aTarget));
+ ENSURE_HR_SUCCEEDED(hr);
+
+ hr = QueryInterface(aTargetIid, aOutInterceptor);
+ ENSURE_HR_SUCCEEDED(hr);
+ return hr;
+ }
+
+ // Raise the refcount for stabilization purposes during aggregation
+ WeakReferenceSupport::StabilizeRefCount stabilizer(*this);
+
+ RefPtr<IUnknown> unkInterceptor;
+ hr = CreateInterceptor(aTargetIid, static_cast<WeakReferenceSupport*>(this),
+ getter_AddRefs(unkInterceptor));
+ ENSURE_HR_SUCCEEDED(hr);
+
+ RefPtr<ICallInterceptor> interceptor;
+ hr = unkInterceptor->QueryInterface(IID_ICallInterceptor,
+ getter_AddRefs(interceptor));
+ ENSURE_HR_SUCCEEDED(hr);
+
+ hr = interceptor->RegisterSink(mEventSink);
+ ENSURE_HR_SUCCEEDED(hr);
+
+ // We must lock mInterceptorMapMutex so that nothing can race with us once we
+ // have been published to the live set.
+ MutexAutoLock lock(mInterceptorMapMutex);
+
+ hr = PublishTarget(aLiveSetLock, unkInterceptor, aTargetIid,
+ std::move(aTarget));
+ ENSURE_HR_SUCCEEDED(hr);
+
+ if (MarshalAs(aTargetIid) == aTargetIid) {
+ hr = unkInterceptor->QueryInterface(aTargetIid, aOutInterceptor);
+ ENSURE_HR_SUCCEEDED(hr);
+ return hr;
+ }
+
+ hr = GetInterceptorForIID(aTargetIid, aOutInterceptor, &lock);
+ ENSURE_HR_SUCCEEDED(hr);
+ return hr;
+}
+
+HRESULT
+Interceptor::GetInterceptorForIID(REFIID aIid, void** aOutInterceptor) {
+ return GetInterceptorForIID(aIid, aOutInterceptor, nullptr);
+}
+
+/**
+ * This method contains the core guts of the handling of QueryInterface calls
+ * that are delegated to us from the ICallInterceptor.
+ *
+ * @param aIid ID of the desired interface
+ * @param aOutInterceptor The resulting emulated vtable that corresponds to
+ * the interface specified by aIid.
+ * @param aAlreadyLocked Proof of an existing lock on |mInterceptorMapMutex|,
+ * if present.
+ */
+HRESULT
+Interceptor::GetInterceptorForIID(REFIID aIid, void** aOutInterceptor,
+ MutexAutoLock* aAlreadyLocked) {
+ detail::LoggedQIResult result(aIid);
+
+ if (!aOutInterceptor) {
+ return E_INVALIDARG;
+ }
+
+ if (aIid == IID_IUnknown) {
+ // Special case: When we see IUnknown, we just provide a reference to this
+ RefPtr<IInterceptor> intcpt(this);
+ intcpt.forget(aOutInterceptor);
+ return S_OK;
+ }
+
+ REFIID interceptorIid = MarshalAs(aIid);
+
+ RefPtr<IUnknown> unkInterceptor;
+ IUnknown* interfaceForQILog = nullptr;
+
+ // (1) Check to see if we already have an existing interceptor for
+ // interceptorIid.
+ auto doLookup = [&]() -> void {
+ MapEntry* entry = Lookup(interceptorIid);
+ if (entry) {
+ unkInterceptor = entry->mInterceptor;
+ interfaceForQILog = entry->mTargetInterface;
+ }
+ };
+
+ if (aAlreadyLocked) {
+ doLookup();
+ } else {
+ MutexAutoLock lock(mInterceptorMapMutex);
+ doLookup();
+ }
+
+ // (1a) A COM interceptor already exists for this interface, so all we need
+ // to do is run a QI on it.
+ if (unkInterceptor) {
+ // Technically we didn't actually execute a QI on the target interface, but
+ // for logging purposes we would like to record the fact that this interface
+ // was requested.
+ result.Log(mTarget.get(), interfaceForQILog);
+ result = unkInterceptor->QueryInterface(interceptorIid, aOutInterceptor);
+ ENSURE_HR_SUCCEEDED(result);
+ return result;
+ }
+
+ // (2) Obtain a new target interface.
+
+ // (2a) First, make sure that the target interface is available
+ // NB: We *MUST* query the correct interface! ICallEvents::Invoke casts its
+ // pvReceiver argument directly to the required interface! DO NOT assume
+ // that COM will use QI or upcast/downcast!
+ HRESULT hr;
+
+ STAUniquePtr<IUnknown> targetInterface;
+ IUnknown* rawTargetInterface = nullptr;
+ hr =
+ QueryInterfaceTarget(interceptorIid, (void**)&rawTargetInterface, result);
+ targetInterface.reset(rawTargetInterface);
+ result = hr;
+ result.Log(mTarget.get(), targetInterface.get());
+ MOZ_ASSERT(SUCCEEDED(hr) || hr == E_NOINTERFACE);
+ if (hr == E_NOINTERFACE) {
+ return hr;
+ }
+ ENSURE_HR_SUCCEEDED(hr);
+
+ // We *really* shouldn't be adding interceptors to proxies
+ MOZ_ASSERT(aIid != IID_IMarshal);
+
+ // (3) Create a new COM interceptor to that interface that delegates its
+ // IUnknown to |this|.
+
+ // Raise the refcount for stabilization purposes during aggregation
+ WeakReferenceSupport::StabilizeRefCount stabilizer(*this);
+
+ hr = CreateInterceptor(interceptorIid,
+ static_cast<WeakReferenceSupport*>(this),
+ getter_AddRefs(unkInterceptor));
+ ENSURE_HR_SUCCEEDED(hr);
+
+ // (4) Obtain the interceptor's ICallInterceptor interface and register our
+ // event sink.
+ RefPtr<ICallInterceptor> interceptor;
+ hr = unkInterceptor->QueryInterface(IID_ICallInterceptor,
+ (void**)getter_AddRefs(interceptor));
+ ENSURE_HR_SUCCEEDED(hr);
+
+ hr = interceptor->RegisterSink(mEventSink);
+ ENSURE_HR_SUCCEEDED(hr);
+
+ // (5) Now that we have this new COM interceptor, insert it into the map.
+ auto doInsertion = [&]() -> void {
+ // We might have raced with another thread, so first check that we don't
+ // already have an entry for this
+ MapEntry* entry = Lookup(interceptorIid);
+ if (entry && entry->mInterceptor) {
+ // Bug 1433046: Because of aggregation, the QI for |interceptor|
+ // AddRefed |this|, not |unkInterceptor|. Thus, releasing |unkInterceptor|
+ // will destroy the object. Before we do that, we must first release
+ // |interceptor|. Otherwise, |interceptor| would be invalidated when
+ // |unkInterceptor| is destroyed.
+ interceptor = nullptr;
+ unkInterceptor = entry->mInterceptor;
+ } else {
+ // MapEntry has a RefPtr to unkInterceptor, OTOH we must not touch the
+ // refcount for the target interface because we are just moving it into
+ // the map and its refcounting might not be thread-safe.
+ IUnknown* rawTargetInterface = targetInterface.release();
+ mInterceptorMap.AppendElement(
+ MapEntry(interceptorIid, unkInterceptor, rawTargetInterface));
+ }
+ };
+
+ if (aAlreadyLocked) {
+ doInsertion();
+ } else {
+ MutexAutoLock lock(mInterceptorMapMutex);
+ doInsertion();
+ }
+
+ hr = unkInterceptor->QueryInterface(interceptorIid, aOutInterceptor);
+ ENSURE_HR_SUCCEEDED(hr);
+ return hr;
+}
+
+HRESULT
+Interceptor::QueryInterfaceTarget(REFIID aIid, void** aOutput,
+ TimeDuration* aOutDuration) {
+ // NB: This QI needs to run on the main thread because the target object
+ // is probably Gecko code that is not thread-safe. Note that this main
+ // thread invocation is *synchronous*.
+ if (!NS_IsMainThread() && tlsCreatingStdMarshal.get()) {
+ mStdMarshalMutex.AssertCurrentThreadOwns();
+ // COM queries for special interfaces such as IFastRundown when creating a
+ // marshaler. We don't want these being dispatched to the main thread,
+ // since this would cause a deadlock on mStdMarshalMutex if the main
+ // thread is also querying for IMarshal. If we do need to respond to these
+ // special interfaces, this should be done before this point; e.g. in
+ // Interceptor::QueryInterface like we do for INoMarshal.
+ return E_NOINTERFACE;
+ }
+
+ if (mEventSink->IsInterfaceMaybeSupported(aIid) == E_NOINTERFACE) {
+ return E_NOINTERFACE;
+ }
+
+ MainThreadInvoker invoker;
+ HRESULT hr;
+ auto runOnMainThread = [&]() -> void {
+ MOZ_ASSERT(NS_IsMainThread());
+ hr = mTarget->QueryInterface(aIid, aOutput);
+ };
+ if (!invoker.Invoke(NS_NewRunnableFunction("Interceptor::QueryInterface",
+ runOnMainThread))) {
+ return E_FAIL;
+ }
+ if (aOutDuration) {
+ *aOutDuration = invoker.GetDuration();
+ }
+ return hr;
+}
+
+HRESULT
+Interceptor::QueryInterface(REFIID riid, void** ppv) {
+ if (riid == IID_INoMarshal) {
+ // This entire library is designed around marshaling, so there's no point
+ // propagating this QI request all over the place!
+ return E_NOINTERFACE;
+ }
+
+ return WeakReferenceSupport::QueryInterface(riid, ppv);
+}
+
+HRESULT
+Interceptor::WeakRefQueryInterface(REFIID aIid, IUnknown** aOutInterface) {
+ if (aIid == IID_IStdMarshalInfo) {
+ detail::ReentrySentinel sentinel(this);
+
+ if (!sentinel.IsOutermost()) {
+ return E_NOINTERFACE;
+ }
+
+ // Do not indicate that this interface is available unless we actually
+ // support it. We'll check that by looking for a successful call to
+ // IInterceptorSink::GetHandler()
+ CLSID dummy;
+ if (FAILED(mEventSink->GetHandler(WrapNotNull(&dummy)))) {
+ return E_NOINTERFACE;
+ }
+
+ RefPtr<IStdMarshalInfo> std(this);
+ std.forget(aOutInterface);
+ return S_OK;
+ }
+
+ if (aIid == IID_IMarshal) {
+ MutexAutoLock lock(mStdMarshalMutex);
+
+ HRESULT hr;
+
+ if (!mStdMarshalUnk) {
+ MOZ_ASSERT(!tlsCreatingStdMarshal.get());
+ tlsCreatingStdMarshal.set(true);
+ if (XRE_IsContentProcess()) {
+ hr = FastMarshaler::Create(static_cast<IWeakReferenceSource*>(this),
+ getter_AddRefs(mStdMarshalUnk));
+ } else {
+ hr = ::CoGetStdMarshalEx(static_cast<IWeakReferenceSource*>(this),
+ SMEXF_SERVER, getter_AddRefs(mStdMarshalUnk));
+ }
+ tlsCreatingStdMarshal.set(false);
+
+ ENSURE_HR_SUCCEEDED(hr);
+ }
+
+ if (!mStdMarshal) {
+ hr = mStdMarshalUnk->QueryInterface(IID_IMarshal, (void**)&mStdMarshal);
+ ENSURE_HR_SUCCEEDED(hr);
+
+ // mStdMarshal is weak, so drop its refcount
+ mStdMarshal->Release();
+ }
+
+ RefPtr<IMarshal> marshal(this);
+ marshal.forget(aOutInterface);
+ return S_OK;
+ }
+
+ if (aIid == IID_IInterceptor) {
+ RefPtr<IInterceptor> intcpt(this);
+ intcpt.forget(aOutInterface);
+ return S_OK;
+ }
+
+ if (aIid == IID_IDispatch) {
+ STAUniquePtr<IDispatch> disp;
+ IDispatch* rawDisp = nullptr;
+ HRESULT hr = QueryInterfaceTarget(aIid, (void**)&rawDisp);
+ ENSURE_HR_SUCCEEDED(hr);
+
+ disp.reset(rawDisp);
+ return DispatchForwarder::Create(this, disp, aOutInterface);
+ }
+
+ return GetInterceptorForIID(aIid, (void**)aOutInterface, nullptr);
+}
+
+ULONG
+Interceptor::AddRef() { return WeakReferenceSupport::AddRef(); }
+
+ULONG
+Interceptor::Release() { return WeakReferenceSupport::Release(); }
+
+/* static */
+HRESULT Interceptor::DisconnectRemotesForTarget(IUnknown* aTarget) {
+ MOZ_ASSERT(aTarget);
+
+ detail::LiveSetAutoLock lock(GetLiveSet());
+
+ // It is not an error if the interceptor doesn't exist, so we return
+ // S_FALSE instead of an error in that case.
+ RefPtr<IWeakReference> existingWeak(GetLiveSet().Get(aTarget));
+ if (!existingWeak) {
+ return S_FALSE;
+ }
+
+ RefPtr<IWeakReferenceSource> existingStrong;
+ if (FAILED(existingWeak->ToStrongRef(getter_AddRefs(existingStrong)))) {
+ return S_FALSE;
+ }
+ // Since we now hold a strong ref on the interceptor, we may now release the
+ // lock.
+ lock.Unlock();
+
+ return ::CoDisconnectObject(existingStrong, 0);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/Interceptor.h b/ipc/mscom/Interceptor.h
new file mode 100644
index 0000000000..8ab578092b
--- /dev/null
+++ b/ipc/mscom/Interceptor.h
@@ -0,0 +1,199 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Interceptor_h
+#define mozilla_mscom_Interceptor_h
+
+#include <callobj.h>
+#include <objidl.h>
+
+#include <utility>
+
+#include "mozilla/Mutex.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/mscom/IHandlerProvider.h"
+#include "mozilla/mscom/Ptr.h"
+#include "mozilla/mscom/WeakRef.h"
+#include "nsTArray.h"
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+class LiveSetAutoLock;
+
+} // namespace detail
+
+// {8831EB53-A937-42BC-9921-B3E1121FDF86}
+DEFINE_GUID(IID_IInterceptorSink, 0x8831eb53, 0xa937, 0x42bc, 0x99, 0x21, 0xb3,
+ 0xe1, 0x12, 0x1f, 0xdf, 0x86);
+
+struct IInterceptorSink : public ICallFrameEvents, public HandlerProvider {
+ virtual STDMETHODIMP SetInterceptor(IWeakReference* aInterceptor) = 0;
+};
+
+// {3710799B-ECA2-4165-B9B0-3FA1E4A9B230}
+DEFINE_GUID(IID_IInterceptor, 0x3710799b, 0xeca2, 0x4165, 0xb9, 0xb0, 0x3f,
+ 0xa1, 0xe4, 0xa9, 0xb2, 0x30);
+
+struct IInterceptor : public IUnknown {
+ virtual STDMETHODIMP GetTargetForIID(
+ REFIID aIid, InterceptorTargetPtr<IUnknown>& aTarget) = 0;
+ virtual STDMETHODIMP GetInterceptorForIID(REFIID aIid,
+ void** aOutInterceptor) = 0;
+ virtual STDMETHODIMP GetEventSink(IInterceptorSink** aSink) = 0;
+};
+
+/**
+ * The COM interceptor is the core functionality in mscom that allows us to
+ * redirect method calls to different threads. It emulates the vtable of a
+ * target interface. When a call is made on this emulated vtable, the call is
+ * packaged up into an instance of the ICallFrame interface which may be passed
+ * to other contexts for execution.
+ *
+ * In order to accomplish this, COM itself provides the CoGetInterceptor
+ * function, which instantiates an ICallInterceptor. Note, however, that
+ * ICallInterceptor only works on a single interface; we need to be able to
+ * interpose QueryInterface calls so that we can instantiate a new
+ * ICallInterceptor for each new interface that is requested.
+ *
+ * We accomplish this by using COM aggregation, which means that the
+ * ICallInterceptor delegates its IUnknown implementation to its outer object
+ * (the mscom::Interceptor we implement and control).
+ *
+ * ACHTUNG! mscom::Interceptor uses FastMarshaler to disable COM garbage
+ * collection. If you use this class, you *must* call
+ * Interceptor::DisconnectRemotesForTarget when an object is no longer relevant.
+ * Otherwise, the object will never be released, causing a leak.
+ */
+class Interceptor final : public WeakReferenceSupport,
+ public IStdMarshalInfo,
+ public IMarshal,
+ public IInterceptor {
+ public:
+ static HRESULT Create(STAUniquePtr<IUnknown> aTarget, IInterceptorSink* aSink,
+ REFIID aInitialIid, void** aOutInterface);
+
+ /**
+ * Disconnect all remote clients for a given target.
+ * Because Interceptors disable COM garbage collection to improve
+ * performance, they never receive Release calls from remote clients. If
+ * the object can be shut down while clients still hold a reference, this
+ * function can be used to force COM to disconnect all remote connections
+ * (using CoDisconnectObject) and thus release the associated references to
+ * the Interceptor, its target and any objects associated with the
+ * HandlerProvider.
+ * Note that the specified target must be the same IUnknown pointer used to
+ * create the Interceptor. Where there is multiple inheritance, querying for
+ * IID_IUnknown and calling this function with that pointer alone will not
+ * disconnect remotes for all interfaces. If you expect that the same object
+ * may be fetched with different initial interfaces, you should call this
+ * function once for each possible IUnknown pointer.
+ * @return S_OK if there was an Interceptor for the given target,
+ * S_FALSE if there was not.
+ */
+ static HRESULT DisconnectRemotesForTarget(IUnknown* aTarget);
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // IStdMarshalInfo
+ STDMETHODIMP GetClassForHandler(DWORD aDestContext, void* aDestContextPtr,
+ CLSID* aHandlerClsid) override;
+
+ // IMarshal
+ STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) override;
+ STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) override;
+ STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) override;
+ STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid,
+ void** ppv) override;
+ STDMETHODIMP ReleaseMarshalData(IStream* pStm) override;
+ STDMETHODIMP DisconnectObject(DWORD dwReserved) override;
+
+ // IInterceptor
+ STDMETHODIMP GetTargetForIID(
+ REFIID aIid, InterceptorTargetPtr<IUnknown>& aTarget) override;
+ STDMETHODIMP GetInterceptorForIID(REFIID aIid,
+ void** aOutInterceptor) override;
+
+ STDMETHODIMP GetEventSink(IInterceptorSink** aSink) override {
+ RefPtr<IInterceptorSink> sink = mEventSink;
+ sink.forget(aSink);
+ return mEventSink ? S_OK : S_FALSE;
+ }
+
+ private:
+ struct MapEntry {
+ MapEntry(REFIID aIid, IUnknown* aInterceptor, IUnknown* aTargetInterface)
+ : mIID(aIid),
+ mInterceptor(aInterceptor),
+ mTargetInterface(aTargetInterface) {}
+
+ IID mIID;
+ RefPtr<IUnknown> mInterceptor;
+ IUnknown* mTargetInterface;
+ };
+
+ private:
+ explicit Interceptor(IInterceptorSink* aSink);
+ ~Interceptor();
+ HRESULT GetInitialInterceptorForIID(detail::LiveSetAutoLock& aLiveSetLock,
+ REFIID aTargetIid,
+ STAUniquePtr<IUnknown> aTarget,
+ void** aOutInterface);
+ HRESULT GetInterceptorForIID(REFIID aIid, void** aOutInterceptor,
+ MutexAutoLock* aAlreadyLocked);
+ MapEntry* Lookup(REFIID aIid);
+ HRESULT QueryInterfaceTarget(REFIID aIid, void** aOutput,
+ TimeDuration* aOutDuration = nullptr);
+ HRESULT WeakRefQueryInterface(REFIID aIid, IUnknown** aOutInterface) override;
+ HRESULT CreateInterceptor(REFIID aIid, IUnknown* aOuter, IUnknown** aOutput);
+ REFIID MarshalAs(REFIID aIid) const;
+ HRESULT PublishTarget(detail::LiveSetAutoLock& aLiveSetLock,
+ RefPtr<IUnknown> aInterceptor, REFIID aTargetIid,
+ STAUniquePtr<IUnknown> aTarget);
+
+ private:
+ InterceptorTargetPtr<IUnknown> mTarget;
+ RefPtr<IInterceptorSink> mEventSink;
+ mozilla::Mutex mInterceptorMapMutex
+ MOZ_UNANNOTATED; // Guards mInterceptorMap
+ // Using a nsTArray since the # of interfaces is not going to be very high
+ nsTArray<MapEntry> mInterceptorMap;
+ mozilla::Mutex mStdMarshalMutex
+ MOZ_UNANNOTATED; // Guards mStdMarshalUnk and mStdMarshal
+ RefPtr<IUnknown> mStdMarshalUnk;
+ IMarshal* mStdMarshal; // WEAK
+ static MOZ_THREAD_LOCAL(bool) tlsCreatingStdMarshal;
+};
+
+template <typename InterfaceT>
+inline HRESULT CreateInterceptor(STAUniquePtr<InterfaceT> aTargetInterface,
+ IInterceptorSink* aEventSink,
+ InterfaceT** aOutInterface) {
+ if (!aTargetInterface || !aEventSink) {
+ return E_INVALIDARG;
+ }
+
+ REFIID iidTarget = __uuidof(InterfaceT);
+
+ STAUniquePtr<IUnknown> targetUnknown(aTargetInterface.release());
+ return Interceptor::Create(std::move(targetUnknown), aEventSink, iidTarget,
+ (void**)aOutInterface);
+}
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Interceptor_h
diff --git a/ipc/mscom/InterceptorLog.cpp b/ipc/mscom/InterceptorLog.cpp
new file mode 100644
index 0000000000..0e127c41b6
--- /dev/null
+++ b/ipc/mscom/InterceptorLog.cpp
@@ -0,0 +1,527 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/InterceptorLog.h"
+
+#include <callobj.h>
+
+#include <utility>
+
+#include "MainThreadUtils.h"
+#include "mozilla/ClearOnShutdown.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/IntegerRange.h"
+#include "mozilla/Mutex.h"
+#include "mozilla/Services.h"
+#include "mozilla/StaticPtr.h"
+#include "mozilla/TimeStamp.h"
+#include "mozilla/Unused.h"
+#include "mozilla/mscom/Registration.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsAppDirectoryServiceDefs.h"
+#include "nsDirectoryServiceDefs.h"
+#include "nsDirectoryServiceUtils.h"
+#include "nsIObserver.h"
+#include "nsIObserverService.h"
+#include "nsIOutputStream.h"
+#include "nsNetUtil.h"
+#include "nsPrintfCString.h"
+#include "nsReadableUtils.h"
+#include "nsTArray.h"
+#include "nsThreadUtils.h"
+#include "nsXPCOMPrivate.h"
+#include "nsXULAppAPI.h"
+#include "prenv.h"
+
+using mozilla::DebugOnly;
+using mozilla::Mutex;
+using mozilla::MutexAutoLock;
+using mozilla::NewNonOwningRunnableMethod;
+using mozilla::StaticAutoPtr;
+using mozilla::TimeDuration;
+using mozilla::TimeStamp;
+using mozilla::Unused;
+using mozilla::mscom::ArrayData;
+using mozilla::mscom::FindArrayData;
+using mozilla::mscom::IsValidGUID;
+using mozilla::services::GetObserverService;
+
+namespace {
+
+class ShutdownEvent final : public nsIObserver {
+ public:
+ NS_DECL_ISUPPORTS
+ NS_DECL_NSIOBSERVER
+
+ private:
+ ~ShutdownEvent() {}
+};
+
+NS_IMPL_ISUPPORTS(ShutdownEvent, nsIObserver)
+
+class Logger final {
+ public:
+ explicit Logger(const nsACString& aLeafBaseName);
+ bool IsValid() {
+ MutexAutoLock lock(mMutex);
+ return !!mThread;
+ }
+ void LogQI(HRESULT aResult, IUnknown* aTarget, REFIID aIid,
+ IUnknown* aInterface, const TimeDuration* aOverheadDuration,
+ const TimeDuration* aGeckoDuration);
+ void CaptureFrame(ICallFrame* aCallFrame, IUnknown* aTargetInterface,
+ nsACString& aCapturedFrame);
+ void LogEvent(const nsACString& aCapturedFrame,
+ const TimeDuration& aOverheadDuration,
+ const TimeDuration& aGeckoDuration);
+ nsresult Shutdown();
+
+ private:
+ void OpenFile();
+ void Flush();
+ void CloseFile();
+ void AssertRunningOnLoggerThread();
+ bool VariantToString(const VARIANT& aVariant, nsACString& aOut,
+ LONG aIndex = 0);
+ bool TryParamAsGuid(REFIID aIid, ICallFrame* aCallFrame,
+ const CALLFRAMEPARAMINFO& aParamInfo, nsACString& aLine);
+ static double GetElapsedTime();
+
+ nsCOMPtr<nsIFile> mLogFileName;
+ nsCOMPtr<nsIOutputStream> mLogFile; // Only accessed by mThread
+ Mutex mMutex MOZ_UNANNOTATED; // Guards mThread and mEntries
+ nsCOMPtr<nsIThread> mThread;
+ nsTArray<nsCString> mEntries;
+};
+
+Logger::Logger(const nsACString& aLeafBaseName)
+ : mMutex("mozilla::com::InterceptorLog::Logger") {
+ MOZ_ASSERT(NS_IsMainThread());
+ nsCOMPtr<nsIFile> logFileName;
+ GeckoProcessType procType = XRE_GetProcessType();
+ nsAutoCString leafName(aLeafBaseName);
+ nsresult rv;
+ if (procType == GeckoProcessType_Default) {
+ leafName.AppendLiteral("-Parent-");
+ rv = NS_GetSpecialDirectory(NS_OS_TEMP_DIR, getter_AddRefs(logFileName));
+ } else if (procType == GeckoProcessType_Content) {
+ leafName.AppendLiteral("-Content-");
+#if defined(MOZ_SANDBOX)
+ rv = NS_GetSpecialDirectory(NS_APP_CONTENT_PROCESS_TEMP_DIR,
+ getter_AddRefs(logFileName));
+#else
+ rv = NS_GetSpecialDirectory(NS_OS_TEMP_DIR, getter_AddRefs(logFileName));
+#endif // defined(MOZ_SANDBOX)
+ } else {
+ return;
+ }
+ if (NS_FAILED(rv)) {
+ return;
+ }
+ DWORD pid = GetCurrentProcessId();
+ leafName.AppendPrintf("%lu.log", pid);
+ // Using AppendNative here because Windows
+ rv = logFileName->AppendNative(leafName);
+ if (NS_FAILED(rv)) {
+ return;
+ }
+ mLogFileName.swap(logFileName);
+
+ nsCOMPtr<nsIObserverService> obsSvc = GetObserverService();
+ nsCOMPtr<nsIObserver> shutdownEvent = new ShutdownEvent();
+ rv = obsSvc->AddObserver(shutdownEvent, NS_XPCOM_SHUTDOWN_THREADS_OBSERVER_ID,
+ false);
+ if (NS_FAILED(rv)) {
+ return;
+ }
+
+ nsCOMPtr<nsIRunnable> openRunnable(
+ NewNonOwningRunnableMethod("Logger::OpenFile", this, &Logger::OpenFile));
+ rv = NS_NewNamedThread("COM Intcpt Log", getter_AddRefs(mThread),
+ openRunnable);
+ if (NS_FAILED(rv)) {
+ obsSvc->RemoveObserver(shutdownEvent,
+ NS_XPCOM_SHUTDOWN_THREADS_OBSERVER_ID);
+ }
+}
+
+void Logger::AssertRunningOnLoggerThread() {
+#if defined(DEBUG)
+ nsCOMPtr<nsIThread> curThread;
+ if (NS_FAILED(NS_GetCurrentThread(getter_AddRefs(curThread)))) {
+ return;
+ }
+ MutexAutoLock lock(mMutex);
+ MOZ_ASSERT(curThread == mThread);
+#endif
+}
+
+void Logger::OpenFile() {
+ AssertRunningOnLoggerThread();
+ MOZ_ASSERT(mLogFileName && !mLogFile);
+ NS_NewLocalFileOutputStream(getter_AddRefs(mLogFile), mLogFileName,
+ PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE,
+ PR_IRUSR | PR_IWUSR | PR_IRGRP);
+}
+
+void Logger::CloseFile() {
+ AssertRunningOnLoggerThread();
+ MOZ_ASSERT(mLogFile);
+ if (!mLogFile) {
+ return;
+ }
+ Flush();
+ mLogFile->Close();
+ mLogFile = nullptr;
+}
+
+nsresult Logger::Shutdown() {
+ MOZ_ASSERT(NS_IsMainThread());
+ nsresult rv = mThread->Dispatch(
+ NewNonOwningRunnableMethod("Logger::CloseFile", this, &Logger::CloseFile),
+ NS_DISPATCH_NORMAL);
+ NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "Dispatch failed");
+
+ rv = mThread->Shutdown();
+ NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "Shutdown failed");
+ (void)rv;
+ return NS_OK;
+}
+
+bool Logger::VariantToString(const VARIANT& aVariant, nsACString& aOut,
+ LONG aIndex) {
+ switch (aVariant.vt) {
+ case VT_DISPATCH: {
+ aOut.AppendPrintf("(IDispatch*) 0x%p", aVariant.pdispVal);
+ return true;
+ }
+ case VT_DISPATCH | VT_BYREF: {
+ aOut.AppendPrintf("(IDispatch*) 0x%p", (aVariant.ppdispVal)[aIndex]);
+ return true;
+ }
+ case VT_UNKNOWN: {
+ aOut.AppendPrintf("(IUnknown*) 0x%p", aVariant.punkVal);
+ return true;
+ }
+ case VT_UNKNOWN | VT_BYREF: {
+ aOut.AppendPrintf("(IUnknown*) 0x%p", (aVariant.ppunkVal)[aIndex]);
+ return true;
+ }
+ case VT_VARIANT | VT_BYREF: {
+ return VariantToString((aVariant.pvarVal)[aIndex], aOut);
+ }
+ case VT_I4 | VT_BYREF: {
+ aOut.AppendPrintf("%ld", aVariant.plVal[aIndex]);
+ return true;
+ }
+ case VT_UI4 | VT_BYREF: {
+ aOut.AppendPrintf("%lu", aVariant.pulVal[aIndex]);
+ return true;
+ }
+ case VT_I4: {
+ aOut.AppendPrintf("%ld", aVariant.lVal);
+ return true;
+ }
+ case VT_UI4: {
+ aOut.AppendPrintf("%lu", aVariant.ulVal);
+ return true;
+ }
+ case VT_EMPTY: {
+ aOut.AppendLiteral("(empty VARIANT)");
+ return true;
+ }
+ case VT_NULL: {
+ aOut.AppendLiteral("(null VARIANT)");
+ return true;
+ }
+ case VT_BSTR: {
+ aOut.AppendPrintf("\"%S\"", aVariant.bstrVal);
+ return true;
+ }
+ case VT_BSTR | VT_BYREF: {
+ aOut.AppendPrintf("\"%S\"", *aVariant.pbstrVal);
+ return true;
+ }
+ default: {
+ aOut.AppendPrintf("(VariantToString failed, VARTYPE == 0x%04hx)",
+ aVariant.vt);
+ return false;
+ }
+ }
+}
+
+/* static */
+double Logger::GetElapsedTime() {
+ TimeStamp ts = TimeStamp::Now();
+ TimeDuration duration = ts - TimeStamp::ProcessCreation();
+ return duration.ToMicroseconds();
+}
+
+void Logger::LogQI(HRESULT aResult, IUnknown* aTarget, REFIID aIid,
+ IUnknown* aInterface, const TimeDuration* aOverheadDuration,
+ const TimeDuration* aGeckoDuration) {
+ if (FAILED(aResult)) {
+ return;
+ }
+
+ double elapsed = GetElapsedTime();
+
+ nsAutoCString strOverheadDuration;
+ if (aOverheadDuration) {
+ strOverheadDuration.AppendPrintf("%.3f",
+ aOverheadDuration->ToMicroseconds());
+ } else {
+ strOverheadDuration.AppendLiteral("(none)");
+ }
+
+ nsAutoCString strGeckoDuration;
+ if (aGeckoDuration) {
+ strGeckoDuration.AppendPrintf("%.3f", aGeckoDuration->ToMicroseconds());
+ } else {
+ strGeckoDuration.AppendLiteral("(none)");
+ }
+
+ nsPrintfCString line("%.3f\t%s\t%s\t0x%p\tIUnknown::QueryInterface\t([in] ",
+ elapsed, strOverheadDuration.get(),
+ strGeckoDuration.get(), aTarget);
+
+ WCHAR buf[39] = {0};
+ if (StringFromGUID2(aIid, buf, mozilla::ArrayLength(buf))) {
+ line.AppendPrintf("%S", buf);
+ } else {
+ line.AppendLiteral("(IID Conversion Failed)");
+ }
+ line.AppendPrintf(", [out] 0x%p)\t0x%08lX\n", aInterface, aResult);
+
+ MutexAutoLock lock(mMutex);
+ mEntries.AppendElement(std::move(line));
+ mThread->Dispatch(
+ NewNonOwningRunnableMethod("Logger::Flush", this, &Logger::Flush),
+ NS_DISPATCH_NORMAL);
+}
+
+bool Logger::TryParamAsGuid(REFIID aIid, ICallFrame* aCallFrame,
+ const CALLFRAMEPARAMINFO& aParamInfo,
+ nsACString& aLine) {
+ if (aIid != IID_IServiceProvider) {
+ return false;
+ }
+
+ GUID** guid = reinterpret_cast<GUID**>(
+ static_cast<BYTE*>(aCallFrame->GetStackLocation()) +
+ aParamInfo.stackOffset);
+
+ if (!IsValidGUID(**guid)) {
+ return false;
+ }
+
+ WCHAR buf[39] = {0};
+ if (!StringFromGUID2(**guid, buf, mozilla::ArrayLength(buf))) {
+ return false;
+ }
+
+ aLine.AppendPrintf("%S", buf);
+ return true;
+}
+
+void Logger::CaptureFrame(ICallFrame* aCallFrame, IUnknown* aTargetInterface,
+ nsACString& aCapturedFrame) {
+ aCapturedFrame.Truncate();
+
+ // (1) Gather info about the call
+ CALLFRAMEINFO callInfo;
+ HRESULT hr = aCallFrame->GetInfo(&callInfo);
+ if (FAILED(hr)) {
+ return;
+ }
+
+ PWSTR interfaceName = nullptr;
+ PWSTR methodName = nullptr;
+ hr = aCallFrame->GetNames(&interfaceName, &methodName);
+ if (FAILED(hr)) {
+ return;
+ }
+
+ // (2) Serialize the call
+ nsPrintfCString line("0x%p\t%S::%S\t(", aTargetInterface, interfaceName,
+ methodName);
+
+ CoTaskMemFree(interfaceName);
+ interfaceName = nullptr;
+ CoTaskMemFree(methodName);
+ methodName = nullptr;
+
+ // Check for supplemental array data
+ const ArrayData* arrayData = FindArrayData(callInfo.iid, callInfo.iMethod);
+
+ for (ULONG paramIndex = 0; paramIndex < callInfo.cParams; ++paramIndex) {
+ CALLFRAMEPARAMINFO paramInfo;
+ hr = aCallFrame->GetParamInfo(paramIndex, &paramInfo);
+ if (SUCCEEDED(hr)) {
+ line.AppendLiteral("[");
+ if (paramInfo.fIn) {
+ line.AppendLiteral("in");
+ }
+ if (paramInfo.fOut) {
+ line.AppendLiteral("out");
+ }
+ line.AppendLiteral("] ");
+ }
+ VARIANT paramValue;
+ hr = aCallFrame->GetParam(paramIndex, &paramValue);
+ if (SUCCEEDED(hr)) {
+ if (arrayData && paramIndex == arrayData->mArrayParamIndex) {
+ VARIANT lengthParam;
+ hr = aCallFrame->GetParam(arrayData->mLengthParamIndex, &lengthParam);
+ if (SUCCEEDED(hr)) {
+ line.AppendLiteral("{ ");
+ StringJoinAppend(line, ", "_ns,
+ mozilla::IntegerRange<LONG>(0, *lengthParam.plVal),
+ [this, &paramValue](nsACString& line, const LONG i) {
+ VariantToString(paramValue, line, i);
+ });
+ line.AppendLiteral(" }");
+ } else {
+ line.AppendPrintf("(GetParam failed with HRESULT 0x%08lX)", hr);
+ }
+ } else {
+ VariantToString(paramValue, line);
+ }
+ } else if (hr != DISP_E_BADVARTYPE ||
+ !TryParamAsGuid(callInfo.iid, aCallFrame, paramInfo, line)) {
+ line.AppendPrintf("(GetParam failed with HRESULT 0x%08lX)", hr);
+ }
+ if (paramIndex < callInfo.cParams - 1) {
+ line.AppendLiteral(", ");
+ }
+ }
+ line.AppendLiteral(")\t");
+
+ HRESULT callResult = aCallFrame->GetReturnValue();
+ line.AppendPrintf("0x%08lX\n", callResult);
+
+ aCapturedFrame = std::move(line);
+}
+
+void Logger::LogEvent(const nsACString& aCapturedFrame,
+ const TimeDuration& aOverheadDuration,
+ const TimeDuration& aGeckoDuration) {
+ double elapsed = GetElapsedTime();
+
+ nsPrintfCString line("%.3f\t%.3f\t%.3f\t%s", elapsed,
+ aOverheadDuration.ToMicroseconds(),
+ aGeckoDuration.ToMicroseconds(),
+ PromiseFlatCString(aCapturedFrame).get());
+
+ MutexAutoLock lock(mMutex);
+ mEntries.AppendElement(line);
+ mThread->Dispatch(
+ NewNonOwningRunnableMethod("Logger::Flush", this, &Logger::Flush),
+ NS_DISPATCH_NORMAL);
+}
+
+void Logger::Flush() {
+ AssertRunningOnLoggerThread();
+ MOZ_ASSERT(mLogFile);
+ if (!mLogFile) {
+ return;
+ }
+ nsTArray<nsCString> linesToWrite;
+ { // Scope for lock
+ MutexAutoLock lock(mMutex);
+ linesToWrite = std::move(mEntries);
+ }
+
+ for (uint32_t i = 0, len = linesToWrite.Length(); i < len; ++i) {
+ uint32_t bytesWritten;
+ nsCString& line = linesToWrite[i];
+ nsresult rv = mLogFile->Write(line.get(), line.Length(), &bytesWritten);
+ Unused << NS_WARN_IF(NS_FAILED(rv));
+ }
+}
+
+StaticAutoPtr<Logger> sLogger;
+
+NS_IMETHODIMP
+ShutdownEvent::Observe(nsISupports* aSubject, const char* aTopic,
+ const char16_t* aData) {
+ if (strcmp(aTopic, NS_XPCOM_SHUTDOWN_THREADS_OBSERVER_ID)) {
+ MOZ_ASSERT(false);
+ return NS_ERROR_NOT_IMPLEMENTED;
+ }
+ MOZ_ASSERT(sLogger);
+ Unused << NS_WARN_IF(NS_FAILED(sLogger->Shutdown()));
+ nsCOMPtr<nsIObserver> kungFuDeathGrip(this);
+ nsCOMPtr<nsIObserverService> obsSvc = GetObserverService();
+ obsSvc->RemoveObserver(this, aTopic);
+ return NS_OK;
+}
+} // anonymous namespace
+
+static bool MaybeCreateLog(const char* aEnvVarName) {
+ MOZ_ASSERT(NS_IsMainThread());
+ MOZ_ASSERT(XRE_IsContentProcess() || XRE_IsParentProcess());
+ MOZ_ASSERT(!sLogger);
+ const char* leafBaseName = PR_GetEnv(aEnvVarName);
+ if (!leafBaseName) {
+ return false;
+ }
+ nsDependentCString strLeafBaseName(leafBaseName);
+ if (strLeafBaseName.IsEmpty()) {
+ return false;
+ }
+ sLogger = new Logger(strLeafBaseName);
+ if (!sLogger->IsValid()) {
+ sLogger = nullptr;
+ return false;
+ }
+ ClearOnShutdown(&sLogger);
+ return true;
+}
+
+namespace mozilla {
+namespace mscom {
+
+/* static */
+bool InterceptorLog::Init() {
+ static const bool isEnabled = MaybeCreateLog("MOZ_MSCOM_LOG_BASENAME");
+ return isEnabled;
+}
+
+/* static */
+void InterceptorLog::QI(HRESULT aResult, IUnknown* aTarget, REFIID aIid,
+ IUnknown* aInterface,
+ const TimeDuration* aOverheadDuration,
+ const TimeDuration* aGeckoDuration) {
+ if (!sLogger) {
+ return;
+ }
+ sLogger->LogQI(aResult, aTarget, aIid, aInterface, aOverheadDuration,
+ aGeckoDuration);
+}
+
+/* static */
+void InterceptorLog::CaptureFrame(ICallFrame* aCallFrame,
+ IUnknown* aTargetInterface,
+ nsACString& aCapturedFrame) {
+ if (!sLogger) {
+ return;
+ }
+ sLogger->CaptureFrame(aCallFrame, aTargetInterface, aCapturedFrame);
+}
+
+/* static */
+void InterceptorLog::Event(const nsACString& aCapturedFrame,
+ const TimeDuration& aOverheadDuration,
+ const TimeDuration& aGeckoDuration) {
+ if (!sLogger) {
+ return;
+ }
+ sLogger->LogEvent(aCapturedFrame, aOverheadDuration, aGeckoDuration);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/InterceptorLog.h b/ipc/mscom/InterceptorLog.h
new file mode 100644
index 0000000000..893c2fcbb5
--- /dev/null
+++ b/ipc/mscom/InterceptorLog.h
@@ -0,0 +1,36 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_InterceptorLog_h
+#define mozilla_mscom_InterceptorLog_h
+
+#include "mozilla/TimeStamp.h"
+#include "nsString.h"
+
+struct ICallFrame;
+struct IUnknown;
+
+namespace mozilla {
+namespace mscom {
+
+class InterceptorLog {
+ public:
+ static bool Init();
+ static void QI(HRESULT aResult, IUnknown* aTarget, REFIID aIid,
+ IUnknown* aInterface,
+ const TimeDuration* aOverheadDuration = nullptr,
+ const TimeDuration* aGeckoDuration = nullptr);
+ static void CaptureFrame(ICallFrame* aCallFrame, IUnknown* aTarget,
+ nsACString& aCapturedFrame);
+ static void Event(const nsACString& aCapturedFrame,
+ const TimeDuration& aOverheadDuration,
+ const TimeDuration& aGeckoDuration);
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_InterceptorLog_h
diff --git a/ipc/mscom/MainThreadHandoff.cpp b/ipc/mscom/MainThreadHandoff.cpp
new file mode 100644
index 0000000000..544befd559
--- /dev/null
+++ b/ipc/mscom/MainThreadHandoff.cpp
@@ -0,0 +1,697 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#define INITGUID
+
+#include "mozilla/mscom/MainThreadHandoff.h"
+
+#include <utility>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/ThreadLocal.h"
+#include "mozilla/TimeStamp.h"
+#include "mozilla/Unused.h"
+#include "mozilla/mscom/AgileReference.h"
+#include "mozilla/mscom/InterceptorLog.h"
+#include "mozilla/mscom/Registration.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsProxyRelease.h"
+#include "nsThreadUtils.h"
+
+using mozilla::DebugOnly;
+using mozilla::Unused;
+using mozilla::mscom::AgileReference;
+
+namespace {
+
+class MOZ_NON_TEMPORARY_CLASS InParamWalker : private ICallFrameWalker {
+ public:
+ InParamWalker() : mPreHandoff(true) {}
+
+ void SetHandoffDone() {
+ mPreHandoff = false;
+ mAgileRefsItr = mAgileRefs.begin();
+ }
+
+ HRESULT Walk(ICallFrame* aFrame) {
+ MOZ_ASSERT(aFrame);
+ if (!aFrame) {
+ return E_INVALIDARG;
+ }
+
+ return aFrame->WalkFrame(CALLFRAME_WALK_IN, this);
+ }
+
+ private:
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override {
+ if (!aOutInterface) {
+ return E_INVALIDARG;
+ }
+ *aOutInterface = nullptr;
+
+ if (aIid == IID_IUnknown || aIid == IID_ICallFrameWalker) {
+ *aOutInterface = static_cast<ICallFrameWalker*>(this);
+ return S_OK;
+ }
+
+ return E_NOINTERFACE;
+ }
+
+ STDMETHODIMP_(ULONG) AddRef() override { return 2; }
+
+ STDMETHODIMP_(ULONG) Release() override { return 1; }
+
+ // ICallFrameWalker
+ STDMETHODIMP OnWalkInterface(REFIID aIid, PVOID* aInterface, BOOL aIn,
+ BOOL aOut) override {
+ MOZ_ASSERT(aIn);
+ if (!aIn) {
+ return E_UNEXPECTED;
+ }
+
+ IUnknown* origInterface = static_cast<IUnknown*>(*aInterface);
+ if (!origInterface) {
+ // Nothing to do
+ return S_OK;
+ }
+
+ if (mPreHandoff) {
+ mAgileRefs.AppendElement(AgileReference(aIid, origInterface));
+ return S_OK;
+ }
+
+ MOZ_ASSERT(mAgileRefsItr != mAgileRefs.end());
+ if (mAgileRefsItr == mAgileRefs.end()) {
+ return E_UNEXPECTED;
+ }
+
+ HRESULT hr = mAgileRefsItr->Resolve(aIid, aInterface);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (SUCCEEDED(hr)) {
+ ++mAgileRefsItr;
+ }
+
+ return hr;
+ }
+
+ InParamWalker(const InParamWalker&) = delete;
+ InParamWalker(InParamWalker&&) = delete;
+ InParamWalker& operator=(const InParamWalker&) = delete;
+ InParamWalker& operator=(InParamWalker&&) = delete;
+
+ private:
+ bool mPreHandoff;
+ AutoTArray<AgileReference, 1> mAgileRefs;
+ nsTArray<AgileReference>::iterator mAgileRefsItr;
+};
+
+class HandoffRunnable : public mozilla::Runnable {
+ public:
+ explicit HandoffRunnable(ICallFrame* aCallFrame, IUnknown* aTargetInterface)
+ : Runnable("HandoffRunnable"),
+ mCallFrame(aCallFrame),
+ mTargetInterface(aTargetInterface),
+ mResult(E_UNEXPECTED) {
+ DebugOnly<HRESULT> hr = mInParamWalker.Walk(aCallFrame);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ }
+
+ NS_IMETHOD Run() override {
+ mInParamWalker.SetHandoffDone();
+ // We declare hr a DebugOnly because if mInParamWalker.Walk() fails, then
+ // mCallFrame->Invoke will fail anyway.
+ DebugOnly<HRESULT> hr = mInParamWalker.Walk(mCallFrame);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ mResult = mCallFrame->Invoke(mTargetInterface);
+ return NS_OK;
+ }
+
+ HRESULT GetResult() const { return mResult; }
+
+ private:
+ ICallFrame* mCallFrame;
+ InParamWalker mInParamWalker;
+ IUnknown* mTargetInterface;
+ HRESULT mResult;
+};
+
+class MOZ_RAII SavedCallFrame final {
+ public:
+ explicit SavedCallFrame(mozilla::NotNull<ICallFrame*> aFrame)
+ : mCallFrame(aFrame) {
+ static const bool sIsInit = tlsFrame.init();
+ MOZ_ASSERT(sIsInit);
+ MOZ_ASSERT(!tlsFrame.get());
+ tlsFrame.set(this);
+ Unused << sIsInit;
+ }
+
+ ~SavedCallFrame() {
+ MOZ_ASSERT(tlsFrame.get());
+ tlsFrame.set(nullptr);
+ }
+
+ HRESULT GetIidAndMethod(mozilla::NotNull<IID*> aIid,
+ mozilla::NotNull<ULONG*> aMethod) const {
+ return mCallFrame->GetIIDAndMethod(aIid, aMethod);
+ }
+
+ static const SavedCallFrame& Get() {
+ SavedCallFrame* saved = tlsFrame.get();
+ MOZ_ASSERT(saved);
+
+ return *saved;
+ }
+
+ SavedCallFrame(const SavedCallFrame&) = delete;
+ SavedCallFrame(SavedCallFrame&&) = delete;
+ SavedCallFrame& operator=(const SavedCallFrame&) = delete;
+ SavedCallFrame& operator=(SavedCallFrame&&) = delete;
+
+ private:
+ ICallFrame* mCallFrame;
+
+ private:
+ static MOZ_THREAD_LOCAL(SavedCallFrame*) tlsFrame;
+};
+
+MOZ_THREAD_LOCAL(SavedCallFrame*) SavedCallFrame::tlsFrame;
+
+class MOZ_RAII LogEvent final {
+ public:
+ LogEvent() : mCallStart(mozilla::TimeStamp::Now()) {}
+
+ ~LogEvent() {
+ if (mCapturedFrame.IsEmpty()) {
+ return;
+ }
+
+ mozilla::TimeStamp callEnd(mozilla::TimeStamp::Now());
+ mozilla::TimeDuration totalTime(callEnd - mCallStart);
+ mozilla::TimeDuration overhead(totalTime - mGeckoDuration -
+ mCaptureDuration);
+
+ mozilla::mscom::InterceptorLog::Event(mCapturedFrame, overhead,
+ mGeckoDuration);
+ }
+
+ void CaptureFrame(ICallFrame* aFrame, IUnknown* aTarget,
+ const mozilla::TimeDuration& aGeckoDuration) {
+ mozilla::TimeStamp captureStart(mozilla::TimeStamp::Now());
+
+ mozilla::mscom::InterceptorLog::CaptureFrame(aFrame, aTarget,
+ mCapturedFrame);
+ mGeckoDuration = aGeckoDuration;
+
+ mozilla::TimeStamp captureEnd(mozilla::TimeStamp::Now());
+
+ // Make sure that the time we spent in CaptureFrame isn't charged against
+ // overall overhead
+ mCaptureDuration = captureEnd - captureStart;
+ }
+
+ LogEvent(const LogEvent&) = delete;
+ LogEvent(LogEvent&&) = delete;
+ LogEvent& operator=(const LogEvent&) = delete;
+ LogEvent& operator=(LogEvent&&) = delete;
+
+ private:
+ mozilla::TimeStamp mCallStart;
+ mozilla::TimeDuration mGeckoDuration;
+ mozilla::TimeDuration mCaptureDuration;
+ nsAutoCString mCapturedFrame;
+};
+
+} // anonymous namespace
+
+namespace mozilla {
+namespace mscom {
+
+/* static */
+HRESULT MainThreadHandoff::Create(IHandlerProvider* aHandlerProvider,
+ IInterceptorSink** aOutput) {
+ RefPtr<MainThreadHandoff> handoff(new MainThreadHandoff(aHandlerProvider));
+ return handoff->QueryInterface(IID_IInterceptorSink, (void**)aOutput);
+}
+
+MainThreadHandoff::MainThreadHandoff(IHandlerProvider* aHandlerProvider)
+ : mRefCnt(0), mHandlerProvider(aHandlerProvider) {}
+
+MainThreadHandoff::~MainThreadHandoff() { MOZ_ASSERT(NS_IsMainThread()); }
+
+HRESULT
+MainThreadHandoff::QueryInterface(REFIID riid, void** ppv) {
+ IUnknown* punk = nullptr;
+ if (!ppv) {
+ return E_INVALIDARG;
+ }
+
+ if (riid == IID_IUnknown || riid == IID_ICallFrameEvents ||
+ riid == IID_IInterceptorSink || riid == IID_IMainThreadHandoff) {
+ punk = static_cast<IMainThreadHandoff*>(this);
+ } else if (riid == IID_ICallFrameWalker) {
+ punk = static_cast<ICallFrameWalker*>(this);
+ }
+
+ *ppv = punk;
+ if (!punk) {
+ return E_NOINTERFACE;
+ }
+
+ punk->AddRef();
+ return S_OK;
+}
+
+ULONG
+MainThreadHandoff::AddRef() {
+ return (ULONG)InterlockedIncrement((LONG*)&mRefCnt);
+}
+
+ULONG
+MainThreadHandoff::Release() {
+ ULONG newRefCnt = (ULONG)InterlockedDecrement((LONG*)&mRefCnt);
+ if (newRefCnt == 0) {
+ // It is possible for the last Release() call to happen off-main-thread.
+ // If so, we need to dispatch an event to delete ourselves.
+ if (NS_IsMainThread()) {
+ delete this;
+ } else {
+ // We need to delete this object on the main thread, but we aren't on the
+ // main thread right now, so we send a reference to ourselves to the main
+ // thread to be re-released there.
+ RefPtr<MainThreadHandoff> self = this;
+ NS_ReleaseOnMainThread("MainThreadHandoff", self.forget());
+ }
+ }
+ return newRefCnt;
+}
+
+HRESULT
+MainThreadHandoff::FixIServiceProvider(ICallFrame* aFrame) {
+ MOZ_ASSERT(aFrame);
+
+ CALLFRAMEPARAMINFO iidOutParamInfo;
+ HRESULT hr = aFrame->GetParamInfo(1, &iidOutParamInfo);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ VARIANT varIfaceOut;
+ hr = aFrame->GetParam(2, &varIfaceOut);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ MOZ_ASSERT(varIfaceOut.vt == (VT_UNKNOWN | VT_BYREF));
+ if (varIfaceOut.vt != (VT_UNKNOWN | VT_BYREF)) {
+ return DISP_E_BADVARTYPE;
+ }
+
+ IID** iidOutParam =
+ reinterpret_cast<IID**>(static_cast<BYTE*>(aFrame->GetStackLocation()) +
+ iidOutParamInfo.stackOffset);
+
+ return OnWalkInterface(**iidOutParam,
+ reinterpret_cast<void**>(varIfaceOut.ppunkVal), FALSE,
+ TRUE);
+}
+
+HRESULT
+MainThreadHandoff::OnCall(ICallFrame* aFrame) {
+ LogEvent logEvent;
+
+ // (1) Get info about the method call
+ HRESULT hr;
+ IID iid;
+ ULONG method;
+ hr = aFrame->GetIIDAndMethod(&iid, &method);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ RefPtr<IInterceptor> interceptor;
+ hr = mInterceptor->Resolve(IID_IInterceptor,
+ (void**)getter_AddRefs(interceptor));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ InterceptorTargetPtr<IUnknown> targetInterface;
+ hr = interceptor->GetTargetForIID(iid, targetInterface);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // (2) Execute the method call synchronously on the main thread
+ RefPtr<HandoffRunnable> handoffInfo(
+ new HandoffRunnable(aFrame, targetInterface.get()));
+ MainThreadInvoker invoker;
+ if (!invoker.Invoke(do_AddRef(handoffInfo))) {
+ MOZ_ASSERT(false);
+ return E_UNEXPECTED;
+ }
+ hr = handoffInfo->GetResult();
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // (3) Capture *before* wrapping outputs so that the log will contain pointers
+ // to the true target interface, not the wrapped ones.
+ logEvent.CaptureFrame(aFrame, targetInterface.get(), invoker.GetDuration());
+
+ // (4) Scan the function call for outparams that contain interface pointers.
+ // Those will need to be wrapped with MainThreadHandoff so that they too will
+ // be exeuted on the main thread.
+
+ hr = aFrame->GetReturnValue();
+ if (FAILED(hr)) {
+ // If the call resulted in an error then there's not going to be anything
+ // that needs to be wrapped.
+ return S_OK;
+ }
+
+ if (iid == IID_IServiceProvider) {
+ // The only possible method index for IID_IServiceProvider is for
+ // QueryService at index 3; its other methods are inherited from IUnknown
+ // and are not processed here.
+ MOZ_ASSERT(method == 3);
+ // (5) If our interface is IServiceProvider, we need to manually ensure
+ // that the correct IID is provided for the interface outparam in
+ // IServiceProvider::QueryService.
+ hr = FixIServiceProvider(aFrame);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ } else if (const ArrayData* arrayData = FindArrayData(iid, method)) {
+ // (6) Unfortunately ICallFrame::WalkFrame does not correctly handle array
+ // outparams. Instead, we find out whether anybody has called
+ // mscom::RegisterArrayData to supply array parameter information and use it
+ // if available. This is a terrible hack, but it works for the short term.
+ // In the longer term we want to be able to use COM proxy/stub metadata to
+ // resolve array information for us.
+ hr = FixArrayElements(aFrame, *arrayData);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ } else {
+ SavedCallFrame savedFrame(WrapNotNull(aFrame));
+
+ // (7) Scan the outputs looking for any outparam interfaces that need
+ // wrapping. NB: WalkFrame does not correctly handle array outparams. It
+ // processes the first element of an array but not the remaining elements
+ // (if any).
+ hr = aFrame->WalkFrame(CALLFRAME_WALK_OUT, this);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+
+ return S_OK;
+}
+
+static PVOID ResolveArrayPtr(VARIANT& aVariant) {
+ if (!(aVariant.vt & VT_BYREF)) {
+ return nullptr;
+ }
+ return aVariant.byref;
+}
+
+static PVOID* ResolveInterfacePtr(PVOID aArrayPtr, VARTYPE aVartype,
+ LONG aIndex) {
+ if (aVartype != (VT_VARIANT | VT_BYREF)) {
+ IUnknown** ifaceArray = reinterpret_cast<IUnknown**>(aArrayPtr);
+ return reinterpret_cast<PVOID*>(&ifaceArray[aIndex]);
+ }
+ VARIANT* variantArray = reinterpret_cast<VARIANT*>(aArrayPtr);
+ VARIANT& element = variantArray[aIndex];
+ return &element.byref;
+}
+
+HRESULT
+MainThreadHandoff::FixArrayElements(ICallFrame* aFrame,
+ const ArrayData& aArrayData) {
+ // Extract the array length
+ VARIANT paramVal;
+ VariantInit(&paramVal);
+ HRESULT hr = aFrame->GetParam(aArrayData.mLengthParamIndex, &paramVal);
+ MOZ_ASSERT(SUCCEEDED(hr) && (paramVal.vt == (VT_I4 | VT_BYREF) ||
+ paramVal.vt == (VT_UI4 | VT_BYREF)));
+ if (FAILED(hr) || (paramVal.vt != (VT_I4 | VT_BYREF) &&
+ paramVal.vt != (VT_UI4 | VT_BYREF))) {
+ return hr;
+ }
+
+ const LONG arrayLength = *(paramVal.plVal);
+ if (!arrayLength) {
+ // Nothing to do
+ return S_OK;
+ }
+
+ // Extract the array parameter
+ VariantInit(&paramVal);
+ PVOID arrayPtr = nullptr;
+ hr = aFrame->GetParam(aArrayData.mArrayParamIndex, &paramVal);
+ if (hr == DISP_E_BADVARTYPE) {
+ // ICallFrame::GetParam is not able to coerce the param into a VARIANT.
+ // That's ok, we can try to do it ourselves.
+ CALLFRAMEPARAMINFO paramInfo;
+ hr = aFrame->GetParamInfo(aArrayData.mArrayParamIndex, &paramInfo);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ PVOID stackBase = aFrame->GetStackLocation();
+ if (aArrayData.mFlag == ArrayData::Flag::eAllocatedByServer) {
+ // In order for the server to allocate the array's buffer and store it in
+ // an outparam, the parameter must be typed as Type***. Since the base
+ // of the array is Type*, we must dereference twice.
+ arrayPtr = **reinterpret_cast<PVOID**>(
+ reinterpret_cast<PBYTE>(stackBase) + paramInfo.stackOffset);
+ } else {
+ // We dereference because we need to obtain the value of a parameter
+ // from a stack offset. This pointer is the base of the array.
+ arrayPtr = *reinterpret_cast<PVOID*>(reinterpret_cast<PBYTE>(stackBase) +
+ paramInfo.stackOffset);
+ }
+ } else if (FAILED(hr)) {
+ return hr;
+ } else {
+ arrayPtr = ResolveArrayPtr(paramVal);
+ }
+
+ MOZ_ASSERT(arrayPtr);
+ if (!arrayPtr) {
+ return DISP_E_BADVARTYPE;
+ }
+
+ // We walk the elements of the array and invoke OnWalkInterface to wrap each
+ // one, just as ICallFrame::WalkFrame would do.
+ for (LONG index = 0; index < arrayLength; ++index) {
+ hr = OnWalkInterface(aArrayData.mArrayParamIid,
+ ResolveInterfacePtr(arrayPtr, paramVal.vt, index),
+ FALSE, TRUE);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+ return S_OK;
+}
+
+HRESULT
+MainThreadHandoff::SetInterceptor(IWeakReference* aInterceptor) {
+ mInterceptor = aInterceptor;
+ return S_OK;
+}
+
+HRESULT
+MainThreadHandoff::GetHandler(NotNull<CLSID*> aHandlerClsid) {
+ if (!mHandlerProvider) {
+ return E_NOTIMPL;
+ }
+
+ return mHandlerProvider->GetHandler(aHandlerClsid);
+}
+
+HRESULT
+MainThreadHandoff::GetHandlerPayloadSize(NotNull<IInterceptor*> aInterceptor,
+ NotNull<DWORD*> aOutPayloadSize) {
+ if (!mHandlerProvider) {
+ return E_NOTIMPL;
+ }
+ return mHandlerProvider->GetHandlerPayloadSize(aInterceptor, aOutPayloadSize);
+}
+
+HRESULT
+MainThreadHandoff::WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor,
+ NotNull<IStream*> aStream) {
+ if (!mHandlerProvider) {
+ return E_NOTIMPL;
+ }
+ return mHandlerProvider->WriteHandlerPayload(aInterceptor, aStream);
+}
+
+REFIID
+MainThreadHandoff::MarshalAs(REFIID aIid) {
+ if (!mHandlerProvider) {
+ return aIid;
+ }
+ return mHandlerProvider->MarshalAs(aIid);
+}
+
+HRESULT
+MainThreadHandoff::DisconnectHandlerRemotes() {
+ if (!mHandlerProvider) {
+ return E_NOTIMPL;
+ }
+
+ return mHandlerProvider->DisconnectHandlerRemotes();
+}
+
+HRESULT
+MainThreadHandoff::IsInterfaceMaybeSupported(REFIID aIid) {
+ if (!mHandlerProvider) {
+ return S_OK;
+ }
+ return mHandlerProvider->IsInterfaceMaybeSupported(aIid);
+}
+
+HRESULT
+MainThreadHandoff::OnWalkInterface(REFIID aIid, PVOID* aInterface,
+ BOOL aIsInParam, BOOL aIsOutParam) {
+ MOZ_ASSERT(aInterface && aIsOutParam);
+ if (!aInterface || !aIsOutParam) {
+ return E_UNEXPECTED;
+ }
+
+ // Adopt aInterface for the time being. We can't touch its refcount off
+ // the main thread, so we'll use STAUniquePtr so that we can safely
+ // Release() it if necessary.
+ STAUniquePtr<IUnknown> origInterface(static_cast<IUnknown*>(*aInterface));
+ *aInterface = nullptr;
+
+ if (!origInterface) {
+ // Nothing to wrap.
+ return S_OK;
+ }
+
+ // First make sure that aInterface isn't a proxy - we don't want to wrap
+ // those.
+ if (IsProxy(origInterface.get())) {
+ *aInterface = origInterface.release();
+ return S_OK;
+ }
+
+ RefPtr<IInterceptor> interceptor;
+ HRESULT hr = mInterceptor->Resolve(IID_IInterceptor,
+ (void**)getter_AddRefs(interceptor));
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // Now make sure that origInterface isn't referring to the same IUnknown
+ // as an interface that we are already managing. We can determine this by
+ // querying (NOT casting!) both objects for IUnknown and then comparing the
+ // resulting pointers.
+ InterceptorTargetPtr<IUnknown> existingTarget;
+ hr = interceptor->GetTargetForIID(aIid, existingTarget);
+ if (SUCCEEDED(hr)) {
+ // We'll start by checking the raw pointers. If they are equal, then the
+ // objects are equal. OTOH, if they differ, we must compare their
+ // IUnknown pointers to know for sure.
+ bool areTargetsEqual = existingTarget.get() == origInterface.get();
+
+ if (!areTargetsEqual) {
+ // This check must be done on the main thread
+ auto checkFn = [&existingTarget, &origInterface,
+ &areTargetsEqual]() -> void {
+ RefPtr<IUnknown> unkExisting;
+ HRESULT hrExisting = existingTarget->QueryInterface(
+ IID_IUnknown, (void**)getter_AddRefs(unkExisting));
+ RefPtr<IUnknown> unkNew;
+ HRESULT hrNew = origInterface->QueryInterface(
+ IID_IUnknown, (void**)getter_AddRefs(unkNew));
+ areTargetsEqual =
+ SUCCEEDED(hrExisting) && SUCCEEDED(hrNew) && unkExisting == unkNew;
+ };
+
+ MainThreadInvoker invoker;
+ invoker.Invoke(NS_NewRunnableFunction(
+ "MainThreadHandoff::OnWalkInterface", checkFn));
+ }
+
+ if (areTargetsEqual) {
+ // The existing interface and the new interface both belong to the same
+ // target object. Let's just use the existing one.
+ void* intercepted = nullptr;
+ hr = interceptor->GetInterceptorForIID(aIid, &intercepted);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+ *aInterface = intercepted;
+ return S_OK;
+ }
+ }
+
+ IID effectiveIid = aIid;
+
+ RefPtr<IHandlerProvider> payload;
+ if (mHandlerProvider) {
+ if (aIid == IID_IUnknown) {
+ const SavedCallFrame& curFrame = SavedCallFrame::Get();
+
+ IID callIid;
+ ULONG callMethod;
+ hr = curFrame.GetIidAndMethod(WrapNotNull(&callIid),
+ WrapNotNull(&callMethod));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ effectiveIid =
+ mHandlerProvider->GetEffectiveOutParamIid(callIid, callMethod);
+ }
+
+ hr = mHandlerProvider->NewInstance(
+ effectiveIid, ToInterceptorTargetPtr(origInterface),
+ WrapNotNull((IHandlerProvider**)getter_AddRefs(payload)));
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+
+ // Now create a new MainThreadHandoff wrapper...
+ RefPtr<IInterceptorSink> handoff;
+ hr = MainThreadHandoff::Create(payload, getter_AddRefs(handoff));
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ REFIID interceptorIid =
+ payload ? payload->MarshalAs(effectiveIid) : effectiveIid;
+
+ RefPtr<IUnknown> wrapped;
+ hr = Interceptor::Create(std::move(origInterface), handoff, interceptorIid,
+ getter_AddRefs(wrapped));
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // And replace the original interface pointer with the wrapped one.
+ wrapped.forget(reinterpret_cast<IUnknown**>(aInterface));
+
+ return S_OK;
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/MainThreadHandoff.h b/ipc/mscom/MainThreadHandoff.h
new file mode 100644
index 0000000000..e790cea4b4
--- /dev/null
+++ b/ipc/mscom/MainThreadHandoff.h
@@ -0,0 +1,105 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_MainThreadHandoff_h
+#define mozilla_mscom_MainThreadHandoff_h
+
+#include <utility>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Mutex.h"
+#include "mozilla/mscom/Interceptor.h"
+#include "mozilla/mscom/MainThreadInvoker.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsTArray.h"
+
+namespace mozilla {
+namespace mscom {
+
+// {9a907000-7829-47f1-80eb-f67a26f47b34}
+DEFINE_GUID(IID_IMainThreadHandoff, 0x9a907000, 0x7829, 0x47f1, 0x80, 0xeb,
+ 0xf6, 0x7a, 0x26, 0xf4, 0x7b, 0x34);
+
+struct IMainThreadHandoff : public IInterceptorSink {
+ virtual STDMETHODIMP GetHandlerProvider(IHandlerProvider** aProvider) = 0;
+};
+
+struct ArrayData;
+
+class MainThreadHandoff final : public IMainThreadHandoff,
+ public ICallFrameWalker {
+ public:
+ static HRESULT Create(IHandlerProvider* aHandlerProvider,
+ IInterceptorSink** aOutput);
+
+ template <typename Interface>
+ static HRESULT WrapInterface(STAUniquePtr<Interface> aTargetInterface,
+ Interface** aOutInterface) {
+ return WrapInterface<Interface>(std::move(aTargetInterface), nullptr,
+ aOutInterface);
+ }
+
+ template <typename Interface>
+ static HRESULT WrapInterface(STAUniquePtr<Interface> aTargetInterface,
+ IHandlerProvider* aHandlerProvider,
+ Interface** aOutInterface) {
+ MOZ_ASSERT(!IsProxy(aTargetInterface.get()));
+ RefPtr<IInterceptorSink> handoff;
+ HRESULT hr =
+ MainThreadHandoff::Create(aHandlerProvider, getter_AddRefs(handoff));
+ if (FAILED(hr)) {
+ return hr;
+ }
+ return CreateInterceptor(std::move(aTargetInterface), handoff,
+ aOutInterface);
+ }
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // ICallFrameEvents
+ STDMETHODIMP OnCall(ICallFrame* aFrame) override;
+
+ // IInterceptorSink
+ STDMETHODIMP SetInterceptor(IWeakReference* aInterceptor) override;
+ STDMETHODIMP GetHandler(NotNull<CLSID*> aHandlerClsid) override;
+ STDMETHODIMP GetHandlerPayloadSize(NotNull<IInterceptor*> aInterceptor,
+ NotNull<DWORD*> aOutPayloadSize) override;
+ STDMETHODIMP WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor,
+ NotNull<IStream*> aStream) override;
+ STDMETHODIMP_(REFIID) MarshalAs(REFIID aIid) override;
+ STDMETHODIMP DisconnectHandlerRemotes() override;
+ STDMETHODIMP IsInterfaceMaybeSupported(REFIID aIid) override;
+
+ // IMainThreadHandoff
+ STDMETHODIMP GetHandlerProvider(IHandlerProvider** aProvider) override {
+ RefPtr<IHandlerProvider> provider = mHandlerProvider;
+ provider.forget(aProvider);
+ return mHandlerProvider ? S_OK : S_FALSE;
+ }
+
+ // ICallFrameWalker
+ STDMETHODIMP OnWalkInterface(REFIID aIid, PVOID* aInterface, BOOL aIsInParam,
+ BOOL aIsOutParam) override;
+
+ private:
+ explicit MainThreadHandoff(IHandlerProvider* aHandlerProvider);
+ ~MainThreadHandoff();
+ HRESULT FixArrayElements(ICallFrame* aFrame, const ArrayData& aArrayData);
+ HRESULT FixIServiceProvider(ICallFrame* aFrame);
+
+ private:
+ ULONG mRefCnt;
+ RefPtr<IWeakReference> mInterceptor;
+ RefPtr<IHandlerProvider> mHandlerProvider;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_MainThreadHandoff_h
diff --git a/ipc/mscom/MainThreadInvoker.cpp b/ipc/mscom/MainThreadInvoker.cpp
new file mode 100644
index 0000000000..8062800249
--- /dev/null
+++ b/ipc/mscom/MainThreadInvoker.cpp
@@ -0,0 +1,177 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/MainThreadInvoker.h"
+
+#include "MainThreadUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/BackgroundHangMonitor.h"
+#include "mozilla/ClearOnShutdown.h"
+#include "mozilla/ProfilerThreadState.h"
+#include "mozilla/SchedulerGroup.h"
+#include "mozilla/mscom/SpinEvent.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/Unused.h"
+#include "private/prpriv.h" // For PR_GetThreadID
+#include <winternl.h> // For NTSTATUS and NTAPI
+
+namespace {
+
+typedef NTSTATUS(NTAPI* NtTestAlertPtr)(VOID);
+
+/**
+ * SyncRunnable implements different code paths depending on whether or not
+ * we are running on a multiprocessor system. In the multiprocessor case, we
+ * leave the thread in a spin loop while waiting for the main thread to execute
+ * our runnable. Since spinning is pointless in the uniprocessor case, we block
+ * on an event that is set by the main thread once it has finished the runnable.
+ */
+class SyncRunnable : public mozilla::Runnable {
+ public:
+ explicit SyncRunnable(already_AddRefed<nsIRunnable> aRunnable)
+ : mozilla::Runnable("MainThreadInvoker"), mRunnable(aRunnable) {
+ static const bool gotStatics = InitStatics();
+ MOZ_ASSERT(gotStatics);
+ Unused << gotStatics;
+ }
+
+ ~SyncRunnable() = default;
+
+ NS_IMETHOD Run() override {
+ if (mHasRun) {
+ // The APC already ran, so we have nothing to do.
+ return NS_OK;
+ }
+
+ // Run the pending APC in the queue.
+ MOZ_ASSERT(sNtTestAlert);
+ sNtTestAlert();
+ return NS_OK;
+ }
+
+ // This is called by MainThreadInvoker::MainThreadAPC.
+ void APCRun() {
+ mHasRun = true;
+
+ TimeStamp runStart(TimeStamp::Now());
+ mRunnable->Run();
+ TimeStamp runEnd(TimeStamp::Now());
+
+ mDuration = runEnd - runStart;
+
+ mEvent.Signal();
+ }
+
+ bool WaitUntilComplete() {
+ return mEvent.Wait(mozilla::mscom::MainThreadInvoker::GetTargetThread());
+ }
+
+ const mozilla::TimeDuration& GetDuration() const { return mDuration; }
+
+ private:
+ bool mHasRun = false;
+ nsCOMPtr<nsIRunnable> mRunnable;
+ mozilla::mscom::SpinEvent mEvent;
+ mozilla::TimeDuration mDuration;
+
+ static NtTestAlertPtr sNtTestAlert;
+
+ static bool InitStatics() {
+ sNtTestAlert = reinterpret_cast<NtTestAlertPtr>(
+ ::GetProcAddress(::GetModuleHandleW(L"ntdll.dll"), "NtTestAlert"));
+ MOZ_ASSERT(sNtTestAlert);
+ return sNtTestAlert;
+ }
+};
+
+NtTestAlertPtr SyncRunnable::sNtTestAlert = nullptr;
+
+} // anonymous namespace
+
+namespace mozilla {
+namespace mscom {
+
+HANDLE MainThreadInvoker::sMainThread = nullptr;
+
+/* static */
+bool MainThreadInvoker::InitStatics() {
+ nsCOMPtr<nsIThread> mainThread;
+ nsresult rv = ::NS_GetMainThread(getter_AddRefs(mainThread));
+ if (NS_FAILED(rv)) {
+ return false;
+ }
+
+ PRThread* mainPrThread = nullptr;
+ rv = mainThread->GetPRThread(&mainPrThread);
+ if (NS_FAILED(rv)) {
+ return false;
+ }
+
+ PRUint32 tid = ::PR_GetThreadID(mainPrThread);
+ sMainThread = ::OpenThread(SYNCHRONIZE | THREAD_SET_CONTEXT, FALSE, tid);
+
+ return !!sMainThread;
+}
+
+MainThreadInvoker::MainThreadInvoker() {
+ static const bool gotStatics = InitStatics();
+ MOZ_ASSERT(gotStatics);
+ Unused << gotStatics;
+}
+
+bool MainThreadInvoker::Invoke(already_AddRefed<nsIRunnable>&& aRunnable) {
+ nsCOMPtr<nsIRunnable> runnable(std::move(aRunnable));
+ if (!runnable) {
+ return false;
+ }
+
+ if (NS_IsMainThread()) {
+ runnable->Run();
+ return true;
+ }
+
+ RefPtr<SyncRunnable> syncRunnable = new SyncRunnable(runnable.forget());
+
+ // The main thread could be either blocked on a condition variable waiting
+ // for a Gecko event, or it could be blocked waiting on a Windows HANDLE in
+ // IPC code (doing a sync message send). In the former case, we wake it by
+ // posting a Gecko runnable to the main thread. In the latter case, we wake
+ // it using an APC. However, the latter case doesn't happen very often now
+ // and APCs aren't otherwise run by the main thread. To ensure the
+ // SyncRunnable is cleaned up, we need both to run consistently.
+ // To do this, we:
+ // 1. Queue an APC which does the actual work.
+ // This ref gets released in MainThreadAPC when it runs.
+ SyncRunnable* syncRunnableRef = syncRunnable.get();
+ NS_ADDREF(syncRunnableRef);
+ if (!::QueueUserAPC(&MainThreadAPC, sMainThread,
+ reinterpret_cast<UINT_PTR>(syncRunnableRef))) {
+ return false;
+ }
+
+ // 2. Post a Gecko runnable (which always runs). If the APC hasn't run, the
+ // Gecko runnable runs it. Otherwise, it does nothing.
+ if (NS_FAILED(SchedulerGroup::Dispatch(TaskCategory::Other,
+ do_AddRef(syncRunnable)))) {
+ return false;
+ }
+
+ bool result = syncRunnable->WaitUntilComplete();
+ mDuration = syncRunnable->GetDuration();
+ return result;
+}
+
+/* static */ VOID CALLBACK MainThreadInvoker::MainThreadAPC(ULONG_PTR aParam) {
+ AUTO_PROFILER_THREAD_WAKE;
+ mozilla::BackgroundHangMonitor().NotifyActivity();
+ MOZ_ASSERT(NS_IsMainThread());
+ auto runnable = reinterpret_cast<SyncRunnable*>(aParam);
+ runnable->APCRun();
+ NS_RELEASE(runnable);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/MainThreadInvoker.h b/ipc/mscom/MainThreadInvoker.h
new file mode 100644
index 0000000000..8b56f9bed0
--- /dev/null
+++ b/ipc/mscom/MainThreadInvoker.h
@@ -0,0 +1,56 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_MainThreadInvoker_h
+#define mozilla_mscom_MainThreadInvoker_h
+
+#include <windows.h>
+
+#include <utility>
+
+#include "mozilla/AlreadyAddRefed.h"
+#include "mozilla/StaticPtr.h"
+#include "mozilla/TimeStamp.h"
+#include "nsCOMPtr.h"
+#include "nsThreadUtils.h"
+
+class nsIRunnable;
+
+namespace mozilla {
+namespace mscom {
+
+class MainThreadInvoker {
+ public:
+ MainThreadInvoker();
+
+ bool Invoke(already_AddRefed<nsIRunnable>&& aRunnable);
+ const TimeDuration& GetDuration() const { return mDuration; }
+ static HANDLE GetTargetThread() { return sMainThread; }
+
+ private:
+ TimeDuration mDuration;
+
+ static bool InitStatics();
+ static VOID CALLBACK MainThreadAPC(ULONG_PTR aParam);
+
+ static HANDLE sMainThread;
+};
+
+template <typename Class, typename... Args>
+inline bool InvokeOnMainThread(const char* aName, Class* aObject,
+ void (Class::*aMethod)(Args...),
+ Args&&... aArgs) {
+ nsCOMPtr<nsIRunnable> runnable(NewNonOwningRunnableMethod<Args...>(
+ aName, aObject, aMethod, std::forward<Args>(aArgs)...));
+
+ MainThreadInvoker invoker;
+ return invoker.Invoke(runnable.forget());
+}
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_MainThreadInvoker_h
diff --git a/ipc/mscom/Objref.cpp b/ipc/mscom/Objref.cpp
new file mode 100644
index 0000000000..37108c71ac
--- /dev/null
+++ b/ipc/mscom/Objref.cpp
@@ -0,0 +1,411 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/Objref.h"
+
+#include "mozilla/Assertions.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/ScopeExit.h"
+#include "mozilla/UniquePtr.h"
+
+#include <guiddef.h>
+#include <objidl.h>
+#include <winnt.h>
+
+// {00000027-0000-0008-C000-000000000046}
+static const GUID CLSID_AggStdMarshal = {
+ 0x27, 0x0, 0x8, {0xC0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x46}};
+
+namespace {
+
+#pragma pack(push, 1)
+
+typedef uint64_t OID;
+typedef uint64_t OXID;
+typedef GUID IPID;
+
+struct STDOBJREF {
+ uint32_t mFlags;
+ uint32_t mPublicRefs;
+ OXID mOxid;
+ OID mOid;
+ IPID mIpid;
+};
+
+enum STDOBJREF_FLAGS { SORF_PING = 0, SORF_NOPING = 0x1000 };
+
+struct DUALSTRINGARRAY {
+ static size_t SizeFromNumEntries(const uint16_t aNumEntries) {
+ return sizeof(mNumEntries) + sizeof(mSecurityOffset) +
+ aNumEntries * sizeof(uint16_t);
+ }
+
+ size_t SizeOf() const { return SizeFromNumEntries(mNumEntries); }
+
+ uint16_t mNumEntries;
+ uint16_t mSecurityOffset;
+ uint16_t mStringArray[1]; // Length is mNumEntries
+};
+
+struct OBJREF_STANDARD {
+ static size_t SizeOfFixedLenHeader() { return sizeof(mStd); }
+
+ size_t SizeOf() const { return SizeOfFixedLenHeader() + mResAddr.SizeOf(); }
+
+ STDOBJREF mStd;
+ DUALSTRINGARRAY mResAddr;
+};
+
+struct OBJREF_HANDLER {
+ static size_t SizeOfFixedLenHeader() { return sizeof(mStd) + sizeof(mClsid); }
+
+ size_t SizeOf() const { return SizeOfFixedLenHeader() + mResAddr.SizeOf(); }
+
+ STDOBJREF mStd;
+ CLSID mClsid;
+ DUALSTRINGARRAY mResAddr;
+};
+
+struct OBJREF_CUSTOM {
+ static size_t SizeOfFixedLenHeader() {
+ return sizeof(mClsid) + sizeof(mCbExtension) + sizeof(mReserved);
+ }
+
+ CLSID mClsid;
+ uint32_t mCbExtension;
+ uint32_t mReserved;
+ uint8_t mPayload[1];
+};
+
+enum OBJREF_FLAGS {
+ OBJREF_TYPE_STANDARD = 0x00000001UL,
+ OBJREF_TYPE_HANDLER = 0x00000002UL,
+ OBJREF_TYPE_CUSTOM = 0x00000004UL,
+ OBJREF_TYPE_EXTENDED = 0x00000008UL,
+};
+
+struct OBJREF {
+ static size_t SizeOfFixedLenHeader(OBJREF_FLAGS aFlags) {
+ size_t size = sizeof(mSignature) + sizeof(mFlags) + sizeof(mIid);
+
+ switch (aFlags) {
+ case OBJREF_TYPE_STANDARD:
+ size += OBJREF_STANDARD::SizeOfFixedLenHeader();
+ break;
+ case OBJREF_TYPE_HANDLER:
+ size += OBJREF_HANDLER::SizeOfFixedLenHeader();
+ break;
+ case OBJREF_TYPE_CUSTOM:
+ size += OBJREF_CUSTOM::SizeOfFixedLenHeader();
+ break;
+ default:
+ MOZ_ASSERT_UNREACHABLE("Unsupported OBJREF type");
+ return 0;
+ }
+
+ return size;
+ }
+
+ size_t SizeOf() const {
+ size_t size = sizeof(mSignature) + sizeof(mFlags) + sizeof(mIid);
+
+ switch (mFlags) {
+ case OBJREF_TYPE_STANDARD:
+ size += mObjRefStd.SizeOf();
+ break;
+ case OBJREF_TYPE_HANDLER:
+ size += mObjRefHandler.SizeOf();
+ break;
+ default:
+ MOZ_ASSERT_UNREACHABLE("Unsupported OBJREF type");
+ return 0;
+ }
+
+ return size;
+ }
+
+ uint32_t mSignature;
+ uint32_t mFlags;
+ IID mIid;
+ union {
+ OBJREF_STANDARD mObjRefStd;
+ OBJREF_HANDLER mObjRefHandler;
+ OBJREF_CUSTOM mObjRefCustom;
+ // There are others but we're not supporting them here
+ };
+};
+
+enum OBJREF_SIGNATURES { OBJREF_SIGNATURE = 0x574F454DUL };
+
+#pragma pack(pop)
+
+struct ByteArrayDeleter {
+ void operator()(void* aPtr) { delete[] reinterpret_cast<uint8_t*>(aPtr); }
+};
+
+template <typename T>
+using VarStructUniquePtr = mozilla::UniquePtr<T, ByteArrayDeleter>;
+
+} // anonymous namespace
+
+namespace mozilla {
+namespace mscom {
+
+bool StripHandlerFromOBJREF(NotNull<IStream*> aStream, const uint64_t aStartPos,
+ const uint64_t aEndPos) {
+ // Ensure that the current stream position is set to the beginning
+ LARGE_INTEGER seekTo;
+ seekTo.QuadPart = aStartPos;
+
+ HRESULT hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ ULONG bytesRead;
+
+ uint32_t signature;
+ hr = aStream->Read(&signature, sizeof(signature), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(signature) ||
+ signature != OBJREF_SIGNATURE) {
+ return false;
+ }
+
+ uint32_t type;
+ hr = aStream->Read(&type, sizeof(type), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(type)) {
+ return false;
+ }
+ if (type != OBJREF_TYPE_HANDLER) {
+ // If we're not a handler then just seek to the end of the OBJREF and return
+ // success; there is nothing left to do.
+ seekTo.QuadPart = aEndPos;
+ return SUCCEEDED(aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr));
+ }
+
+ IID iid;
+ hr = aStream->Read(&iid, sizeof(iid), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(iid) || !IsValidGUID(iid)) {
+ return false;
+ }
+
+ // Seek past fixed-size STDOBJREF and CLSID
+ seekTo.QuadPart = sizeof(STDOBJREF) + sizeof(CLSID);
+ hr = aStream->Seek(seekTo, STREAM_SEEK_CUR, nullptr);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ uint16_t numEntries;
+ hr = aStream->Read(&numEntries, sizeof(numEntries), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(numEntries)) {
+ return false;
+ }
+
+ // We'll try to use a stack buffer if resAddrSize <= kMinDualStringArraySize
+ const uint32_t kMinDualStringArraySize = 12;
+ uint16_t staticResAddrBuf[kMinDualStringArraySize / sizeof(uint16_t)];
+
+ size_t resAddrSize = DUALSTRINGARRAY::SizeFromNumEntries(numEntries);
+
+ DUALSTRINGARRAY* resAddr;
+ VarStructUniquePtr<DUALSTRINGARRAY> dynamicResAddrBuf;
+
+ if (resAddrSize <= kMinDualStringArraySize) {
+ resAddr = reinterpret_cast<DUALSTRINGARRAY*>(staticResAddrBuf);
+ } else {
+ dynamicResAddrBuf.reset(
+ reinterpret_cast<DUALSTRINGARRAY*>(new uint8_t[resAddrSize]));
+ resAddr = dynamicResAddrBuf.get();
+ }
+
+ resAddr->mNumEntries = numEntries;
+
+ // Because we've already read numEntries
+ ULONG bytesToRead = resAddrSize - sizeof(numEntries);
+
+ hr = aStream->Read(&resAddr->mSecurityOffset, bytesToRead, &bytesRead);
+ if (FAILED(hr) || bytesRead != bytesToRead) {
+ return false;
+ }
+
+ // Signature doesn't change so we'll seek past that
+ seekTo.QuadPart = aStartPos + sizeof(signature);
+ hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ ULONG bytesWritten;
+
+ uint32_t newType = OBJREF_TYPE_STANDARD;
+ hr = aStream->Write(&newType, sizeof(newType), &bytesWritten);
+ if (FAILED(hr) || bytesWritten != sizeof(newType)) {
+ return false;
+ }
+
+ // Skip past IID and STDOBJREF since those don't change
+ seekTo.QuadPart = sizeof(IID) + sizeof(STDOBJREF);
+ hr = aStream->Seek(seekTo, STREAM_SEEK_CUR, nullptr);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ hr = aStream->Write(resAddr, resAddrSize, &bytesWritten);
+ if (FAILED(hr) || bytesWritten != resAddrSize) {
+ return false;
+ }
+
+ // The difference between a OBJREF_STANDARD and an OBJREF_HANDLER is
+ // sizeof(CLSID), so we'll zero out the remaining bytes.
+ hr = aStream->Write(&CLSID_NULL, sizeof(CLSID), &bytesWritten);
+ if (FAILED(hr) || bytesWritten != sizeof(CLSID)) {
+ return false;
+ }
+
+ // Back up to the end of the tweaked OBJREF.
+ // There are now sizeof(CLSID) less bytes.
+ // Bug 1403180: Using -sizeof(CLSID) with a relative seek sometimes
+ // doesn't work on Windows 7.
+ // It succeeds, but doesn't seek the stream for some unknown reason.
+ // Use an absolute seek instead.
+ seekTo.QuadPart = aEndPos - sizeof(CLSID);
+ return SUCCEEDED(aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr));
+}
+
+uint32_t GetOBJREFSize(NotNull<IStream*> aStream) {
+ // Make a clone so that we don't manipulate aStream's seek pointer
+ RefPtr<IStream> cloned;
+ HRESULT hr = aStream->Clone(getter_AddRefs(cloned));
+ if (FAILED(hr)) {
+ return 0;
+ }
+
+ uint32_t accumulatedSize = 0;
+
+ ULONG bytesRead;
+
+ uint32_t signature;
+ hr = cloned->Read(&signature, sizeof(signature), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(signature) ||
+ signature != OBJREF_SIGNATURE) {
+ return 0;
+ }
+
+ accumulatedSize += bytesRead;
+
+ uint32_t type;
+ hr = cloned->Read(&type, sizeof(type), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(type)) {
+ return 0;
+ }
+
+ accumulatedSize += bytesRead;
+
+ IID iid;
+ hr = cloned->Read(&iid, sizeof(iid), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(iid) || !IsValidGUID(iid)) {
+ return 0;
+ }
+
+ accumulatedSize += bytesRead;
+
+ LARGE_INTEGER seekTo;
+
+ if (type == OBJREF_TYPE_CUSTOM) {
+ CLSID clsid;
+ hr = cloned->Read(&clsid, sizeof(CLSID), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(CLSID)) {
+ return 0;
+ }
+
+ if (clsid != CLSID_StdMarshal && clsid != CLSID_AggStdMarshal) {
+ // We can only calulate the size if the payload is a standard OBJREF as
+ // identified by clsid being CLSID_StdMarshal or CLSID_AggStdMarshal.
+ // (CLSID_AggStdMarshal, the aggregated standard marshaler, is used when
+ // the handler marshals an interface.)
+ return 0;
+ }
+
+ accumulatedSize += bytesRead;
+
+ seekTo.QuadPart =
+ sizeof(OBJREF_CUSTOM::mCbExtension) + sizeof(OBJREF_CUSTOM::mReserved);
+ hr = cloned->Seek(seekTo, STREAM_SEEK_CUR, nullptr);
+ if (FAILED(hr)) {
+ return 0;
+ }
+
+ accumulatedSize += seekTo.LowPart;
+
+ uint32_t payloadLen = GetOBJREFSize(WrapNotNull(cloned.get()));
+ if (!payloadLen) {
+ return 0;
+ }
+
+ accumulatedSize += payloadLen;
+ return accumulatedSize;
+ }
+
+ switch (type) {
+ case OBJREF_TYPE_STANDARD:
+ seekTo.QuadPart = OBJREF_STANDARD::SizeOfFixedLenHeader();
+ break;
+ case OBJREF_TYPE_HANDLER:
+ seekTo.QuadPart = OBJREF_HANDLER::SizeOfFixedLenHeader();
+ break;
+ default:
+ return 0;
+ }
+
+ hr = cloned->Seek(seekTo, STREAM_SEEK_CUR, nullptr);
+ if (FAILED(hr)) {
+ return 0;
+ }
+
+ accumulatedSize += seekTo.LowPart;
+
+ uint16_t numEntries;
+ hr = cloned->Read(&numEntries, sizeof(numEntries), &bytesRead);
+ if (FAILED(hr) || bytesRead != sizeof(numEntries)) {
+ return 0;
+ }
+
+ accumulatedSize += DUALSTRINGARRAY::SizeFromNumEntries(numEntries);
+ return accumulatedSize;
+}
+
+bool SetIID(NotNull<IStream*> aStream, const uint64_t aStart, REFIID aNewIid) {
+ ULARGE_INTEGER initialStreamPos;
+
+ LARGE_INTEGER seekTo;
+ seekTo.QuadPart = 0LL;
+ HRESULT hr = aStream->Seek(seekTo, STREAM_SEEK_CUR, &initialStreamPos);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ auto resetStreamPos = MakeScopeExit([&]() {
+ seekTo.QuadPart = initialStreamPos.QuadPart;
+ hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr);
+ MOZ_DIAGNOSTIC_ASSERT(SUCCEEDED(hr));
+ });
+
+ seekTo.QuadPart =
+ aStart + sizeof(OBJREF::mSignature) + sizeof(OBJREF::mFlags);
+ hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ ULONG bytesWritten;
+ hr = aStream->Write(&aNewIid, sizeof(IID), &bytesWritten);
+ return SUCCEEDED(hr) && bytesWritten == sizeof(IID);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/Objref.h b/ipc/mscom/Objref.h
new file mode 100644
index 0000000000..1b19b68644
--- /dev/null
+++ b/ipc/mscom/Objref.h
@@ -0,0 +1,53 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Objref_h
+#define mozilla_mscom_Objref_h
+
+#include "mozilla/NotNull.h"
+#include "mozilla/RefPtr.h"
+
+#include <guiddef.h>
+
+struct IStream;
+
+namespace mozilla {
+namespace mscom {
+
+/**
+ * Given a buffer containing a serialized proxy to an interface with a handler,
+ * this function strips out the handler and converts it to a standard one.
+ * @param aStream IStream containing a serialized proxy.
+ * There should be nothing else written to the stream past the
+ * current OBJREF.
+ * @param aStart Absolute position of the beginning of the OBJREF.
+ * @param aEnd Absolute position of the end of the OBJREF.
+ * @return true if the handler was successfully stripped, otherwise false.
+ */
+bool StripHandlerFromOBJREF(NotNull<IStream*> aStream, const uint64_t aStart,
+ const uint64_t aEnd);
+
+/**
+ * Given a buffer containing a serialized proxy to an interface, this function
+ * returns the length of the serialized data.
+ * @param aStream IStream containing a serialized proxy. The stream pointer
+ * must be positioned at the beginning of the OBJREF.
+ * @return The size of the serialized proxy, or 0 on error.
+ */
+uint32_t GetOBJREFSize(NotNull<IStream*> aStream);
+
+/**
+ * Overrides the IID in a serialized proxy with the specified IID.
+ * @param aStream Pointer to a stream containing a serialized proxy.
+ * @param aStart Offset to the beginning of the serialized proxy within aStream.
+ * @param aNewIid The replacement IID to apply to the serialized proxy.
+ */
+bool SetIID(NotNull<IStream*> aStream, const uint64_t aStart, REFIID aNewIid);
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Objref_h
diff --git a/ipc/mscom/PassthruProxy.cpp b/ipc/mscom/PassthruProxy.cpp
new file mode 100644
index 0000000000..39bc7dba7e
--- /dev/null
+++ b/ipc/mscom/PassthruProxy.cpp
@@ -0,0 +1,393 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/PassthruProxy.h"
+#include "mozilla/mscom/ProxyStream.h"
+#include "VTableBuilder.h"
+
+// {96EF5801-CE6D-416E-A50A-0C2959AEAE1C}
+static const GUID CLSID_PassthruProxy = {
+ 0x96ef5801, 0xce6d, 0x416e, {0xa5, 0xa, 0xc, 0x29, 0x59, 0xae, 0xae, 0x1c}};
+
+namespace mozilla {
+namespace mscom {
+
+PassthruProxy::PassthruProxy()
+ : mRefCnt(0),
+ mWrappedIid(),
+ mVTableSize(0),
+ mVTable(nullptr),
+ mForgetPreservedStream(false) {}
+
+PassthruProxy::PassthruProxy(ProxyStream::Environment* aEnv, REFIID aIidToWrap,
+ uint32_t aVTableSize,
+ NotNull<IUnknown*> aObjToWrap)
+ : mRefCnt(0),
+ mWrappedIid(aIidToWrap),
+ mVTableSize(aVTableSize),
+ mVTable(nullptr),
+ mForgetPreservedStream(false) {
+ ProxyStream proxyStream(aIidToWrap, aObjToWrap, aEnv,
+ ProxyStreamFlags::ePreservable);
+ mPreservedStream = proxyStream.GetPreservedStream();
+ MOZ_ASSERT(mPreservedStream);
+}
+
+PassthruProxy::~PassthruProxy() {
+ if (mForgetPreservedStream) {
+ // We want to release the ref without clearing marshal data
+ IStream* stream = mPreservedStream.release();
+ stream->Release();
+ }
+
+ if (mVTable) {
+ DeleteNullVTable(mVTable);
+ }
+}
+
+HRESULT
+PassthruProxy::QueryProxyInterface(void** aOutInterface) {
+ // Even though we don't really provide the methods for the interface that
+ // we are proxying, we need to support it in QueryInterface. Instead we
+ // return an interface that, other than IUnknown, contains nullptr for all of
+ // its vtable entires. Obviously this interface is not intended to actually
+ // be called, it just has to be there.
+
+ if (!mVTable) {
+ MOZ_ASSERT(mVTableSize);
+ mVTable = BuildNullVTable(static_cast<IMarshal*>(this), mVTableSize);
+ MOZ_ASSERT(mVTable);
+ }
+
+ *aOutInterface = mVTable;
+ mVTable->AddRef();
+ return S_OK;
+}
+
+HRESULT
+PassthruProxy::QueryInterface(REFIID aIid, void** aOutInterface) {
+ if (!aOutInterface) {
+ return E_INVALIDARG;
+ }
+
+ *aOutInterface = nullptr;
+
+ if (aIid == IID_IUnknown || aIid == IID_IMarshal) {
+ RefPtr<IMarshal> ptr(this);
+ ptr.forget(aOutInterface);
+ return S_OK;
+ }
+
+ if (!IsInitialMarshal()) {
+ // We implement IClientSecurity so that IsProxy() recognizes us as such
+ if (aIid == IID_IClientSecurity) {
+ RefPtr<IClientSecurity> ptr(this);
+ ptr.forget(aOutInterface);
+ return S_OK;
+ }
+
+ if (aIid == mWrappedIid) {
+ return QueryProxyInterface(aOutInterface);
+ }
+ }
+
+ return E_NOINTERFACE;
+}
+
+ULONG
+PassthruProxy::AddRef() { return ++mRefCnt; }
+
+ULONG
+PassthruProxy::Release() {
+ ULONG result = --mRefCnt;
+ if (!result) {
+ delete this;
+ }
+
+ return result;
+}
+
+HRESULT
+PassthruProxy::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) {
+ if (!pCid) {
+ return E_INVALIDARG;
+ }
+
+ if (IsInitialMarshal()) {
+ // To properly use this class we need to be using TABLESTRONG marshaling
+ MOZ_ASSERT(mshlflags & MSHLFLAGS_TABLESTRONG);
+
+ // When we're marshaling for the first time, we identify ourselves as the
+ // class to use for unmarshaling.
+ *pCid = CLSID_PassthruProxy;
+ } else {
+ // Subsequent marshals use the standard marshaler.
+ *pCid = CLSID_StdMarshal;
+ }
+
+ return S_OK;
+}
+
+HRESULT
+PassthruProxy::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) {
+ STATSTG statstg;
+ HRESULT hr;
+
+ if (!IsInitialMarshal()) {
+ // If we are not the initial marshal then we are just copying mStream out
+ // to the marshal stream, so we just use mStream's size.
+ hr = mStream->Stat(&statstg, STATFLAG_NONAME);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ *pSize = statstg.cbSize.LowPart;
+
+ return hr;
+ }
+
+ // To properly use this class we need to be using TABLESTRONG marshaling
+ MOZ_ASSERT(mshlflags & MSHLFLAGS_TABLESTRONG);
+
+ if (!mPreservedStream) {
+ return E_POINTER;
+ }
+
+ hr = mPreservedStream->Stat(&statstg, STATFLAG_NONAME);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ *pSize = statstg.cbSize.LowPart + sizeof(mVTableSize) + sizeof(mWrappedIid);
+ return hr;
+}
+
+HRESULT
+PassthruProxy::MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) {
+ MOZ_ASSERT(riid == mWrappedIid);
+ if (riid != mWrappedIid) {
+ return E_NOINTERFACE;
+ }
+
+ MOZ_ASSERT(pv == mVTable);
+ if (pv != mVTable) {
+ return E_INVALIDARG;
+ }
+
+ HRESULT hr;
+ RefPtr<IStream> cloned;
+
+ if (IsInitialMarshal()) {
+ // To properly use this class we need to be using TABLESTRONG marshaling
+ MOZ_ASSERT(mshlflags & MSHLFLAGS_TABLESTRONG);
+
+ if (!mPreservedStream) {
+ return E_POINTER;
+ }
+
+ // We write out the vtable size and the IID so that the wrapped proxy knows
+ // how to build its vtable on the content side.
+ ULONG bytesWritten;
+ hr = pStm->Write(&mVTableSize, sizeof(mVTableSize), &bytesWritten);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ if (bytesWritten != sizeof(mVTableSize)) {
+ return E_UNEXPECTED;
+ }
+
+ hr = pStm->Write(&mWrappedIid, sizeof(mWrappedIid), &bytesWritten);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ if (bytesWritten != sizeof(mWrappedIid)) {
+ return E_UNEXPECTED;
+ }
+
+ hr = mPreservedStream->Clone(getter_AddRefs(cloned));
+ } else {
+ hr = mStream->Clone(getter_AddRefs(cloned));
+ }
+
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ STATSTG statstg;
+ hr = cloned->Stat(&statstg, STATFLAG_NONAME);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // Copy the proxy data
+ hr = cloned->CopyTo(pStm, statstg.cbSize, nullptr, nullptr);
+
+ if (SUCCEEDED(hr) && IsInitialMarshal() && mPreservedStream &&
+ (mshlflags & MSHLFLAGS_TABLESTRONG)) {
+ // If we have successfully copied mPreservedStream at least once for a
+ // MSHLFLAGS_TABLESTRONG marshal, then we want to forget our reference to
+ // it. This is because the COM runtime will manage it from here on out.
+ mForgetPreservedStream = true;
+ }
+
+ return hr;
+}
+
+HRESULT
+PassthruProxy::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) {
+ // Read out the interface info that we copied during marshaling
+ ULONG bytesRead;
+ HRESULT hr = pStm->Read(&mVTableSize, sizeof(mVTableSize), &bytesRead);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ if (bytesRead != sizeof(mVTableSize)) {
+ return E_UNEXPECTED;
+ }
+
+ hr = pStm->Read(&mWrappedIid, sizeof(mWrappedIid), &bytesRead);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ if (bytesRead != sizeof(mWrappedIid)) {
+ return E_UNEXPECTED;
+ }
+
+ // Now we copy the proxy inside pStm into mStream
+ hr = CopySerializedProxy(pStm, getter_AddRefs(mStream));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ return QueryInterface(riid, ppv);
+}
+
+HRESULT
+PassthruProxy::ReleaseMarshalData(IStream* pStm) {
+ if (!IsInitialMarshal()) {
+ return S_OK;
+ }
+
+ if (!pStm) {
+ return E_INVALIDARG;
+ }
+
+ if (mPreservedStream) {
+ // If we still have mPreservedStream, then simply clearing it will release
+ // its marshal data automagically.
+ mPreservedStream = nullptr;
+ return S_OK;
+ }
+
+ // Skip past the metadata that we wrote during initial marshaling.
+ LARGE_INTEGER seekTo;
+ seekTo.QuadPart = sizeof(mVTableSize) + sizeof(mWrappedIid);
+ HRESULT hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, nullptr);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // Now release the "inner" marshal data
+ return ::CoReleaseMarshalData(pStm);
+}
+
+HRESULT
+PassthruProxy::DisconnectObject(DWORD dwReserved) { return S_OK; }
+
+// The remainder of this code is just boilerplate COM stuff that provides the
+// association between CLSID_PassthruProxy and the PassthruProxy class itself.
+
+class PassthruProxyClassObject final : public IClassFactory {
+ public:
+ PassthruProxyClassObject();
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // IClassFactory
+ STDMETHODIMP CreateInstance(IUnknown* aOuter, REFIID aIid,
+ void** aOutObject) override;
+ STDMETHODIMP LockServer(BOOL aLock) override;
+
+ private:
+ ~PassthruProxyClassObject() = default;
+
+ Atomic<ULONG> mRefCnt;
+};
+
+PassthruProxyClassObject::PassthruProxyClassObject() : mRefCnt(0) {}
+
+HRESULT
+PassthruProxyClassObject::QueryInterface(REFIID aIid, void** aOutInterface) {
+ if (!aOutInterface) {
+ return E_INVALIDARG;
+ }
+
+ *aOutInterface = nullptr;
+
+ if (aIid == IID_IUnknown || aIid == IID_IClassFactory) {
+ RefPtr<IClassFactory> ptr(this);
+ ptr.forget(aOutInterface);
+ return S_OK;
+ }
+
+ return E_NOINTERFACE;
+}
+
+ULONG
+PassthruProxyClassObject::AddRef() { return ++mRefCnt; }
+
+ULONG
+PassthruProxyClassObject::Release() {
+ ULONG result = --mRefCnt;
+ if (!result) {
+ delete this;
+ }
+
+ return result;
+}
+
+HRESULT
+PassthruProxyClassObject::CreateInstance(IUnknown* aOuter, REFIID aIid,
+ void** aOutObject) {
+ // We don't expect to aggregate
+ MOZ_ASSERT(!aOuter);
+ if (aOuter) {
+ return E_INVALIDARG;
+ }
+
+ RefPtr<PassthruProxy> ptr(new PassthruProxy());
+ return ptr->QueryInterface(aIid, aOutObject);
+}
+
+HRESULT
+PassthruProxyClassObject::LockServer(BOOL aLock) {
+ // No-op since xul.dll is always in memory
+ return S_OK;
+}
+
+/* static */
+HRESULT PassthruProxy::Register() {
+ DWORD cookie;
+ RefPtr<IClassFactory> classObj(new PassthruProxyClassObject());
+ return ::CoRegisterClassObject(CLSID_PassthruProxy, classObj,
+ CLSCTX_INPROC_SERVER, REGCLS_MULTIPLEUSE,
+ &cookie);
+}
+
+} // namespace mscom
+} // namespace mozilla
+
+HRESULT
+RegisterPassthruProxy() { return mozilla::mscom::PassthruProxy::Register(); }
diff --git a/ipc/mscom/PassthruProxy.h b/ipc/mscom/PassthruProxy.h
new file mode 100644
index 0000000000..afdb5c648e
--- /dev/null
+++ b/ipc/mscom/PassthruProxy.h
@@ -0,0 +1,127 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_PassthruProxy_h
+#define mozilla_mscom_PassthruProxy_h
+
+#include "mozilla/Atomics.h"
+#include "mozilla/mscom/ProxyStream.h"
+#include "mozilla/mscom/Ptr.h"
+#include "mozilla/NotNull.h"
+#if defined(MOZ_SANDBOX)
+# include "mozilla/SandboxSettings.h"
+#endif // defined(MOZ_SANDBOX)
+
+#include <objbase.h>
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+template <typename Iface>
+struct VTableSizer;
+
+template <>
+struct VTableSizer<IDispatch> {
+ enum { Size = 7 };
+};
+
+} // namespace detail
+
+class PassthruProxy final : public IMarshal, public IClientSecurity {
+ public:
+ template <typename Iface>
+ static RefPtr<Iface> Wrap(NotNull<Iface*> aIn) {
+ static_assert(detail::VTableSizer<Iface>::Size >= 3, "VTable too small");
+
+#if defined(MOZ_SANDBOX)
+ if (mozilla::GetEffectiveContentSandboxLevel() < 3) {
+ // The sandbox isn't strong enough to be a problem; no wrapping required
+ return aIn.get();
+ }
+
+ typename detail::EnvironmentSelector<Iface>::Type env;
+
+ RefPtr<PassthruProxy> passthru(new PassthruProxy(
+ &env, __uuidof(Iface), detail::VTableSizer<Iface>::Size, aIn));
+
+ RefPtr<Iface> result;
+ if (FAILED(passthru->QueryProxyInterface(getter_AddRefs(result)))) {
+ return nullptr;
+ }
+
+ return result;
+#else
+ // No wrapping required
+ return aIn.get();
+#endif // defined(MOZ_SANDBOX)
+ }
+
+ static HRESULT Register();
+
+ PassthruProxy();
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // IMarshal
+ STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) override;
+ STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) override;
+ STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) override;
+ STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid,
+ void** ppv) override;
+ STDMETHODIMP ReleaseMarshalData(IStream* pStm) override;
+ STDMETHODIMP DisconnectObject(DWORD dwReserved) override;
+
+ // IClientSecurity - we don't actually implement this interface, but its
+ // presence signals to mscom::IsProxy() that we are a proxy.
+ STDMETHODIMP QueryBlanket(IUnknown* aProxy, DWORD* aAuthnSvc,
+ DWORD* aAuthzSvc, OLECHAR** aSrvPrincName,
+ DWORD* aAuthnLevel, DWORD* aImpLevel,
+ void** aAuthInfo, DWORD* aCapabilities) override {
+ return E_NOTIMPL;
+ }
+
+ STDMETHODIMP SetBlanket(IUnknown* aProxy, DWORD aAuthnSvc, DWORD aAuthzSvc,
+ OLECHAR* aSrvPrincName, DWORD aAuthnLevel,
+ DWORD aImpLevel, void* aAuthInfo,
+ DWORD aCapabilities) override {
+ return E_NOTIMPL;
+ }
+
+ STDMETHODIMP CopyProxy(IUnknown* aProxy, IUnknown** aOutCopy) override {
+ return E_NOTIMPL;
+ }
+
+ private:
+ PassthruProxy(ProxyStream::Environment* aEnv, REFIID aIidToWrap,
+ uint32_t aVTableSize, NotNull<IUnknown*> aObjToWrap);
+ ~PassthruProxy();
+
+ bool IsInitialMarshal() const { return !mStream; }
+ HRESULT QueryProxyInterface(void** aOutInterface);
+
+ Atomic<ULONG> mRefCnt;
+ IID mWrappedIid;
+ PreservedStreamPtr mPreservedStream;
+ RefPtr<IStream> mStream;
+ uint32_t mVTableSize;
+ IUnknown* mVTable;
+ bool mForgetPreservedStream;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_PassthruProxy_h
diff --git a/ipc/mscom/ProcessRuntime.cpp b/ipc/mscom/ProcessRuntime.cpp
new file mode 100644
index 0000000000..2b4e613d75
--- /dev/null
+++ b/ipc/mscom/ProcessRuntime.cpp
@@ -0,0 +1,480 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/ProcessRuntime.h"
+
+#if defined(ACCESSIBILITY) && \
+ (defined(MOZILLA_INTERNAL_API) || defined(MOZ_HAS_MOZGLUE))
+# include "mozilla/mscom/ActCtxResource.h"
+#endif // defined(ACCESSIBILITY) && (defined(MOZILLA_INTERNAL_API) ||
+ // defined(MOZ_HAS_MOZGLUE))
+#include "mozilla/Assertions.h"
+#include "mozilla/DynamicallyLinkedFunctionPtr.h"
+#include "mozilla/mscom/COMWrappers.h"
+#include "mozilla/mscom/ProcessRuntimeShared.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/UniquePtr.h"
+#include "mozilla/Unused.h"
+#include "mozilla/Vector.h"
+#include "mozilla/WindowsProcessMitigations.h"
+#include "mozilla/WindowsVersion.h"
+
+#if defined(MOZILLA_INTERNAL_API)
+# include "mozilla/mscom/EnsureMTA.h"
+# if defined(MOZ_SANDBOX)
+# include "mozilla/sandboxTarget.h"
+# endif // defined(MOZ_SANDBOX)
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include <accctrl.h>
+#include <aclapi.h>
+#include <objbase.h>
+#include <objidl.h>
+
+// This API from oleaut32.dll is not declared in Windows SDK headers
+extern "C" void __cdecl SetOaNoCache(void);
+
+using namespace mozilla::mscom::detail;
+
+namespace mozilla {
+namespace mscom {
+
+#if defined(MOZILLA_INTERNAL_API)
+ProcessRuntime* ProcessRuntime::sInstance = nullptr;
+
+ProcessRuntime::ProcessRuntime() : ProcessRuntime(XRE_GetProcessType()) {}
+
+ProcessRuntime::ProcessRuntime(const GeckoProcessType aProcessType)
+ : ProcessRuntime(aProcessType == GeckoProcessType_Default
+ ? ProcessCategory::GeckoBrowserParent
+ : ProcessCategory::GeckoChild) {}
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ProcessRuntime::ProcessRuntime(const ProcessCategory aProcessCategory)
+ : mInitResult(CO_E_NOTINITIALIZED), mProcessCategory(aProcessCategory) {
+#if defined(ACCESSIBILITY)
+# if defined(MOZILLA_INTERNAL_API)
+ // If we're inside XUL, and we're the parent process, then we trust that
+ // this has already been initialized for us prior to XUL being loaded.
+ // Only required in the child if the Resource ID has been passed down.
+ if (aProcessCategory != ProcessCategory::GeckoBrowserParent &&
+ ActCtxResource::GetAccessibilityResourceId()) {
+ mActCtxRgn.emplace(ActCtxResource::GetAccessibilityResource());
+ }
+# elif defined(MOZ_HAS_MOZGLUE)
+ // If we're here, then we're in mozglue and initializing this for the parent
+ // process.
+ MOZ_ASSERT(aProcessCategory == ProcessCategory::GeckoBrowserParent);
+ mActCtxRgn.emplace(ActCtxResource::GetAccessibilityResource());
+# endif
+#endif // defined(ACCESSIBILITY)
+
+#if defined(MOZILLA_INTERNAL_API)
+ MOZ_DIAGNOSTIC_ASSERT(!sInstance);
+ sInstance = this;
+
+ EnsureMTA();
+ /**
+ * From this point forward, all threads in this process are implicitly
+ * members of the multi-threaded apartment, with the following exceptions:
+ * 1. If any Win32 GUI APIs were called on the current thread prior to
+ * executing this constructor, then this thread has already been implicitly
+ * initialized as the process's main STA thread; or
+ * 2. A thread explicitly and successfully calls CoInitialize(Ex) to specify
+ * otherwise.
+ */
+
+ const bool isCurThreadImplicitMTA = IsCurrentThreadImplicitMTA();
+ // We only assert that the implicit MTA precondition holds when not running
+ // as the Gecko parent process.
+ MOZ_DIAGNOSTIC_ASSERT(aProcessCategory ==
+ ProcessCategory::GeckoBrowserParent ||
+ isCurThreadImplicitMTA);
+
+# if defined(MOZ_SANDBOX)
+ const bool isLockedDownChildProcess =
+ mProcessCategory == ProcessCategory::GeckoChild && IsWin32kLockedDown();
+ // If our process is running under Win32k lockdown, we cannot initialize
+ // COM with a single-threaded apartment. This is because STAs create a hidden
+ // window, which implicitly requires user32 and Win32k, which are blocked.
+ // Instead we start the multi-threaded apartment and conduct our process-wide
+ // COM initialization there.
+ if (isLockedDownChildProcess) {
+ // Make sure we're still running with the sandbox's privileged impersonation
+ // token.
+ HANDLE rawCurThreadImpToken;
+ if (!::OpenThreadToken(::GetCurrentThread(), TOKEN_DUPLICATE | TOKEN_QUERY,
+ FALSE, &rawCurThreadImpToken)) {
+ mInitResult = HRESULT_FROM_WIN32(::GetLastError());
+ return;
+ }
+ nsAutoHandle curThreadImpToken(rawCurThreadImpToken);
+
+ // Ensure that our current token is still an impersonation token (ie, we
+ // have not yet called RevertToSelf() on this thread).
+ DWORD len;
+ TOKEN_TYPE tokenType;
+ MOZ_RELEASE_ASSERT(
+ ::GetTokenInformation(rawCurThreadImpToken, TokenType, &tokenType,
+ sizeof(tokenType), &len) &&
+ len == sizeof(tokenType) && tokenType == TokenImpersonation);
+
+ // Ideally we want our current thread to be running implicitly inside the
+ // MTA, but if for some wacky reason we did not end up with that, we may
+ // compensate by completing initialization via EnsureMTA's persistent
+ // thread.
+ if (!isCurThreadImplicitMTA) {
+ InitUsingPersistentMTAThread(curThreadImpToken);
+ return;
+ }
+ }
+# endif // defined(MOZ_SANDBOX)
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ mAptRegion.Init(GetDesiredApartmentType(mProcessCategory));
+
+ // It can happen that we are not the outermost COM initialization on this
+ // thread. In fact it should regularly be the case that the outermost
+ // initialization occurs from outside of XUL, before we show the skeleton UI,
+ // at which point we still need to run some things here from within XUL.
+ if (!mAptRegion.IsValidOutermost()) {
+ mInitResult = mAptRegion.GetHResult();
+#if defined(MOZILLA_INTERNAL_API)
+ MOZ_ASSERT(mProcessCategory == ProcessCategory::GeckoBrowserParent);
+ if (mProcessCategory != ProcessCategory::GeckoBrowserParent) {
+ // This is unexpected unless we're GeckoBrowserParent
+ return;
+ }
+
+ ProcessInitLock lock;
+
+ // Is another instance of ProcessRuntime responsible for the outer
+ // initialization?
+ const bool prevInit =
+ lock.GetInitState() == ProcessInitState::FullyInitialized;
+ MOZ_ASSERT(prevInit);
+ if (prevInit) {
+ PostInit();
+ }
+#endif // defined(MOZILLA_INTERNAL_API)
+ return;
+ }
+
+ InitInsideApartment();
+ if (FAILED(mInitResult)) {
+ return;
+ }
+
+#if defined(MOZILLA_INTERNAL_API)
+# if defined(MOZ_SANDBOX)
+ if (isLockedDownChildProcess) {
+ // In locked-down child processes, defer PostInit until priv drop
+ SandboxTarget::Instance()->RegisterSandboxStartCallback([self = this]() {
+ // Ensure that we're still live and the init was successful before
+ // calling PostInit()
+ if (self == sInstance && SUCCEEDED(self->mInitResult)) {
+ PostInit();
+ }
+ });
+ return;
+ }
+# endif // defined(MOZ_SANDBOX)
+
+ PostInit();
+#endif // defined(MOZILLA_INTERNAL_API)
+}
+
+#if defined(MOZILLA_INTERNAL_API)
+ProcessRuntime::~ProcessRuntime() {
+ MOZ_DIAGNOSTIC_ASSERT(sInstance == this);
+ sInstance = nullptr;
+}
+
+# if defined(MOZ_SANDBOX)
+void ProcessRuntime::InitUsingPersistentMTAThread(
+ const nsAutoHandle& aCurThreadToken) {
+ // Create an impersonation token based on the current thread's token
+ HANDLE rawMtaThreadImpToken = nullptr;
+ if (!::DuplicateToken(aCurThreadToken, SecurityImpersonation,
+ &rawMtaThreadImpToken)) {
+ mInitResult = HRESULT_FROM_WIN32(::GetLastError());
+ return;
+ }
+ nsAutoHandle mtaThreadImpToken(rawMtaThreadImpToken);
+
+ // Impersonate and initialize.
+ bool tokenSet = false;
+ EnsureMTA(
+ [this, rawMtaThreadImpToken, &tokenSet]() -> void {
+ if (!::SetThreadToken(nullptr, rawMtaThreadImpToken)) {
+ mInitResult = HRESULT_FROM_WIN32(::GetLastError());
+ return;
+ }
+
+ tokenSet = true;
+ InitInsideApartment();
+ },
+ EnsureMTA::Option::ForceDispatchToPersistentThread);
+
+ if (!tokenSet) {
+ return;
+ }
+
+ SandboxTarget::Instance()->RegisterSandboxStartCallback(
+ [self = this]() -> void {
+ EnsureMTA(
+ []() -> void {
+ // This is a security risk if it fails, so we release assert
+ MOZ_RELEASE_ASSERT(::RevertToSelf(),
+ "mscom::ProcessRuntime RevertToSelf failed");
+ },
+ EnsureMTA::Option::ForceDispatchToPersistentThread);
+
+ // Ensure that we're still live and the init was successful before
+ // calling PostInit()
+ if (self == sInstance && SUCCEEDED(self->mInitResult)) {
+ PostInit();
+ }
+ });
+}
+# endif // defined(MOZ_SANDBOX)
+#endif // defined(MOZILLA_INTERNAL_API)
+
+/* static */
+COINIT ProcessRuntime::GetDesiredApartmentType(
+ const ProcessRuntime::ProcessCategory aProcessCategory) {
+ switch (aProcessCategory) {
+ case ProcessCategory::GeckoBrowserParent:
+ return COINIT_APARTMENTTHREADED;
+ case ProcessCategory::GeckoChild:
+ if (!IsWin32kLockedDown()) {
+ // If Win32k is not locked down then we probably still need STA.
+ // We disable DDE since that is not usable from child processes.
+ return static_cast<COINIT>(COINIT_APARTMENTTHREADED |
+ COINIT_DISABLE_OLE1DDE);
+ }
+
+ [[fallthrough]];
+ default:
+ return COINIT_MULTITHREADED;
+ }
+}
+
+void ProcessRuntime::InitInsideApartment() {
+ ProcessInitLock lock;
+ const ProcessInitState prevInitState = lock.GetInitState();
+ if (prevInitState == ProcessInitState::FullyInitialized) {
+ // COM has already been initialized by a previous ProcessRuntime instance
+ mInitResult = S_OK;
+ return;
+ }
+
+ if (prevInitState < ProcessInitState::PartialSecurityInitialized) {
+ // We are required to initialize security prior to configuring global
+ // options.
+ mInitResult = InitializeSecurity(mProcessCategory);
+ MOZ_DIAGNOSTIC_ASSERT(SUCCEEDED(mInitResult));
+
+ // Even though this isn't great, we should try to proceed even when
+ // CoInitializeSecurity has previously been called: the additional settings
+ // we want to change are important enough that we don't want to skip them.
+ if (FAILED(mInitResult) && mInitResult != RPC_E_TOO_LATE) {
+ return;
+ }
+
+ lock.SetInitState(ProcessInitState::PartialSecurityInitialized);
+ }
+
+ if (prevInitState < ProcessInitState::PartialGlobalOptions) {
+ RefPtr<IGlobalOptions> globalOpts;
+ mInitResult = wrapped::CoCreateInstance(
+ CLSID_GlobalOptions, nullptr, CLSCTX_INPROC_SERVER, IID_IGlobalOptions,
+ getter_AddRefs(globalOpts));
+ MOZ_ASSERT(SUCCEEDED(mInitResult));
+ if (FAILED(mInitResult)) {
+ return;
+ }
+
+ // Disable COM's catch-all exception handler
+ mInitResult = globalOpts->Set(COMGLB_EXCEPTION_HANDLING,
+ COMGLB_EXCEPTION_DONOT_HANDLE_ANY);
+ MOZ_ASSERT(SUCCEEDED(mInitResult));
+ if (FAILED(mInitResult)) {
+ return;
+ }
+
+ lock.SetInitState(ProcessInitState::PartialGlobalOptions);
+ }
+
+ // Disable the BSTR cache (as it never invalidates, thus leaking memory)
+ // (This function is itself idempotent, so we do not concern ourselves with
+ // tracking whether or not we've already called it.)
+ ::SetOaNoCache();
+
+ lock.SetInitState(ProcessInitState::FullyInitialized);
+}
+
+#if defined(MOZILLA_INTERNAL_API)
+/**
+ * Guaranteed to run *after* the COM (and possible sandboxing) initialization
+ * has successfully completed and stabilized. This method MUST BE IDEMPOTENT!
+ */
+/* static */ void ProcessRuntime::PostInit() {
+ // Currently "roughed-in" but unused.
+}
+#endif // defined(MOZILLA_INTERNAL_API)
+
+/* static */
+DWORD
+ProcessRuntime::GetClientThreadId() {
+ DWORD callerTid;
+ HRESULT hr = ::CoGetCallerTID(&callerTid);
+ // Don't return callerTid unless the call succeeded and returned S_FALSE,
+ // indicating that the caller originates from a different process.
+ if (hr != S_FALSE) {
+ return 0;
+ }
+
+ return callerTid;
+}
+
+/* static */
+HRESULT
+ProcessRuntime::InitializeSecurity(const ProcessCategory aProcessCategory) {
+ HANDLE rawToken = nullptr;
+ BOOL ok = ::OpenProcessToken(::GetCurrentProcess(), TOKEN_QUERY, &rawToken);
+ if (!ok) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ nsAutoHandle token(rawToken);
+
+ DWORD len = 0;
+ ok = ::GetTokenInformation(token, TokenUser, nullptr, len, &len);
+ DWORD win32Error = ::GetLastError();
+ if (!ok && win32Error != ERROR_INSUFFICIENT_BUFFER) {
+ return HRESULT_FROM_WIN32(win32Error);
+ }
+
+ auto tokenUserBuf = MakeUnique<BYTE[]>(len);
+ TOKEN_USER& tokenUser = *reinterpret_cast<TOKEN_USER*>(tokenUserBuf.get());
+ ok = ::GetTokenInformation(token, TokenUser, tokenUserBuf.get(), len, &len);
+ if (!ok) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ len = 0;
+ ok = ::GetTokenInformation(token, TokenPrimaryGroup, nullptr, len, &len);
+ win32Error = ::GetLastError();
+ if (!ok && win32Error != ERROR_INSUFFICIENT_BUFFER) {
+ return HRESULT_FROM_WIN32(win32Error);
+ }
+
+ auto tokenPrimaryGroupBuf = MakeUnique<BYTE[]>(len);
+ TOKEN_PRIMARY_GROUP& tokenPrimaryGroup =
+ *reinterpret_cast<TOKEN_PRIMARY_GROUP*>(tokenPrimaryGroupBuf.get());
+ ok = ::GetTokenInformation(token, TokenPrimaryGroup,
+ tokenPrimaryGroupBuf.get(), len, &len);
+ if (!ok) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ SECURITY_DESCRIPTOR sd;
+ if (!::InitializeSecurityDescriptor(&sd, SECURITY_DESCRIPTOR_REVISION)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ BYTE systemSid[SECURITY_MAX_SID_SIZE];
+ DWORD systemSidSize = sizeof(systemSid);
+ if (!::CreateWellKnownSid(WinLocalSystemSid, nullptr, systemSid,
+ &systemSidSize)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ BYTE adminSid[SECURITY_MAX_SID_SIZE];
+ DWORD adminSidSize = sizeof(adminSid);
+ if (!::CreateWellKnownSid(WinBuiltinAdministratorsSid, nullptr, adminSid,
+ &adminSidSize)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ const bool allowAppContainers =
+ aProcessCategory == ProcessCategory::GeckoBrowserParent &&
+ IsWin8OrLater();
+
+ BYTE appContainersSid[SECURITY_MAX_SID_SIZE];
+ DWORD appContainersSidSize = sizeof(appContainersSid);
+ if (allowAppContainers) {
+ if (!::CreateWellKnownSid(WinBuiltinAnyPackageSid, nullptr,
+ appContainersSid, &appContainersSidSize)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ }
+
+ // Grant access to SYSTEM, Administrators, the user, and when running as the
+ // browser process on Windows 8+, all app containers.
+ const size_t kMaxInlineEntries = 4;
+ mozilla::Vector<EXPLICIT_ACCESS_W, kMaxInlineEntries> entries;
+
+ Unused << entries.append(EXPLICIT_ACCESS_W{
+ COM_RIGHTS_EXECUTE,
+ GRANT_ACCESS,
+ NO_INHERITANCE,
+ {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID, TRUSTEE_IS_USER,
+ reinterpret_cast<LPWSTR>(systemSid)}});
+
+ Unused << entries.append(EXPLICIT_ACCESS_W{
+ COM_RIGHTS_EXECUTE,
+ GRANT_ACCESS,
+ NO_INHERITANCE,
+ {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID,
+ TRUSTEE_IS_WELL_KNOWN_GROUP, reinterpret_cast<LPWSTR>(adminSid)}});
+
+ Unused << entries.append(EXPLICIT_ACCESS_W{
+ COM_RIGHTS_EXECUTE,
+ GRANT_ACCESS,
+ NO_INHERITANCE,
+ {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID, TRUSTEE_IS_USER,
+ reinterpret_cast<LPWSTR>(tokenUser.User.Sid)}});
+
+ if (allowAppContainers) {
+ Unused << entries.append(
+ EXPLICIT_ACCESS_W{COM_RIGHTS_EXECUTE,
+ GRANT_ACCESS,
+ NO_INHERITANCE,
+ {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID,
+ TRUSTEE_IS_WELL_KNOWN_GROUP,
+ reinterpret_cast<LPWSTR>(appContainersSid)}});
+ }
+
+ PACL rawDacl = nullptr;
+ win32Error =
+ ::SetEntriesInAclW(entries.length(), entries.begin(), nullptr, &rawDacl);
+ if (win32Error != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(win32Error);
+ }
+
+ UniquePtr<ACL, LocalFreeDeleter> dacl(rawDacl);
+
+ if (!::SetSecurityDescriptorDacl(&sd, TRUE, dacl.get(), FALSE)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ if (!::SetSecurityDescriptorOwner(&sd, tokenUser.User.Sid, FALSE)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ if (!::SetSecurityDescriptorGroup(&sd, tokenPrimaryGroup.PrimaryGroup,
+ FALSE)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ return wrapped::CoInitializeSecurity(
+ &sd, -1, nullptr, nullptr, RPC_C_AUTHN_LEVEL_DEFAULT,
+ RPC_C_IMP_LEVEL_IDENTIFY, nullptr, EOAC_NONE, nullptr);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/ProcessRuntime.h b/ipc/mscom/ProcessRuntime.h
new file mode 100644
index 0000000000..7b52a56b8a
--- /dev/null
+++ b/ipc/mscom/ProcessRuntime.h
@@ -0,0 +1,96 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_ProcessRuntime_h
+#define mozilla_mscom_ProcessRuntime_h
+
+#include "mozilla/Attributes.h"
+#if defined(ACCESSIBILITY)
+# include "mozilla/mscom/ActivationContext.h"
+# include "mozilla/Maybe.h"
+#endif // defined(ACCESSIBILITY)
+#include "mozilla/mscom/ApartmentRegion.h"
+#include "nsWindowsHelpers.h"
+#if defined(MOZILLA_INTERNAL_API)
+# include "nsXULAppAPI.h"
+#endif // defined(MOZILLA_INTERNAL_API)
+
+namespace mozilla {
+namespace mscom {
+
+class MOZ_NON_TEMPORARY_CLASS ProcessRuntime final {
+#if !defined(MOZILLA_INTERNAL_API)
+ public:
+#endif // defined(MOZILLA_INTERNAL_API)
+ enum class ProcessCategory {
+ GeckoBrowserParent,
+ // We give Launcher its own process category, but internally to this class
+ // it should be treated identically to GeckoBrowserParent.
+ Launcher = GeckoBrowserParent,
+ GeckoChild,
+ Service,
+ };
+
+ // This constructor is only public when compiled outside of XUL
+ explicit ProcessRuntime(const ProcessCategory aProcessCategory);
+
+ public:
+#if defined(MOZILLA_INTERNAL_API)
+ ProcessRuntime();
+ ~ProcessRuntime();
+#else
+ ~ProcessRuntime() = default;
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ explicit operator bool() const { return SUCCEEDED(mInitResult); }
+ HRESULT GetHResult() const { return mInitResult; }
+
+ ProcessRuntime(const ProcessRuntime&) = delete;
+ ProcessRuntime(ProcessRuntime&&) = delete;
+ ProcessRuntime& operator=(const ProcessRuntime&) = delete;
+ ProcessRuntime& operator=(ProcessRuntime&&) = delete;
+
+ /**
+ * @return 0 if call is in-process or resolving the calling thread failed,
+ * otherwise contains the thread id of the calling thread.
+ */
+ static DWORD GetClientThreadId();
+
+ private:
+#if defined(MOZILLA_INTERNAL_API)
+ explicit ProcessRuntime(const GeckoProcessType aProcessType);
+# if defined(MOZ_SANDBOX)
+ void InitUsingPersistentMTAThread(const nsAutoHandle& aCurThreadToken);
+# endif // defined(MOZ_SANDBOX)
+#endif // defined(MOZILLA_INTERNAL_API)
+ void InitInsideApartment();
+
+#if defined(MOZILLA_INTERNAL_API)
+ static void PostInit();
+#endif // defined(MOZILLA_INTERNAL_API)
+ static HRESULT InitializeSecurity(const ProcessCategory aProcessCategory);
+ static COINIT GetDesiredApartmentType(const ProcessCategory aProcessCategory);
+
+ private:
+ HRESULT mInitResult;
+ const ProcessCategory mProcessCategory;
+#if defined(ACCESSIBILITY) && \
+ (defined(MOZILLA_INTERNAL_API) || defined(MOZ_HAS_MOZGLUE))
+ Maybe<ActivationContextRegion> mActCtxRgn;
+#endif // defined(ACCESSIBILITY) && (defined(MOZILLA_INTERNAL_API) ||
+ // defined(MOZ_HAS_MOZGLUE))
+ ApartmentRegion mAptRegion;
+
+ private:
+#if defined(MOZILLA_INTERNAL_API)
+ static ProcessRuntime* sInstance;
+#endif // defined(MOZILLA_INTERNAL_API)
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_ProcessRuntime_h
diff --git a/ipc/mscom/ProfilerMarkers.cpp b/ipc/mscom/ProfilerMarkers.cpp
new file mode 100644
index 0000000000..29d033723f
--- /dev/null
+++ b/ipc/mscom/ProfilerMarkers.cpp
@@ -0,0 +1,236 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ProfilerMarkers.h"
+
+#include "MainThreadUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/Atomics.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/ProfilerMarkers.h"
+#include "mozilla/Services.h"
+#include "nsCOMPtr.h"
+#include "nsIObserver.h"
+#include "nsIObserverService.h"
+#include "nsISupportsImpl.h"
+#include "nsString.h"
+#include "nsXULAppAPI.h"
+
+#include <objbase.h>
+#include <objidlbase.h>
+
+// {9DBE6B28-E5E7-4FDE-AF00-9404604E74DC}
+static const GUID GUID_MozProfilerMarkerExtension = {
+ 0x9dbe6b28, 0xe5e7, 0x4fde, {0xaf, 0x0, 0x94, 0x4, 0x60, 0x4e, 0x74, 0xdc}};
+
+namespace {
+
+class ProfilerMarkerChannelHook final : public IChannelHook {
+ ~ProfilerMarkerChannelHook() = default;
+
+ public:
+ ProfilerMarkerChannelHook() : mRefCnt(0) {}
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ /**
+ * IChannelHook exposes six methods: The Client* methods are called when
+ * a client is sending an IPC request, whereas the Server* methods are called
+ * when a server is receiving an IPC request.
+ *
+ * For our purposes, we only care about the client-side methods. The COM
+ * runtime invokes the methods in the following order:
+ * 1. ClientGetSize, where the hook specifies the size of its payload;
+ * 2. ClientFillBuffer, where the hook fills the channel's buffer with its
+ * payload information. NOTE: This method is only called when ClientGetSize
+ * specifies a non-zero payload size. For our purposes, since we are not
+ * sending a payload, this method will never be called!
+ * 3. ClientNotify, when the response has been received from the server.
+ *
+ * Since we want to use these hooks to record the beginning and end of a COM
+ * IPC call, we use ClientGetSize for logging the start, and ClientNotify for
+ * logging the end.
+ *
+ * Finally, our implementation responds to any request matching our extension
+ * ID, however we only care about main thread COM calls.
+ */
+
+ // IChannelHook
+ STDMETHODIMP_(void)
+ ClientGetSize(REFGUID aExtensionId, REFIID aIid,
+ ULONG* aOutDataSize) override;
+
+ // No-op (see the large comment above)
+ STDMETHODIMP_(void)
+ ClientFillBuffer(REFGUID aExtensionId, REFIID aIid, ULONG* aDataSize,
+ void* aDataBuf) override {}
+
+ STDMETHODIMP_(void)
+ ClientNotify(REFGUID aExtensionId, REFIID aIid, ULONG aDataSize,
+ void* aDataBuffer, DWORD aDataRep, HRESULT aFault) override;
+
+ // We don't care about the server-side notifications, so leave as no-ops.
+ STDMETHODIMP_(void)
+ ServerNotify(REFGUID aExtensionId, REFIID aIid, ULONG aDataSize,
+ void* aDataBuf, DWORD aDataRep) override {}
+
+ STDMETHODIMP_(void)
+ ServerGetSize(REFGUID aExtensionId, REFIID aIid, HRESULT aFault,
+ ULONG* aOutDataSize) override {}
+
+ STDMETHODIMP_(void)
+ ServerFillBuffer(REFGUID aExtensionId, REFIID aIid, ULONG* aDataSize,
+ void* aDataBuf, HRESULT aFault) override {}
+
+ private:
+ void BuildMarkerName(REFIID aIid, nsACString& aOutMarkerName);
+
+ private:
+ mozilla::Atomic<ULONG> mRefCnt;
+};
+
+HRESULT ProfilerMarkerChannelHook::QueryInterface(REFIID aIid,
+ void** aOutInterface) {
+ if (aIid == IID_IChannelHook || aIid == IID_IUnknown) {
+ RefPtr<IChannelHook> ptr(this);
+ ptr.forget(aOutInterface);
+ return S_OK;
+ }
+
+ return E_NOINTERFACE;
+}
+
+ULONG ProfilerMarkerChannelHook::AddRef() { return ++mRefCnt; }
+
+ULONG ProfilerMarkerChannelHook::Release() {
+ ULONG result = --mRefCnt;
+ if (!result) {
+ delete this;
+ }
+
+ return result;
+}
+
+void ProfilerMarkerChannelHook::BuildMarkerName(REFIID aIid,
+ nsACString& aOutMarkerName) {
+ aOutMarkerName.AssignLiteral("ORPC Call for ");
+
+ nsAutoCString iidStr;
+ mozilla::mscom::DiagnosticNameForIID(aIid, iidStr);
+ aOutMarkerName.Append(iidStr);
+}
+
+void ProfilerMarkerChannelHook::ClientGetSize(REFGUID aExtensionId, REFIID aIid,
+ ULONG* aOutDataSize) {
+ if (aExtensionId == GUID_MozProfilerMarkerExtension) {
+ if (NS_IsMainThread()) {
+ nsAutoCString markerName;
+ BuildMarkerName(aIid, markerName);
+ PROFILER_MARKER(markerName, IPC, mozilla::MarkerTiming::IntervalStart(),
+ Tracing, "MSCOM");
+ }
+
+ if (aOutDataSize) {
+ // We don't add any payload data to the channel
+ *aOutDataSize = 0UL;
+ }
+ }
+}
+
+void ProfilerMarkerChannelHook::ClientNotify(REFGUID aExtensionId, REFIID aIid,
+ ULONG aDataSize, void* aDataBuffer,
+ DWORD aDataRep, HRESULT aFault) {
+ if (NS_IsMainThread() && aExtensionId == GUID_MozProfilerMarkerExtension) {
+ nsAutoCString markerName;
+ BuildMarkerName(aIid, markerName);
+ PROFILER_MARKER(markerName, IPC, mozilla::MarkerTiming::IntervalEnd(),
+ Tracing, "MSCOM");
+ }
+}
+
+} // anonymous namespace
+
+static void RegisterChannelHook() {
+ RefPtr<ProfilerMarkerChannelHook> hook(new ProfilerMarkerChannelHook());
+ mozilla::DebugOnly<HRESULT> hr =
+ ::CoRegisterChannelHook(GUID_MozProfilerMarkerExtension, hook);
+ MOZ_ASSERT(SUCCEEDED(hr));
+}
+
+namespace {
+
+class ProfilerStartupObserver final : public nsIObserver {
+ ~ProfilerStartupObserver() = default;
+
+ public:
+ NS_DECL_ISUPPORTS
+ NS_DECL_NSIOBSERVER
+};
+
+NS_IMPL_ISUPPORTS(ProfilerStartupObserver, nsIObserver)
+
+NS_IMETHODIMP ProfilerStartupObserver::Observe(nsISupports* aSubject,
+ const char* aTopic,
+ const char16_t* aData) {
+ if (strcmp(aTopic, "profiler-started")) {
+ return NS_OK;
+ }
+
+ RegisterChannelHook();
+
+ // Once we've set the channel hook, we don't care about this notification
+ // anymore; our channel hook will remain set for the lifetime of the process.
+ nsCOMPtr<nsIObserverService> obsServ(mozilla::services::GetObserverService());
+ MOZ_ASSERT(!!obsServ);
+ if (!obsServ) {
+ return NS_OK;
+ }
+
+ obsServ->RemoveObserver(this, "profiler-started");
+ return NS_OK;
+}
+
+} // anonymous namespace
+
+namespace mozilla {
+namespace mscom {
+
+void InitProfilerMarkers() {
+ if (!XRE_IsParentProcess()) {
+ return;
+ }
+
+ MOZ_ASSERT(NS_IsMainThread());
+ if (!NS_IsMainThread()) {
+ return;
+ }
+
+ if (profiler_is_active()) {
+ // If the profiler is already running, we'll immediately register our
+ // channel hook.
+ RegisterChannelHook();
+ return;
+ }
+
+ // The profiler is not running yet. To avoid unnecessary invocations of the
+ // channel hook, we won't bother with installing it until the profiler starts.
+ // Set up an observer to watch for this.
+ nsCOMPtr<nsIObserverService> obsServ(mozilla::services::GetObserverService());
+ MOZ_ASSERT(!!obsServ);
+ if (!obsServ) {
+ return;
+ }
+
+ nsCOMPtr<nsIObserver> obs(new ProfilerStartupObserver());
+ obsServ->AddObserver(obs, "profiler-started", false);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/ProfilerMarkers.h b/ipc/mscom/ProfilerMarkers.h
new file mode 100644
index 0000000000..a21d5b375a
--- /dev/null
+++ b/ipc/mscom/ProfilerMarkers.h
@@ -0,0 +1,18 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_ProfilerMarkers_h
+#define mozilla_mscom_ProfilerMarkers_h
+
+namespace mozilla {
+namespace mscom {
+
+void InitProfilerMarkers();
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_ProfilerMarkers_h
diff --git a/ipc/mscom/ProxyStream.cpp b/ipc/mscom/ProxyStream.cpp
new file mode 100644
index 0000000000..5a3457e147
--- /dev/null
+++ b/ipc/mscom/ProxyStream.cpp
@@ -0,0 +1,411 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <utility>
+#if defined(ACCESSIBILITY)
+# include "HandlerData.h"
+# include "mozilla/a11y/Platform.h"
+# include "mozilla/mscom/ActivationContext.h"
+#endif // defined(ACCESSIBILITY)
+#include "mozilla/mscom/EnsureMTA.h"
+#include "mozilla/mscom/ProxyStream.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/ScopeExit.h"
+
+#include "mozilla/mscom/Objref.h"
+#include "nsExceptionHandler.h"
+#include "nsPrintfCString.h"
+#include "RegistrationAnnotator.h"
+
+#include <windows.h>
+#include <objbase.h>
+#include <shlwapi.h>
+
+namespace mozilla {
+namespace mscom {
+
+ProxyStream::ProxyStream()
+ : mGlobalLockedBuf(nullptr),
+ mHGlobal(nullptr),
+ mBufSize(0),
+ mPreserveStream(false) {}
+
+// GetBuffer() fails with this variant, but that's okay because we're just
+// reconstructing the stream from a buffer anyway.
+ProxyStream::ProxyStream(REFIID aIID, const BYTE* aInitBuf,
+ const int aInitBufSize, Environment* aEnv)
+ : mGlobalLockedBuf(nullptr),
+ mHGlobal(nullptr),
+ mBufSize(aInitBufSize),
+ mPreserveStream(false) {
+ CrashReporter::Annotation kCrashReportKey =
+ CrashReporter::Annotation::ProxyStreamUnmarshalStatus;
+
+ if (!aInitBufSize) {
+ CrashReporter::AnnotateCrashReport(kCrashReportKey, "!aInitBufSize"_ns);
+ // We marshaled a nullptr. Nothing else to do here.
+ return;
+ }
+
+ HRESULT createStreamResult =
+ CreateStream(aInitBuf, aInitBufSize, getter_AddRefs(mStream));
+ if (FAILED(createStreamResult)) {
+ nsPrintfCString hrAsStr("0x%08lX", createStreamResult);
+ CrashReporter::AnnotateCrashReport(kCrashReportKey, hrAsStr);
+ return;
+ }
+
+ // NB: We can't check for a null mStream until after we have checked for
+ // the zero aInitBufSize above. This is because InitStream will also fail
+ // in that case, even though marshaling a nullptr is allowable.
+ MOZ_ASSERT(mStream);
+ if (!mStream) {
+ CrashReporter::AnnotateCrashReport(kCrashReportKey, "!mStream"_ns);
+ return;
+ }
+
+#if defined(ACCESSIBILITY)
+ const uint32_t expectedStreamLen = GetOBJREFSize(WrapNotNull(mStream));
+ nsAutoCString strActCtx;
+ nsAutoString manifestPath;
+#endif // defined(ACCESSIBILITY)
+
+ HRESULT unmarshalResult = S_OK;
+
+ // We need to convert to an interface here otherwise we mess up const
+ // correctness with IPDL. We'll request an IUnknown and then QI the
+ // actual interface later.
+
+#if defined(ACCESSIBILITY)
+ auto marshalFn = [this, &strActCtx, &manifestPath, &unmarshalResult, &aIID,
+ aEnv]() -> void
+#else
+ auto marshalFn = [this, &unmarshalResult, &aIID, aEnv]() -> void
+#endif // defined(ACCESSIBILITY)
+ {
+ if (aEnv) {
+ bool pushOk = aEnv->Push();
+ MOZ_DIAGNOSTIC_ASSERT(pushOk);
+ if (!pushOk) {
+ return;
+ }
+ }
+
+ auto popEnv = MakeScopeExit([aEnv]() -> void {
+ if (!aEnv) {
+ return;
+ }
+
+#ifdef MOZ_DIAGNOSTIC_ASSERT_ENABLED
+ bool popOk =
+#endif
+ aEnv->Pop();
+ MOZ_DIAGNOSTIC_ASSERT(popOk);
+ });
+
+#if defined(ACCESSIBILITY)
+ auto curActCtx = ActivationContext::GetCurrent();
+ if (curActCtx.isOk()) {
+ strActCtx.AppendPrintf("0x%" PRIxPTR, curActCtx.unwrap());
+ } else {
+ strActCtx.AppendPrintf("HRESULT 0x%08lX", curActCtx.unwrapErr());
+ }
+
+ ActivationContext::GetCurrentManifestPath(manifestPath);
+#endif // defined(ACCESSIBILITY)
+
+ unmarshalResult = ::CoUnmarshalInterface(mStream, aIID,
+ getter_AddRefs(mUnmarshaledProxy));
+ MOZ_ASSERT(SUCCEEDED(unmarshalResult));
+ };
+
+ if (XRE_IsParentProcess()) {
+ // We'll marshal this stuff directly using the current thread, therefore its
+ // proxy will reside in the same apartment as the current thread.
+ marshalFn();
+ } else {
+ // When marshaling in child processes, we want to force the MTA.
+ EnsureMTA mta(marshalFn);
+ }
+
+ mStream = nullptr;
+
+ if (FAILED(unmarshalResult) || !mUnmarshaledProxy) {
+ nsPrintfCString hrAsStr("0x%08lX", unmarshalResult);
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::CoUnmarshalInterfaceResult, hrAsStr);
+ AnnotateInterfaceRegistration(aIID);
+ if (!mUnmarshaledProxy) {
+ CrashReporter::AnnotateCrashReport(kCrashReportKey,
+ "!mUnmarshaledProxy"_ns);
+ }
+
+#if defined(ACCESSIBILITY)
+ AnnotateClassRegistration(CLSID_AccessibleHandler);
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::UnmarshalActCtx, strActCtx);
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::UnmarshalActCtxManifestPath,
+ NS_ConvertUTF16toUTF8(manifestPath));
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::A11yHandlerRegistered,
+ a11y::IsHandlerRegistered() ? "true"_ns : "false"_ns);
+
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::ExpectedStreamLen, expectedStreamLen);
+
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::ActualStreamLen, aInitBufSize);
+#endif // defined(ACCESSIBILITY)
+ }
+}
+
+ProxyStream::ProxyStream(ProxyStream&& aOther)
+ : mGlobalLockedBuf(nullptr),
+ mHGlobal(nullptr),
+ mBufSize(0),
+ mPreserveStream(false) {
+ *this = std::move(aOther);
+}
+
+ProxyStream& ProxyStream::operator=(ProxyStream&& aOther) {
+ if (mHGlobal && mGlobalLockedBuf) {
+ DebugOnly<BOOL> result = ::GlobalUnlock(mHGlobal);
+ MOZ_ASSERT(!result && ::GetLastError() == NO_ERROR);
+ }
+
+ mStream = std::move(aOther.mStream);
+
+ mGlobalLockedBuf = aOther.mGlobalLockedBuf;
+ aOther.mGlobalLockedBuf = nullptr;
+
+ // ::GlobalFree() was called implicitly when mStream was replaced.
+ mHGlobal = aOther.mHGlobal;
+ aOther.mHGlobal = nullptr;
+
+ mBufSize = aOther.mBufSize;
+ aOther.mBufSize = 0;
+
+ mUnmarshaledProxy = std::move(aOther.mUnmarshaledProxy);
+
+ mPreserveStream = aOther.mPreserveStream;
+ return *this;
+}
+
+ProxyStream::~ProxyStream() {
+ if (mHGlobal && mGlobalLockedBuf) {
+ DebugOnly<BOOL> result = ::GlobalUnlock(mHGlobal);
+ MOZ_ASSERT(!result && ::GetLastError() == NO_ERROR);
+ // ::GlobalFree() is called implicitly when mStream is released
+ }
+
+ // If this assert triggers then we will be leaking a marshaled proxy!
+ // Call GetPreservedStream to obtain a preservable stream and then save it
+ // until the proxy is no longer needed.
+ MOZ_ASSERT(!mPreserveStream);
+}
+
+const BYTE* ProxyStream::GetBuffer(int& aReturnedBufSize) const {
+ aReturnedBufSize = 0;
+ if (!mStream) {
+ return nullptr;
+ }
+ if (!mGlobalLockedBuf) {
+ return nullptr;
+ }
+ aReturnedBufSize = mBufSize;
+ return mGlobalLockedBuf;
+}
+
+PreservedStreamPtr ProxyStream::GetPreservedStream() {
+ MOZ_ASSERT(mStream);
+ MOZ_ASSERT(mHGlobal);
+
+ if (!mStream || !mPreserveStream) {
+ return nullptr;
+ }
+
+ // Clone the stream so that the result has a distinct seek pointer.
+ RefPtr<IStream> cloned;
+ HRESULT hr = mStream->Clone(getter_AddRefs(cloned));
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ // Ensure the stream is rewound. We do this because CoReleaseMarshalData needs
+ // the stream to be pointing to the beginning of the marshal data.
+ LARGE_INTEGER pos;
+ pos.QuadPart = 0LL;
+ hr = cloned->Seek(pos, STREAM_SEEK_SET, nullptr);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ mPreserveStream = false;
+ return ToPreservedStreamPtr(std::move(cloned));
+}
+
+bool ProxyStream::GetInterface(void** aOutInterface) {
+ // We should not have a locked buffer on this side
+ MOZ_ASSERT(!mGlobalLockedBuf);
+ MOZ_ASSERT(aOutInterface);
+
+ if (!aOutInterface) {
+ return false;
+ }
+
+ *aOutInterface = mUnmarshaledProxy.release();
+ return true;
+}
+
+ProxyStream::ProxyStream(REFIID aIID, IUnknown* aObject, Environment* aEnv,
+ ProxyStreamFlags aFlags)
+ : mGlobalLockedBuf(nullptr),
+ mHGlobal(nullptr),
+ mBufSize(0),
+ mPreserveStream(aFlags & ProxyStreamFlags::ePreservable) {
+ if (!aObject) {
+ return;
+ }
+
+ RefPtr<IStream> stream;
+ HGLOBAL hglobal = NULL;
+ int streamSize = 0;
+ DWORD mshlFlags = mPreserveStream ? MSHLFLAGS_TABLESTRONG : MSHLFLAGS_NORMAL;
+
+ HRESULT createStreamResult = S_OK;
+ HRESULT marshalResult = S_OK;
+ HRESULT statResult = S_OK;
+ HRESULT getHGlobalResult = S_OK;
+
+#if defined(ACCESSIBILITY)
+ nsAutoString manifestPath;
+ auto marshalFn = [&aIID, aObject, mshlFlags, &stream, &streamSize, &hglobal,
+ &createStreamResult, &marshalResult, &statResult,
+ &getHGlobalResult, aEnv, &manifestPath]() -> void {
+#else
+ auto marshalFn = [&aIID, aObject, mshlFlags, &stream, &streamSize, &hglobal,
+ &createStreamResult, &marshalResult, &statResult,
+ &getHGlobalResult, aEnv]() -> void {
+#endif // defined(ACCESSIBILITY)
+ if (aEnv) {
+ bool pushOk = aEnv->Push();
+ MOZ_DIAGNOSTIC_ASSERT(pushOk);
+ if (!pushOk) {
+ return;
+ }
+ }
+
+ auto popEnv = MakeScopeExit([aEnv]() -> void {
+ if (!aEnv) {
+ return;
+ }
+
+#ifdef MOZ_DIAGNOSTIC_ASSERT_ENABLED
+ bool popOk =
+#endif
+ aEnv->Pop();
+ MOZ_DIAGNOSTIC_ASSERT(popOk);
+ });
+
+ createStreamResult =
+ ::CreateStreamOnHGlobal(nullptr, TRUE, getter_AddRefs(stream));
+ if (FAILED(createStreamResult)) {
+ return;
+ }
+
+#if defined(ACCESSIBILITY)
+ ActivationContext::GetCurrentManifestPath(manifestPath);
+#endif // defined(ACCESSIBILITY)
+
+ marshalResult = ::CoMarshalInterface(stream, aIID, aObject, MSHCTX_LOCAL,
+ nullptr, mshlFlags);
+ MOZ_ASSERT(marshalResult != E_INVALIDARG);
+ if (FAILED(marshalResult)) {
+ return;
+ }
+
+ STATSTG statstg;
+ statResult = stream->Stat(&statstg, STATFLAG_NONAME);
+ if (SUCCEEDED(statResult)) {
+ streamSize = static_cast<int>(statstg.cbSize.LowPart);
+ } else {
+ return;
+ }
+
+ getHGlobalResult = ::GetHGlobalFromStream(stream, &hglobal);
+ MOZ_ASSERT(SUCCEEDED(getHGlobalResult));
+ };
+
+ if (XRE_IsParentProcess()) {
+ // We'll marshal this stuff directly using the current thread, therefore its
+ // stub will reside in the same apartment as the current thread.
+ marshalFn();
+ } else {
+ // When marshaling in child processes, we want to force the MTA.
+ EnsureMTA mta(marshalFn);
+ }
+
+ if (FAILED(createStreamResult)) {
+ nsPrintfCString hrAsStr("0x%08lX", createStreamResult);
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::CreateStreamOnHGlobalFailure, hrAsStr);
+ }
+
+ if (FAILED(marshalResult)) {
+ AnnotateInterfaceRegistration(aIID);
+ nsPrintfCString hrAsStr("0x%08lX", marshalResult);
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::CoMarshalInterfaceFailure, hrAsStr);
+#if defined(ACCESSIBILITY)
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::MarshalActCtxManifestPath,
+ NS_ConvertUTF16toUTF8(manifestPath));
+#endif // defined(ACCESSIBILITY)
+ }
+
+ if (FAILED(statResult)) {
+ nsPrintfCString hrAsStr("0x%08lX", statResult);
+ CrashReporter::AnnotateCrashReport(CrashReporter::Annotation::StatFailure,
+ hrAsStr);
+ }
+
+ if (FAILED(getHGlobalResult)) {
+ nsPrintfCString hrAsStr("0x%08lX", getHGlobalResult);
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::GetHGlobalFromStreamFailure, hrAsStr);
+ }
+
+ mStream = std::move(stream);
+
+ if (streamSize) {
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::ProxyStreamSizeFrom, "IStream::Stat"_ns);
+ mBufSize = streamSize;
+ }
+
+ if (!hglobal) {
+ return;
+ }
+
+ mGlobalLockedBuf = reinterpret_cast<BYTE*>(::GlobalLock(hglobal));
+ mHGlobal = hglobal;
+
+ // If we couldn't get the stream size directly from mStream, we may use
+ // the size of the memory block allocated by the HGLOBAL, though it might
+ // be larger than the actual stream size.
+ if (!streamSize) {
+ CrashReporter::AnnotateCrashReport(
+ CrashReporter::Annotation::ProxyStreamSizeFrom, "GlobalSize"_ns);
+ mBufSize = static_cast<int>(::GlobalSize(hglobal));
+ }
+
+ CrashReporter::AnnotateCrashReport(CrashReporter::Annotation::ProxyStreamSize,
+ mBufSize);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/ProxyStream.h b/ipc/mscom/ProxyStream.h
new file mode 100644
index 0000000000..92b55745ca
--- /dev/null
+++ b/ipc/mscom/ProxyStream.h
@@ -0,0 +1,88 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_ProxyStream_h
+#define mozilla_mscom_ProxyStream_h
+
+#include "ipc/IPCMessageUtils.h"
+
+#include "mozilla/mscom/Ptr.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/TypedEnumBits.h"
+#include "mozilla/UniquePtr.h"
+
+namespace mozilla {
+namespace mscom {
+
+enum class ProxyStreamFlags : uint32_t {
+ eDefault = 0,
+ // When ePreservable is set on a ProxyStream, its caller *must* call
+ // GetPreservableStream() before the ProxyStream is destroyed.
+ ePreservable = 1
+};
+
+MOZ_MAKE_ENUM_CLASS_BITWISE_OPERATORS(ProxyStreamFlags);
+
+class ProxyStream final {
+ public:
+ class MOZ_RAII Environment {
+ public:
+ virtual ~Environment() = default;
+ virtual bool Push() = 0;
+ virtual bool Pop() = 0;
+ };
+
+ class MOZ_RAII DefaultEnvironment : public Environment {
+ public:
+ bool Push() override { return true; }
+ bool Pop() override { return true; }
+ };
+
+ ProxyStream();
+ ProxyStream(REFIID aIID, IUnknown* aObject, Environment* aEnv,
+ ProxyStreamFlags aFlags = ProxyStreamFlags::eDefault);
+ ProxyStream(REFIID aIID, const BYTE* aInitBuf, const int aInitBufSize,
+ Environment* aEnv);
+
+ ~ProxyStream();
+
+ // Not copyable because this would mess up the COM marshaling.
+ ProxyStream(const ProxyStream& aOther) = delete;
+ ProxyStream& operator=(const ProxyStream& aOther) = delete;
+
+ ProxyStream(ProxyStream&& aOther);
+ ProxyStream& operator=(ProxyStream&& aOther);
+
+ inline bool IsValid() const { return !(mUnmarshaledProxy && mStream); }
+
+ bool GetInterface(void** aOutInterface);
+ const BYTE* GetBuffer(int& aReturnedBufSize) const;
+
+ PreservedStreamPtr GetPreservedStream();
+
+ bool operator==(const ProxyStream& aOther) const { return this == &aOther; }
+
+ private:
+ RefPtr<IStream> mStream;
+ BYTE* mGlobalLockedBuf;
+ HGLOBAL mHGlobal;
+ int mBufSize;
+ ProxyUniquePtr<IUnknown> mUnmarshaledProxy;
+ bool mPreserveStream;
+};
+
+namespace detail {
+
+template <typename Interface>
+struct EnvironmentSelector {
+ typedef ProxyStream::DefaultEnvironment Type;
+};
+
+} // namespace detail
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_ProxyStream_h
diff --git a/ipc/mscom/Ptr.h b/ipc/mscom/Ptr.h
new file mode 100644
index 0000000000..fd4bd9c91f
--- /dev/null
+++ b/ipc/mscom/Ptr.h
@@ -0,0 +1,306 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Ptr_h
+#define mozilla_mscom_Ptr_h
+
+#include "mozilla/Assertions.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/mscom/EnsureMTA.h"
+#include "mozilla/SchedulerGroup.h"
+#include "mozilla/UniquePtr.h"
+#include "nsError.h"
+#include "nsThreadUtils.h"
+#include "nsXULAppAPI.h"
+
+#include <objidl.h>
+
+/**
+ * The glue code in mozilla::mscom often needs to pass around interface pointers
+ * belonging to a different apartment from the current one. We must not touch
+ * the reference counts of those objects on the wrong apartment. By using these
+ * UniquePtr specializations, we may ensure that the reference counts are always
+ * handled correctly.
+ */
+
+namespace mozilla {
+namespace mscom {
+
+namespace detail {
+
+template <typename T>
+struct MainThreadRelease {
+ void operator()(T* aPtr) {
+ if (!aPtr) {
+ return;
+ }
+ if (NS_IsMainThread()) {
+ aPtr->Release();
+ return;
+ }
+ DebugOnly<nsresult> rv = SchedulerGroup::Dispatch(
+ TaskCategory::Other,
+ NewNonOwningRunnableMethod("mscom::MainThreadRelease", aPtr,
+ &T::Release));
+ MOZ_ASSERT(NS_SUCCEEDED(rv));
+ }
+};
+
+template <typename T>
+struct MTADelete {
+ void operator()(T* aPtr) {
+ if (!aPtr) {
+ return;
+ }
+
+ EnsureMTA::AsyncOperation([aPtr]() -> void { delete aPtr; });
+ }
+};
+
+template <typename T>
+struct MTARelease {
+ void operator()(T* aPtr) {
+ if (!aPtr) {
+ return;
+ }
+
+ // Static analysis doesn't recognize that, even though aPtr escapes the
+ // current scope, we are in effect moving our strong ref into the lambda.
+ void* ptr = aPtr;
+ EnsureMTA::AsyncOperation(
+ [ptr]() -> void { reinterpret_cast<T*>(ptr)->Release(); });
+ }
+};
+
+template <typename T>
+struct MTAReleaseInChildProcess {
+ void operator()(T* aPtr) {
+ if (!aPtr) {
+ return;
+ }
+
+ if (XRE_IsParentProcess()) {
+ MOZ_ASSERT(NS_IsMainThread());
+ aPtr->Release();
+ return;
+ }
+
+ // Static analysis doesn't recognize that, even though aPtr escapes the
+ // current scope, we are in effect moving our strong ref into the lambda.
+ void* ptr = aPtr;
+ EnsureMTA::AsyncOperation(
+ [ptr]() -> void { reinterpret_cast<T*>(ptr)->Release(); });
+ }
+};
+
+struct InterceptorTargetDeleter {
+ void operator()(IUnknown* aPtr) {
+ // We intentionally do not touch the refcounts of interceptor targets!
+ }
+};
+
+struct PreservedStreamDeleter {
+ void operator()(IStream* aPtr) {
+ if (!aPtr) {
+ return;
+ }
+
+ // Static analysis doesn't recognize that, even though aPtr escapes the
+ // current scope, we are in effect moving our strong ref into the lambda.
+ void* ptr = aPtr;
+ auto cleanup = [ptr]() -> void {
+ DebugOnly<HRESULT> hr =
+ ::CoReleaseMarshalData(reinterpret_cast<LPSTREAM>(ptr));
+ MOZ_ASSERT(SUCCEEDED(hr));
+ reinterpret_cast<LPSTREAM>(ptr)->Release();
+ };
+
+ if (XRE_IsParentProcess()) {
+ MOZ_ASSERT(NS_IsMainThread());
+ cleanup();
+ return;
+ }
+
+ EnsureMTA::AsyncOperation(cleanup);
+ }
+};
+
+} // namespace detail
+
+template <typename T>
+using STAUniquePtr = mozilla::UniquePtr<T, detail::MainThreadRelease<T>>;
+
+template <typename T>
+using MTAUniquePtr = mozilla::UniquePtr<T, detail::MTARelease<T>>;
+
+template <typename T>
+using MTADeletePtr = mozilla::UniquePtr<T, detail::MTADelete<T>>;
+
+template <typename T>
+using ProxyUniquePtr =
+ mozilla::UniquePtr<T, detail::MTAReleaseInChildProcess<T>>;
+
+template <typename T>
+using InterceptorTargetPtr =
+ mozilla::UniquePtr<T, detail::InterceptorTargetDeleter>;
+
+using PreservedStreamPtr =
+ mozilla::UniquePtr<IStream, detail::PreservedStreamDeleter>;
+
+namespace detail {
+
+// We don't have direct access to UniquePtr's storage, so we use mPtrStorage
+// to receive the pointer and then set the target inside the destructor.
+template <typename T, typename Deleter>
+class UniquePtrGetterAddRefs {
+ public:
+ explicit UniquePtrGetterAddRefs(UniquePtr<T, Deleter>& aSmartPtr)
+ : mTargetSmartPtr(aSmartPtr), mPtrStorage(nullptr) {}
+
+ ~UniquePtrGetterAddRefs() { mTargetSmartPtr.reset(mPtrStorage); }
+
+ operator void**() { return reinterpret_cast<void**>(&mPtrStorage); }
+
+ operator T**() { return &mPtrStorage; }
+
+ T*& operator*() { return mPtrStorage; }
+
+ private:
+ UniquePtr<T, Deleter>& mTargetSmartPtr;
+ T* mPtrStorage;
+};
+
+} // namespace detail
+
+template <typename T>
+inline STAUniquePtr<T> ToSTAUniquePtr(RefPtr<T>&& aRefPtr) {
+ return STAUniquePtr<T>(aRefPtr.forget().take());
+}
+
+template <typename T>
+inline STAUniquePtr<T> ToSTAUniquePtr(const RefPtr<T>& aRefPtr) {
+ MOZ_ASSERT(NS_IsMainThread());
+ return STAUniquePtr<T>(do_AddRef(aRefPtr).take());
+}
+
+template <typename T>
+inline STAUniquePtr<T> ToSTAUniquePtr(T* aRawPtr) {
+ MOZ_ASSERT(NS_IsMainThread());
+ if (aRawPtr) {
+ aRawPtr->AddRef();
+ }
+ return STAUniquePtr<T>(aRawPtr);
+}
+
+template <typename T, typename U>
+inline STAUniquePtr<T> ToSTAUniquePtr(const InterceptorTargetPtr<U>& aTarget) {
+ MOZ_ASSERT(NS_IsMainThread());
+ RefPtr<T> newRef(static_cast<T*>(aTarget.get()));
+ return ToSTAUniquePtr(std::move(newRef));
+}
+
+template <typename T>
+inline MTAUniquePtr<T> ToMTAUniquePtr(RefPtr<T>&& aRefPtr) {
+ return MTAUniquePtr<T>(aRefPtr.forget().take());
+}
+
+template <typename T>
+inline MTAUniquePtr<T> ToMTAUniquePtr(const RefPtr<T>& aRefPtr) {
+ MOZ_ASSERT(IsCurrentThreadMTA());
+ return MTAUniquePtr<T>(do_AddRef(aRefPtr).take());
+}
+
+template <typename T>
+inline MTAUniquePtr<T> ToMTAUniquePtr(T* aRawPtr) {
+ MOZ_ASSERT(IsCurrentThreadMTA());
+ if (aRawPtr) {
+ aRawPtr->AddRef();
+ }
+ return MTAUniquePtr<T>(aRawPtr);
+}
+
+template <typename T>
+inline ProxyUniquePtr<T> ToProxyUniquePtr(RefPtr<T>&& aRefPtr) {
+ return ProxyUniquePtr<T>(aRefPtr.forget().take());
+}
+
+template <typename T>
+inline ProxyUniquePtr<T> ToProxyUniquePtr(const RefPtr<T>& aRefPtr) {
+ MOZ_ASSERT(IsProxy(aRefPtr));
+ MOZ_ASSERT((XRE_IsParentProcess() && NS_IsMainThread()) ||
+ (XRE_IsContentProcess() && IsCurrentThreadMTA()));
+
+ return ProxyUniquePtr<T>(do_AddRef(aRefPtr).take());
+}
+
+template <typename T>
+inline ProxyUniquePtr<T> ToProxyUniquePtr(T* aRawPtr) {
+ MOZ_ASSERT(IsProxy(aRawPtr));
+ MOZ_ASSERT((XRE_IsParentProcess() && NS_IsMainThread()) ||
+ (XRE_IsContentProcess() && IsCurrentThreadMTA()));
+
+ if (aRawPtr) {
+ aRawPtr->AddRef();
+ }
+ return ProxyUniquePtr<T>(aRawPtr);
+}
+
+template <typename T, typename Deleter>
+inline InterceptorTargetPtr<T> ToInterceptorTargetPtr(
+ const UniquePtr<T, Deleter>& aTargetPtr) {
+ return InterceptorTargetPtr<T>(aTargetPtr.get());
+}
+
+inline PreservedStreamPtr ToPreservedStreamPtr(RefPtr<IStream>&& aStream) {
+ return PreservedStreamPtr(aStream.forget().take());
+}
+
+inline PreservedStreamPtr ToPreservedStreamPtr(
+ already_AddRefed<IStream>& aStream) {
+ return PreservedStreamPtr(aStream.take());
+}
+
+template <typename T, typename Deleter>
+inline detail::UniquePtrGetterAddRefs<T, Deleter> getter_AddRefs(
+ UniquePtr<T, Deleter>& aSmartPtr) {
+ return detail::UniquePtrGetterAddRefs<T, Deleter>(aSmartPtr);
+}
+
+} // namespace mscom
+} // namespace mozilla
+
+// This block makes it possible for these smart pointers to be correctly
+// applied in NewRunnableMethod and friends
+namespace detail {
+
+template <typename T>
+struct SmartPointerStorageClass<mozilla::mscom::STAUniquePtr<T>> {
+ typedef StoreCopyPassByRRef<mozilla::mscom::STAUniquePtr<T>> Type;
+};
+
+template <typename T>
+struct SmartPointerStorageClass<mozilla::mscom::MTAUniquePtr<T>> {
+ typedef StoreCopyPassByRRef<mozilla::mscom::MTAUniquePtr<T>> Type;
+};
+
+template <typename T>
+struct SmartPointerStorageClass<mozilla::mscom::ProxyUniquePtr<T>> {
+ typedef StoreCopyPassByRRef<mozilla::mscom::ProxyUniquePtr<T>> Type;
+};
+
+template <typename T>
+struct SmartPointerStorageClass<mozilla::mscom::InterceptorTargetPtr<T>> {
+ typedef StoreCopyPassByRRef<mozilla::mscom::InterceptorTargetPtr<T>> Type;
+};
+
+template <>
+struct SmartPointerStorageClass<mozilla::mscom::PreservedStreamPtr> {
+ typedef StoreCopyPassByRRef<mozilla::mscom::PreservedStreamPtr> Type;
+};
+
+} // namespace detail
+
+#endif // mozilla_mscom_Ptr_h
diff --git a/ipc/mscom/Registration.cpp b/ipc/mscom/Registration.cpp
new file mode 100644
index 0000000000..e984ad0de4
--- /dev/null
+++ b/ipc/mscom/Registration.cpp
@@ -0,0 +1,534 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+// COM registration data structures are built with C code, so we need to
+// simulate that in our C++ code by defining CINTERFACE before including
+// anything else that could possibly pull in Windows header files.
+#define CINTERFACE
+
+#include "mozilla/mscom/Registration.h"
+
+#include <utility>
+
+#include "mozilla/ArrayUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/StaticPtr.h"
+#include "mozilla/Vector.h"
+#include "mozilla/mscom/ActivationContext.h"
+#include "mozilla/mscom/Utils.h"
+#include "nsWindowsHelpers.h"
+
+#if defined(MOZILLA_INTERNAL_API)
+# include "mozilla/ClearOnShutdown.h"
+# include "mozilla/mscom/EnsureMTA.h"
+HRESULT RegisterPassthruProxy();
+#else
+# include <stdlib.h>
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include <oaidl.h>
+#include <objidl.h>
+#include <rpcproxy.h>
+#include <shlwapi.h>
+
+#include <algorithm>
+
+/* This code MUST NOT use any non-inlined internal Mozilla APIs, as it will be
+ compiled into DLLs that COM may load into non-Mozilla processes! */
+
+extern "C" {
+
+// This function is defined in generated code for proxy DLLs but is not declared
+// in rpcproxy.h, so we need this declaration.
+void RPC_ENTRY GetProxyDllInfo(const ProxyFileInfo*** aInfo, const CLSID** aId);
+}
+
+namespace mozilla {
+namespace mscom {
+
+static bool GetContainingLibPath(wchar_t* aBuffer, size_t aBufferLen) {
+ HMODULE thisModule = reinterpret_cast<HMODULE>(GetContainingModuleHandle());
+ if (!thisModule) {
+ return false;
+ }
+
+ DWORD fileNameResult = GetModuleFileName(thisModule, aBuffer, aBufferLen);
+ if (!fileNameResult || (fileNameResult == aBufferLen &&
+ ::GetLastError() == ERROR_INSUFFICIENT_BUFFER)) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool BuildLibPath(RegistrationFlags aFlags, wchar_t* aBuffer,
+ size_t aBufferLen, const wchar_t* aLeafName) {
+ if (aFlags == RegistrationFlags::eUseBinDirectory) {
+ if (!GetContainingLibPath(aBuffer, aBufferLen)) {
+ return false;
+ }
+
+ if (!PathRemoveFileSpec(aBuffer)) {
+ return false;
+ }
+ } else if (aFlags == RegistrationFlags::eUseSystemDirectory) {
+ UINT result = GetSystemDirectoryW(aBuffer, static_cast<UINT>(aBufferLen));
+ if (!result || result > aBufferLen) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+
+ if (!PathAppend(aBuffer, aLeafName)) {
+ return false;
+ }
+ return true;
+}
+
+static bool RegisterPSClsids(const ProxyFileInfo** aProxyInfo,
+ const CLSID* aProxyClsid) {
+ while (*aProxyInfo) {
+ const ProxyFileInfo& curInfo = **aProxyInfo;
+ for (unsigned short idx = 0, size = curInfo.TableSize; idx < size; ++idx) {
+ HRESULT hr = CoRegisterPSClsid(*(curInfo.pStubVtblList[idx]->header.piid),
+ *aProxyClsid);
+ if (FAILED(hr)) {
+ return false;
+ }
+ }
+ ++aProxyInfo;
+ }
+
+ return true;
+}
+
+#if !defined(MOZILLA_INTERNAL_API)
+using GetProxyDllInfoFnT = decltype(&GetProxyDllInfo);
+
+static GetProxyDllInfoFnT ResolveGetProxyDllInfo() {
+ HMODULE thisModule = reinterpret_cast<HMODULE>(GetContainingModuleHandle());
+ if (!thisModule) {
+ return nullptr;
+ }
+
+ return reinterpret_cast<GetProxyDllInfoFnT>(
+ GetProcAddress(thisModule, "GetProxyDllInfo"));
+}
+#endif // !defined(MOZILLA_INTERNAL_API)
+
+UniquePtr<RegisteredProxy> RegisterProxy() {
+#if !defined(MOZILLA_INTERNAL_API)
+ GetProxyDllInfoFnT GetProxyDllInfoFn = ResolveGetProxyDllInfo();
+ MOZ_ASSERT(!!GetProxyDllInfoFn);
+ if (!GetProxyDllInfoFn) {
+ return nullptr;
+ }
+#endif // !defined(MOZILLA_INTERNAL_API)
+
+ const ProxyFileInfo** proxyInfo = nullptr;
+ const CLSID* proxyClsid = nullptr;
+#if defined(MOZILLA_INTERNAL_API)
+ GetProxyDllInfo(&proxyInfo, &proxyClsid);
+#else
+ GetProxyDllInfoFn(&proxyInfo, &proxyClsid);
+#endif // defined(MOZILLA_INTERNAL_API)
+ if (!proxyInfo || !proxyClsid) {
+ return nullptr;
+ }
+
+ IUnknown* classObject = nullptr;
+ HRESULT hr =
+ DllGetClassObject(*proxyClsid, IID_IUnknown, (void**)&classObject);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ DWORD regCookie;
+ hr = CoRegisterClassObject(*proxyClsid, classObject, CLSCTX_INPROC_SERVER,
+ REGCLS_MULTIPLEUSE, &regCookie);
+ if (FAILED(hr)) {
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ wchar_t modulePathBuf[MAX_PATH + 1] = {0};
+ if (!GetContainingLibPath(modulePathBuf, ArrayLength(modulePathBuf))) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ ITypeLib* typeLib = nullptr;
+ hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+#if defined(MOZILLA_INTERNAL_API)
+ hr = RegisterPassthruProxy();
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ // RegisteredProxy takes ownership of classObject and typeLib references
+ auto result(MakeUnique<RegisteredProxy>(classObject, regCookie, typeLib));
+
+ if (!RegisterPSClsids(proxyInfo, proxyClsid)) {
+ return nullptr;
+ }
+
+ return result;
+}
+
+UniquePtr<RegisteredProxy> RegisterProxy(const wchar_t* aLeafName,
+ RegistrationFlags aFlags) {
+ wchar_t modulePathBuf[MAX_PATH + 1] = {0};
+ if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf),
+ aLeafName)) {
+ return nullptr;
+ }
+
+ nsModuleHandle proxyDll(LoadLibrary(modulePathBuf));
+ if (!proxyDll.get()) {
+ return nullptr;
+ }
+
+ // Instantiate an activation context so that CoGetClassObject will use any
+ // COM metadata embedded in proxyDll's manifest to resolve CLSIDs.
+ ActivationContextRegion actCtxRgn(proxyDll.get());
+
+ auto GetProxyDllInfoFn = reinterpret_cast<decltype(&GetProxyDllInfo)>(
+ GetProcAddress(proxyDll, "GetProxyDllInfo"));
+ if (!GetProxyDllInfoFn) {
+ return nullptr;
+ }
+
+ const ProxyFileInfo** proxyInfo = nullptr;
+ const CLSID* proxyClsid = nullptr;
+ GetProxyDllInfoFn(&proxyInfo, &proxyClsid);
+ if (!proxyInfo || !proxyClsid) {
+ return nullptr;
+ }
+
+ // We call CoGetClassObject instead of DllGetClassObject because it forces
+ // the COM runtime to manage the lifetime of the DLL.
+ IUnknown* classObject = nullptr;
+ HRESULT hr = CoGetClassObject(*proxyClsid, CLSCTX_INPROC_SERVER, nullptr,
+ IID_IUnknown, (void**)&classObject);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ DWORD regCookie;
+ hr = CoRegisterClassObject(*proxyClsid, classObject, CLSCTX_INPROC_SERVER,
+ REGCLS_MULTIPLEUSE, &regCookie);
+ if (FAILED(hr)) {
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ ITypeLib* typeLib = nullptr;
+ hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ if (FAILED(hr)) {
+ CoRevokeClassObject(regCookie);
+ classObject->lpVtbl->Release(classObject);
+ return nullptr;
+ }
+
+ // RegisteredProxy takes ownership of proxyDll, classObject, and typeLib
+ // references
+ auto result(MakeUnique<RegisteredProxy>(
+ reinterpret_cast<uintptr_t>(proxyDll.disown()), classObject, regCookie,
+ typeLib));
+
+ if (!RegisterPSClsids(proxyInfo, proxyClsid)) {
+ return nullptr;
+ }
+
+ return result;
+}
+
+UniquePtr<RegisteredProxy> RegisterTypelib(const wchar_t* aLeafName,
+ RegistrationFlags aFlags) {
+ wchar_t modulePathBuf[MAX_PATH + 1] = {0};
+ if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf),
+ aLeafName)) {
+ return nullptr;
+ }
+
+ ITypeLib* typeLib = nullptr;
+ HRESULT hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib);
+ if (FAILED(hr)) {
+ return nullptr;
+ }
+
+ // RegisteredProxy takes ownership of typeLib reference
+ auto result(MakeUnique<RegisteredProxy>(typeLib));
+ return result;
+}
+
+RegisteredProxy::RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject,
+ uint32_t aRegCookie, ITypeLib* aTypeLib)
+ : mModule(aModule),
+ mClassObject(aClassObject),
+ mRegCookie(aRegCookie),
+ mTypeLib(aTypeLib)
+#if defined(MOZILLA_INTERNAL_API)
+ ,
+ mIsRegisteredInMTA(IsCurrentThreadMTA())
+#endif // defined(MOZILLA_INTERNAL_API)
+{
+ MOZ_ASSERT(aClassObject);
+ MOZ_ASSERT(aTypeLib);
+ AddToRegistry(this);
+}
+
+RegisteredProxy::RegisteredProxy(IUnknown* aClassObject, uint32_t aRegCookie,
+ ITypeLib* aTypeLib)
+ : mModule(0),
+ mClassObject(aClassObject),
+ mRegCookie(aRegCookie),
+ mTypeLib(aTypeLib)
+#if defined(MOZILLA_INTERNAL_API)
+ ,
+ mIsRegisteredInMTA(IsCurrentThreadMTA())
+#endif // defined(MOZILLA_INTERNAL_API)
+{
+ MOZ_ASSERT(aClassObject);
+ MOZ_ASSERT(aTypeLib);
+ AddToRegistry(this);
+}
+
+// If we're initializing from a typelib, it doesn't matter which apartment we
+// run in, so mIsRegisteredInMTA may always be set to false in this case.
+RegisteredProxy::RegisteredProxy(ITypeLib* aTypeLib)
+ : mModule(0),
+ mClassObject(nullptr),
+ mRegCookie(0),
+ mTypeLib(aTypeLib)
+#if defined(MOZILLA_INTERNAL_API)
+ ,
+ mIsRegisteredInMTA(false)
+#endif // defined(MOZILLA_INTERNAL_API)
+{
+ MOZ_ASSERT(aTypeLib);
+ AddToRegistry(this);
+}
+
+void RegisteredProxy::Clear() {
+ if (mTypeLib) {
+ mTypeLib->lpVtbl->Release(mTypeLib);
+ mTypeLib = nullptr;
+ }
+ if (mClassObject) {
+ // NB: mClassObject and mRegCookie must be freed from inside the apartment
+ // which they were created in.
+ auto cleanupFn = [&]() -> void {
+ ::CoRevokeClassObject(mRegCookie);
+ mRegCookie = 0;
+ mClassObject->lpVtbl->Release(mClassObject);
+ mClassObject = nullptr;
+ };
+#if defined(MOZILLA_INTERNAL_API)
+ // This code only supports MTA when built internally
+ if (mIsRegisteredInMTA) {
+ EnsureMTA mta(cleanupFn);
+ } else {
+ cleanupFn();
+ }
+#else
+ cleanupFn();
+#endif // defined(MOZILLA_INTERNAL_API)
+ }
+ if (mModule) {
+ ::FreeLibrary(reinterpret_cast<HMODULE>(mModule));
+ mModule = 0;
+ }
+}
+
+RegisteredProxy::~RegisteredProxy() {
+ DeleteFromRegistry(this);
+ Clear();
+}
+
+RegisteredProxy::RegisteredProxy(RegisteredProxy&& aOther)
+ : mModule(0),
+ mClassObject(nullptr),
+ mRegCookie(0),
+ mTypeLib(nullptr)
+#if defined(MOZILLA_INTERNAL_API)
+ ,
+ mIsRegisteredInMTA(false)
+#endif // defined(MOZILLA_INTERNAL_API)
+{
+ *this = std::forward<RegisteredProxy>(aOther);
+ AddToRegistry(this);
+}
+
+RegisteredProxy& RegisteredProxy::operator=(RegisteredProxy&& aOther) {
+ Clear();
+
+ mModule = aOther.mModule;
+ aOther.mModule = 0;
+ mClassObject = aOther.mClassObject;
+ aOther.mClassObject = nullptr;
+ mRegCookie = aOther.mRegCookie;
+ aOther.mRegCookie = 0;
+ mTypeLib = aOther.mTypeLib;
+ aOther.mTypeLib = nullptr;
+
+#if defined(MOZILLA_INTERNAL_API)
+ mIsRegisteredInMTA = aOther.mIsRegisteredInMTA;
+#endif // defined(MOZILLA_INTERNAL_API)
+
+ return *this;
+}
+
+HRESULT
+RegisteredProxy::GetTypeInfoForGuid(REFGUID aGuid,
+ ITypeInfo** aOutTypeInfo) const {
+ if (!aOutTypeInfo) {
+ return E_INVALIDARG;
+ }
+ if (!mTypeLib) {
+ return E_UNEXPECTED;
+ }
+ return mTypeLib->lpVtbl->GetTypeInfoOfGuid(mTypeLib, aGuid, aOutTypeInfo);
+}
+
+static StaticAutoPtr<Vector<RegisteredProxy*>> sRegistry;
+
+namespace UseGetMutexForAccess {
+
+// This must not be accessed directly; use GetMutex() instead
+static CRITICAL_SECTION sMutex;
+
+} // namespace UseGetMutexForAccess
+
+static CRITICAL_SECTION* GetMutex() {
+ static CRITICAL_SECTION& mutex = []() -> CRITICAL_SECTION& {
+#if defined(RELEASE_OR_BETA)
+ DWORD flags = CRITICAL_SECTION_NO_DEBUG_INFO;
+#else
+ DWORD flags = 0;
+#endif
+ InitializeCriticalSectionEx(&UseGetMutexForAccess::sMutex, 4000, flags);
+#if !defined(MOZILLA_INTERNAL_API)
+ atexit([]() { DeleteCriticalSection(&UseGetMutexForAccess::sMutex); });
+#endif
+ return UseGetMutexForAccess::sMutex;
+ }();
+ return &mutex;
+}
+
+/* static */
+bool RegisteredProxy::Find(REFIID aIid, ITypeInfo** aTypeInfo) {
+ AutoCriticalSection lock(GetMutex());
+
+ if (!sRegistry) {
+ return false;
+ }
+
+ for (auto&& proxy : *sRegistry) {
+ if (SUCCEEDED(proxy->GetTypeInfoForGuid(aIid, aTypeInfo))) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+/* static */
+void RegisteredProxy::AddToRegistry(RegisteredProxy* aProxy) {
+ MOZ_ASSERT(aProxy);
+
+ AutoCriticalSection lock(GetMutex());
+
+ if (!sRegistry) {
+ sRegistry = new Vector<RegisteredProxy*>();
+
+#if !defined(MOZILLA_INTERNAL_API)
+ // sRegistry allocation is fallible outside of Mozilla processes
+ if (!sRegistry) {
+ return;
+ }
+#endif
+ }
+
+ MOZ_ALWAYS_TRUE(sRegistry->emplaceBack(aProxy));
+}
+
+/* static */
+void RegisteredProxy::DeleteFromRegistry(RegisteredProxy* aProxy) {
+ MOZ_ASSERT(aProxy);
+
+ AutoCriticalSection lock(GetMutex());
+
+ MOZ_ASSERT(sRegistry && !sRegistry->empty());
+
+ if (!sRegistry) {
+ return;
+ }
+
+ sRegistry->erase(std::remove(sRegistry->begin(), sRegistry->end(), aProxy),
+ sRegistry->end());
+
+ if (sRegistry->empty()) {
+ sRegistry = nullptr;
+ }
+}
+
+#if defined(MOZILLA_INTERNAL_API)
+
+static StaticAutoPtr<Vector<std::pair<const ArrayData*, size_t>>> sArrayData;
+
+void RegisterArrayData(const ArrayData* aArrayData, size_t aLength) {
+ AutoCriticalSection lock(GetMutex());
+
+ if (!sArrayData) {
+ sArrayData = new Vector<std::pair<const ArrayData*, size_t>>();
+ ClearOnShutdown(&sArrayData, ShutdownPhase::XPCOMShutdownThreads);
+ }
+
+ MOZ_ALWAYS_TRUE(sArrayData->emplaceBack(std::make_pair(aArrayData, aLength)));
+}
+
+const ArrayData* FindArrayData(REFIID aIid, ULONG aMethodIndex) {
+ AutoCriticalSection lock(GetMutex());
+
+ if (!sArrayData) {
+ return nullptr;
+ }
+
+ for (auto&& data : *sArrayData) {
+ for (size_t innerIdx = 0, innerLen = data.second; innerIdx < innerLen;
+ ++innerIdx) {
+ const ArrayData* array = data.first;
+ if (aMethodIndex == array[innerIdx].mMethodIndex &&
+ IsInterfaceEqualToOrInheritedFrom(aIid, array[innerIdx].mIid,
+ aMethodIndex)) {
+ return &array[innerIdx];
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+#endif // defined(MOZILLA_INTERNAL_API)
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/Registration.h b/ipc/mscom/Registration.h
new file mode 100644
index 0000000000..3dfd5458cd
--- /dev/null
+++ b/ipc/mscom/Registration.h
@@ -0,0 +1,142 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Registration_h
+#define mozilla_mscom_Registration_h
+
+#include "mozilla/RefPtr.h"
+#include "mozilla/UniquePtr.h"
+
+#include <objbase.h>
+
+struct ITypeInfo;
+struct ITypeLib;
+
+namespace mozilla {
+namespace mscom {
+
+/**
+ * Assumptions:
+ * (1) The DLL exports GetProxyDllInfo. This is not exported by default; it must
+ * be specified in the EXPORTS section of the DLL's module definition file.
+ */
+class RegisteredProxy {
+ public:
+ RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject,
+ uint32_t aRegCookie, ITypeLib* aTypeLib);
+ RegisteredProxy(IUnknown* aClassObject, uint32_t aRegCookie,
+ ITypeLib* aTypeLib);
+ explicit RegisteredProxy(ITypeLib* aTypeLib);
+ RegisteredProxy(RegisteredProxy&& aOther);
+ RegisteredProxy& operator=(RegisteredProxy&& aOther);
+
+ ~RegisteredProxy();
+
+ HRESULT GetTypeInfoForGuid(REFGUID aGuid, ITypeInfo** aOutTypeInfo) const;
+
+ static bool Find(REFIID aIid, ITypeInfo** aOutTypeInfo);
+
+ private:
+ RegisteredProxy() = delete;
+ RegisteredProxy(RegisteredProxy&) = delete;
+ RegisteredProxy& operator=(RegisteredProxy&) = delete;
+
+ void Clear();
+
+ static void AddToRegistry(RegisteredProxy* aProxy);
+ static void DeleteFromRegistry(RegisteredProxy* aProxy);
+
+ private:
+ // Not using Windows types here: We shouldn't #include windows.h
+ // since it might pull in COM code which we want to do very carefully in
+ // Registration.cpp.
+ uintptr_t mModule;
+ IUnknown* mClassObject;
+ uint32_t mRegCookie;
+ ITypeLib* mTypeLib;
+#if defined(MOZILLA_INTERNAL_API)
+ bool mIsRegisteredInMTA;
+#endif // defined(MOZILLA_INTERNAL_API)
+};
+
+enum class RegistrationFlags { eUseBinDirectory, eUseSystemDirectory };
+
+// For our own DLL that we are currently executing in (ie, xul).
+// Assumes corresponding TLB is embedded in resources.
+UniquePtr<RegisteredProxy> RegisterProxy();
+
+// For DLL files. Assumes corresponding TLB is embedded in resources.
+UniquePtr<RegisteredProxy> RegisterProxy(
+ const wchar_t* aLeafName,
+ RegistrationFlags aFlags = RegistrationFlags::eUseBinDirectory);
+// For standalone TLB files.
+UniquePtr<RegisteredProxy> RegisterTypelib(
+ const wchar_t* aLeafName,
+ RegistrationFlags aFlags = RegistrationFlags::eUseBinDirectory);
+
+#if defined(MOZILLA_INTERNAL_API)
+
+/**
+ * The COM interceptor uses type library information to build its interface
+ * proxies. Unfortunately type libraries do not encode size_is and length_is
+ * annotations that have been specified in IDL. This structure allows us to
+ * explicitly declare such relationships so that the COM interceptor may
+ * be made aware of them.
+ */
+struct ArrayData {
+ enum class Flag {
+ eNone = 0,
+ eAllocatedByServer = 1 // This implies an extra level of indirection
+ };
+
+ ArrayData(REFIID aIid, ULONG aMethodIndex, ULONG aArrayParamIndex,
+ VARTYPE aArrayParamType, REFIID aArrayParamIid,
+ ULONG aLengthParamIndex, Flag aFlag = Flag::eNone)
+ : mIid(aIid),
+ mMethodIndex(aMethodIndex),
+ mArrayParamIndex(aArrayParamIndex),
+ mArrayParamType(aArrayParamType),
+ mArrayParamIid(aArrayParamIid),
+ mLengthParamIndex(aLengthParamIndex),
+ mFlag(aFlag) {}
+
+ ArrayData(const ArrayData& aOther) { *this = aOther; }
+
+ ArrayData& operator=(const ArrayData& aOther) {
+ mIid = aOther.mIid;
+ mMethodIndex = aOther.mMethodIndex;
+ mArrayParamIndex = aOther.mArrayParamIndex;
+ mArrayParamType = aOther.mArrayParamType;
+ mArrayParamIid = aOther.mArrayParamIid;
+ mLengthParamIndex = aOther.mLengthParamIndex;
+ mFlag = aOther.mFlag;
+ return *this;
+ }
+
+ IID mIid;
+ ULONG mMethodIndex;
+ ULONG mArrayParamIndex;
+ VARTYPE mArrayParamType;
+ IID mArrayParamIid;
+ ULONG mLengthParamIndex;
+ Flag mFlag;
+};
+
+void RegisterArrayData(const ArrayData* aArrayData, size_t aLength);
+
+template <size_t N>
+inline void RegisterArrayData(const ArrayData (&aData)[N]) {
+ RegisterArrayData(aData, N);
+}
+
+const ArrayData* FindArrayData(REFIID aIid, ULONG aMethodIndex);
+
+#endif // defined(MOZILLA_INTERNAL_API)
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Registration_h
diff --git a/ipc/mscom/RegistrationAnnotator.cpp b/ipc/mscom/RegistrationAnnotator.cpp
new file mode 100644
index 0000000000..7c45920975
--- /dev/null
+++ b/ipc/mscom/RegistrationAnnotator.cpp
@@ -0,0 +1,385 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "RegistrationAnnotator.h"
+
+#include "mozilla/JSONStringWriteFuncs.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/NotNull.h"
+#include "nsExceptionHandler.h"
+#include "nsPrintfCString.h"
+#include "nsWindowsHelpers.h"
+#include "nsXULAppAPI.h"
+
+#include <oleauto.h>
+
+namespace mozilla {
+namespace mscom {
+
+static const char16_t kSoftwareClasses[] = u"SOFTWARE\\Classes";
+static const char16_t kInterface[] = u"\\Interface\\";
+static const char16_t kDefaultValue[] = u"";
+static const char16_t kThreadingModel[] = u"ThreadingModel";
+static const char16_t kBackslash[] = u"\\";
+static const char16_t kFlags[] = u"FLAGS";
+static const char16_t kProxyStubClsid32[] = u"\\ProxyStubClsid32";
+static const char16_t kClsid[] = u"\\CLSID\\";
+static const char16_t kInprocServer32[] = u"\\InprocServer32";
+static const char16_t kInprocHandler32[] = u"\\InprocHandler32";
+static const char16_t kTypeLib[] = u"\\TypeLib";
+static const char16_t kVersion[] = u"Version";
+static const char16_t kWin32[] = u"Win32";
+static const char16_t kWin64[] = u"Win64";
+
+static bool GetStringValue(HKEY aBaseKey, const nsAString& aStrSubKey,
+ const nsAString& aValueName, nsAString& aOutput) {
+ const nsString& flatSubKey = PromiseFlatString(aStrSubKey);
+ const nsString& flatValueName = PromiseFlatString(aValueName);
+ LPCWSTR valueName = aValueName.IsEmpty() ? nullptr : flatValueName.get();
+
+ DWORD type = 0;
+ DWORD numBytes = 0;
+ LONG result = RegGetValue(aBaseKey, flatSubKey.get(), valueName, RRF_RT_ANY,
+ &type, nullptr, &numBytes);
+ if (result != ERROR_SUCCESS || (type != REG_SZ && type != REG_EXPAND_SZ)) {
+ return false;
+ }
+
+ int numChars = (numBytes + 1) / sizeof(wchar_t);
+ aOutput.SetLength(numChars);
+
+ DWORD acceptFlag = type == REG_SZ ? RRF_RT_REG_SZ : RRF_RT_REG_EXPAND_SZ;
+
+ result = RegGetValue(aBaseKey, flatSubKey.get(), valueName, acceptFlag,
+ nullptr, aOutput.BeginWriting(), &numBytes);
+ if (result == ERROR_SUCCESS) {
+ // Truncate null terminator
+ aOutput.SetLength(((numBytes + 1) / sizeof(wchar_t)) - 1);
+ }
+
+ return result == ERROR_SUCCESS;
+}
+
+template <size_t N>
+inline static bool GetStringValue(HKEY aBaseKey, const nsAString& aStrSubKey,
+ const char16_t (&aValueName)[N],
+ nsAString& aOutput) {
+ return GetStringValue(aBaseKey, aStrSubKey, nsLiteralString(aValueName),
+ aOutput);
+}
+
+/**
+ * This function fails unless the entire string has been converted.
+ * (eg, the string "FLAGS" will convert to 0xF but we will return false)
+ */
+static bool ConvertLCID(const wchar_t* aStr, NotNull<unsigned long*> aOutLcid) {
+ wchar_t* endChar;
+ *aOutLcid = wcstoul(aStr, &endChar, 16);
+ return *endChar == 0;
+}
+
+static bool GetLoadedPath(nsAString& aPath) {
+ // These paths may be REG_EXPAND_SZ, so we expand any environment strings
+ DWORD bufCharLen =
+ ExpandEnvironmentStrings(PromiseFlatString(aPath).get(), nullptr, 0);
+
+ auto buf = MakeUnique<WCHAR[]>(bufCharLen);
+
+ if (!ExpandEnvironmentStrings(PromiseFlatString(aPath).get(), buf.get(),
+ bufCharLen)) {
+ return false;
+ }
+
+ // Use LoadLibrary so that the DLL is resolved using the loader's DLL search
+ // rules
+ nsModuleHandle mod(LoadLibrary(buf.get()));
+ if (!mod) {
+ return false;
+ }
+
+ WCHAR finalPath[MAX_PATH + 1] = {};
+ DWORD result = GetModuleFileNameW(mod, finalPath, ArrayLength(finalPath));
+ if (!result || (result == ArrayLength(finalPath) &&
+ GetLastError() == ERROR_INSUFFICIENT_BUFFER)) {
+ return false;
+ }
+
+ aPath = nsDependentString(finalPath, result);
+ return true;
+}
+
+static void AnnotateClsidRegistrationForHive(
+ JSONWriter& aJson, HKEY aHive, const nsAString& aClsid,
+ const JSONWriter::CollectionStyle aStyle) {
+ nsAutoString clsidSubkey;
+ clsidSubkey.AppendLiteral(kSoftwareClasses);
+ clsidSubkey.AppendLiteral(kClsid);
+ clsidSubkey.Append(aClsid);
+
+ nsAutoString className;
+ if (GetStringValue(aHive, clsidSubkey, kDefaultValue, className)) {
+ aJson.StringProperty("ClassName", NS_ConvertUTF16toUTF8(className));
+ }
+
+ nsAutoString inprocServerSubkey(clsidSubkey);
+ inprocServerSubkey.AppendLiteral(kInprocServer32);
+
+ nsAutoString pathToServerDll;
+ if (GetStringValue(aHive, inprocServerSubkey, kDefaultValue,
+ pathToServerDll)) {
+ aJson.StringProperty("Path", NS_ConvertUTF16toUTF8(pathToServerDll));
+ if (GetLoadedPath(pathToServerDll)) {
+ aJson.StringProperty("LoadedPath",
+ NS_ConvertUTF16toUTF8(pathToServerDll));
+ }
+ }
+
+ nsAutoString apartment;
+ if (GetStringValue(aHive, inprocServerSubkey, kThreadingModel, apartment)) {
+ aJson.StringProperty("ThreadingModel", NS_ConvertUTF16toUTF8(apartment));
+ }
+
+ nsAutoString inprocHandlerSubkey(clsidSubkey);
+ inprocHandlerSubkey.AppendLiteral(kInprocHandler32);
+ nsAutoString pathToHandlerDll;
+ if (GetStringValue(aHive, inprocHandlerSubkey, kDefaultValue,
+ pathToHandlerDll)) {
+ aJson.StringProperty("HandlerPath",
+ NS_ConvertUTF16toUTF8(pathToHandlerDll));
+ if (GetLoadedPath(pathToHandlerDll)) {
+ aJson.StringProperty("LoadedHandlerPath",
+ NS_ConvertUTF16toUTF8(pathToHandlerDll));
+ }
+ }
+
+ nsAutoString handlerApartment;
+ if (GetStringValue(aHive, inprocHandlerSubkey, kThreadingModel,
+ handlerApartment)) {
+ aJson.StringProperty("HandlerThreadingModel",
+ NS_ConvertUTF16toUTF8(handlerApartment));
+ }
+}
+
+static void CheckTlbPath(JSONWriter& aJson, const nsAString& aTypelibPath) {
+ const nsString& flatPath = PromiseFlatString(aTypelibPath);
+ DWORD bufCharLen = ExpandEnvironmentStrings(flatPath.get(), nullptr, 0);
+
+ auto buf = MakeUnique<WCHAR[]>(bufCharLen);
+
+ if (!ExpandEnvironmentStrings(flatPath.get(), buf.get(), bufCharLen)) {
+ return;
+ }
+
+ // See whether this tlb can actually be loaded
+ RefPtr<ITypeLib> typeLib;
+ HRESULT hr = LoadTypeLibEx(buf.get(), REGKIND_NONE, getter_AddRefs(typeLib));
+
+ nsPrintfCString loadResult("0x%08lX", hr);
+ aJson.StringProperty("LoadResult", loadResult);
+}
+
+template <size_t N>
+static void AnnotateTypelibPlatform(JSONWriter& aJson, HKEY aBaseKey,
+ const nsAString& aLcidSubkey,
+ const char16_t (&aPlatform)[N],
+ const JSONWriter::CollectionStyle aStyle) {
+ nsLiteralString platform(aPlatform);
+
+ nsAutoString fullSubkey(aLcidSubkey);
+ fullSubkey.AppendLiteral(kBackslash);
+ fullSubkey.Append(platform);
+
+ nsAutoString tlbPath;
+ if (GetStringValue(aBaseKey, fullSubkey, kDefaultValue, tlbPath)) {
+ aJson.StartObjectProperty(NS_ConvertUTF16toUTF8(platform), aStyle);
+ aJson.StringProperty("Path", NS_ConvertUTF16toUTF8(tlbPath));
+ CheckTlbPath(aJson, tlbPath);
+ aJson.EndObject();
+ }
+}
+
+static void AnnotateTypelibRegistrationForHive(
+ JSONWriter& aJson, HKEY aHive, const nsAString& aTypelibId,
+ const nsAString& aTypelibVersion,
+ const JSONWriter::CollectionStyle aStyle) {
+ nsAutoString typelibSubKey;
+ typelibSubKey.AppendLiteral(kSoftwareClasses);
+ typelibSubKey.AppendLiteral(kTypeLib);
+ typelibSubKey.AppendLiteral(kBackslash);
+ typelibSubKey.Append(aTypelibId);
+ typelibSubKey.AppendLiteral(kBackslash);
+ typelibSubKey.Append(aTypelibVersion);
+
+ nsAutoString typelibDesc;
+ if (GetStringValue(aHive, typelibSubKey, kDefaultValue, typelibDesc)) {
+ aJson.StringProperty("Description", NS_ConvertUTF16toUTF8(typelibDesc));
+ }
+
+ nsAutoString flagsSubKey(typelibSubKey);
+ flagsSubKey.AppendLiteral(kBackslash);
+ flagsSubKey.AppendLiteral(kFlags);
+
+ nsAutoString typelibFlags;
+ if (GetStringValue(aHive, flagsSubKey, kDefaultValue, typelibFlags)) {
+ aJson.StringProperty("Flags", NS_ConvertUTF16toUTF8(typelibFlags));
+ }
+
+ HKEY rawTypelibKey;
+ LONG result =
+ RegOpenKeyEx(aHive, typelibSubKey.get(), 0, KEY_READ, &rawTypelibKey);
+ if (result != ERROR_SUCCESS) {
+ return;
+ }
+ nsAutoRegKey typelibKey(rawTypelibKey);
+
+ const size_t kMaxLcidCharLen = 9;
+ WCHAR keyName[kMaxLcidCharLen];
+
+ for (DWORD index = 0; result == ERROR_SUCCESS; ++index) {
+ DWORD keyNameLength = ArrayLength(keyName);
+ result = RegEnumKeyEx(typelibKey, index, keyName, &keyNameLength, nullptr,
+ nullptr, nullptr, nullptr);
+
+ unsigned long lcid;
+ if (result == ERROR_SUCCESS && ConvertLCID(keyName, WrapNotNull(&lcid))) {
+ nsDependentString strLcid(keyName, keyNameLength);
+ aJson.StartObjectProperty(NS_ConvertUTF16toUTF8(strLcid), aStyle);
+ AnnotateTypelibPlatform(aJson, typelibKey, strLcid, kWin32, aStyle);
+#if defined(HAVE_64BIT_BUILD)
+ AnnotateTypelibPlatform(aJson, typelibKey, strLcid, kWin64, aStyle);
+#endif
+ aJson.EndObject();
+ }
+ }
+}
+
+static void AnnotateInterfaceRegistrationForHive(
+ JSONWriter& aJson, HKEY aHive, REFIID aIid,
+ const JSONWriter::CollectionStyle aStyle) {
+ nsAutoString interfaceSubKey;
+ interfaceSubKey.AppendLiteral(kSoftwareClasses);
+ interfaceSubKey.AppendLiteral(kInterface);
+ nsAutoString iid;
+ GUIDToString(aIid, iid);
+ interfaceSubKey.Append(iid);
+
+ nsAutoString interfaceName;
+ if (GetStringValue(aHive, interfaceSubKey, kDefaultValue, interfaceName)) {
+ aJson.StringProperty("InterfaceName", NS_ConvertUTF16toUTF8(interfaceName));
+ }
+
+ nsAutoString psSubKey(interfaceSubKey);
+ psSubKey.AppendLiteral(kProxyStubClsid32);
+
+ nsAutoString psClsid;
+ if (GetStringValue(aHive, psSubKey, kDefaultValue, psClsid)) {
+ aJson.StartObjectProperty("ProxyStub", aStyle);
+ aJson.StringProperty("CLSID", NS_ConvertUTF16toUTF8(psClsid));
+ AnnotateClsidRegistrationForHive(aJson, aHive, psClsid, aStyle);
+ aJson.EndObject();
+ }
+
+ nsAutoString typelibSubKey(interfaceSubKey);
+ typelibSubKey.AppendLiteral(kTypeLib);
+
+ nsAutoString typelibId;
+ bool haveTypelibId =
+ GetStringValue(aHive, typelibSubKey, kDefaultValue, typelibId);
+
+ nsAutoString typelibVersion;
+ bool haveTypelibVersion =
+ GetStringValue(aHive, typelibSubKey, kVersion, typelibVersion);
+
+ if (haveTypelibId || haveTypelibVersion) {
+ aJson.StartObjectProperty("TypeLib", aStyle);
+ }
+
+ if (haveTypelibId) {
+ aJson.StringProperty("ID", NS_ConvertUTF16toUTF8(typelibId));
+ }
+
+ if (haveTypelibVersion) {
+ aJson.StringProperty("Version", NS_ConvertUTF16toUTF8(typelibVersion));
+ }
+
+ if (haveTypelibId && haveTypelibVersion) {
+ AnnotateTypelibRegistrationForHive(aJson, aHive, typelibId, typelibVersion,
+ aStyle);
+ }
+
+ if (haveTypelibId || haveTypelibVersion) {
+ aJson.EndObject();
+ }
+}
+
+void AnnotateInterfaceRegistration(REFIID aIid) {
+#if defined(DEBUG)
+ const JSONWriter::CollectionStyle style = JSONWriter::MultiLineStyle;
+#else
+ const JSONWriter::CollectionStyle style = JSONWriter::SingleLineStyle;
+#endif
+
+ JSONStringWriteFunc<nsCString> jsonString;
+ JSONWriter json(jsonString);
+
+ json.Start(style);
+
+ json.StartObjectProperty("HKLM", style);
+ AnnotateInterfaceRegistrationForHive(json, HKEY_LOCAL_MACHINE, aIid, style);
+ json.EndObject();
+
+ json.StartObjectProperty("HKCU", style);
+ AnnotateInterfaceRegistrationForHive(json, HKEY_CURRENT_USER, aIid, style);
+ json.EndObject();
+
+ json.End();
+
+ CrashReporter::Annotation annotationKey;
+ if (XRE_IsParentProcess()) {
+ annotationKey = CrashReporter::Annotation::InterfaceRegistrationInfoParent;
+ } else {
+ annotationKey = CrashReporter::Annotation::InterfaceRegistrationInfoChild;
+ }
+ CrashReporter::AnnotateCrashReport(annotationKey, jsonString.StringCRef());
+}
+
+void AnnotateClassRegistration(REFCLSID aClsid) {
+#if defined(DEBUG)
+ const JSONWriter::CollectionStyle style = JSONWriter::MultiLineStyle;
+#else
+ const JSONWriter::CollectionStyle style = JSONWriter::SingleLineStyle;
+#endif
+
+ nsAutoString strClsid;
+ GUIDToString(aClsid, strClsid);
+
+ JSONStringWriteFunc<nsCString> jsonString;
+ JSONWriter json(jsonString);
+
+ json.Start(style);
+
+ json.StartObjectProperty("HKLM", style);
+ AnnotateClsidRegistrationForHive(json, HKEY_LOCAL_MACHINE, strClsid, style);
+ json.EndObject();
+
+ json.StartObjectProperty("HKCU", style);
+ AnnotateClsidRegistrationForHive(json, HKEY_CURRENT_USER, strClsid, style);
+ json.EndObject();
+
+ json.End();
+
+ CrashReporter::Annotation annotationKey;
+ if (XRE_IsParentProcess()) {
+ annotationKey = CrashReporter::Annotation::ClassRegistrationInfoParent;
+ } else {
+ annotationKey = CrashReporter::Annotation::ClassRegistrationInfoChild;
+ }
+
+ CrashReporter::AnnotateCrashReport(annotationKey, jsonString.StringCRef());
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/RegistrationAnnotator.h b/ipc/mscom/RegistrationAnnotator.h
new file mode 100644
index 0000000000..3065595045
--- /dev/null
+++ b/ipc/mscom/RegistrationAnnotator.h
@@ -0,0 +1,19 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_RegistrationAnnotator_h
+#define mozilla_mscom_RegistrationAnnotator_h
+
+namespace mozilla {
+namespace mscom {
+
+void AnnotateInterfaceRegistration(REFIID aIid);
+void AnnotateClassRegistration(REFCLSID aClsid);
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_RegistrationAnnotator_h
diff --git a/ipc/mscom/SpinEvent.cpp b/ipc/mscom/SpinEvent.cpp
new file mode 100644
index 0000000000..493a1646b7
--- /dev/null
+++ b/ipc/mscom/SpinEvent.cpp
@@ -0,0 +1,77 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/SpinEvent.h"
+
+#include "mozilla/ArrayUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/TimeStamp.h"
+#include "nsServiceManagerUtils.h"
+#include "nsString.h"
+#include "nsSystemInfo.h"
+
+namespace mozilla {
+namespace mscom {
+
+static const TimeDuration kMaxSpinTime = TimeDuration::FromMilliseconds(30);
+bool SpinEvent::sIsMulticore = false;
+
+/* static */
+bool SpinEvent::InitStatics() {
+ SYSTEM_INFO sysInfo;
+ ::GetSystemInfo(&sysInfo);
+ sIsMulticore = sysInfo.dwNumberOfProcessors > 1;
+ return true;
+}
+
+SpinEvent::SpinEvent() : mDone(false) {
+ static const bool gotStatics = InitStatics();
+ MOZ_ALWAYS_TRUE(gotStatics);
+
+ mDoneEvent.own(::CreateEventW(nullptr, FALSE, FALSE, nullptr));
+ MOZ_ASSERT(mDoneEvent);
+}
+
+bool SpinEvent::Wait(HANDLE aTargetThread) {
+ MOZ_ASSERT(aTargetThread);
+ if (!aTargetThread) {
+ return false;
+ }
+
+ if (sIsMulticore) {
+ // Bug 1311834: Spinning allows for faster response than waiting on an
+ // event, as events are constrained by the system's timer resolution.
+ // Bug 1429665: However, we only want to spin for a very short time. If
+ // we're waiting for a while, we don't want to be burning CPU for the
+ // entire time. At that point, a few extra ms isn't going to make much
+ // difference to perceived responsiveness.
+ TimeStamp start(TimeStamp::Now());
+ while (!mDone) {
+ TimeDuration elapsed(TimeStamp::Now() - start);
+ if (elapsed >= kMaxSpinTime) {
+ break;
+ }
+ YieldProcessor();
+ }
+ if (mDone) {
+ return true;
+ }
+ }
+
+ MOZ_ASSERT(mDoneEvent);
+ HANDLE handles[] = {mDoneEvent, aTargetThread};
+ DWORD waitResult = ::WaitForMultipleObjects(mozilla::ArrayLength(handles),
+ handles, FALSE, INFINITE);
+ return waitResult == WAIT_OBJECT_0;
+}
+
+void SpinEvent::Signal() {
+ ::SetEvent(mDoneEvent);
+ mDone = true;
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/SpinEvent.h b/ipc/mscom/SpinEvent.h
new file mode 100644
index 0000000000..461865a6bc
--- /dev/null
+++ b/ipc/mscom/SpinEvent.h
@@ -0,0 +1,40 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_SpinEvent_h
+#define mozilla_mscom_SpinEvent_h
+
+#include "mozilla/Atomics.h"
+#include "mozilla/Attributes.h"
+#include "nsWindowsHelpers.h"
+
+namespace mozilla {
+namespace mscom {
+
+class MOZ_NON_TEMPORARY_CLASS SpinEvent final {
+ public:
+ SpinEvent();
+ ~SpinEvent() = default;
+
+ bool Wait(HANDLE aTargetThread);
+ void Signal();
+
+ SpinEvent(const SpinEvent&) = delete;
+ SpinEvent(SpinEvent&&) = delete;
+ SpinEvent& operator=(SpinEvent&&) = delete;
+ SpinEvent& operator=(const SpinEvent&) = delete;
+
+ private:
+ Atomic<bool, ReleaseAcquire> mDone;
+ nsAutoHandle mDoneEvent;
+ static bool InitStatics();
+ static bool sIsMulticore;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_SpinEvent_h
diff --git a/ipc/mscom/StructStream.cpp b/ipc/mscom/StructStream.cpp
new file mode 100644
index 0000000000..4a24962a6f
--- /dev/null
+++ b/ipc/mscom/StructStream.cpp
@@ -0,0 +1,23 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <malloc.h>
+#include <rpc.h>
+
+/**
+ * These functions need to be defined in order for the types that use
+ * mozilla::mscom::StructToStream and mozilla::mscom::StructFromStream to work.
+ */
+extern "C" {
+
+void __RPC_FAR* __RPC_USER midl_user_allocate(size_t aNumBytes) {
+ const unsigned long kRpcReqdBufAlignment = 8;
+ return _aligned_malloc(aNumBytes, kRpcReqdBufAlignment);
+}
+
+void __RPC_USER midl_user_free(void* aBuffer) { _aligned_free(aBuffer); }
+
+} // extern "C"
diff --git a/ipc/mscom/StructStream.h b/ipc/mscom/StructStream.h
new file mode 100644
index 0000000000..45a32c7687
--- /dev/null
+++ b/ipc/mscom/StructStream.h
@@ -0,0 +1,239 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_StructStream_h
+#define mozilla_mscom_StructStream_h
+
+#include "mozilla/Attributes.h"
+#include "mozilla/UniquePtr.h"
+#include "nscore.h"
+
+#include <memory.h>
+#include <midles.h>
+#include <objidl.h>
+#include <rpc.h>
+
+/**
+ * This code is used for (de)serializing data structures that have been
+ * declared using midl, thus allowing us to use Microsoft RPC for marshaling
+ * data for our COM handlers that may run in other processes that are not ours.
+ */
+
+namespace mozilla {
+namespace mscom {
+
+namespace detail {
+
+typedef ULONG EncodedLenT;
+
+} // namespace detail
+
+class MOZ_NON_TEMPORARY_CLASS StructToStream {
+ public:
+ /**
+ * This constructor variant represents an empty/null struct to be serialized.
+ */
+ StructToStream()
+ : mStatus(RPC_S_OK), mHandle(nullptr), mBuffer(nullptr), mEncodedLen(0) {}
+
+ template <typename StructT>
+ StructToStream(StructT& aSrcStruct, void (*aEncodeFnPtr)(handle_t, StructT*))
+ : mStatus(RPC_X_INVALID_BUFFER),
+ mHandle(nullptr),
+ mBuffer(nullptr),
+ mEncodedLen(0) {
+ mStatus =
+ ::MesEncodeDynBufferHandleCreate(&mBuffer, &mEncodedLen, &mHandle);
+ if (mStatus != RPC_S_OK) {
+ return;
+ }
+
+ MOZ_SEH_TRY { aEncodeFnPtr(mHandle, &aSrcStruct); }
+#ifdef HAVE_SEH_EXCEPTIONS
+ MOZ_SEH_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
+ mStatus = ::RpcExceptionCode();
+ return;
+ }
+#endif
+
+ if (!mBuffer || !mEncodedLen) {
+ mStatus = RPC_X_NO_MEMORY;
+ return;
+ }
+ }
+
+ ~StructToStream() {
+ if (mHandle) {
+ ::MesHandleFree(mHandle);
+ }
+ if (mBuffer) {
+ // Bug 1440564: You'd think that MesHandleFree would free the buffer,
+ // since it was created by RPC, but it doesn't.
+ midl_user_free(mBuffer);
+ }
+ }
+
+ static unsigned long GetEmptySize() { return sizeof(detail::EncodedLenT); }
+
+ static HRESULT WriteEmpty(IStream* aDestStream) {
+ StructToStream emptyStruct;
+ return emptyStruct.Write(aDestStream);
+ }
+
+ explicit operator bool() const { return mStatus == RPC_S_OK; }
+
+ bool IsEmpty() const { return mStatus == RPC_S_OK && !mEncodedLen; }
+
+ unsigned long GetSize() const { return sizeof(mEncodedLen) + mEncodedLen; }
+
+ HRESULT Write(IStream* aDestStream) {
+ if (!aDestStream) {
+ return E_INVALIDARG;
+ }
+ if (mStatus != RPC_S_OK) {
+ return E_FAIL;
+ }
+
+ ULONG bytesWritten;
+ HRESULT hr =
+ aDestStream->Write(&mEncodedLen, sizeof(mEncodedLen), &bytesWritten);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ if (bytesWritten != sizeof(mEncodedLen)) {
+ return E_UNEXPECTED;
+ }
+
+ if (mBuffer && mEncodedLen) {
+ hr = aDestStream->Write(mBuffer, mEncodedLen, &bytesWritten);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ if (bytesWritten != mEncodedLen) {
+ return E_UNEXPECTED;
+ }
+ }
+
+ return hr;
+ }
+
+ StructToStream(const StructToStream&) = delete;
+ StructToStream(StructToStream&&) = delete;
+ StructToStream& operator=(const StructToStream&) = delete;
+ StructToStream& operator=(StructToStream&&) = delete;
+
+ private:
+ RPC_STATUS mStatus;
+ handle_t mHandle;
+ char* mBuffer;
+ detail::EncodedLenT mEncodedLen;
+};
+
+class MOZ_NON_TEMPORARY_CLASS StructFromStream {
+ struct AlignedFreeDeleter {
+ void operator()(void* aPtr) { ::_aligned_free(aPtr); }
+ };
+
+ static const detail::EncodedLenT kRpcReqdBufAlignment = 8;
+
+ public:
+ explicit StructFromStream(IStream* aStream)
+ : mStatus(RPC_X_INVALID_BUFFER), mHandle(nullptr) {
+ MOZ_ASSERT(aStream);
+
+ // Read the length of the encoded data first
+ detail::EncodedLenT encodedLen = 0;
+ ULONG bytesRead = 0;
+ HRESULT hr = aStream->Read(&encodedLen, sizeof(encodedLen), &bytesRead);
+ if (FAILED(hr)) {
+ return;
+ }
+
+ // NB: Some implementations of IStream return S_FALSE to indicate EOF,
+ // other implementations return S_OK and set the number of bytes read to 0.
+ // We must handle both.
+ if (hr == S_FALSE || !bytesRead) {
+ mStatus = RPC_S_OBJECT_NOT_FOUND;
+ return;
+ }
+
+ if (bytesRead != sizeof(encodedLen)) {
+ return;
+ }
+
+ if (!encodedLen) {
+ mStatus = RPC_S_OBJECT_NOT_FOUND;
+ return;
+ }
+
+ MOZ_ASSERT(encodedLen % kRpcReqdBufAlignment == 0);
+ if (encodedLen % kRpcReqdBufAlignment) {
+ return;
+ }
+
+ // This memory allocation is fallible
+ mEncodedBuffer.reset(static_cast<char*>(
+ ::_aligned_malloc(encodedLen, kRpcReqdBufAlignment)));
+ if (!mEncodedBuffer) {
+ return;
+ }
+
+ ULONG bytesReadFromStream = 0;
+ hr = aStream->Read(mEncodedBuffer.get(), encodedLen, &bytesReadFromStream);
+ if (FAILED(hr) || bytesReadFromStream != encodedLen) {
+ return;
+ }
+
+ mStatus = ::MesDecodeBufferHandleCreate(mEncodedBuffer.get(), encodedLen,
+ &mHandle);
+ }
+
+ ~StructFromStream() {
+ if (mHandle) {
+ ::MesHandleFree(mHandle);
+ }
+ }
+
+ explicit operator bool() const { return mStatus == RPC_S_OK || IsEmpty(); }
+
+ bool IsEmpty() const { return mStatus == RPC_S_OBJECT_NOT_FOUND; }
+
+ template <typename StructT>
+ bool Read(StructT* aDestStruct, void (*aDecodeFnPtr)(handle_t, StructT*)) {
+ if (!aDestStruct || !aDecodeFnPtr || mStatus != RPC_S_OK) {
+ return false;
+ }
+
+ // NB: Deserialization will fail with BSTRs unless the destination data
+ // is zeroed out!
+ ZeroMemory(aDestStruct, sizeof(StructT));
+
+ MOZ_SEH_TRY { aDecodeFnPtr(mHandle, aDestStruct); }
+#ifdef HAVE_SEH_EXCEPTIONS
+ MOZ_SEH_EXCEPT(EXCEPTION_EXECUTE_HANDLER) {
+ mStatus = ::RpcExceptionCode();
+ return false;
+ }
+#endif
+
+ return true;
+ }
+
+ StructFromStream(const StructFromStream&) = delete;
+ StructFromStream(StructFromStream&&) = delete;
+ StructFromStream& operator=(const StructFromStream&) = delete;
+ StructFromStream& operator=(StructFromStream&&) = delete;
+
+ private:
+ RPC_STATUS mStatus;
+ handle_t mHandle;
+ UniquePtr<char, AlignedFreeDeleter> mEncodedBuffer;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_StructStream_h
diff --git a/ipc/mscom/Utils.cpp b/ipc/mscom/Utils.cpp
new file mode 100644
index 0000000000..377bbeba5c
--- /dev/null
+++ b/ipc/mscom/Utils.cpp
@@ -0,0 +1,600 @@
+/* -*- Mode: C; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#if defined(MOZILLA_INTERNAL_API)
+# include "MainThreadUtils.h"
+# include "mozilla/dom/ContentChild.h"
+#endif
+
+#if defined(ACCESSIBILITY)
+# include "mozilla/mscom/Registration.h"
+# if defined(MOZILLA_INTERNAL_API)
+# include "nsTArray.h"
+# endif
+#endif
+
+#include "mozilla/ArrayUtils.h"
+#include "mozilla/mscom/COMWrappers.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/mscom/Objref.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/WindowsVersion.h"
+
+#include <objidl.h>
+#include <shlwapi.h>
+#include <winnt.h>
+
+#include <utility>
+
+#if defined(_MSC_VER)
+extern "C" IMAGE_DOS_HEADER __ImageBase;
+#endif
+
+namespace mozilla {
+namespace mscom {
+
+bool IsCOMInitializedOnCurrentThread() {
+ APTTYPE aptType;
+ APTTYPEQUALIFIER aptTypeQualifier;
+ HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier);
+ return hr != CO_E_NOTINITIALIZED;
+}
+
+bool IsCurrentThreadMTA() {
+ APTTYPE aptType;
+ APTTYPEQUALIFIER aptTypeQualifier;
+ HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ return aptType == APTTYPE_MTA;
+}
+
+bool IsCurrentThreadExplicitMTA() {
+ APTTYPE aptType;
+ APTTYPEQUALIFIER aptTypeQualifier;
+ HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ return aptType == APTTYPE_MTA &&
+ aptTypeQualifier != APTTYPEQUALIFIER_IMPLICIT_MTA;
+}
+
+bool IsCurrentThreadImplicitMTA() {
+ APTTYPE aptType;
+ APTTYPEQUALIFIER aptTypeQualifier;
+ HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ return aptType == APTTYPE_MTA &&
+ aptTypeQualifier == APTTYPEQUALIFIER_IMPLICIT_MTA;
+}
+
+#if defined(MOZILLA_INTERNAL_API)
+bool IsCurrentThreadNonMainMTA() {
+ if (NS_IsMainThread()) {
+ return false;
+ }
+
+ return IsCurrentThreadMTA();
+}
+#endif // defined(MOZILLA_INTERNAL_API)
+
+bool IsProxy(IUnknown* aUnknown) {
+ if (!aUnknown) {
+ return false;
+ }
+
+ // Only proxies implement this interface, so if it is present then we must
+ // be dealing with a proxied object.
+ RefPtr<IClientSecurity> clientSecurity;
+ HRESULT hr = aUnknown->QueryInterface(IID_IClientSecurity,
+ (void**)getter_AddRefs(clientSecurity));
+ if (SUCCEEDED(hr) || hr == RPC_E_WRONG_THREAD) {
+ return true;
+ }
+ return false;
+}
+
+bool IsValidGUID(REFGUID aCheckGuid) {
+ // This function determines whether or not aCheckGuid conforms to RFC4122
+ // as it applies to Microsoft COM.
+
+ BYTE variant = aCheckGuid.Data4[0];
+ if (!(variant & 0x80)) {
+ // NCS Reserved
+ return false;
+ }
+ if ((variant & 0xE0) == 0xE0) {
+ // Reserved for future use
+ return false;
+ }
+ if ((variant & 0xC0) == 0xC0) {
+ // Microsoft Reserved.
+ return true;
+ }
+
+ BYTE version = HIBYTE(aCheckGuid.Data3) >> 4;
+ // Other versions are specified in RFC4122 but these are the two used by COM.
+ return version == 1 || version == 4;
+}
+
+uintptr_t GetContainingModuleHandle() {
+ HMODULE thisModule = nullptr;
+#if defined(_MSC_VER)
+ thisModule = reinterpret_cast<HMODULE>(&__ImageBase);
+#else
+ if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+ reinterpret_cast<LPCTSTR>(&GetContainingModuleHandle),
+ &thisModule)) {
+ return 0;
+ }
+#endif
+ return reinterpret_cast<uintptr_t>(thisModule);
+}
+
+namespace detail {
+
+long BuildRegGuidPath(REFGUID aGuid, const GuidType aGuidType, wchar_t* aBuf,
+ const size_t aBufLen) {
+ constexpr wchar_t kClsid[] = L"CLSID\\";
+ constexpr wchar_t kAppid[] = L"AppID\\";
+ constexpr wchar_t kSubkeyBase[] = L"SOFTWARE\\Classes\\";
+
+ // We exclude null terminators in these length calculations because we include
+ // the stringified GUID's null terminator at the end. Since kClsid and kAppid
+ // have identical lengths, we just choose one to compute this length.
+ constexpr size_t kSubkeyBaseLen = mozilla::ArrayLength(kSubkeyBase) - 1;
+ constexpr size_t kSubkeyLen =
+ kSubkeyBaseLen + mozilla::ArrayLength(kClsid) - 1;
+ // Guid length as formatted for the registry (including curlies and dashes),
+ // but excluding null terminator.
+ constexpr size_t kGuidLen = kGuidRegFormatCharLenInclNul - 1;
+ constexpr size_t kExpectedPathLenInclNul = kSubkeyLen + kGuidLen + 1;
+
+ if (aBufLen < kExpectedPathLenInclNul) {
+ // Buffer is too short
+ return E_INVALIDARG;
+ }
+
+ if (wcscpy_s(aBuf, aBufLen, kSubkeyBase)) {
+ return E_INVALIDARG;
+ }
+
+ const wchar_t* strGuidType = aGuidType == GuidType::CLSID ? kClsid : kAppid;
+ if (wcscat_s(aBuf, aBufLen, strGuidType)) {
+ return E_INVALIDARG;
+ }
+
+ int guidConversionResult =
+ ::StringFromGUID2(aGuid, &aBuf[kSubkeyLen], aBufLen - kSubkeyLen);
+ if (!guidConversionResult) {
+ return E_INVALIDARG;
+ }
+
+ return S_OK;
+}
+
+} // namespace detail
+
+long CreateStream(const uint8_t* aInitBuf, const uint32_t aInitBufSize,
+ IStream** aOutStream) {
+ if (!aInitBufSize || !aOutStream) {
+ return E_INVALIDARG;
+ }
+
+ *aOutStream = nullptr;
+
+ HRESULT hr;
+ RefPtr<IStream> stream;
+
+ if (IsWin8OrLater()) {
+ // SHCreateMemStream is not safe for us to use until Windows 8. On older
+ // versions of Windows it is not thread-safe and it creates IStreams that do
+ // not support the full IStream API.
+
+ // If aInitBuf is null then initSize must be 0.
+ UINT initSize = aInitBuf ? aInitBufSize : 0;
+ stream = already_AddRefed<IStream>(::SHCreateMemStream(aInitBuf, initSize));
+ if (!stream) {
+ return E_OUTOFMEMORY;
+ }
+
+ if (!aInitBuf) {
+ // Now we'll set the required size
+ ULARGE_INTEGER newSize;
+ newSize.QuadPart = aInitBufSize;
+ hr = stream->SetSize(newSize);
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+ } else {
+ HGLOBAL hglobal = ::GlobalAlloc(GMEM_MOVEABLE, aInitBufSize);
+ if (!hglobal) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ // stream takes ownership of hglobal if this call is successful
+ hr = ::CreateStreamOnHGlobal(hglobal, TRUE, getter_AddRefs(stream));
+ if (FAILED(hr)) {
+ ::GlobalFree(hglobal);
+ return hr;
+ }
+
+ // The default stream size is derived from ::GlobalSize(hglobal), which due
+ // to rounding may be larger than aInitBufSize. We forcibly set the correct
+ // stream size here.
+ ULARGE_INTEGER streamSize;
+ streamSize.QuadPart = aInitBufSize;
+ hr = stream->SetSize(streamSize);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ if (aInitBuf) {
+ ULONG bytesWritten;
+ hr = stream->Write(aInitBuf, aInitBufSize, &bytesWritten);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ if (bytesWritten != aInitBufSize) {
+ return E_UNEXPECTED;
+ }
+ }
+ }
+
+ // Ensure that the stream is rewound
+ LARGE_INTEGER streamOffset;
+ streamOffset.QuadPart = 0LL;
+ hr = stream->Seek(streamOffset, STREAM_SEEK_SET, nullptr);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ stream.forget(aOutStream);
+ return S_OK;
+}
+
+long CopySerializedProxy(IStream* aInStream, IStream** aOutStream) {
+ if (!aInStream || !aOutStream) {
+ return E_INVALIDARG;
+ }
+
+ *aOutStream = nullptr;
+
+ uint32_t desiredStreamSize = GetOBJREFSize(WrapNotNull(aInStream));
+ if (!desiredStreamSize) {
+ return E_INVALIDARG;
+ }
+
+ RefPtr<IStream> stream;
+ HRESULT hr = CreateStream(nullptr, desiredStreamSize, getter_AddRefs(stream));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ ULARGE_INTEGER numBytesToCopy;
+ numBytesToCopy.QuadPart = desiredStreamSize;
+ hr = aInStream->CopyTo(stream, numBytesToCopy, nullptr, nullptr);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ LARGE_INTEGER seekTo;
+ seekTo.QuadPart = 0LL;
+ hr = stream->Seek(seekTo, STREAM_SEEK_SET, nullptr);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ stream.forget(aOutStream);
+ return S_OK;
+}
+
+#if defined(MOZILLA_INTERNAL_API)
+
+void GUIDToString(REFGUID aGuid, nsAString& aOutString) {
+ // This buffer length is long enough to hold a GUID string that is formatted
+ // to include curly braces and dashes.
+ const int kBufLenWithNul = 39;
+ aOutString.SetLength(kBufLenWithNul);
+ int result = StringFromGUID2(aGuid, char16ptr_t(aOutString.BeginWriting()),
+ kBufLenWithNul);
+ MOZ_ASSERT(result);
+ if (result) {
+ // Truncate the terminator
+ aOutString.SetLength(result - 1);
+ }
+}
+
+// Undocumented IIDs that are relevant for diagnostic purposes
+static const IID IID_ISCMLocalActivator = {
+ 0x00000136,
+ 0x0000,
+ 0x0000,
+ {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}};
+static const IID IID_IRundown = {
+ 0x00000134,
+ 0x0000,
+ 0x0000,
+ {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}};
+static const IID IID_IRemUnknown = {
+ 0x00000131,
+ 0x0000,
+ 0x0000,
+ {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}};
+static const IID IID_IRemUnknown2 = {
+ 0x00000143,
+ 0x0000,
+ 0x0000,
+ {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}};
+
+struct IIDToLiteralMapEntry {
+ constexpr IIDToLiteralMapEntry(REFIID aIid, nsLiteralCString&& aStr)
+ : mIid(aIid), mStr(std::forward<nsLiteralCString>(aStr)) {}
+
+ REFIID mIid;
+ const nsLiteralCString mStr;
+};
+
+/**
+ * Given the name of an interface, the IID_ENTRY macro generates a pair
+ * containing a reference to the interface ID and a stringified version of
+ * the interface name.
+ *
+ * For example:
+ *
+ * {IID_ENTRY(IUnknown)}
+ * is expanded to:
+ * {IID_IUnknown, "IUnknown"_ns}
+ *
+ */
+// clang-format off
+# define IID_ENTRY_STRINGIFY(iface) #iface##_ns
+# define IID_ENTRY(iface) IID_##iface, IID_ENTRY_STRINGIFY(iface)
+// clang-format on
+
+// Mapping of selected IIDs to friendly, human readable descriptions for each
+// interface.
+static constexpr IIDToLiteralMapEntry sIidDiagStrs[] = {
+ {IID_ENTRY(IUnknown)},
+ {IID_IRemUnknown, "cross-apartment IUnknown"_ns},
+ {IID_IRundown, "cross-apartment object management"_ns},
+ {IID_ISCMLocalActivator, "out-of-process object instantiation"_ns},
+ {IID_IRemUnknown2, "cross-apartment IUnknown"_ns}};
+
+# undef IID_ENTRY
+# undef IID_ENTRY_STRINGIFY
+
+void DiagnosticNameForIID(REFIID aIid, nsACString& aOutString) {
+ // If the IID matches something in sIidDiagStrs, output its string.
+ for (const auto& curEntry : sIidDiagStrs) {
+ if (curEntry.mIid == aIid) {
+ aOutString.Assign(curEntry.mStr);
+ return;
+ }
+ }
+
+ // Otherwise just convert the IID to string form and output that.
+ nsAutoString strIid;
+ GUIDToString(aIid, strIid);
+
+ aOutString.AssignLiteral("IID ");
+ AppendUTF16toUTF8(strIid, aOutString);
+}
+
+#else
+
+void GUIDToString(REFGUID aGuid,
+ wchar_t (&aOutBuf)[kGuidRegFormatCharLenInclNul]) {
+ DebugOnly<int> result =
+ ::StringFromGUID2(aGuid, aOutBuf, ArrayLength(aOutBuf));
+ MOZ_ASSERT(result);
+}
+
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#if defined(ACCESSIBILITY)
+
+static bool IsVtableIndexFromParentInterface(TYPEATTR* aTypeAttr,
+ unsigned long aVtableIndex) {
+ MOZ_ASSERT(aTypeAttr);
+
+ // This is the number of functions declared in this interface (excluding
+ // parent interfaces).
+ unsigned int numExclusiveFuncs = aTypeAttr->cFuncs;
+
+ // This is the number of vtable entries (which includes parent interfaces).
+ // TYPEATTR::cbSizeVft is the entire vtable size in bytes, so we need to
+ // divide in order to compute the number of entries.
+ unsigned int numVtblEntries = aTypeAttr->cbSizeVft / sizeof(void*);
+
+ // This is the index of the first entry in the vtable that belongs to this
+ // interface and not a parent.
+ unsigned int firstVtblIndex = numVtblEntries - numExclusiveFuncs;
+
+ // If aVtableIndex is less than firstVtblIndex, then we're asking for an
+ // index that may belong to a parent interface.
+ return aVtableIndex < firstVtblIndex;
+}
+
+bool IsVtableIndexFromParentInterface(REFIID aInterface,
+ unsigned long aVtableIndex) {
+ RefPtr<ITypeInfo> typeInfo;
+ if (!RegisteredProxy::Find(aInterface, getter_AddRefs(typeInfo))) {
+ return false;
+ }
+
+ TYPEATTR* typeAttr = nullptr;
+ HRESULT hr = typeInfo->GetTypeAttr(&typeAttr);
+ if (FAILED(hr)) {
+ return false;
+ }
+
+ bool result = IsVtableIndexFromParentInterface(typeAttr, aVtableIndex);
+
+ typeInfo->ReleaseTypeAttr(typeAttr);
+ return result;
+}
+
+# if defined(MOZILLA_INTERNAL_API)
+
+bool IsCallerExternalProcess() {
+ MOZ_ASSERT(XRE_IsContentProcess());
+
+ /**
+ * CoGetCallerTID() gives us the caller's thread ID when that thread resides
+ * in a single-threaded apartment. Since our chrome main thread does live
+ * inside an STA, we will therefore be able to check whether the caller TID
+ * equals our chrome main thread TID. This enables us to distinguish
+ * between our chrome thread vs other out-of-process callers. We check for
+ * S_FALSE to ensure that the caller is a different process from ours, which
+ * is the only scenario that we care about.
+ */
+ DWORD callerTid;
+ if (::CoGetCallerTID(&callerTid) != S_FALSE) {
+ return false;
+ }
+
+ // Now check whether the caller is our parent process main thread.
+ const DWORD parentMainTid =
+ dom::ContentChild::GetSingleton()->GetChromeMainThreadId();
+ return callerTid != parentMainTid;
+}
+
+bool IsInterfaceEqualToOrInheritedFrom(REFIID aInterface, REFIID aFrom,
+ unsigned long aVtableIndexHint) {
+ if (aInterface == aFrom) {
+ return true;
+ }
+
+ // We expect this array to be length 1 but that is not guaranteed by the API.
+ AutoTArray<RefPtr<ITypeInfo>, 1> typeInfos;
+
+ // Grab aInterface's ITypeInfo so that we may obtain information about its
+ // inheritance hierarchy.
+ RefPtr<ITypeInfo> typeInfo;
+ if (RegisteredProxy::Find(aInterface, getter_AddRefs(typeInfo))) {
+ typeInfos.AppendElement(std::move(typeInfo));
+ }
+
+ /**
+ * The main loop of this function searches the hierarchy of aInterface's
+ * parent interfaces, searching for aFrom.
+ */
+ while (!typeInfos.IsEmpty()) {
+ RefPtr<ITypeInfo> curTypeInfo(typeInfos.PopLastElement());
+
+ TYPEATTR* typeAttr = nullptr;
+ HRESULT hr = curTypeInfo->GetTypeAttr(&typeAttr);
+ if (FAILED(hr)) {
+ break;
+ }
+
+ bool isFromParentVtable =
+ IsVtableIndexFromParentInterface(typeAttr, aVtableIndexHint);
+ WORD numParentInterfaces = typeAttr->cImplTypes;
+
+ curTypeInfo->ReleaseTypeAttr(typeAttr);
+ typeAttr = nullptr;
+
+ if (!isFromParentVtable) {
+ // The vtable index cannot belong to this interface (otherwise the IIDs
+ // would already have matched and we would have returned true). Since we
+ // now also know that the vtable index cannot possibly be contained inside
+ // curTypeInfo's parent interface, there is no point searching any further
+ // up the hierarchy from here. OTOH we still should check any remaining
+ // entries that are still in the typeInfos array, so we continue.
+ continue;
+ }
+
+ for (WORD i = 0; i < numParentInterfaces; ++i) {
+ HREFTYPE refCookie;
+ hr = curTypeInfo->GetRefTypeOfImplType(i, &refCookie);
+ if (FAILED(hr)) {
+ continue;
+ }
+
+ RefPtr<ITypeInfo> nextTypeInfo;
+ hr = curTypeInfo->GetRefTypeInfo(refCookie, getter_AddRefs(nextTypeInfo));
+ if (FAILED(hr)) {
+ continue;
+ }
+
+ hr = nextTypeInfo->GetTypeAttr(&typeAttr);
+ if (FAILED(hr)) {
+ continue;
+ }
+
+ IID nextIid = typeAttr->guid;
+
+ nextTypeInfo->ReleaseTypeAttr(typeAttr);
+ typeAttr = nullptr;
+
+ if (nextIid == aFrom) {
+ return true;
+ }
+
+ typeInfos.AppendElement(std::move(nextTypeInfo));
+ }
+ }
+
+ return false;
+}
+
+# endif // defined(MOZILLA_INTERNAL_API)
+
+#endif // defined(ACCESSIBILITY)
+
+#if defined(MOZILLA_INTERNAL_API)
+bool IsClassThreadAwareInprocServer(REFCLSID aClsid) {
+ nsAutoString strClsid;
+ GUIDToString(aClsid, strClsid);
+
+ nsAutoString inprocServerSubkey(u"CLSID\\"_ns);
+ inprocServerSubkey.Append(strClsid);
+ inprocServerSubkey.Append(u"\\InprocServer32"_ns);
+
+ // Of the possible values, "Apartment" is the longest, so we'll make this
+ // buffer large enough to hold that one.
+ wchar_t threadingModelBuf[ArrayLength(L"Apartment")] = {};
+
+ DWORD numBytes = sizeof(threadingModelBuf);
+ LONG result = ::RegGetValueW(HKEY_CLASSES_ROOT, inprocServerSubkey.get(),
+ L"ThreadingModel", RRF_RT_REG_SZ, nullptr,
+ threadingModelBuf, &numBytes);
+ if (result != ERROR_SUCCESS) {
+ // This will also handle the case where the CLSID is not an inproc server.
+ return false;
+ }
+
+ DWORD numChars = numBytes / sizeof(wchar_t);
+ // numChars includes the null terminator
+ if (numChars <= 1) {
+ return false;
+ }
+
+ nsDependentString threadingModel(threadingModelBuf, numChars - 1);
+
+ // Ensure that the threading model is one of the known values that indicates
+ // that the class can operate natively (ie, no proxying) inside a MTA.
+ return threadingModel.LowerCaseEqualsLiteral("both") ||
+ threadingModel.LowerCaseEqualsLiteral("free") ||
+ threadingModel.LowerCaseEqualsLiteral("neutral");
+}
+#endif // defined(MOZILLA_INTERNAL_API)
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/Utils.h b/ipc/mscom/Utils.h
new file mode 100644
index 0000000000..c36efde3d3
--- /dev/null
+++ b/ipc/mscom/Utils.h
@@ -0,0 +1,168 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Utils_h
+#define mozilla_mscom_Utils_h
+
+#if defined(MOZILLA_INTERNAL_API)
+# include "nsString.h"
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include "mozilla/Attributes.h"
+#include <guiddef.h>
+#include <stdint.h>
+
+struct IStream;
+struct IUnknown;
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+enum class GuidType {
+ CLSID,
+ AppID,
+};
+
+long BuildRegGuidPath(REFGUID aGuid, const GuidType aGuidType, wchar_t* aBuf,
+ const size_t aBufLen);
+
+} // namespace detail
+
+bool IsCOMInitializedOnCurrentThread();
+bool IsCurrentThreadMTA();
+bool IsCurrentThreadExplicitMTA();
+bool IsCurrentThreadImplicitMTA();
+#if defined(MOZILLA_INTERNAL_API)
+bool IsCurrentThreadNonMainMTA();
+#endif // defined(MOZILLA_INTERNAL_API)
+bool IsProxy(IUnknown* aUnknown);
+bool IsValidGUID(REFGUID aCheckGuid);
+uintptr_t GetContainingModuleHandle();
+
+template <size_t N>
+inline long BuildAppidPath(REFGUID aAppId, wchar_t (&aPath)[N]) {
+ return detail::BuildRegGuidPath(aAppId, detail::GuidType::AppID, aPath, N);
+}
+
+template <size_t N>
+inline long BuildClsidPath(REFCLSID aClsid, wchar_t (&aPath)[N]) {
+ return detail::BuildRegGuidPath(aClsid, detail::GuidType::CLSID, aPath, N);
+}
+
+/**
+ * Given a buffer, create a new IStream object.
+ * @param aBuf Buffer containing data to initialize the stream. This parameter
+ * may be nullptr, causing the stream to be created with aBufLen
+ * bytes of uninitialized data.
+ * @param aBufLen Length of data in aBuf, or desired stream size if aBuf is
+ * nullptr.
+ * @param aOutStream Outparam to receive the newly created stream.
+ * @return HRESULT error code.
+ */
+long CreateStream(const uint8_t* aBuf, const uint32_t aBufLen,
+ IStream** aOutStream);
+
+/**
+ * Creates a deep copy of a proxy contained in a stream.
+ * @param aInStream Stream containing the proxy to copy. Its seek pointer must
+ * be positioned to point at the beginning of the proxy data.
+ * @param aOutStream Outparam to receive the newly created stream.
+ * @return HRESULT error code.
+ */
+long CopySerializedProxy(IStream* aInStream, IStream** aOutStream);
+
+/**
+ * Length of a stringified GUID as formatted for the registry, i.e. including
+ * curly-braces and dashes.
+ */
+constexpr size_t kGuidRegFormatCharLenInclNul = 39;
+
+#if defined(MOZILLA_INTERNAL_API)
+/**
+ * Checks the registry to see if |aClsid| is a thread-aware in-process server.
+ *
+ * In DCOM, an in-process server is a server that is implemented inside a DLL
+ * that is loaded into the client's process for execution. If |aClsid| declares
+ * itself to be a local server (that is, a server that resides in another
+ * process), this function returns false.
+ *
+ * For the server to be thread-aware, its registry entry must declare a
+ * ThreadingModel that is one of "Free", "Both", or "Neutral". If the threading
+ * model is "Apartment" or some other, invalid value, the class is treated as
+ * being single-threaded.
+ *
+ * NB: This function cannot check CLSIDs that were registered via manifests,
+ * as unfortunately there is not a documented API available to query for those.
+ * This should not be an issue for most CLSIDs that Gecko is interested in, as
+ * we typically instantiate system CLSIDs which are available in the registry.
+ *
+ * @param aClsid The CLSID of the COM class to be checked.
+ * @return true if the class meets the above criteria, otherwise false.
+ */
+bool IsClassThreadAwareInprocServer(REFCLSID aClsid);
+
+void GUIDToString(REFGUID aGuid, nsAString& aOutString);
+
+/**
+ * Converts an IID to a human-readable string for the purposes of diagnostic
+ * tools such as the profiler. For some special cases, we output a friendly
+ * string that describes the purpose of the interface. If no such description
+ * exists, we simply fall back to outputting the IID as a string formatted by
+ * GUIDToString().
+ */
+void DiagnosticNameForIID(REFIID aIid, nsACString& aOutString);
+#else
+void GUIDToString(REFGUID aGuid,
+ wchar_t (&aOutBuf)[kGuidRegFormatCharLenInclNul]);
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#if defined(ACCESSIBILITY)
+bool IsVtableIndexFromParentInterface(REFIID aInterface,
+ unsigned long aVtableIndex);
+
+# if defined(MOZILLA_INTERNAL_API)
+bool IsCallerExternalProcess();
+
+bool IsInterfaceEqualToOrInheritedFrom(REFIID aInterface, REFIID aFrom,
+ unsigned long aVtableIndexHint);
+# endif // defined(MOZILLA_INTERNAL_API)
+
+#endif // defined(ACCESSIBILITY)
+
+/**
+ * Execute cleanup code when going out of scope if a condition is met.
+ * This is useful when, for example, particular cleanup needs to be performed
+ * whenever a call returns a failure HRESULT.
+ * Both the condition and cleanup code are provided as functions (usually
+ * lambdas).
+ */
+template <typename CondFnT, typename ExeFnT>
+class MOZ_RAII ExecuteWhen final {
+ public:
+ ExecuteWhen(CondFnT& aCondFn, ExeFnT& aExeFn)
+ : mCondFn(aCondFn), mExeFn(aExeFn) {}
+
+ ~ExecuteWhen() {
+ if (mCondFn()) {
+ mExeFn();
+ }
+ }
+
+ ExecuteWhen(const ExecuteWhen&) = delete;
+ ExecuteWhen(ExecuteWhen&&) = delete;
+ ExecuteWhen& operator=(const ExecuteWhen&) = delete;
+ ExecuteWhen& operator=(ExecuteWhen&&) = delete;
+
+ private:
+ CondFnT& mCondFn;
+ ExeFnT& mExeFn;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Utils_h
diff --git a/ipc/mscom/VTableBuilder.c b/ipc/mscom/VTableBuilder.c
new file mode 100644
index 0000000000..48550503c2
--- /dev/null
+++ b/ipc/mscom/VTableBuilder.c
@@ -0,0 +1,54 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "VTableBuilder.h"
+
+#include <stdlib.h>
+
+static HRESULT STDMETHODCALLTYPE QueryInterfaceThunk(IUnknown* aThis,
+ REFIID aIid,
+ void** aOutInterface) {
+ void** table = (void**)aThis;
+ IUnknown* real = (IUnknown*)table[1];
+ return real->lpVtbl->QueryInterface(real, aIid, aOutInterface);
+}
+
+static ULONG STDMETHODCALLTYPE AddRefThunk(IUnknown* aThis) {
+ void** table = (void**)aThis;
+ IUnknown* real = (IUnknown*)table[1];
+ return real->lpVtbl->AddRef(real);
+}
+
+static ULONG STDMETHODCALLTYPE ReleaseThunk(IUnknown* aThis) {
+ void** table = (void**)aThis;
+ IUnknown* real = (IUnknown*)table[1];
+ return real->lpVtbl->Release(real);
+}
+
+IUnknown* BuildNullVTable(IUnknown* aUnk, uint32_t aVtblSize) {
+ void** table;
+
+ if (aVtblSize < 3) {
+ return NULL;
+ }
+
+ // We need to allocate slots for two additional pointers: The |lpVtbl|
+ // pointer, as well as |aUnk| which is needed to get |this| right when
+ // something calls through one of the IUnknown thunks.
+ table = calloc(aVtblSize + 2, sizeof(void*));
+
+ table[0] = &table[2]; // |lpVtbl|, points to the first entry of the vtable
+ table[1] = aUnk; // |this|
+ // Now the actual vtable entries for IUnknown
+ table[2] = &QueryInterfaceThunk;
+ table[3] = &AddRefThunk;
+ table[4] = &ReleaseThunk;
+ // Remaining entries are NULL thanks to calloc zero-initializing everything.
+
+ return (IUnknown*)table;
+}
+
+void DeleteNullVTable(IUnknown* aUnk) { free(aUnk); }
diff --git a/ipc/mscom/VTableBuilder.h b/ipc/mscom/VTableBuilder.h
new file mode 100644
index 0000000000..765ec51b0b
--- /dev/null
+++ b/ipc/mscom/VTableBuilder.h
@@ -0,0 +1,37 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_VTableBuilder_h
+#define mozilla_mscom_VTableBuilder_h
+
+#include <stdint.h>
+#include <unknwn.h>
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+/**
+ * This function constructs an interface with |aVtblSize| entries, where the
+ * IUnknown methods are valid and redirect to |aUnk|. Other than the IUnknown
+ * methods, the resulting interface is not intended to actually be called; the
+ * remaining vtable entries are null pointers to enforce that.
+ *
+ * @param aUnk The IUnknown to which the null vtable will redirect its IUnknown
+ * methods.
+ * @param aVtblSize The total size of the vtable, including the IUnknown
+ * entries (thus this parameter must be >= 3).
+ * @return The resulting IUnknown, or nullptr on error.
+ */
+IUnknown* BuildNullVTable(IUnknown* aUnk, uint32_t aVtblSize);
+
+void DeleteNullVTable(IUnknown* aUnk);
+
+#if defined(__cplusplus)
+}
+#endif
+
+#endif // mozilla_mscom_VTableBuilder_h
diff --git a/ipc/mscom/WeakRef.cpp b/ipc/mscom/WeakRef.cpp
new file mode 100644
index 0000000000..b15c027153
--- /dev/null
+++ b/ipc/mscom/WeakRef.cpp
@@ -0,0 +1,225 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#define INITGUID
+#include "mozilla/mscom/WeakRef.h"
+
+#include "mozilla/DebugOnly.h"
+#include "mozilla/Mutex.h"
+#include "nsThreadUtils.h"
+#include "nsWindowsHelpers.h"
+#include "nsProxyRelease.h"
+
+static void InitializeCS(CRITICAL_SECTION& aCS) {
+ DWORD flags = 0;
+#if defined(RELEASE_OR_BETA)
+ flags |= CRITICAL_SECTION_NO_DEBUG_INFO;
+#endif
+ InitializeCriticalSectionEx(&aCS, 4000, flags);
+}
+
+namespace mozilla {
+namespace mscom {
+
+namespace detail {
+
+SharedRef::SharedRef(WeakReferenceSupport* aSupport) : mSupport(aSupport) {
+ ::InitializeCS(mCS);
+}
+
+SharedRef::~SharedRef() { ::DeleteCriticalSection(&mCS); }
+
+void SharedRef::Lock() { ::EnterCriticalSection(&mCS); }
+
+void SharedRef::Unlock() { ::LeaveCriticalSection(&mCS); }
+
+HRESULT
+SharedRef::ToStrongRef(IWeakReferenceSource** aOutStrongReference) {
+ RefPtr<IWeakReferenceSource> strongRef;
+
+ { // Scope for lock
+ AutoCriticalSection lock(&mCS);
+ if (!mSupport) {
+ return E_POINTER;
+ }
+ strongRef = mSupport;
+ }
+
+ strongRef.forget(aOutStrongReference);
+ return S_OK;
+}
+
+HRESULT
+SharedRef::Resolve(REFIID aIid, void** aOutStrongReference) {
+ RefPtr<WeakReferenceSupport> strongRef;
+
+ { // Scope for lock
+ AutoCriticalSection lock(&mCS);
+ if (!mSupport) {
+ return E_POINTER;
+ }
+ strongRef = mSupport;
+ }
+
+ return strongRef->QueryInterface(aIid, aOutStrongReference);
+}
+
+void SharedRef::Clear() {
+ AutoCriticalSection lock(&mCS);
+ mSupport = nullptr;
+}
+
+} // namespace detail
+
+typedef mozilla::detail::BaseAutoLock<detail::SharedRef&> SharedRefAutoLock;
+typedef mozilla::detail::BaseAutoUnlock<detail::SharedRef&> SharedRefAutoUnlock;
+
+WeakReferenceSupport::WeakReferenceSupport(Flags aFlags)
+ : mRefCnt(0), mFlags(aFlags) {
+ mSharedRef = new detail::SharedRef(this);
+}
+
+HRESULT
+WeakReferenceSupport::QueryInterface(REFIID riid, void** ppv) {
+ RefPtr<IUnknown> punk;
+ if (!ppv) {
+ return E_INVALIDARG;
+ }
+ *ppv = nullptr;
+
+ // Raise the refcount for stabilization purposes during aggregation
+ StabilizeRefCount stabilize(*this);
+
+ if (riid == IID_IUnknown || riid == IID_IWeakReferenceSource) {
+ punk = static_cast<IUnknown*>(this);
+ } else {
+ HRESULT hr = WeakRefQueryInterface(riid, getter_AddRefs(punk));
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+
+ if (!punk) {
+ return E_NOINTERFACE;
+ }
+
+ punk.forget(ppv);
+ return S_OK;
+}
+
+WeakReferenceSupport::StabilizeRefCount::StabilizeRefCount(
+ WeakReferenceSupport& aObject)
+ : mObject(aObject) {
+ SharedRefAutoLock lock(*mObject.mSharedRef);
+ ++mObject.mRefCnt;
+}
+
+WeakReferenceSupport::StabilizeRefCount::~StabilizeRefCount() {
+ // We directly access these fields instead of calling Release() because we
+ // want to adjust the ref count without the other side effects (such as
+ // deleting this if the count drops back to zero, which may happen during
+ // an initial QI during object creation).
+ SharedRefAutoLock lock(*mObject.mSharedRef);
+ --mObject.mRefCnt;
+}
+
+ULONG
+WeakReferenceSupport::AddRef() {
+ SharedRefAutoLock lock(*mSharedRef);
+ ULONG result = ++mRefCnt;
+ NS_LOG_ADDREF(this, result, "mscom::WeakReferenceSupport", sizeof(*this));
+ return result;
+}
+
+ULONG
+WeakReferenceSupport::Release() {
+ ULONG newRefCnt;
+ { // Scope for lock
+ SharedRefAutoLock lock(*mSharedRef);
+ newRefCnt = --mRefCnt;
+ if (newRefCnt == 0) {
+ mSharedRef->Clear();
+ }
+ }
+ NS_LOG_RELEASE(this, newRefCnt, "mscom::WeakReferenceSupport");
+ if (newRefCnt == 0) {
+ if (mFlags != Flags::eDestroyOnMainThread || NS_IsMainThread()) {
+ delete this;
+ } else {
+ // We need to delete this object on the main thread, but we aren't on the
+ // main thread right now, so we send a reference to ourselves to the main
+ // thread to be re-released there.
+ RefPtr<WeakReferenceSupport> self = this;
+ NS_ReleaseOnMainThread("WeakReferenceSupport", self.forget());
+ }
+ }
+ return newRefCnt;
+}
+
+HRESULT
+WeakReferenceSupport::GetWeakReference(IWeakReference** aOutWeakRef) {
+ if (!aOutWeakRef) {
+ return E_INVALIDARG;
+ }
+
+ RefPtr<WeakRef> weakRef = MakeAndAddRef<WeakRef>(mSharedRef);
+ return weakRef->QueryInterface(IID_IWeakReference, (void**)aOutWeakRef);
+}
+
+WeakRef::WeakRef(RefPtr<detail::SharedRef>& aSharedRef)
+ : mRefCnt(0), mSharedRef(aSharedRef) {
+ MOZ_ASSERT(aSharedRef);
+}
+
+HRESULT
+WeakRef::QueryInterface(REFIID riid, void** ppv) {
+ IUnknown* punk = nullptr;
+ if (!ppv) {
+ return E_INVALIDARG;
+ }
+
+ if (riid == IID_IUnknown || riid == IID_IWeakReference) {
+ punk = static_cast<IUnknown*>(this);
+ }
+
+ *ppv = punk;
+ if (!punk) {
+ return E_NOINTERFACE;
+ }
+
+ punk->AddRef();
+ return S_OK;
+}
+
+ULONG
+WeakRef::AddRef() {
+ ULONG result = ++mRefCnt;
+ NS_LOG_ADDREF(this, result, "mscom::WeakRef", sizeof(*this));
+ return result;
+}
+
+ULONG
+WeakRef::Release() {
+ ULONG newRefCnt = --mRefCnt;
+ NS_LOG_RELEASE(this, newRefCnt, "mscom::WeakRef");
+ if (newRefCnt == 0) {
+ delete this;
+ }
+ return newRefCnt;
+}
+
+HRESULT
+WeakRef::ToStrongRef(IWeakReferenceSource** aOutStrongReference) {
+ return mSharedRef->ToStrongRef(aOutStrongReference);
+}
+
+HRESULT
+WeakRef::Resolve(REFIID aIid, void** aOutStrongReference) {
+ return mSharedRef->Resolve(aIid, aOutStrongReference);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/WeakRef.h b/ipc/mscom/WeakRef.h
new file mode 100644
index 0000000000..ba87eecb70
--- /dev/null
+++ b/ipc/mscom/WeakRef.h
@@ -0,0 +1,142 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_WeakRef_h
+#define mozilla_mscom_WeakRef_h
+
+#include <guiddef.h>
+#include <unknwn.h>
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Atomics.h"
+#include "mozilla/RefPtr.h"
+#include "nsISupportsImpl.h"
+
+/**
+ * Thread-safe weak references for COM that works pre-Windows 8 and do not
+ * require WinRT.
+ */
+
+namespace mozilla {
+namespace mscom {
+
+struct IWeakReferenceSource;
+class WeakReferenceSupport;
+
+namespace detail {
+
+class SharedRef final {
+ public:
+ explicit SharedRef(WeakReferenceSupport* aSupport);
+ void Lock();
+ void Unlock();
+
+ HRESULT ToStrongRef(IWeakReferenceSource** aOutStringReference);
+ HRESULT Resolve(REFIID aIid, void** aOutStrongReference);
+ void Clear();
+
+ NS_INLINE_DECL_THREADSAFE_REFCOUNTING(SharedRef)
+
+ SharedRef(const SharedRef&) = delete;
+ SharedRef(SharedRef&&) = delete;
+ SharedRef& operator=(const SharedRef&) = delete;
+ SharedRef& operator=(SharedRef&&) = delete;
+
+ private:
+ ~SharedRef();
+
+ private:
+ CRITICAL_SECTION mCS;
+ WeakReferenceSupport* mSupport;
+};
+
+} // namespace detail
+
+// {F841AEFA-064C-49A4-B73D-EBD14A90F012}
+DEFINE_GUID(IID_IWeakReference, 0xf841aefa, 0x64c, 0x49a4, 0xb7, 0x3d, 0xeb,
+ 0xd1, 0x4a, 0x90, 0xf0, 0x12);
+
+struct IWeakReference : public IUnknown {
+ virtual STDMETHODIMP ToStrongRef(
+ IWeakReferenceSource** aOutStrongReference) = 0;
+ virtual STDMETHODIMP Resolve(REFIID aIid, void** aOutStrongReference) = 0;
+};
+
+// {87611F0C-9BBB-4F78-9D43-CAC5AD432CA1}
+DEFINE_GUID(IID_IWeakReferenceSource, 0x87611f0c, 0x9bbb, 0x4f78, 0x9d, 0x43,
+ 0xca, 0xc5, 0xad, 0x43, 0x2c, 0xa1);
+
+struct IWeakReferenceSource : public IUnknown {
+ virtual STDMETHODIMP GetWeakReference(IWeakReference** aOutWeakRef) = 0;
+};
+
+class WeakRef;
+
+class WeakReferenceSupport : public IWeakReferenceSource {
+ public:
+ enum class Flags { eNone = 0, eDestroyOnMainThread = 1 };
+
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // IWeakReferenceSource
+ STDMETHODIMP GetWeakReference(IWeakReference** aOutWeakRef) override;
+
+ protected:
+ explicit WeakReferenceSupport(Flags aFlags);
+ virtual ~WeakReferenceSupport() = default;
+
+ virtual HRESULT WeakRefQueryInterface(REFIID aIid,
+ IUnknown** aOutInterface) = 0;
+
+ class MOZ_RAII StabilizeRefCount final {
+ public:
+ explicit StabilizeRefCount(WeakReferenceSupport& aObject);
+ ~StabilizeRefCount();
+
+ StabilizeRefCount(const StabilizeRefCount&) = delete;
+ StabilizeRefCount(StabilizeRefCount&&) = delete;
+ StabilizeRefCount& operator=(const StabilizeRefCount&) = delete;
+ StabilizeRefCount& operator=(StabilizeRefCount&&) = delete;
+
+ private:
+ WeakReferenceSupport& mObject;
+ };
+
+ friend class StabilizeRefCount;
+
+ private:
+ RefPtr<detail::SharedRef> mSharedRef;
+ ULONG mRefCnt;
+ Flags mFlags;
+};
+
+class WeakRef final : public IWeakReference {
+ public:
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override;
+ STDMETHODIMP_(ULONG) AddRef() override;
+ STDMETHODIMP_(ULONG) Release() override;
+
+ // IWeakReference
+ STDMETHODIMP ToStrongRef(IWeakReferenceSource** aOutStrongReference) override;
+ STDMETHODIMP Resolve(REFIID aIid, void** aOutStrongReference) override;
+
+ explicit WeakRef(RefPtr<detail::SharedRef>& aSharedRef);
+
+ private:
+ ~WeakRef() = default;
+
+ Atomic<ULONG> mRefCnt;
+ RefPtr<detail::SharedRef> mSharedRef;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_WeakRef_h
diff --git a/ipc/mscom/moz.build b/ipc/mscom/moz.build
new file mode 100644
index 0000000000..d615bd82d3
--- /dev/null
+++ b/ipc/mscom/moz.build
@@ -0,0 +1,97 @@
+# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*-
+# vim: set filetype=python:
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+EXPORTS.mozilla.mscom += [
+ "Aggregation.h",
+ "AgileReference.h",
+ "ApartmentRegion.h",
+ "AsyncInvoker.h",
+ "COMPtrHolder.h",
+ "COMWrappers.h",
+ "EnsureMTA.h",
+ "Objref.h",
+ "PassthruProxy.h",
+ "ProcessRuntime.h",
+ "ProfilerMarkers.h",
+ "ProxyStream.h",
+ "Ptr.h",
+ "Utils.h",
+]
+
+DIRS += [
+ "mozglue",
+]
+
+SOURCES += [
+ "VTableBuilder.c",
+]
+
+UNIFIED_SOURCES += [
+ "AgileReference.cpp",
+ "COMWrappers.cpp",
+ "EnsureMTA.cpp",
+ "Objref.cpp",
+ "PassthruProxy.cpp",
+ "ProcessRuntime.cpp",
+ "ProfilerMarkers.cpp",
+ "ProxyStream.cpp",
+ "RegistrationAnnotator.cpp",
+ "Utils.cpp",
+]
+
+if CONFIG["ACCESSIBILITY"]:
+ DIRS += [
+ "oop",
+ ]
+
+ EXPORTS.mozilla.mscom += [
+ "ActivationContext.h",
+ "DispatchForwarder.h",
+ "FastMarshaler.h",
+ "IHandlerProvider.h",
+ "Interceptor.h",
+ "InterceptorLog.h",
+ "MainThreadHandoff.h",
+ "MainThreadInvoker.h",
+ "Registration.h",
+ "SpinEvent.h",
+ "StructStream.h",
+ "WeakRef.h",
+ ]
+
+ SOURCES += [
+ "Interceptor.cpp",
+ "MainThreadHandoff.cpp",
+ "Registration.cpp",
+ "SpinEvent.cpp",
+ "WeakRef.cpp",
+ ]
+
+ UNIFIED_SOURCES += [
+ "ActivationContext.cpp",
+ "DispatchForwarder.cpp",
+ "FastMarshaler.cpp",
+ "InterceptorLog.cpp",
+ "MainThreadInvoker.cpp",
+ "StructStream.cpp",
+ ]
+
+LOCAL_INCLUDES += [
+ "/xpcom/base",
+ "/xpcom/build",
+]
+
+DEFINES["MOZ_MSCOM_REMARSHAL_NO_HANDLER"] = True
+
+include("/ipc/chromium/chromium-config.mozbuild")
+
+FINAL_LIBRARY = "xul"
+
+with Files("**"):
+ BUG_COMPONENT = ("Core", "IPC: MSCOM")
+ SCHEDULES.exclusive = ["windows"]
+
+REQUIRES_UNIFIED_BUILD = True
diff --git a/ipc/mscom/mozglue/ActCtxResource.cpp b/ipc/mscom/mozglue/ActCtxResource.cpp
new file mode 100644
index 0000000000..778e36c4bc
--- /dev/null
+++ b/ipc/mscom/mozglue/ActCtxResource.cpp
@@ -0,0 +1,239 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ActCtxResource.h"
+
+#include <string>
+
+#include "mozilla/GetKnownFolderPath.h"
+#include "mozilla/WindowsVersion.h"
+#include "nsWindowsHelpers.h"
+
+namespace mozilla {
+namespace mscom {
+
+#if !defined(HAVE_64BIT_BUILD)
+
+static bool ReadCOMRegDefaultString(const std::wstring& aRegPath,
+ std::wstring& aOutBuf) {
+ aOutBuf.clear();
+
+ std::wstring fullyQualifiedRegPath;
+ fullyQualifiedRegPath.append(L"SOFTWARE\\Classes\\");
+ fullyQualifiedRegPath.append(aRegPath);
+
+ // Get the required size and type of the registry value.
+ // We expect either REG_SZ or REG_EXPAND_SZ.
+ DWORD type;
+ DWORD bufLen = 0;
+ LONG result =
+ ::RegGetValueW(HKEY_LOCAL_MACHINE, fullyQualifiedRegPath.c_str(), nullptr,
+ RRF_RT_ANY, &type, nullptr, &bufLen);
+ if (result != ERROR_SUCCESS || (type != REG_SZ && type != REG_EXPAND_SZ)) {
+ return false;
+ }
+
+ // Now obtain the value
+ DWORD flags = type == REG_SZ ? RRF_RT_REG_SZ : RRF_RT_REG_EXPAND_SZ;
+
+ aOutBuf.resize((bufLen + 1) / sizeof(char16_t));
+
+ result = ::RegGetValueW(HKEY_LOCAL_MACHINE, fullyQualifiedRegPath.c_str(),
+ nullptr, flags, nullptr, &aOutBuf[0], &bufLen);
+ if (result != ERROR_SUCCESS) {
+ aOutBuf.clear();
+ return false;
+ }
+
+ // Truncate terminator
+ aOutBuf.resize((bufLen + 1) / sizeof(char16_t) - 1);
+ return true;
+}
+
+static bool IsSystemOleAcc(HANDLE aFile) {
+ if (aFile == INVALID_HANDLE_VALUE) {
+ return false;
+ }
+
+ BY_HANDLE_FILE_INFORMATION info = {};
+ if (!::GetFileInformationByHandle(aFile, &info)) {
+ return false;
+ }
+
+ // Use FOLDERID_SystemX86 so that Windows doesn't give us a redirected
+ // system32 if we're a 32-bit process running on a 64-bit OS. This is
+ // necessary because the values that we are reading from the registry
+ // are not redirected; they reference SysWOW64 directly.
+ auto systemPath = GetKnownFolderPath(FOLDERID_SystemX86);
+ if (!systemPath) {
+ return false;
+ }
+
+ std::wstring oleAccPath(systemPath.get());
+ oleAccPath.append(L"\\oleacc.dll");
+
+ nsAutoHandle oleAcc(
+ ::CreateFileW(oleAccPath.c_str(), GENERIC_READ,
+ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+ nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr));
+
+ if (oleAcc.get() == INVALID_HANDLE_VALUE) {
+ return false;
+ }
+
+ BY_HANDLE_FILE_INFORMATION oleAccInfo = {};
+ if (!::GetFileInformationByHandle(oleAcc, &oleAccInfo)) {
+ return false;
+ }
+
+ return info.dwVolumeSerialNumber == oleAccInfo.dwVolumeSerialNumber &&
+ info.nFileIndexLow == oleAccInfo.nFileIndexLow &&
+ info.nFileIndexHigh == oleAccInfo.nFileIndexHigh;
+}
+
+static bool IsTypelibPreferred() {
+ // If IAccessible's Proxy/Stub CLSID is kUniversalMarshalerClsid, then any
+ // external a11y clients are expecting to use a typelib.
+ const wchar_t kUniversalMarshalerClsid[] =
+ L"{00020424-0000-0000-C000-000000000046}";
+
+ const wchar_t kIAccessiblePSClsidPath[] =
+ L"Interface\\{618736E0-3C3D-11CF-810C-00AA00389B71}"
+ L"\\ProxyStubClsid32";
+
+ std::wstring psClsid;
+ if (!ReadCOMRegDefaultString(kIAccessiblePSClsidPath, psClsid)) {
+ return false;
+ }
+
+ if (psClsid.size() !=
+ sizeof(kUniversalMarshalerClsid) / sizeof(kUniversalMarshalerClsid)[0] -
+ 1) {
+ return false;
+ }
+
+ int index = 0;
+ while (kUniversalMarshalerClsid[index]) {
+ if (toupper(psClsid[index]) != kUniversalMarshalerClsid[index]) {
+ return false;
+ }
+ index++;
+ }
+ return true;
+}
+
+static bool IsIAccessibleTypelibRegistered() {
+ // The system default IAccessible typelib is always registered with version
+ // 1.1, under the neutral locale (LCID 0).
+ const wchar_t kIAccessibleTypelibRegPath[] =
+ L"TypeLib\\{1EA4DBF0-3C3B-11CF-810C-00AA00389B71}\\1.1\\0\\win32";
+
+ std::wstring typelibPath;
+ if (!ReadCOMRegDefaultString(kIAccessibleTypelibRegPath, typelibPath)) {
+ return false;
+ }
+
+ nsAutoHandle libTestFile(
+ ::CreateFileW(typelibPath.c_str(), GENERIC_READ,
+ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+ nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr));
+
+ return IsSystemOleAcc(libTestFile);
+}
+
+static bool IsIAccessiblePSRegistered() {
+ const wchar_t kIAccessiblePSRegPath[] =
+ L"CLSID\\{03022430-ABC4-11D0-BDE2-00AA001A1953}\\InProcServer32";
+
+ std::wstring proxyStubPath;
+ if (!ReadCOMRegDefaultString(kIAccessiblePSRegPath, proxyStubPath)) {
+ return false;
+ }
+
+ nsAutoHandle libTestFile(
+ ::CreateFileW(proxyStubPath.c_str(), GENERIC_READ,
+ FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
+ nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr));
+
+ return IsSystemOleAcc(libTestFile);
+}
+
+static bool UseIAccessibleProxyStub() {
+ // If a typelib is preferred then external clients are expecting to use
+ // typelib marshaling, so we should use that whenever available.
+ if (IsTypelibPreferred() && IsIAccessibleTypelibRegistered()) {
+ return false;
+ }
+
+ // Otherwise we try the proxy/stub
+ if (IsIAccessiblePSRegistered()) {
+ return true;
+ }
+
+ return false;
+}
+
+#endif // !defined(HAVE_64BIT_BUILD)
+
+#if defined(_MSC_VER)
+extern "C" IMAGE_DOS_HEADER __ImageBase;
+#endif
+
+static HMODULE GetContainingModuleHandle() {
+ HMODULE thisModule = nullptr;
+#if defined(_MSC_VER)
+ thisModule = reinterpret_cast<HMODULE>(&__ImageBase);
+#else
+ if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+ reinterpret_cast<LPCWSTR>(&GetContainingModuleHandle),
+ &thisModule)) {
+ return 0;
+ }
+#endif
+ return thisModule;
+}
+
+static uint16_t sActCtxResourceId = 0;
+
+/* static */
+void ActCtxResource::SetAccessibilityResourceId(uint16_t aResourceId) {
+ sActCtxResourceId = aResourceId;
+}
+
+/* static */
+uint16_t ActCtxResource::GetAccessibilityResourceId() {
+ return sActCtxResourceId;
+}
+
+static void EnsureAccessibilityResourceId() {
+ if (!sActCtxResourceId) {
+#if defined(HAVE_64BIT_BUILD)
+ // The manifest for 64-bit Windows is embedded with resource ID 64.
+ sActCtxResourceId = 64;
+#else
+ // The manifest for 32-bit Windows is embedded with resource ID 32.
+ // Beginning with Windows 10 Creators Update, 32-bit builds always use the
+ // 64-bit manifest. Older builds of Windows may or may not require the
+ // 64-bit manifest: UseIAccessibleProxyStub() determines the course of
+ // action.
+ if (mozilla::IsWin10CreatorsUpdateOrLater() || UseIAccessibleProxyStub()) {
+ sActCtxResourceId = 64;
+ } else {
+ sActCtxResourceId = 32;
+ }
+#endif // defined(HAVE_64BIT_BUILD)
+ }
+}
+
+ActCtxResource ActCtxResource::GetAccessibilityResource() {
+ ActCtxResource result = {};
+ result.mModule = GetContainingModuleHandle();
+ EnsureAccessibilityResourceId();
+ result.mId = GetAccessibilityResourceId();
+ return result;
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/mozglue/ActCtxResource.h b/ipc/mscom/mozglue/ActCtxResource.h
new file mode 100644
index 0000000000..4e9c16c8e0
--- /dev/null
+++ b/ipc/mscom/mozglue/ActCtxResource.h
@@ -0,0 +1,40 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef ACT_CTX_RESOURCE_H
+#define ACT_CTX_RESOURCE_H
+
+#include <stdint.h>
+#include <windows.h>
+#include "mozilla/Types.h"
+
+namespace mozilla {
+namespace mscom {
+
+struct ActCtxResource {
+ uint16_t mId;
+ HMODULE mModule;
+
+ /**
+ * Set the resource ID used by GetAccessibilityResource. This is so that
+ * sandboxed child processes can use a value passed down from the parent.
+ */
+ static MFBT_API void SetAccessibilityResourceId(uint16_t aResourceId);
+
+ /**
+ * Get the resource ID used by GetAccessibilityResource.
+ */
+ static MFBT_API uint16_t GetAccessibilityResourceId();
+
+ /**
+ * @return ActCtxResource of a11y manifest resource to be passed to
+ * mscom::ActivationContext
+ */
+ static MFBT_API ActCtxResource GetAccessibilityResource();
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif
diff --git a/ipc/mscom/mozglue/ProcessRuntimeShared.cpp b/ipc/mscom/mozglue/ProcessRuntimeShared.cpp
new file mode 100644
index 0000000000..ee4f6b221f
--- /dev/null
+++ b/ipc/mscom/mozglue/ProcessRuntimeShared.cpp
@@ -0,0 +1,31 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/mscom/ProcessRuntimeShared.h"
+
+#include "mozilla/glue/WinUtils.h"
+
+// We allow multiple ProcessRuntime instances to exist simultaneously (even
+// on separate threads), but only one should be doing the process-wide
+// initialization. These variables provide that mutual exclusion.
+static mozilla::glue::Win32SRWLock gLock;
+static mozilla::mscom::detail::ProcessInitState gProcessInitState =
+ mozilla::mscom::detail::ProcessInitState::Uninitialized;
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+MFBT_API ProcessInitState& BeginProcessRuntimeInit() {
+ gLock.LockExclusive();
+ return gProcessInitState;
+}
+
+MFBT_API void EndProcessRuntimeInit() { gLock.UnlockExclusive(); }
+
+} // namespace detail
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/mozglue/ProcessRuntimeShared.h b/ipc/mscom/mozglue/ProcessRuntimeShared.h
new file mode 100644
index 0000000000..d5a58a7b92
--- /dev/null
+++ b/ipc/mscom/mozglue/ProcessRuntimeShared.h
@@ -0,0 +1,55 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_ProcessRuntimeShared_h
+#define mozilla_mscom_ProcessRuntimeShared_h
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/Types.h"
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+enum class ProcessInitState : uint32_t {
+ Uninitialized = 0,
+ PartialSecurityInitialized,
+ PartialGlobalOptions,
+ FullyInitialized,
+};
+
+MFBT_API ProcessInitState& BeginProcessRuntimeInit();
+MFBT_API void EndProcessRuntimeInit();
+
+} // namespace detail
+
+class MOZ_RAII ProcessInitLock final {
+ public:
+ ProcessInitLock() : mInitState(detail::BeginProcessRuntimeInit()) {}
+
+ ~ProcessInitLock() { detail::EndProcessRuntimeInit(); }
+
+ detail::ProcessInitState GetInitState() const { return mInitState; }
+
+ void SetInitState(const detail::ProcessInitState aNewState) {
+ MOZ_DIAGNOSTIC_ASSERT(aNewState > mInitState);
+ mInitState = aNewState;
+ }
+
+ ProcessInitLock(const ProcessInitLock&) = delete;
+ ProcessInitLock(ProcessInitLock&&) = delete;
+ ProcessInitLock operator=(const ProcessInitLock&) = delete;
+ ProcessInitLock operator=(ProcessInitLock&&) = delete;
+
+ private:
+ detail::ProcessInitState& mInitState;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_ProcessRuntimeShared_h
diff --git a/ipc/mscom/mozglue/moz.build b/ipc/mscom/mozglue/moz.build
new file mode 100644
index 0000000000..635b5264cf
--- /dev/null
+++ b/ipc/mscom/mozglue/moz.build
@@ -0,0 +1,19 @@
+# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*-
+# vim: set filetype=python:
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+FINAL_LIBRARY = "mozglue"
+
+EXPORTS.mozilla.mscom += [
+ "ActCtxResource.h",
+ "ProcessRuntimeShared.h",
+]
+
+UNIFIED_SOURCES += [
+ "ActCtxResource.cpp",
+ "ProcessRuntimeShared.cpp",
+]
+
+REQUIRES_UNIFIED_BUILD = True
diff --git a/ipc/mscom/oop/Factory.h b/ipc/mscom/oop/Factory.h
new file mode 100644
index 0000000000..e95f1d2499
--- /dev/null
+++ b/ipc/mscom/oop/Factory.h
@@ -0,0 +1,142 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Factory_h
+#define mozilla_mscom_Factory_h
+
+#if defined(MOZILLA_INTERNAL_API)
+# error This code is NOT for internal Gecko use!
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include <objbase.h>
+#include <unknwn.h>
+
+#include <utility>
+
+#include "Module.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/RefPtr.h"
+#include "mozilla/StaticPtr.h"
+
+/* WARNING! The code in this file may be loaded into the address spaces of other
+ processes! It MUST NOT link against xul.dll or other Gecko binaries! Only
+ inline code may be included! */
+
+namespace mozilla {
+namespace mscom {
+
+template <typename T>
+class MOZ_NONHEAP_CLASS Factory : public IClassFactory {
+ template <typename... Args>
+ HRESULT DoCreate(Args&&... args) {
+ MOZ_DIAGNOSTIC_ASSERT(false, "This should not be executed");
+ return E_NOTIMPL;
+ }
+
+ template <typename... Args>
+ HRESULT DoCreate(HRESULT (*aFnPtr)(IUnknown*, REFIID, void**),
+ Args&&... args) {
+ return aFnPtr(std::forward<Args>(args)...);
+ }
+
+ public:
+ // IUnknown
+ STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override {
+ if (!aOutInterface) {
+ return E_INVALIDARG;
+ }
+
+ if (aIid == IID_IUnknown || aIid == IID_IClassFactory) {
+ RefPtr<IClassFactory> punk(this);
+ punk.forget(aOutInterface);
+ return S_OK;
+ }
+
+ *aOutInterface = nullptr;
+
+ return E_NOINTERFACE;
+ }
+
+ STDMETHODIMP_(ULONG) AddRef() override {
+ Module::Lock();
+ return 2;
+ }
+
+ STDMETHODIMP_(ULONG) Release() override {
+ Module::Unlock();
+ return 1;
+ }
+
+ // IClassFactory
+ STDMETHODIMP CreateInstance(IUnknown* aOuter, REFIID aIid,
+ void** aOutInterface) override {
+ return DoCreate(&T::Create, aOuter, aIid, aOutInterface);
+ }
+
+ STDMETHODIMP LockServer(BOOL aLock) override {
+ if (aLock) {
+ Module::Lock();
+ } else {
+ Module::Unlock();
+ }
+ return S_OK;
+ }
+};
+
+template <typename T>
+class MOZ_NONHEAP_CLASS SingletonFactory : public Factory<T> {
+ public:
+ STDMETHODIMP CreateInstance(IUnknown* aOuter, REFIID aIid,
+ void** aOutInterface) override {
+ if (aOuter || !aOutInterface) {
+ return E_INVALIDARG;
+ }
+
+ RefPtr<T> obj(sInstance);
+ if (!obj) {
+ obj = GetOrCreateSingleton();
+ }
+
+ return obj->QueryInterface(aIid, aOutInterface);
+ }
+
+ RefPtr<T> GetOrCreateSingleton() {
+ if (!sInstance) {
+ RefPtr<T> object;
+ if (FAILED(T::Create(getter_AddRefs(object)))) {
+ return nullptr;
+ }
+
+ sInstance = object.forget();
+ }
+
+ return sInstance;
+ }
+
+ RefPtr<T> GetSingleton() { return sInstance; }
+
+ void ClearSingleton() {
+ if (!sInstance) {
+ return;
+ }
+
+ DebugOnly<HRESULT> hr = ::CoDisconnectObject(sInstance.get(), 0);
+ MOZ_ASSERT(SUCCEEDED(hr));
+ sInstance = nullptr;
+ }
+
+ private:
+ static StaticRefPtr<T> sInstance;
+};
+
+template <typename T>
+StaticRefPtr<T> SingletonFactory<T>::sInstance;
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Factory_h
diff --git a/ipc/mscom/oop/Handler.cpp b/ipc/mscom/oop/Handler.cpp
new file mode 100644
index 0000000000..1c40e5300f
--- /dev/null
+++ b/ipc/mscom/oop/Handler.cpp
@@ -0,0 +1,281 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "Handler.h"
+#include "Module.h"
+
+#include "mozilla/ArrayUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/mscom/Objref.h"
+#include "nsWindowsHelpers.h"
+
+#include <objbase.h>
+#include <shlwapi.h>
+#include <string.h>
+
+/* WARNING! The code in this file may be loaded into the address spaces of other
+ processes! It MUST NOT link against xul.dll or other Gecko binaries! Only
+ inline code may be included! */
+
+namespace mozilla {
+namespace mscom {
+
+Handler::Handler(IUnknown* aOuter, HRESULT* aResult)
+ : mRefCnt(0), mOuter(aOuter), mUnmarshal(nullptr), mHasPayload(false) {
+ MOZ_ASSERT(aResult);
+
+ if (!aOuter) {
+ *aResult = E_INVALIDARG;
+ return;
+ }
+
+ StabilizedRefCount<ULONG> stabilizer(mRefCnt);
+
+ *aResult =
+ ::CoGetStdMarshalEx(aOuter, SMEXF_HANDLER, getter_AddRefs(mInnerUnk));
+ if (FAILED(*aResult)) {
+ return;
+ }
+
+ *aResult = mInnerUnk->QueryInterface(IID_IMarshal, (void**)&mUnmarshal);
+ if (FAILED(*aResult)) {
+ return;
+ }
+
+ // mUnmarshal is a weak ref
+ mUnmarshal->Release();
+}
+
+HRESULT
+Handler::InternalQueryInterface(REFIID riid, void** ppv) {
+ if (!ppv) {
+ return E_INVALIDARG;
+ }
+
+ if (riid == IID_IUnknown) {
+ RefPtr<IUnknown> punk(static_cast<IUnknown*>(&mInternalUnknown));
+ punk.forget(ppv);
+ return S_OK;
+ }
+
+ if (riid == IID_IMarshal) {
+ RefPtr<IMarshal> ptr(this);
+ ptr.forget(ppv);
+ return S_OK;
+ }
+
+ // Try the handler implementation
+ HRESULT hr = QueryHandlerInterface(mInnerUnk, riid, ppv);
+ if (hr == S_FALSE) {
+ // The handler knows this interface is not available, so don't bother
+ // asking the proxy.
+ return E_NOINTERFACE;
+ }
+ if (hr != E_NOINTERFACE) {
+ return hr;
+ }
+
+ // Now forward to the marshaler's inner
+ return mInnerUnk->QueryInterface(riid, ppv);
+}
+
+ULONG
+Handler::InternalAddRef() {
+ if (!mRefCnt) {
+ Module::Lock();
+ }
+ return ++mRefCnt;
+}
+
+ULONG
+Handler::InternalRelease() {
+ ULONG newRefCnt = --mRefCnt;
+ if (newRefCnt == 0) {
+ delete this;
+ Module::Unlock();
+ }
+ return newRefCnt;
+}
+
+HRESULT
+Handler::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags, CLSID* pCid) {
+ return mUnmarshal->GetUnmarshalClass(MarshalAs(riid), pv, dwDestContext,
+ pvDestContext, mshlflags, pCid);
+}
+
+HRESULT
+Handler::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags, DWORD* pSize) {
+ if (!pSize) {
+ return E_INVALIDARG;
+ }
+
+ *pSize = 0;
+
+ RefPtr<IUnknown> unkToMarshal;
+ HRESULT hr;
+
+ REFIID marshalAs = MarshalAs(riid);
+ if (marshalAs == riid) {
+ unkToMarshal = static_cast<IUnknown*>(pv);
+ } else {
+ hr = mInnerUnk->QueryInterface(marshalAs, getter_AddRefs(unkToMarshal));
+ if (FAILED(hr)) {
+ return hr;
+ }
+ }
+
+ // We do not necessarily want to use the pv that COM is giving us; we may want
+ // to marshal a different proxy that is more appropriate to what we're
+ // wrapping...
+ hr = mUnmarshal->GetMarshalSizeMax(marshalAs, unkToMarshal.get(),
+ dwDestContext, pvDestContext, mshlflags,
+ pSize);
+
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ return hr;
+#else
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ if (!HasPayload()) {
+ return S_OK;
+ }
+
+ DWORD payloadSize = 0;
+ hr = GetHandlerPayloadSize(marshalAs, &payloadSize);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ *pSize += payloadSize;
+ return S_OK;
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+}
+
+HRESULT
+Handler::GetMarshalInterface(REFIID aMarshalAsIid, NotNull<IUnknown*> aProxy,
+ NotNull<IID*> aOutIid,
+ NotNull<IUnknown**> aOutUnk) {
+ *aOutIid = aMarshalAsIid;
+ return aProxy->QueryInterface(
+ aMarshalAsIid,
+ reinterpret_cast<void**>(static_cast<IUnknown**>(aOutUnk)));
+}
+
+HRESULT
+Handler::MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) {
+ // We do not necessarily want to use the pv that COM is giving us; we may want
+ // to marshal a different proxy that is more appropriate to what we're
+ // wrapping...
+ RefPtr<IUnknown> unkToMarshal;
+ HRESULT hr;
+
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ LARGE_INTEGER seekTo;
+ seekTo.QuadPart = 0;
+
+ ULARGE_INTEGER objrefPos;
+
+ // Save the current position as it points to the location where the OBJREF
+ // will be written.
+ hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &objrefPos);
+ if (FAILED(hr)) {
+ return hr;
+ }
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+
+ REFIID marshalAs = MarshalAs(riid);
+ IID marshalOutAs;
+
+ hr = GetMarshalInterface(
+ marshalAs, WrapNotNull<IUnknown*>(mInnerUnk), WrapNotNull(&marshalOutAs),
+ WrapNotNull<IUnknown**>(getter_AddRefs(unkToMarshal)));
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ hr = mUnmarshal->MarshalInterface(pStm, marshalAs, unkToMarshal.get(),
+ dwDestContext, pvDestContext, mshlflags);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+ // Obtain the current stream position which is the end of the OBJREF
+ ULARGE_INTEGER endPos;
+ hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &endPos);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // Now strip out the handler.
+ if (!StripHandlerFromOBJREF(WrapNotNull(pStm), objrefPos.QuadPart,
+ endPos.QuadPart)) {
+ return E_FAIL;
+ }
+
+ // Fix the IID
+ if (!SetIID(WrapNotNull(pStm), objrefPos.QuadPart, marshalOutAs)) {
+ return E_FAIL;
+ }
+
+ return S_OK;
+#else
+ if (!HasPayload()) {
+ return S_OK;
+ }
+
+ // Unfortunately when COM re-marshals a proxy that prevouisly had a payload,
+ // we must re-serialize it.
+ return WriteHandlerPayload(pStm, marshalAs);
+#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER)
+}
+
+HRESULT
+Handler::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) {
+ REFIID unmarshalAs = MarshalAs(riid);
+ HRESULT hr = mUnmarshal->UnmarshalInterface(pStm, unmarshalAs, ppv);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // This method may be called on the same object multiple times (as new
+ // interfaces are queried off the proxy). Not all interfaces will necessarily
+ // refresh the payload, so we set mHasPayload using OR to reflect that fact.
+ // (Otherwise mHasPayload could be cleared and the handler would think that
+ // it doesn't have a payload even though it actually does).
+ mHasPayload |= (ReadHandlerPayload(pStm, unmarshalAs) == S_OK);
+
+ return hr;
+}
+
+HRESULT
+Handler::ReleaseMarshalData(IStream* pStm) {
+ return mUnmarshal->ReleaseMarshalData(pStm);
+}
+
+HRESULT
+Handler::DisconnectObject(DWORD dwReserved) {
+ return mUnmarshal->DisconnectObject(dwReserved);
+}
+
+HRESULT
+Handler::Unregister(REFCLSID aClsid) { return Module::Deregister(aClsid); }
+
+HRESULT
+Handler::Register(REFCLSID aClsid, const bool aMsixContainer) {
+ return Module::Register(aClsid, Module::ThreadingModel::DedicatedUiThreadOnly,
+ Module::ClassType::InprocHandler, nullptr,
+ aMsixContainer);
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/oop/Handler.h b/ipc/mscom/oop/Handler.h
new file mode 100644
index 0000000000..b22c925c22
--- /dev/null
+++ b/ipc/mscom/oop/Handler.h
@@ -0,0 +1,134 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Handler_h
+#define mozilla_mscom_Handler_h
+
+#if defined(MOZILLA_INTERNAL_API)
+# error This code is NOT for internal Gecko use!
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include <objidl.h>
+
+#include "mozilla/mscom/Aggregation.h"
+#include "mozilla/NotNull.h"
+#include "mozilla/RefPtr.h"
+
+/* WARNING! The code in this file may be loaded into the address spaces of other
+ processes! It MUST NOT link against xul.dll or other Gecko binaries! Only
+ inline code may be included! */
+
+namespace mozilla {
+namespace mscom {
+
+class Handler : public IMarshal {
+ public:
+ // IMarshal
+ STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ CLSID* pCid) override;
+ STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext,
+ void* pvDestContext, DWORD mshlflags,
+ DWORD* pSize) override;
+ STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv,
+ DWORD dwDestContext, void* pvDestContext,
+ DWORD mshlflags) override;
+ STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid,
+ void** ppv) override;
+ STDMETHODIMP ReleaseMarshalData(IStream* pStm) override;
+ STDMETHODIMP DisconnectObject(DWORD dwReserved) override;
+
+ /**
+ * This method allows the handler to return its own interfaces that override
+ * those interfaces that are exposed by the underlying COM proxy.
+ * @param aProxyUnknown is the IUnknown of the underlying COM proxy. This is
+ * provided to give the handler implementation an
+ * opportunity to acquire interfaces to the underlying
+ * remote object, if needed.
+ * @param aIid Interface requested, similar to IUnknown::QueryInterface
+ * @param aOutInterface Outparam for the resulting interface to return to the
+ * client.
+ * @return The usual HRESULT codes similarly to IUnknown::QueryInterface.
+ * If E_NOINTERFACE is returned, the proxy will be queried.
+ * If the handler is certain that this interface is not available,
+ * it can return S_FALSE to avoid querying the proxy. This will be
+ * translated to E_NOINTERFACE before it is returned to the client.
+ */
+ virtual HRESULT QueryHandlerInterface(IUnknown* aProxyUnknown, REFIID aIid,
+ void** aOutInterface) = 0;
+ /**
+ * Called when the implementer should deserialize data in aStream.
+ * @return S_OK on success;
+ * S_FALSE if the deserialization was successful but there was no
+ * data; HRESULT error code otherwise.
+ */
+ virtual HRESULT ReadHandlerPayload(IStream* aStream, REFIID aIid) {
+ return S_FALSE;
+ }
+
+ /**
+ * Unfortunately when COM marshals a proxy, it doesn't implicitly marshal
+ * the payload that was originally sent with the proxy. We must implement
+ * that code in the handler in order to make this happen.
+ */
+
+ /**
+ * This function allows the implementer to substitute a different interface
+ * for marshaling than the one that COM is intending to marshal. For example,
+ * the implementer might want to marshal a proxy for an interface that is
+ * derived from the requested interface.
+ *
+ * The default implementation is the identity function.
+ */
+ virtual REFIID MarshalAs(REFIID aRequestedIid) { return aRequestedIid; }
+
+ virtual HRESULT GetMarshalInterface(REFIID aMarshalAsIid,
+ NotNull<IUnknown*> aProxy,
+ NotNull<IID*> aOutIid,
+ NotNull<IUnknown**> aOutUnk);
+
+ /**
+ * Called when the implementer must provide the size of the payload.
+ */
+ virtual HRESULT GetHandlerPayloadSize(REFIID aIid, DWORD* aOutPayloadSize) {
+ if (!aOutPayloadSize) {
+ return E_INVALIDARG;
+ }
+ *aOutPayloadSize = 0;
+ return S_OK;
+ }
+
+ /**
+ * Called when the implementer should serialize the payload data into aStream.
+ */
+ virtual HRESULT WriteHandlerPayload(IStream* aStream, REFIID aIid) {
+ return S_OK;
+ }
+
+ IUnknown* GetProxy() const { return mInnerUnk; }
+
+ static HRESULT Register(REFCLSID aClsid, const bool aMsixContainer = false);
+ static HRESULT Unregister(REFCLSID aClsid);
+
+ protected:
+ Handler(IUnknown* aOuter, HRESULT* aResult);
+ virtual ~Handler() {}
+ bool HasPayload() const { return mHasPayload; }
+ IUnknown* GetOuter() const { return mOuter; }
+
+ private:
+ ULONG mRefCnt;
+ IUnknown* mOuter;
+ RefPtr<IUnknown> mInnerUnk;
+ IMarshal* mUnmarshal; // WEAK
+ bool mHasPayload;
+ DECLARE_AGGREGATABLE(Handler);
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Handler_h
diff --git a/ipc/mscom/oop/Module.cpp b/ipc/mscom/oop/Module.cpp
new file mode 100644
index 0000000000..88ebf91531
--- /dev/null
+++ b/ipc/mscom/oop/Module.cpp
@@ -0,0 +1,292 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "Module.h"
+
+#include <stdlib.h>
+
+#include <ktmw32.h>
+#include <memory.h>
+#include <rpc.h>
+
+#include "mozilla/ArrayUtils.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/Range.h"
+#include "nsWindowsHelpers.h"
+
+template <size_t N>
+static const mozilla::Range<const wchar_t> LiteralToRange(
+ const wchar_t (&aArg)[N]) {
+ return mozilla::Range(aArg, N);
+}
+
+namespace mozilla {
+namespace mscom {
+
+ULONG Module::sRefCount = 0;
+
+static const wchar_t* SubkeyNameFromClassType(
+ const Module::ClassType aClassType) {
+ switch (aClassType) {
+ case Module::ClassType::InprocServer:
+ return L"InprocServer32";
+ case Module::ClassType::InprocHandler:
+ return L"InprocHandler32";
+ default:
+ MOZ_CRASH("Unknown ClassType");
+ return nullptr;
+ }
+}
+
+static const Range<const wchar_t> ThreadingModelAsString(
+ const Module::ThreadingModel aThreadingModel) {
+ switch (aThreadingModel) {
+ case Module::ThreadingModel::DedicatedUiThreadOnly:
+ return LiteralToRange(L"Apartment");
+ case Module::ThreadingModel::MultiThreadedApartmentOnly:
+ return LiteralToRange(L"Free");
+ case Module::ThreadingModel::DedicatedUiThreadXorMultiThreadedApartment:
+ return LiteralToRange(L"Both");
+ case Module::ThreadingModel::AllThreadsAllApartments:
+ return LiteralToRange(L"Neutral");
+ default:
+ MOZ_CRASH("Unknown ThreadingModel");
+ return Range<const wchar_t>();
+ }
+}
+
+/* static */
+HRESULT Module::Register(const CLSID* const* aClsids, const size_t aNumClsids,
+ const ThreadingModel aThreadingModel,
+ const ClassType aClassType, const GUID* const aAppId,
+ const bool aMsixContainer) {
+ MOZ_ASSERT(aClsids && aNumClsids);
+ if (!aClsids || !aNumClsids) {
+ return E_INVALIDARG;
+ }
+ MOZ_ASSERT(!aAppId || !aMsixContainer,
+ "aAppId isn't valid in an MSIX container");
+
+ const wchar_t* inprocName = SubkeyNameFromClassType(aClassType);
+
+ const Range<const wchar_t> threadingModelStr =
+ ThreadingModelAsString(aThreadingModel);
+ const DWORD threadingModelStrLenBytesInclNul =
+ threadingModelStr.length() * sizeof(wchar_t);
+
+ wchar_t strAppId[kGuidRegFormatCharLenInclNul] = {};
+ if (aAppId) {
+ GUIDToString(*aAppId, strAppId);
+ }
+
+ // Obtain the full path to this DLL
+ HMODULE thisModule;
+ if (!::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
+ GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
+ reinterpret_cast<LPCWSTR>(&Module::CanUnload),
+ &thisModule)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+
+ wchar_t absThisModulePath[MAX_PATH + 1] = {};
+ DWORD actualPathLenCharsExclNul = ::GetModuleFileNameW(
+ thisModule, absThisModulePath, ArrayLength(absThisModulePath));
+ if (!actualPathLenCharsExclNul ||
+ actualPathLenCharsExclNul == ArrayLength(absThisModulePath)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ const DWORD actualPathLenBytesInclNul =
+ (actualPathLenCharsExclNul + 1) * sizeof(wchar_t);
+
+ nsAutoHandle txn;
+ // RegCreateKeyTransacted doesn't work in MSIX containers.
+ if (!aMsixContainer) {
+ // Use the name of this DLL as the name of the transaction
+ wchar_t txnName[_MAX_FNAME] = {};
+ if (_wsplitpath_s(absThisModulePath, nullptr, 0, nullptr, 0, txnName,
+ ArrayLength(txnName), nullptr, 0)) {
+ return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND);
+ }
+
+ // Manipulate the registry using a transaction so that any failures are
+ // rolled back.
+ txn.own(::CreateTransaction(nullptr, nullptr, TRANSACTION_DO_NOT_PROMOTE, 0,
+ 0, 0, txnName));
+ if (txn.get() == INVALID_HANDLE_VALUE) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ }
+
+ HRESULT hr;
+ LSTATUS status;
+
+ // A single DLL may serve multiple components. For each CLSID, we register
+ // this DLL as its server and, when an AppId is specified, set up a reference
+ // from the CLSID to the specified AppId.
+ for (size_t idx = 0; idx < aNumClsids; ++idx) {
+ if (!aClsids[idx]) {
+ return E_INVALIDARG;
+ }
+
+ wchar_t clsidKeyPath[256];
+ hr = BuildClsidPath(*aClsids[idx], clsidKeyPath);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ // Create the CLSID key
+ HKEY rawClsidKey;
+ // Subtle: If aMsixContainer is true, as well as calling a different
+ // function, we also use HKEY_CURRENT_USER. When false, we use
+ // HKEY_LOCAL_MACHINE.
+ if (aMsixContainer) {
+ status = ::RegCreateKeyExW(HKEY_CURRENT_USER, clsidKeyPath, 0, nullptr,
+ REG_OPTION_NON_VOLATILE, KEY_ALL_ACCESS,
+ nullptr, &rawClsidKey, nullptr);
+ } else {
+ status = ::RegCreateKeyTransactedW(
+ HKEY_LOCAL_MACHINE, clsidKeyPath, 0, nullptr, REG_OPTION_NON_VOLATILE,
+ KEY_ALL_ACCESS, nullptr, &rawClsidKey, nullptr, txn, nullptr);
+ }
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ nsAutoRegKey clsidKey(rawClsidKey);
+
+ if (aAppId) {
+ // This value associates the registered CLSID with the specified AppID
+ status = ::RegSetValueExW(clsidKey, L"AppID", 0, REG_SZ,
+ reinterpret_cast<const BYTE*>(strAppId),
+ ArrayLength(strAppId) * sizeof(wchar_t));
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ }
+
+ HKEY rawInprocKey;
+ if (aMsixContainer) {
+ status = ::RegCreateKeyExW(clsidKey, inprocName, 0, nullptr,
+ REG_OPTION_NON_VOLATILE, KEY_ALL_ACCESS,
+ nullptr, &rawInprocKey, nullptr);
+ } else {
+ status = ::RegCreateKeyTransactedW(
+ clsidKey, inprocName, 0, nullptr, REG_OPTION_NON_VOLATILE,
+ KEY_ALL_ACCESS, nullptr, &rawInprocKey, nullptr, txn, nullptr);
+ }
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ nsAutoRegKey inprocKey(rawInprocKey);
+
+ // Set the component's path to this DLL
+ status = ::RegSetValueExW(inprocKey, nullptr, 0, REG_EXPAND_SZ,
+ reinterpret_cast<const BYTE*>(absThisModulePath),
+ actualPathLenBytesInclNul);
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+
+ status = ::RegSetValueExW(
+ inprocKey, L"ThreadingModel", 0, REG_SZ,
+ reinterpret_cast<const BYTE*>(threadingModelStr.begin().get()),
+ threadingModelStrLenBytesInclNul);
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ }
+
+ if (aAppId) {
+ // When specified, we must also create a key for the AppID.
+ wchar_t appidKeyPath[256];
+ hr = BuildAppidPath(*aAppId, appidKeyPath);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ HKEY rawAppidKey;
+ status = ::RegCreateKeyTransactedW(
+ HKEY_LOCAL_MACHINE, appidKeyPath, 0, nullptr, REG_OPTION_NON_VOLATILE,
+ KEY_ALL_ACCESS, nullptr, &rawAppidKey, nullptr, txn, nullptr);
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ nsAutoRegKey appidKey(rawAppidKey);
+
+ // Setting DllSurrogate to a null or empty string indicates to Windows that
+ // we want to use the default surrogate (i.e. dllhost.exe) to load our DLL.
+ status =
+ ::RegSetValueExW(appidKey, L"DllSurrogate", 0, REG_SZ,
+ reinterpret_cast<const BYTE*>(L""), sizeof(wchar_t));
+ if (status != ERROR_SUCCESS) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ }
+
+ if (!aMsixContainer) {
+ if (!::CommitTransaction(txn)) {
+ return HRESULT_FROM_WIN32(::GetLastError());
+ }
+ }
+
+ return S_OK;
+}
+
+/**
+ * Unfortunately the registry transaction APIs are not as well-developed for
+ * deleting things as they are for creating them. We just use RegDeleteTree
+ * for the implementation of this method.
+ */
+HRESULT Module::Deregister(const CLSID* const* aClsids, const size_t aNumClsids,
+ const GUID* const aAppId) {
+ MOZ_ASSERT(aClsids && aNumClsids);
+ if (!aClsids || !aNumClsids) {
+ return E_INVALIDARG;
+ }
+
+ HRESULT hr;
+ LSTATUS status;
+
+ // Delete the key for each CLSID. This will also delete any references to
+ // the AppId.
+ for (size_t idx = 0; idx < aNumClsids; ++idx) {
+ if (!aClsids[idx]) {
+ return E_INVALIDARG;
+ }
+
+ wchar_t clsidKeyPath[256];
+ hr = BuildClsidPath(*aClsids[idx], clsidKeyPath);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ status = ::RegDeleteTreeW(HKEY_LOCAL_MACHINE, clsidKeyPath);
+ // We allow the deletion to succeed if the key was already gone
+ if (status != ERROR_SUCCESS && status != ERROR_FILE_NOT_FOUND) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ }
+
+ // Now delete the AppID key, if desired.
+ if (aAppId) {
+ wchar_t appidKeyPath[256];
+ hr = BuildAppidPath(*aAppId, appidKeyPath);
+ if (FAILED(hr)) {
+ return hr;
+ }
+
+ status = ::RegDeleteTreeW(HKEY_LOCAL_MACHINE, appidKeyPath);
+ // We allow the deletion to succeed if the key was already gone
+ if (status != ERROR_SUCCESS && status != ERROR_FILE_NOT_FOUND) {
+ return HRESULT_FROM_WIN32(status);
+ }
+ }
+
+ return S_OK;
+}
+
+} // namespace mscom
+} // namespace mozilla
diff --git a/ipc/mscom/oop/Module.h b/ipc/mscom/oop/Module.h
new file mode 100644
index 0000000000..1fa4b31c49
--- /dev/null
+++ b/ipc/mscom/oop/Module.h
@@ -0,0 +1,98 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Module_h
+#define mozilla_mscom_Module_h
+
+#if defined(MOZILLA_INTERNAL_API)
+# error This code is NOT for internal Gecko use!
+#endif // defined(MOZILLA_INTERNAL_API)
+
+#include <objbase.h>
+
+/* WARNING! The code in this file may be loaded into the address spaces of other
+ processes! It MUST NOT link against xul.dll or other Gecko binaries! Only
+ inline code may be included! */
+
+namespace mozilla {
+namespace mscom {
+
+class Module {
+ public:
+ static HRESULT CanUnload() { return sRefCount == 0 ? S_OK : S_FALSE; }
+
+ static void Lock() { ++sRefCount; }
+ static void Unlock() { --sRefCount; }
+
+ enum class ThreadingModel {
+ DedicatedUiThreadOnly,
+ MultiThreadedApartmentOnly,
+ DedicatedUiThreadXorMultiThreadedApartment,
+ AllThreadsAllApartments,
+ };
+
+ enum class ClassType {
+ InprocServer,
+ InprocHandler,
+ };
+
+ /**
+ * In the Register functions, the aMsixContainer parameter specifies whether
+ * this registration is being performed inside an MSIX container. If true,
+ * the CLSID is registered in HKCU and a registry transaction is not used, as
+ * registry transactions don't work in an MSIX container. If false (the
+ * default), the CLSID is registered in HKLM and a registry transaction is
+ * used so that any failures roll back the entire operation. Specifying aAppId
+ * is only valid when aMsixContainer is false.
+ */
+ static HRESULT Register(REFCLSID aClsid, const ThreadingModel aThreadingModel,
+ const ClassType aClassType = ClassType::InprocServer,
+ const GUID* const aAppId = nullptr,
+ const bool aMsixContainer = false) {
+ const CLSID* clsidArray[] = {&aClsid};
+ return Register(clsidArray, aThreadingModel, aClassType, aAppId,
+ aMsixContainer);
+ }
+
+ template <size_t N>
+ static HRESULT Register(const CLSID* (&aClsids)[N],
+ const ThreadingModel aThreadingModel,
+ const ClassType aClassType = ClassType::InprocServer,
+ const GUID* const aAppId = nullptr,
+ const bool aMsixContainer = false) {
+ return Register(aClsids, N, aThreadingModel, aClassType, aAppId,
+ aMsixContainer);
+ }
+
+ static HRESULT Deregister(REFCLSID aClsid,
+ const GUID* const aAppId = nullptr) {
+ const CLSID* clsidArray[] = {&aClsid};
+ return Deregister(clsidArray, aAppId);
+ }
+
+ template <size_t N>
+ static HRESULT Deregister(const CLSID* (&aClsids)[N],
+ const GUID* const aAppId = nullptr) {
+ return Deregister(aClsids, N, aAppId);
+ }
+
+ private:
+ static HRESULT Register(const CLSID* const* aClsids, const size_t aNumClsids,
+ const ThreadingModel aThreadingModel,
+ const ClassType aClassType, const GUID* const aAppId,
+ const bool aMsixContainer);
+
+ static HRESULT Deregister(const CLSID* const* aClsids,
+ const size_t aNumClsids, const GUID* const aAppId);
+
+ private:
+ static ULONG sRefCount;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Module_h
diff --git a/ipc/mscom/oop/moz.build b/ipc/mscom/oop/moz.build
new file mode 100644
index 0000000000..5ead655e6f
--- /dev/null
+++ b/ipc/mscom/oop/moz.build
@@ -0,0 +1,43 @@
+# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*-
+# vim: set filetype=python:
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+Library("mscom_oop")
+
+SOURCES += [
+ "../ActivationContext.cpp",
+ "../COMWrappers.cpp",
+ "../Objref.cpp",
+ "../Registration.cpp",
+ "../StructStream.cpp",
+ "../Utils.cpp",
+]
+
+UNIFIED_SOURCES += [
+ "Handler.cpp",
+ "Module.cpp",
+]
+
+OS_LIBS += [
+ "ktmw32",
+ "ole32",
+ "oleaut32",
+ "shlwapi",
+]
+
+LIBRARY_DEFINES["UNICODE"] = True
+LIBRARY_DEFINES["_UNICODE"] = True
+LIBRARY_DEFINES["MOZ_NO_MOZALLOC"] = True
+LIBRARY_DEFINES["MOZ_MSCOM_REMARSHAL_NO_HANDLER"] = True
+
+DisableStlWrapping()
+NO_EXPAND_LIBS = True
+FORCE_STATIC_LIB = True
+
+# This DLL may be loaded into other processes, so we need static libs for
+# Windows 7 and Windows 8.
+USE_STATIC_LIBS = True
+
+REQUIRES_UNIFIED_BUILD = True