diff options
Diffstat (limited to 'ipc/mscom')
62 files changed, 12072 insertions, 0 deletions
diff --git a/ipc/mscom/ActivationContext.cpp b/ipc/mscom/ActivationContext.cpp new file mode 100644 index 0000000000..c5ee9d1419 --- /dev/null +++ b/ipc/mscom/ActivationContext.cpp @@ -0,0 +1,223 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/ActivationContext.h" + +#include "mozilla/Assertions.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/mscom/Utils.h" + +namespace mozilla { +namespace mscom { + +ActivationContext::ActivationContext(ActCtxResource aResource) + : ActivationContext(aResource.mModule, aResource.mId) {} + +ActivationContext::ActivationContext(HMODULE aLoadFromModule, WORD aResourceId) + : mActCtx(INVALID_HANDLE_VALUE) { + ACTCTXW actCtx = {sizeof(actCtx)}; + actCtx.dwFlags = ACTCTX_FLAG_RESOURCE_NAME_VALID | ACTCTX_FLAG_HMODULE_VALID; + actCtx.lpResourceName = MAKEINTRESOURCEW(aResourceId); + actCtx.hModule = aLoadFromModule; + + Init(actCtx); +} + +void ActivationContext::Init(ACTCTXW& aActCtx) { + MOZ_ASSERT(mActCtx == INVALID_HANDLE_VALUE); + mActCtx = ::CreateActCtxW(&aActCtx); + MOZ_ASSERT(mActCtx != INVALID_HANDLE_VALUE); +} + +void ActivationContext::AddRef() { + if (mActCtx == INVALID_HANDLE_VALUE) { + return; + } + ::AddRefActCtx(mActCtx); +} + +ActivationContext::ActivationContext(ActivationContext&& aOther) + : mActCtx(aOther.mActCtx) { + aOther.mActCtx = INVALID_HANDLE_VALUE; +} + +ActivationContext& ActivationContext::operator=(ActivationContext&& aOther) { + Release(); + + mActCtx = aOther.mActCtx; + aOther.mActCtx = INVALID_HANDLE_VALUE; + return *this; +} + +ActivationContext::ActivationContext(const ActivationContext& aOther) + : mActCtx(aOther.mActCtx) { + AddRef(); +} + +ActivationContext& ActivationContext::operator=( + const ActivationContext& aOther) { + Release(); + mActCtx = aOther.mActCtx; + AddRef(); + return *this; +} + +void ActivationContext::Release() { + if (mActCtx == INVALID_HANDLE_VALUE) { + return; + } + ::ReleaseActCtx(mActCtx); + mActCtx = INVALID_HANDLE_VALUE; +} + +ActivationContext::~ActivationContext() { Release(); } + +#if defined(MOZILLA_INTERNAL_API) + +/* static */ Result<uintptr_t, HRESULT> ActivationContext::GetCurrent() { + HANDLE actCtx; + if (!::GetCurrentActCtx(&actCtx)) { + return Result<uintptr_t, HRESULT>(HRESULT_FROM_WIN32(::GetLastError())); + } + + return reinterpret_cast<uintptr_t>(actCtx); +} + +/* static */ +HRESULT ActivationContext::GetCurrentManifestPath(nsAString& aOutManifestPath) { + aOutManifestPath.Truncate(); + + SIZE_T bytesNeeded; + BOOL ok = ::QueryActCtxW(QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr, + nullptr, ActivationContextDetailedInformation, + nullptr, 0, &bytesNeeded); + if (!ok) { + DWORD err = ::GetLastError(); + if (err != ERROR_INSUFFICIENT_BUFFER) { + return HRESULT_FROM_WIN32(err); + } + } + + auto ctxBuf = MakeUnique<BYTE[]>(bytesNeeded); + + ok = ::QueryActCtxW(QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr, nullptr, + ActivationContextDetailedInformation, ctxBuf.get(), + bytesNeeded, nullptr); + if (!ok) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + auto ctxInfo = + reinterpret_cast<ACTIVATION_CONTEXT_DETAILED_INFORMATION*>(ctxBuf.get()); + + // assemblyIndex is 1-based, and we want the last index, so we can just copy + // ctxInfo->ulAssemblyCount directly. + DWORD assemblyIndex = ctxInfo->ulAssemblyCount; + ok = ::QueryActCtxW( + QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr, &assemblyIndex, + AssemblyDetailedInformationInActivationContext, nullptr, 0, &bytesNeeded); + if (!ok) { + DWORD err = ::GetLastError(); + if (err != ERROR_INSUFFICIENT_BUFFER) { + return HRESULT_FROM_WIN32(err); + } + } + + auto assemblyBuf = MakeUnique<BYTE[]>(bytesNeeded); + + ok = ::QueryActCtxW(QUERY_ACTCTX_FLAG_USE_ACTIVE_ACTCTX, nullptr, + &assemblyIndex, + AssemblyDetailedInformationInActivationContext, + assemblyBuf.get(), bytesNeeded, &bytesNeeded); + if (!ok) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + auto assemblyInfo = + reinterpret_cast<ACTIVATION_CONTEXT_ASSEMBLY_DETAILED_INFORMATION*>( + assemblyBuf.get()); + aOutManifestPath = nsDependentString( + assemblyInfo->lpAssemblyManifestPath, + (assemblyInfo->ulManifestPathLength + 1) / sizeof(wchar_t)); + + return S_OK; +} + +#endif // defined(MOZILLA_INTERNAL_API) + +ActivationContextRegion::ActivationContextRegion() : mActCookie(0) {} + +ActivationContextRegion::ActivationContextRegion( + const ActivationContext& aActCtx) + : mActCtx(aActCtx), mActCookie(0) { + Activate(); +} + +ActivationContextRegion& ActivationContextRegion::operator=( + const ActivationContext& aActCtx) { + Deactivate(); + mActCtx = aActCtx; + Activate(); + return *this; +} + +ActivationContextRegion::ActivationContextRegion(ActivationContext&& aActCtx) + : mActCtx(std::move(aActCtx)), mActCookie(0) { + Activate(); +} + +ActivationContextRegion& ActivationContextRegion::operator=( + ActivationContext&& aActCtx) { + Deactivate(); + mActCtx = std::move(aActCtx); + Activate(); + return *this; +} + +ActivationContextRegion::ActivationContextRegion(ActivationContextRegion&& aRgn) + : mActCtx(std::move(aRgn.mActCtx)), mActCookie(aRgn.mActCookie) { + aRgn.mActCookie = 0; +} + +ActivationContextRegion& ActivationContextRegion::operator=( + ActivationContextRegion&& aRgn) { + Deactivate(); + mActCtx = std::move(aRgn.mActCtx); + mActCookie = aRgn.mActCookie; + aRgn.mActCookie = 0; + return *this; +} + +void ActivationContextRegion::Activate() { + if (mActCtx.mActCtx == INVALID_HANDLE_VALUE) { + return; + } + +#ifdef MOZ_DIAGNOSTIC_ASSERT_ENABLED + BOOL activated = +#endif + ::ActivateActCtx(mActCtx.mActCtx, &mActCookie); + MOZ_DIAGNOSTIC_ASSERT(activated); +} + +bool ActivationContextRegion::Deactivate() { + if (!mActCookie) { + return true; + } + + BOOL deactivated = ::DeactivateActCtx(0, mActCookie); + MOZ_DIAGNOSTIC_ASSERT(deactivated); + if (deactivated) { + mActCookie = 0; + } + + return !!deactivated; +} + +ActivationContextRegion::~ActivationContextRegion() { Deactivate(); } + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/ActivationContext.h b/ipc/mscom/ActivationContext.h new file mode 100644 index 0000000000..3d0144528d --- /dev/null +++ b/ipc/mscom/ActivationContext.h @@ -0,0 +1,101 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_ActivationContext_h +#define mozilla_mscom_ActivationContext_h + +#include <utility> + +#include "mozilla/Attributes.h" +#include "mozilla/mscom/ActCtxResource.h" + +#if defined(MOZILLA_INTERNAL_API) +# include "mozilla/ResultVariant.h" +# include "nsString.h" +#endif // defined(MOZILLA_INTERNAL_API) + +#include <windows.h> + +namespace mozilla { +namespace mscom { + +class ActivationContext final { + public: + // This is the default resource ID that the Windows dynamic linker searches + // for when seeking a manifest while loading a DLL. + static constexpr WORD kDllManifestDefaultResourceId = 2; + + ActivationContext() : mActCtx(INVALID_HANDLE_VALUE) {} + + explicit ActivationContext(ActCtxResource aResource); + explicit ActivationContext(HMODULE aLoadFromModule, + WORD aResourceId = kDllManifestDefaultResourceId); + + ActivationContext(ActivationContext&& aOther); + ActivationContext& operator=(ActivationContext&& aOther); + + ActivationContext(const ActivationContext& aOther); + ActivationContext& operator=(const ActivationContext& aOther); + + ~ActivationContext(); + + explicit operator bool() const { return mActCtx != INVALID_HANDLE_VALUE; } + +#if defined(MOZILLA_INTERNAL_API) + static Result<uintptr_t, HRESULT> GetCurrent(); + static HRESULT GetCurrentManifestPath(nsAString& aOutManifestPath); +#endif // defined(MOZILLA_INTERNAL_API) + + private: + void Init(ACTCTXW& aActCtx); + void AddRef(); + void Release(); + + private: + HANDLE mActCtx; + + friend class ActivationContextRegion; +}; + +class MOZ_NON_TEMPORARY_CLASS ActivationContextRegion final { + public: + template <typename... Args> + explicit ActivationContextRegion(Args&&... aArgs) + : mActCtx(std::forward<Args>(aArgs)...), mActCookie(0) { + Activate(); + } + + ActivationContextRegion(); + + explicit ActivationContextRegion(const ActivationContext& aActCtx); + ActivationContextRegion& operator=(const ActivationContext& aActCtx); + + explicit ActivationContextRegion(ActivationContext&& aActCtx); + ActivationContextRegion& operator=(ActivationContext&& aActCtx); + + ActivationContextRegion(ActivationContextRegion&& aRgn); + ActivationContextRegion& operator=(ActivationContextRegion&& aRgn); + + ~ActivationContextRegion(); + + explicit operator bool() const { return !!mActCookie; } + + ActivationContextRegion(const ActivationContextRegion&) = delete; + ActivationContextRegion& operator=(const ActivationContextRegion&) = delete; + + bool Deactivate(); + + private: + void Activate(); + + ActivationContext mActCtx; + ULONG_PTR mActCookie; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_ActivationContext_h diff --git a/ipc/mscom/Aggregation.h b/ipc/mscom/Aggregation.h new file mode 100644 index 0000000000..ca171e39cc --- /dev/null +++ b/ipc/mscom/Aggregation.h @@ -0,0 +1,83 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Aggregation_h +#define mozilla_mscom_Aggregation_h + +#include "mozilla/Attributes.h" + +#include <stddef.h> +#include <unknwn.h> + +namespace mozilla { +namespace mscom { + +/** + * This is used for stabilizing a COM object's reference count during + * construction when that object aggregates other objects. Since the aggregated + * object(s) may AddRef() or Release(), we need to artifically boost the + * refcount to prevent premature destruction. Note that we increment/decrement + * instead of AddRef()/Release() in this class because we want to adjust the + * refcount without causing any other side effects (like object destruction). + */ +template <typename RefCntT> +class MOZ_RAII StabilizedRefCount { + public: + explicit StabilizedRefCount(RefCntT& aRefCnt) : mRefCnt(aRefCnt) { + ++aRefCnt; + } + + ~StabilizedRefCount() { --mRefCnt; } + + StabilizedRefCount(const StabilizedRefCount&) = delete; + StabilizedRefCount(StabilizedRefCount&&) = delete; + StabilizedRefCount& operator=(const StabilizedRefCount&) = delete; + StabilizedRefCount& operator=(StabilizedRefCount&&) = delete; + + private: + RefCntT& mRefCnt; +}; + +namespace detail { + +template <typename T> +class InternalUnknown : public IUnknown { + public: + STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override { + return This()->InternalQueryInterface(aIid, aOutInterface); + } + + STDMETHODIMP_(ULONG) AddRef() override { return This()->InternalAddRef(); } + + STDMETHODIMP_(ULONG) Release() override { return This()->InternalRelease(); } + + private: + T* This() { + return reinterpret_cast<T*>(reinterpret_cast<char*>(this) - + offsetof(T, mInternalUnknown)); + } +}; + +} // namespace detail +} // namespace mscom +} // namespace mozilla + +#define DECLARE_AGGREGATABLE(Type) \ + public: \ + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override { \ + return mOuter->QueryInterface(riid, ppv); \ + } \ + STDMETHODIMP_(ULONG) AddRef() override { return mOuter->AddRef(); } \ + STDMETHODIMP_(ULONG) Release() override { return mOuter->Release(); } \ + \ + protected: \ + STDMETHODIMP InternalQueryInterface(REFIID riid, void** ppv); \ + STDMETHODIMP_(ULONG) InternalAddRef(); \ + STDMETHODIMP_(ULONG) InternalRelease(); \ + friend class mozilla::mscom::detail::InternalUnknown<Type>; \ + mozilla::mscom::detail::InternalUnknown<Type> mInternalUnknown + +#endif // mozilla_mscom_Aggregation_h diff --git a/ipc/mscom/AgileReference.cpp b/ipc/mscom/AgileReference.cpp new file mode 100644 index 0000000000..7a1b94822d --- /dev/null +++ b/ipc/mscom/AgileReference.cpp @@ -0,0 +1,223 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/AgileReference.h" + +#include <utility> + +#include "mozilla/Assertions.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/DynamicallyLinkedFunctionPtr.h" +#include "mozilla/mscom/Utils.h" + +#if defined(MOZILLA_INTERNAL_API) +# include "nsDebug.h" +# include "nsPrintfCString.h" +#endif // defined(MOZILLA_INTERNAL_API) + +#if NTDDI_VERSION < NTDDI_WINBLUE + +// Declarations from Windows SDK specific to Windows 8.1 + +enum AgileReferenceOptions { + AGILEREFERENCE_DEFAULT = 0, + AGILEREFERENCE_DELAYEDMARSHAL = 1, +}; + +HRESULT WINAPI RoGetAgileReference(AgileReferenceOptions options, REFIID riid, + IUnknown* pUnk, + IAgileReference** ppAgileReference); + +#endif // NTDDI_VERSION < NTDDI_WINBLUE + +namespace mozilla { +namespace mscom { +namespace detail { + +GlobalInterfaceTableCookie::GlobalInterfaceTableCookie(IUnknown* aObject, + REFIID aIid, + HRESULT& aOutHResult) + : mCookie(0) { + IGlobalInterfaceTable* git = ObtainGit(); + MOZ_ASSERT(git); + if (!git) { + aOutHResult = E_POINTER; + return; + } + + aOutHResult = git->RegisterInterfaceInGlobal(aObject, aIid, &mCookie); + MOZ_ASSERT(SUCCEEDED(aOutHResult)); +} + +GlobalInterfaceTableCookie::~GlobalInterfaceTableCookie() { + IGlobalInterfaceTable* git = ObtainGit(); + MOZ_ASSERT(git); + if (!git) { + return; + } + + DebugOnly<HRESULT> hr = git->RevokeInterfaceFromGlobal(mCookie); +#if defined(MOZILLA_INTERNAL_API) + NS_WARNING_ASSERTION( + SUCCEEDED(hr), + nsPrintfCString("IGlobalInterfaceTable::RevokeInterfaceFromGlobal failed " + "with HRESULT 0x%08lX", + ((HRESULT)hr)) + .get()); +#else + MOZ_ASSERT(SUCCEEDED(hr)); +#endif // defined(MOZILLA_INTERNAL_API) + mCookie = 0; +} + +HRESULT GlobalInterfaceTableCookie::GetInterface(REFIID aIid, + void** aOutInterface) const { + IGlobalInterfaceTable* git = ObtainGit(); + MOZ_ASSERT(git); + if (!git) { + return E_UNEXPECTED; + } + + MOZ_ASSERT(IsValid()); + return git->GetInterfaceFromGlobal(mCookie, aIid, aOutInterface); +} + +/* static */ +IGlobalInterfaceTable* GlobalInterfaceTableCookie::ObtainGit() { + // Internally to COM, the Global Interface Table is a singleton, therefore we + // don't worry about holding onto this reference indefinitely. + static IGlobalInterfaceTable* sGit = []() -> IGlobalInterfaceTable* { + IGlobalInterfaceTable* result = nullptr; + DebugOnly<HRESULT> hr = ::CoCreateInstance( + CLSID_StdGlobalInterfaceTable, nullptr, CLSCTX_INPROC_SERVER, + IID_IGlobalInterfaceTable, reinterpret_cast<void**>(&result)); + MOZ_ASSERT(SUCCEEDED(hr)); + return result; + }(); + + return sGit; +} + +} // namespace detail + +AgileReference::AgileReference() : mIid(), mHResult(E_NOINTERFACE) {} + +AgileReference::AgileReference(REFIID aIid, IUnknown* aObject) + : mIid(aIid), mHResult(E_UNEXPECTED) { + AssignInternal(aObject); +} + +AgileReference::AgileReference(AgileReference&& aOther) + : mIid(aOther.mIid), + mAgileRef(std::move(aOther.mAgileRef)), + mGitCookie(std::move(aOther.mGitCookie)), + mHResult(aOther.mHResult) { + aOther.mHResult = CO_E_RELEASED; +} + +void AgileReference::Assign(REFIID aIid, IUnknown* aObject) { + Clear(); + mIid = aIid; + AssignInternal(aObject); +} + +void AgileReference::AssignInternal(IUnknown* aObject) { + // We expect mIid to already be set + DebugOnly<IID> zeroIid = {}; + MOZ_ASSERT(mIid != zeroIid); + + /* + * There are two possible techniques for creating agile references. Starting + * with Windows 8.1, we may use the RoGetAgileReference API, which is faster. + * If that API is not available, we fall back to using the Global Interface + * Table. + */ + static const StaticDynamicallyLinkedFunctionPtr< + decltype(&::RoGetAgileReference)> + pRoGetAgileReference(L"ole32.dll", "RoGetAgileReference"); + + MOZ_ASSERT(aObject); + + if (pRoGetAgileReference && + SUCCEEDED(mHResult = + pRoGetAgileReference(AGILEREFERENCE_DEFAULT, mIid, aObject, + getter_AddRefs(mAgileRef)))) { + return; + } + + mGitCookie = new detail::GlobalInterfaceTableCookie(aObject, mIid, mHResult); + MOZ_ASSERT(mGitCookie->IsValid()); +} + +AgileReference::~AgileReference() { Clear(); } + +void AgileReference::Clear() { + mIid = {}; + mAgileRef = nullptr; + mGitCookie = nullptr; + mHResult = E_NOINTERFACE; +} + +AgileReference& AgileReference::operator=(const AgileReference& aOther) { + Clear(); + mIid = aOther.mIid; + mAgileRef = aOther.mAgileRef; + mGitCookie = aOther.mGitCookie; + mHResult = aOther.mHResult; + return *this; +} + +AgileReference& AgileReference::operator=(AgileReference&& aOther) { + Clear(); + mIid = aOther.mIid; + mAgileRef = std::move(aOther.mAgileRef); + mGitCookie = std::move(aOther.mGitCookie); + mHResult = aOther.mHResult; + aOther.mHResult = CO_E_RELEASED; + return *this; +} + +HRESULT +AgileReference::Resolve(REFIID aIid, void** aOutInterface) const { + MOZ_ASSERT(aOutInterface); + // This check is exclusive-OR; we should have one or the other, but not both + MOZ_ASSERT((mAgileRef || mGitCookie) && !(mAgileRef && mGitCookie)); + MOZ_ASSERT(IsCOMInitializedOnCurrentThread()); + + if (!aOutInterface) { + return E_INVALIDARG; + } + + *aOutInterface = nullptr; + + if (mAgileRef) { + // IAgileReference lets you directly resolve the interface you want... + return mAgileRef->Resolve(aIid, aOutInterface); + } + + if (!mGitCookie) { + return E_UNEXPECTED; + } + + RefPtr<IUnknown> originalInterface; + HRESULT hr = + mGitCookie->GetInterface(mIid, getter_AddRefs(originalInterface)); + if (FAILED(hr)) { + return hr; + } + + if (aIid == mIid) { + originalInterface.forget(aOutInterface); + return S_OK; + } + + // ...Whereas the GIT requires us to obtain the same interface that we + // requested and then QI for the desired interface afterward. + return originalInterface->QueryInterface(aIid, aOutInterface); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/AgileReference.h b/ipc/mscom/AgileReference.h new file mode 100644 index 0000000000..d39e444494 --- /dev/null +++ b/ipc/mscom/AgileReference.h @@ -0,0 +1,143 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_AgileReference_h +#define mozilla_mscom_AgileReference_h + +#include "mozilla/Attributes.h" +#include "mozilla/RefPtr.h" +#include "nsISupportsImpl.h" + +#include <objidl.h> + +namespace mozilla { +namespace mscom { +namespace detail { + +class MOZ_HEAP_CLASS GlobalInterfaceTableCookie final { + public: + GlobalInterfaceTableCookie(IUnknown* aObject, REFIID aIid, + HRESULT& aOutHResult); + + bool IsValid() const { return !!mCookie; } + HRESULT GetInterface(REFIID aIid, void** aOutInterface) const; + + NS_INLINE_DECL_THREADSAFE_REFCOUNTING(GlobalInterfaceTableCookie) + + GlobalInterfaceTableCookie(const GlobalInterfaceTableCookie&) = delete; + GlobalInterfaceTableCookie(GlobalInterfaceTableCookie&&) = delete; + + GlobalInterfaceTableCookie& operator=(const GlobalInterfaceTableCookie&) = + delete; + GlobalInterfaceTableCookie& operator=(GlobalInterfaceTableCookie&&) = delete; + + private: + ~GlobalInterfaceTableCookie(); + + private: + DWORD mCookie; + + private: + static IGlobalInterfaceTable* ObtainGit(); +}; + +} // namespace detail + +/** + * This class encapsulates an "agile reference." These are references that + * allow you to pass COM interfaces between apartments. When you have an + * interface that you would like to pass between apartments, you wrap that + * interface in an AgileReference and pass the agile reference instead. Then + * you unwrap the interface by calling AgileReference::Resolve. + * + * Sample usage: + * + * // In the multithreaded apartment, foo is an IFoo* + * auto myAgileRef = MakeUnique<AgileReference>(IID_IFoo, foo); + * + * // myAgileRef is passed to our main thread, which runs in a single-threaded + * // apartment: + * + * RefPtr<IFoo> foo; + * HRESULT hr = myAgileRef->Resolve(IID_IFoo, getter_AddRefs(foo)); + * // Now foo may be called from the main thread + */ +class AgileReference final { + public: + AgileReference(); + + template <typename InterfaceT> + explicit AgileReference(RefPtr<InterfaceT>& aObject) + : AgileReference(__uuidof(InterfaceT), aObject) {} + + AgileReference(REFIID aIid, IUnknown* aObject); + + AgileReference(const AgileReference& aOther) = default; + AgileReference(AgileReference&& aOther); + + ~AgileReference(); + + explicit operator bool() const { + return mAgileRef || (mGitCookie && mGitCookie->IsValid()); + } + + HRESULT GetHResult() const { return mHResult; } + + template <typename T> + void Assign(const RefPtr<T>& aOther) { + Assign(__uuidof(T), aOther); + } + + template <typename T> + AgileReference& operator=(const RefPtr<T>& aOther) { + Assign(aOther); + return *this; + } + + HRESULT Resolve(REFIID aIid, void** aOutInterface) const; + + AgileReference& operator=(const AgileReference& aOther); + AgileReference& operator=(AgileReference&& aOther); + + AgileReference& operator=(decltype(nullptr)) { + Clear(); + return *this; + } + + void Clear(); + + private: + void Assign(REFIID aIid, IUnknown* aObject); + void AssignInternal(IUnknown* aObject); + + private: + IID mIid; + RefPtr<IAgileReference> mAgileRef; + RefPtr<detail::GlobalInterfaceTableCookie> mGitCookie; + HRESULT mHResult; +}; + +} // namespace mscom +} // namespace mozilla + +template <typename T> +RefPtr<T>::RefPtr(const mozilla::mscom::AgileReference& aAgileRef) + : mRawPtr(nullptr) { + (*this) = aAgileRef; +} + +template <typename T> +RefPtr<T>& RefPtr<T>::operator=( + const mozilla::mscom::AgileReference& aAgileRef) { + void* newRawPtr; + if (FAILED(aAgileRef.Resolve(__uuidof(T), &newRawPtr))) { + newRawPtr = nullptr; + } + assign_assuming_AddRef(static_cast<T*>(newRawPtr)); + return *this; +} + +#endif // mozilla_mscom_AgileReference_h diff --git a/ipc/mscom/ApartmentRegion.h b/ipc/mscom/ApartmentRegion.h new file mode 100644 index 0000000000..41915710df --- /dev/null +++ b/ipc/mscom/ApartmentRegion.h @@ -0,0 +1,93 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_ApartmentRegion_h +#define mozilla_mscom_ApartmentRegion_h + +#include "mozilla/Assertions.h" +#include "mozilla/Attributes.h" +#include "mozilla/mscom/COMWrappers.h" + +namespace mozilla { +namespace mscom { + +class MOZ_NON_TEMPORARY_CLASS ApartmentRegion final { + public: + /** + * This constructor is to be used when we want to instantiate the object but + * we do not yet know which type of apartment we want. Call Init() to + * complete initialization. + */ + constexpr ApartmentRegion() : mInitResult(CO_E_NOTINITIALIZED) {} + + explicit ApartmentRegion(COINIT aAptType) + : mInitResult(wrapped::CoInitializeEx(nullptr, aAptType)) { + // If this fires then we're probably mixing apartments on the same thread + MOZ_ASSERT(IsValid()); + } + + ~ApartmentRegion() { + if (IsValid()) { + wrapped::CoUninitialize(); + } + } + + explicit operator bool() const { return IsValid(); } + + bool IsValidOutermost() const { return mInitResult == S_OK; } + + bool IsValid() const { return SUCCEEDED(mInitResult); } + + bool Init(COINIT aAptType) { + MOZ_ASSERT(mInitResult == CO_E_NOTINITIALIZED); + mInitResult = wrapped::CoInitializeEx(nullptr, aAptType); + MOZ_ASSERT(IsValid()); + return IsValid(); + } + + HRESULT + GetHResult() const { return mInitResult; } + + private: + ApartmentRegion(const ApartmentRegion&) = delete; + ApartmentRegion& operator=(const ApartmentRegion&) = delete; + ApartmentRegion(ApartmentRegion&&) = delete; + ApartmentRegion& operator=(ApartmentRegion&&) = delete; + + HRESULT mInitResult; +}; + +template <COINIT T> +class MOZ_NON_TEMPORARY_CLASS ApartmentRegionT final { + public: + ApartmentRegionT() : mAptRgn(T) {} + + ~ApartmentRegionT() = default; + + explicit operator bool() const { return mAptRgn.IsValid(); } + + bool IsValidOutermost() const { return mAptRgn.IsValidOutermost(); } + + bool IsValid() const { return mAptRgn.IsValid(); } + + HRESULT GetHResult() const { return mAptRgn.GetHResult(); } + + private: + ApartmentRegionT(const ApartmentRegionT&) = delete; + ApartmentRegionT& operator=(const ApartmentRegionT&) = delete; + ApartmentRegionT(ApartmentRegionT&&) = delete; + ApartmentRegionT& operator=(ApartmentRegionT&&) = delete; + + ApartmentRegion mAptRgn; +}; + +typedef ApartmentRegionT<COINIT_APARTMENTTHREADED> STARegion; +typedef ApartmentRegionT<COINIT_MULTITHREADED> MTARegion; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_ApartmentRegion_h diff --git a/ipc/mscom/AsyncInvoker.h b/ipc/mscom/AsyncInvoker.h new file mode 100644 index 0000000000..ce25a4224b --- /dev/null +++ b/ipc/mscom/AsyncInvoker.h @@ -0,0 +1,465 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_AsyncInvoker_h +#define mozilla_mscom_AsyncInvoker_h + +#include <objidl.h> +#include <windows.h> + +#include <utility> + +#include "mozilla/Assertions.h" +#include "mozilla/Attributes.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/Maybe.h" +#include "mozilla/Mutex.h" +#include "mozilla/mscom/Aggregation.h" +#include "mozilla/mscom/Utils.h" +#include "nsISerialEventTarget.h" +#include "nsISupportsImpl.h" +#include "nsThreadUtils.h" + +namespace mozilla { +namespace mscom { +namespace detail { + +template <typename AsyncInterface> +class ForgettableAsyncCall : public ISynchronize { + public: + explicit ForgettableAsyncCall(ICallFactory* aCallFactory) + : mRefCnt(0), mAsyncCall(nullptr) { + StabilizedRefCount<Atomic<ULONG>> stabilizer(mRefCnt); + + HRESULT hr = + aCallFactory->CreateCall(__uuidof(AsyncInterface), this, IID_IUnknown, + getter_AddRefs(mInnerUnk)); + if (FAILED(hr)) { + return; + } + + hr = mInnerUnk->QueryInterface(__uuidof(AsyncInterface), + reinterpret_cast<void**>(&mAsyncCall)); + if (SUCCEEDED(hr)) { + // Don't hang onto a ref. Because mAsyncCall is aggregated, its refcount + // is this->mRefCnt, so we'd create a cycle! + mAsyncCall->Release(); + } + } + + AsyncInterface* GetInterface() const { return mAsyncCall; } + + // IUnknown + STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) final { + if (aIid == IID_ISynchronize || aIid == IID_IUnknown) { + RefPtr<ISynchronize> ptr(this); + ptr.forget(aOutInterface); + return S_OK; + } + + return mInnerUnk->QueryInterface(aIid, aOutInterface); + } + + STDMETHODIMP_(ULONG) AddRef() final { + ULONG result = ++mRefCnt; + NS_LOG_ADDREF(this, result, "ForgettableAsyncCall", sizeof(*this)); + return result; + } + + STDMETHODIMP_(ULONG) Release() final { + ULONG result = --mRefCnt; + NS_LOG_RELEASE(this, result, "ForgettableAsyncCall"); + if (!result) { + delete this; + } + return result; + } + + // ISynchronize + STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override { + return E_NOTIMPL; + } + + STDMETHODIMP Signal() override { + // Even though this function is a no-op, we must return S_OK as opposed to + // E_NOTIMPL or else COM will consider the async call to have failed. + return S_OK; + } + + STDMETHODIMP Reset() override { + // Even though this function is a no-op, we must return S_OK as opposed to + // E_NOTIMPL or else COM will consider the async call to have failed. + return S_OK; + } + + protected: + virtual ~ForgettableAsyncCall() = default; + + private: + Atomic<ULONG> mRefCnt; + RefPtr<IUnknown> mInnerUnk; + AsyncInterface* mAsyncCall; // weak reference +}; + +template <typename AsyncInterface> +class WaitableAsyncCall : public ForgettableAsyncCall<AsyncInterface> { + public: + explicit WaitableAsyncCall(ICallFactory* aCallFactory) + : ForgettableAsyncCall<AsyncInterface>(aCallFactory), + mEvent(::CreateEventW(nullptr, FALSE, FALSE, nullptr)) {} + + STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override { + const DWORD waitStart = + aTimeoutMilliseconds == INFINITE ? 0 : ::GetTickCount(); + DWORD flags = aFlags; + if (XRE_IsContentProcess() && NS_IsMainThread()) { + flags |= COWAIT_ALERTABLE; + } + + HRESULT hr; + DWORD signaledIdx; + + DWORD elapsed = 0; + + while (true) { + if (aTimeoutMilliseconds != INFINITE) { + elapsed = ::GetTickCount() - waitStart; + } + if (elapsed >= aTimeoutMilliseconds) { + return RPC_S_CALLPENDING; + } + + ::SetLastError(ERROR_SUCCESS); + + hr = ::CoWaitForMultipleHandles(flags, aTimeoutMilliseconds - elapsed, 1, + &mEvent, &signaledIdx); + if (hr == RPC_S_CALLPENDING || FAILED(hr)) { + return hr; + } + + if (hr == S_OK && signaledIdx == 0) { + return hr; + } + } + } + + STDMETHODIMP Signal() override { + if (!::SetEvent(mEvent)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + return S_OK; + } + + protected: + ~WaitableAsyncCall() { + if (mEvent) { + ::CloseHandle(mEvent); + } + } + + private: + HANDLE mEvent; +}; + +template <typename AsyncInterface> +class EventDrivenAsyncCall : public ForgettableAsyncCall<AsyncInterface> { + public: + explicit EventDrivenAsyncCall(ICallFactory* aCallFactory) + : ForgettableAsyncCall<AsyncInterface>(aCallFactory) {} + + bool HasCompletionRunnable() const { return !!mCompletionRunnable; } + + void ClearCompletionRunnable() { mCompletionRunnable = nullptr; } + + void SetCompletionRunnable(already_AddRefed<nsIRunnable> aRunnable) { + nsCOMPtr<nsIRunnable> innerRunnable(aRunnable); + MOZ_ASSERT(!!innerRunnable); + if (!innerRunnable) { + return; + } + + // We need to retain a ref to ourselves to outlive the AsyncInvoker + // such that our callback can execute. + RefPtr<EventDrivenAsyncCall<AsyncInterface>> self(this); + + mCompletionRunnable = NS_NewRunnableFunction( + "EventDrivenAsyncCall outer completion Runnable", + [innerRunnable = std::move(innerRunnable), self = std::move(self)]() { + innerRunnable->Run(); + }); + } + + void SetEventTarget(nsISerialEventTarget* aTarget) { mEventTarget = aTarget; } + + STDMETHODIMP Signal() override { + MOZ_ASSERT(!!mCompletionRunnable); + if (!mCompletionRunnable) { + return S_OK; + } + + nsCOMPtr<nsISerialEventTarget> eventTarget(mEventTarget.forget()); + if (!eventTarget) { + eventTarget = GetMainThreadSerialEventTarget(); + } + + DebugOnly<nsresult> rv = + eventTarget->Dispatch(mCompletionRunnable.forget(), NS_DISPATCH_NORMAL); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + return S_OK; + } + + private: + nsCOMPtr<nsIRunnable> mCompletionRunnable; + nsCOMPtr<nsISerialEventTarget> mEventTarget; +}; + +template <typename AsyncInterface> +class FireAndForgetInvoker { + protected: + void OnBeginInvoke() {} + void OnSyncInvoke(HRESULT aHr) {} + void OnAsyncInvokeFailed() {} + + typedef ForgettableAsyncCall<AsyncInterface> AsyncCallType; + + RefPtr<ForgettableAsyncCall<AsyncInterface>> mAsyncCall; +}; + +template <typename AsyncInterface> +class WaitableInvoker { + public: + HRESULT Wait(DWORD aTimeout = INFINITE) const { + if (!mAsyncCall) { + // Nothing to wait for + return S_OK; + } + + return mAsyncCall->Wait(0, aTimeout); + } + + protected: + void OnBeginInvoke() {} + void OnSyncInvoke(HRESULT aHr) {} + void OnAsyncInvokeFailed() {} + + typedef WaitableAsyncCall<AsyncInterface> AsyncCallType; + + RefPtr<WaitableAsyncCall<AsyncInterface>> mAsyncCall; +}; + +template <typename AsyncInterface> +class EventDrivenInvoker { + public: + void SetCompletionRunnable(already_AddRefed<nsIRunnable> aRunnable) { + if (mAsyncCall) { + mAsyncCall->SetCompletionRunnable(std::move(aRunnable)); + return; + } + + mCompletionRunnable = aRunnable; + } + + void SetAsyncEventTarget(nsISerialEventTarget* aTarget) { + if (mAsyncCall) { + mAsyncCall->SetEventTarget(aTarget); + } + } + + protected: + void OnBeginInvoke() { + MOZ_RELEASE_ASSERT( + mCompletionRunnable || + (mAsyncCall && mAsyncCall->HasCompletionRunnable()), + "You should have called SetCompletionRunnable before invoking!"); + } + + void OnSyncInvoke(HRESULT aHr) { + nsCOMPtr<nsIRunnable> completionRunnable(mCompletionRunnable.forget()); + if (FAILED(aHr)) { + return; + } + + completionRunnable->Run(); + } + + void OnAsyncInvokeFailed() { + MOZ_ASSERT(!!mAsyncCall); + mAsyncCall->ClearCompletionRunnable(); + } + + typedef EventDrivenAsyncCall<AsyncInterface> AsyncCallType; + + RefPtr<EventDrivenAsyncCall<AsyncInterface>> mAsyncCall; + nsCOMPtr<nsIRunnable> mCompletionRunnable; +}; + +} // namespace detail + +/** + * This class is intended for "fire-and-forget" asynchronous invocations of COM + * interfaces. This requires that an interface be annotated with the + * |async_uuid| attribute in midl. We also require that there be no outparams + * in the desired asynchronous interface (otherwise that would break the + * desired "fire-and-forget" semantics). + * + * For example, let us suppose we have some IDL as such: + * [object, uuid(...), async_uuid(...)] + * interface IFoo : IUnknown + * { + * HRESULT Bar([in] long baz); + * } + * + * Then, given an IFoo, we may construct an AsyncInvoker<IFoo, AsyncIFoo>: + * + * IFoo* foo = ...; + * AsyncInvoker<IFoo, AsyncIFoo> myInvoker(foo); + * HRESULT hr = myInvoker.Invoke(&IFoo::Bar, &AsyncIFoo::Begin_Bar, 7); + * + * Alternatively you may use the ASYNC_INVOKER_FOR and ASYNC_INVOKE macros, + * which automatically deduce the name of the asynchronous interface from the + * name of the synchronous interface: + * + * ASYNC_INVOKER_FOR(IFoo) myInvoker(foo); + * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7); + * + * This class may also be used when a synchronous COM call must be made that + * might reenter the content process. In this case, use the WaitableAsyncInvoker + * variant, or the WAITABLE_ASYNC_INVOKER_FOR macro: + * + * WAITABLE_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo); + * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7); + * if (SUCCEEDED(hr)) { + * myInvoker.Wait(); // <-- Wait for the COM call to complete. + * } + * + * In general you should avoid using the waitable version, but in some corner + * cases it is absolutely necessary in order to preserve correctness while + * avoiding deadlock. + * + * Finally, it is also possible to have the async invoker enqueue a runnable + * to the main thread when the async operation completes: + * + * EVENT_DRIVEN_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo); + * // myRunnable will be invoked on the main thread once the async operation + * // has completed. Note that we set this *before* we do the ASYNC_INVOKE! + * myInvoker.SetCompletionRunnable(myRunnable.forget()); + * HRESULT hr = ASYNC_INVOKE(myInvoker, Bar, 7); + * // ... + */ +template <typename SyncInterface, typename AsyncInterface, + template <typename Iface> class WaitPolicy = + detail::FireAndForgetInvoker> +class MOZ_RAII AsyncInvoker final : public WaitPolicy<AsyncInterface> { + using Base = WaitPolicy<AsyncInterface>; + + public: + typedef SyncInterface SyncInterfaceT; + typedef AsyncInterface AsyncInterfaceT; + + /** + * @param aSyncObj The COM object on which to invoke the asynchronous event. + * If this object is not a proxy to the synchronous variant + * of AsyncInterface, then it will be invoked synchronously + * instead (because it is an in-process virtual method call). + * @param aIsProxy An optional hint as to whether or not aSyncObj is a proxy. + * If not specified, AsyncInvoker will automatically detect + * whether aSyncObj is a proxy, however there may be a + * performance penalty associated with that. + */ + explicit AsyncInvoker(SyncInterface* aSyncObj, + const Maybe<bool>& aIsProxy = Nothing()) { + MOZ_ASSERT(aSyncObj); + + RefPtr<ICallFactory> callFactory; + if ((aIsProxy.isSome() && !aIsProxy.value()) || + FAILED(aSyncObj->QueryInterface(IID_ICallFactory, + getter_AddRefs(callFactory)))) { + mSyncObj = aSyncObj; + return; + } + + this->mAsyncCall = new typename Base::AsyncCallType(callFactory); + } + + /** + * @brief Invoke a method on the object. Member function pointers are provided + * for both the sychronous and asynchronous variants of the interface. + * If this invoker's encapsulated COM object is a proxy, then Invoke + * will call the asynchronous member function. Otherwise the + * synchronous version must be used, as the invocation will simply be a + * virtual function call that executes in-process. + * @param aSyncMethod Pointer to the method that we would like to invoke on + * the synchronous interface. + * @param aAsyncMethod Pointer to the method that we would like to invoke on + * the asynchronous interface. + */ + template <typename SyncMethod, typename AsyncMethod, typename... Args> + HRESULT Invoke(SyncMethod aSyncMethod, AsyncMethod aAsyncMethod, + Args&&... aArgs) { + this->OnBeginInvoke(); + if (mSyncObj) { + HRESULT hr = (mSyncObj->*aSyncMethod)(std::forward<Args>(aArgs)...); + this->OnSyncInvoke(hr); + return hr; + } + + MOZ_ASSERT(this->mAsyncCall); + if (!this->mAsyncCall) { + this->OnAsyncInvokeFailed(); + return E_POINTER; + } + + AsyncInterface* asyncInterface = this->mAsyncCall->GetInterface(); + MOZ_ASSERT(asyncInterface); + if (!asyncInterface) { + this->OnAsyncInvokeFailed(); + return E_POINTER; + } + + HRESULT hr = (asyncInterface->*aAsyncMethod)(std::forward<Args>(aArgs)...); + if (FAILED(hr)) { + this->OnAsyncInvokeFailed(); + } + + return hr; + } + + AsyncInvoker(const AsyncInvoker& aOther) = delete; + AsyncInvoker(AsyncInvoker&& aOther) = delete; + AsyncInvoker& operator=(const AsyncInvoker& aOther) = delete; + AsyncInvoker& operator=(AsyncInvoker&& aOther) = delete; + + private: + RefPtr<SyncInterface> mSyncObj; +}; + +template <typename SyncInterface, typename AsyncInterface> +using WaitableAsyncInvoker = + AsyncInvoker<SyncInterface, AsyncInterface, detail::WaitableInvoker>; + +template <typename SyncInterface, typename AsyncInterface> +using EventDrivenAsyncInvoker = + AsyncInvoker<SyncInterface, AsyncInterface, detail::EventDrivenInvoker>; + +} // namespace mscom +} // namespace mozilla + +#define ASYNC_INVOKER_FOR(SyncIface) \ + mozilla::mscom::AsyncInvoker<SyncIface, Async##SyncIface> + +#define WAITABLE_ASYNC_INVOKER_FOR(SyncIface) \ + mozilla::mscom::WaitableAsyncInvoker<SyncIface, Async##SyncIface> + +#define EVENT_DRIVEN_ASYNC_INVOKER_FOR(SyncIface) \ + mozilla::mscom::EventDrivenAsyncInvoker<SyncIface, Async##SyncIface> + +#define ASYNC_INVOKE(InvokerObj, SyncMethodName, ...) \ + InvokerObj.Invoke( \ + &decltype(InvokerObj)::SyncInterfaceT::SyncMethodName, \ + &decltype(InvokerObj)::AsyncInterfaceT::Begin_##SyncMethodName, \ + ##__VA_ARGS__) + +#endif // mozilla_mscom_AsyncInvoker_h diff --git a/ipc/mscom/COMPtrHolder.h b/ipc/mscom/COMPtrHolder.h new file mode 100644 index 0000000000..e99698b8e7 --- /dev/null +++ b/ipc/mscom/COMPtrHolder.h @@ -0,0 +1,201 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_COMPtrHolder_h +#define mozilla_mscom_COMPtrHolder_h + +#include <utility> + +#include "mozilla/Assertions.h" +#include "mozilla/Attributes.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/mscom/ProxyStream.h" +#include "mozilla/mscom/Ptr.h" +#if defined(MOZ_SANDBOX) +# include "mozilla/SandboxSettings.h" +#endif // defined(MOZ_SANDBOX) +#include "nsExceptionHandler.h" + +namespace mozilla { +namespace mscom { + +template <typename Interface, const IID& _IID> +class COMPtrHolder { + public: + typedef ProxyUniquePtr<Interface> COMPtrType; + typedef COMPtrHolder<Interface, _IID> ThisType; + typedef typename detail::EnvironmentSelector<Interface>::Type EnvType; + + COMPtrHolder() {} + + MOZ_IMPLICIT COMPtrHolder(decltype(nullptr)) {} + + explicit COMPtrHolder(COMPtrType&& aPtr) + : mPtr(std::forward<COMPtrType>(aPtr)) {} + + COMPtrHolder(COMPtrType&& aPtr, const ActivationContext& aActCtx) + : mPtr(std::forward<COMPtrType>(aPtr)), mActCtx(aActCtx) {} + + Interface* Get() const { return mPtr.get(); } + + [[nodiscard]] Interface* Release() { return mPtr.release(); } + + void Set(COMPtrType&& aPtr) { mPtr = std::forward<COMPtrType>(aPtr); } + + void SetActCtx(const ActivationContext& aActCtx) { mActCtx = aActCtx; } + +#if defined(MOZ_SANDBOX) + // This method is const because we need to call it during IPC write, where + // we are passed as a const argument. At higher sandboxing levels we need to + // save this artifact from the serialization process for later deletion. + void PreserveStream(PreservedStreamPtr aPtr) const { + MOZ_ASSERT(!mMarshaledStream); + mMarshaledStream = std::move(aPtr); + } + + PreservedStreamPtr GetPreservedStream() { + return std::move(mMarshaledStream); + } +#endif // defined(MOZ_SANDBOX) + + COMPtrHolder(const COMPtrHolder& aOther) = delete; + + COMPtrHolder(COMPtrHolder&& aOther) + : mPtr(std::move(aOther.mPtr)) +#if defined(MOZ_SANDBOX) + , + mMarshaledStream(std::move(aOther.mMarshaledStream)) +#endif // defined(MOZ_SANDBOX) + { + } + + // COMPtrHolder is eventually added as a member of a struct that is declared + // in IPDL. The generated C++ code for that IPDL struct includes copy + // constructors and assignment operators that assume that all members are + // copyable. I don't think that those copy constructors and operator= are + // actually used by any generated code, but they are made available. Since no + // move semantics are available, this terrible hack makes COMPtrHolder build + // when used as a member of an IPDL struct. + ThisType& operator=(const ThisType& aOther) { + Set(std::move(aOther.mPtr)); + +#if defined(MOZ_SANDBOX) + mMarshaledStream = std::move(aOther.mMarshaledStream); +#endif // defined(MOZ_SANDBOX) + + return *this; + } + + ThisType& operator=(ThisType&& aOther) { + Set(std::move(aOther.mPtr)); + +#if defined(MOZ_SANDBOX) + mMarshaledStream = std::move(aOther.mMarshaledStream); +#endif // defined(MOZ_SANDBOX) + + return *this; + } + + bool operator==(const ThisType& aOther) const { return mPtr == aOther.mPtr; } + + bool IsNull() const { return !mPtr; } + + private: + // This is mutable to facilitate the above operator= hack + mutable COMPtrType mPtr; + ActivationContext mActCtx; + +#if defined(MOZ_SANDBOX) + // This is mutable so that we may optionally store a reference to a marshaled + // stream to be cleaned up later via PreserveStream(). + mutable PreservedStreamPtr mMarshaledStream; +#endif // defined(MOZ_SANDBOX) +}; + +} // namespace mscom +} // namespace mozilla + +namespace IPC { + +template <typename Interface, const IID& _IID> +struct ParamTraits<mozilla::mscom::COMPtrHolder<Interface, _IID>> { + typedef mozilla::mscom::COMPtrHolder<Interface, _IID> paramType; + + static void Write(MessageWriter* aWriter, const paramType& aParam) { +#if defined(MOZ_SANDBOX) + static const bool sIsStreamPreservationNeeded = + XRE_IsParentProcess() && + mozilla::GetEffectiveContentSandboxLevel() >= 3; +#else + const bool sIsStreamPreservationNeeded = false; +#endif // defined(MOZ_SANDBOX) + + typename paramType::EnvType env; + + mozilla::mscom::ProxyStreamFlags flags = + sIsStreamPreservationNeeded + ? mozilla::mscom::ProxyStreamFlags::ePreservable + : mozilla::mscom::ProxyStreamFlags::eDefault; + + mozilla::mscom::ProxyStream proxyStream(_IID, aParam.Get(), &env, flags); + int bufLen; + const BYTE* buf = proxyStream.GetBuffer(bufLen); + MOZ_ASSERT(buf || !bufLen); + aWriter->WriteInt(bufLen); + if (bufLen) { + aWriter->WriteBytes(reinterpret_cast<const char*>(buf), bufLen); + } + +#if defined(MOZ_SANDBOX) + if (sIsStreamPreservationNeeded) { + /** + * When we're sending a ProxyStream from parent to content and the + * content sandboxing level is >= 3, content is unable to communicate + * its releasing of its reference to the proxied object. We preserve the + * marshaled proxy data here and later manually release it on content's + * behalf. + */ + aParam.PreserveStream(proxyStream.GetPreservedStream()); + } +#endif // defined(MOZ_SANDBOX) + } + + static bool Read(MessageReader* aReader, paramType* aResult) { + int length; + if (!aReader->ReadLength(&length)) { + return false; + } + + mozilla::UniquePtr<BYTE[]> buf; + if (length) { + buf = mozilla::MakeUnique<BYTE[]>(length); + if (!aReader->ReadBytesInto(buf.get(), length)) { + return false; + } + } + + typename paramType::EnvType env; + + mozilla::mscom::ProxyStream proxyStream(_IID, buf.get(), length, &env); + if (!proxyStream.IsValid()) { + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::ProxyStreamValid, "false"_ns); + return false; + } + + typename paramType::COMPtrType ptr; + if (!proxyStream.GetInterface(mozilla::mscom::getter_AddRefs(ptr))) { + return false; + } + + aResult->Set(std::move(ptr)); + return true; + } +}; + +} // namespace IPC + +#endif // mozilla_mscom_COMPtrHolder_h diff --git a/ipc/mscom/COMWrappers.cpp b/ipc/mscom/COMWrappers.cpp new file mode 100644 index 0000000000..1bc48a4927 --- /dev/null +++ b/ipc/mscom/COMWrappers.cpp @@ -0,0 +1,101 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/COMWrappers.h" + +#include <objbase.h> + +#include "mozilla/Assertions.h" +#include "mozilla/DynamicallyLinkedFunctionPtr.h" + +namespace mozilla::mscom::wrapped { + +HRESULT CoInitializeEx(LPVOID pvReserved, DWORD dwCoInit) { + static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoInitializeEx)> + pCoInitializeEx(L"combase.dll", "CoInitializeEx"); + if (!pCoInitializeEx) { + return ::CoInitializeEx(pvReserved, dwCoInit); + } + + return pCoInitializeEx(pvReserved, dwCoInit); +} + +void CoUninitialize() { + static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoUninitialize)> + pCoUninitialize(L"combase.dll", "CoUninitialize"); + if (!pCoUninitialize) { + return ::CoUninitialize(); + } + + return pCoUninitialize(); +} + +HRESULT CoIncrementMTAUsage(CO_MTA_USAGE_COOKIE* pCookie) { + static const StaticDynamicallyLinkedFunctionPtr< + decltype(&::CoIncrementMTAUsage)> + pCoIncrementMTAUsage(L"combase.dll", "CoIncrementMTAUsage"); + // This API is only available beginning with Windows 8. + if (!pCoIncrementMTAUsage) { + return E_NOTIMPL; + } + + HRESULT hr = pCoIncrementMTAUsage(pCookie); + MOZ_ASSERT(SUCCEEDED(hr)); + return hr; +} + +HRESULT CoGetApartmentType(APTTYPE* pAptType, APTTYPEQUALIFIER* pAptQualifier) { + static const StaticDynamicallyLinkedFunctionPtr< + decltype(&::CoGetApartmentType)> + pCoGetApartmentType(L"combase.dll", "CoGetApartmentType"); + if (!pCoGetApartmentType) { + return ::CoGetApartmentType(pAptType, pAptQualifier); + } + + return pCoGetApartmentType(pAptType, pAptQualifier); +} + +HRESULT CoInitializeSecurity(PSECURITY_DESCRIPTOR pSecDesc, LONG cAuthSvc, + SOLE_AUTHENTICATION_SERVICE* asAuthSvc, + void* pReserved1, DWORD dwAuthnLevel, + DWORD dwImpLevel, void* pAuthList, + DWORD dwCapabilities, void* pReserved3) { + static const StaticDynamicallyLinkedFunctionPtr< + decltype(&::CoInitializeSecurity)> + pCoInitializeSecurity(L"combase.dll", "CoInitializeSecurity"); + if (!pCoInitializeSecurity) { + return ::CoInitializeSecurity(pSecDesc, cAuthSvc, asAuthSvc, pReserved1, + dwAuthnLevel, dwImpLevel, pAuthList, + dwCapabilities, pReserved3); + } + + return pCoInitializeSecurity(pSecDesc, cAuthSvc, asAuthSvc, pReserved1, + dwAuthnLevel, dwImpLevel, pAuthList, + dwCapabilities, pReserved3); +} + +HRESULT CoCreateInstance(REFCLSID rclsid, LPUNKNOWN pUnkOuter, + DWORD dwClsContext, REFIID riid, LPVOID* ppv) { + static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoCreateInstance)> + pCoCreateInstance(L"combase.dll", "CoCreateInstance"); + if (!pCoCreateInstance) { + return ::CoCreateInstance(rclsid, pUnkOuter, dwClsContext, riid, ppv); + } + + return pCoCreateInstance(rclsid, pUnkOuter, dwClsContext, riid, ppv); +} + +HRESULT CoCreateGuid(GUID* pguid) { + static const StaticDynamicallyLinkedFunctionPtr<decltype(&::CoCreateGuid)> + pCoCreateGuid(L"combase.dll", "CoCreateGuid"); + if (!pCoCreateGuid) { + return ::CoCreateGuid(pguid); + } + + return pCoCreateGuid(pguid); +} + +} // namespace mozilla::mscom::wrapped diff --git a/ipc/mscom/COMWrappers.h b/ipc/mscom/COMWrappers.h new file mode 100644 index 0000000000..38bef10749 --- /dev/null +++ b/ipc/mscom/COMWrappers.h @@ -0,0 +1,44 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_COMWrappers_h +#define mozilla_mscom_COMWrappers_h + +#include <objbase.h> + +#if (NTDDI_VERSION < NTDDI_WIN8) +// Win8+ API that we use very carefully +DECLARE_HANDLE(CO_MTA_USAGE_COOKIE); +HRESULT WINAPI CoIncrementMTAUsage(CO_MTA_USAGE_COOKIE* pCookie); +#endif // (NTDDI_VERSION < NTDDI_WIN8) + +// A set of wrapped COM functions, so that we can dynamically link to the +// functions in combase.dll on win8+. This prevents ole32.dll and many other +// DLLs loading, which are not required when we have win32k locked down. +namespace mozilla::mscom::wrapped { + +HRESULT CoInitializeEx(LPVOID pvReserved, DWORD dwCoInit); + +void CoUninitialize(); + +HRESULT CoIncrementMTAUsage(CO_MTA_USAGE_COOKIE* pCookie); + +HRESULT CoGetApartmentType(APTTYPE* pAptType, APTTYPEQUALIFIER* pAptQualifier); + +HRESULT CoInitializeSecurity(PSECURITY_DESCRIPTOR pSecDesc, LONG cAuthSvc, + SOLE_AUTHENTICATION_SERVICE* asAuthSvc, + void* pReserved1, DWORD dwAuthnLevel, + DWORD dwImpLevel, void* pAuthList, + DWORD dwCapabilities, void* pReserved3); + +HRESULT CoCreateInstance(REFCLSID rclsid, LPUNKNOWN pUnkOuter, + DWORD dwClsContext, REFIID riid, LPVOID* ppv); + +HRESULT CoCreateGuid(GUID* pguid); + +} // namespace mozilla::mscom::wrapped + +#endif // mozilla_mscom_COMWrappers_h diff --git a/ipc/mscom/DispatchForwarder.cpp b/ipc/mscom/DispatchForwarder.cpp new file mode 100644 index 0000000000..e55b1316d0 --- /dev/null +++ b/ipc/mscom/DispatchForwarder.cpp @@ -0,0 +1,156 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/DispatchForwarder.h" + +#include <oleauto.h> + +#include <utility> + +#include "mozilla/mscom/MainThreadInvoker.h" + +namespace mozilla { +namespace mscom { + +/* static */ +HRESULT DispatchForwarder::Create(IInterceptor* aInterceptor, + STAUniquePtr<IDispatch>& aTarget, + IUnknown** aOutput) { + MOZ_ASSERT(aInterceptor && aOutput); + if (!aOutput) { + return E_INVALIDARG; + } + *aOutput = nullptr; + if (!aInterceptor) { + return E_INVALIDARG; + } + DispatchForwarder* forwarder = new DispatchForwarder(aInterceptor, aTarget); + HRESULT hr = forwarder->QueryInterface(IID_IDispatch, (void**)aOutput); + forwarder->Release(); + return hr; +} + +DispatchForwarder::DispatchForwarder(IInterceptor* aInterceptor, + STAUniquePtr<IDispatch>& aTarget) + : mRefCnt(1), mInterceptor(aInterceptor), mTarget(std::move(aTarget)) {} + +DispatchForwarder::~DispatchForwarder() {} + +HRESULT +DispatchForwarder::QueryInterface(REFIID riid, void** ppv) { + if (!ppv) { + return E_INVALIDARG; + } + + // Since this class implements a tearoff, any interfaces that are not + // IDispatch must be routed to the original object's QueryInterface. + // This is especially important for IUnknown since COM uses that interface + // to determine object identity. + if (riid != IID_IDispatch) { + return mInterceptor->QueryInterface(riid, ppv); + } + + IUnknown* punk = static_cast<IDispatch*>(this); + *ppv = punk; + if (!punk) { + return E_NOINTERFACE; + } + + punk->AddRef(); + return S_OK; +} + +ULONG +DispatchForwarder::AddRef() { + return (ULONG)InterlockedIncrement((LONG*)&mRefCnt); +} + +ULONG +DispatchForwarder::Release() { + ULONG newRefCnt = (ULONG)InterlockedDecrement((LONG*)&mRefCnt); + if (newRefCnt == 0) { + delete this; + } + return newRefCnt; +} + +HRESULT +DispatchForwarder::GetTypeInfoCount(UINT* pctinfo) { + if (!pctinfo) { + return E_INVALIDARG; + } + *pctinfo = 1; + return S_OK; +} + +HRESULT +DispatchForwarder::GetTypeInfo(UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo) { + // ITypeInfo as implemented by COM is apartment-neutral, so we don't need + // to wrap it (yay!) + if (mTypeInfo) { + RefPtr<ITypeInfo> copy(mTypeInfo); + copy.forget(ppTInfo); + return S_OK; + } + HRESULT hr = E_UNEXPECTED; + auto fn = [&]() -> void { hr = mTarget->GetTypeInfo(iTInfo, lcid, ppTInfo); }; + MainThreadInvoker invoker; + if (!invoker.Invoke( + NS_NewRunnableFunction("DispatchForwarder::GetTypeInfo", fn))) { + return E_UNEXPECTED; + } + if (FAILED(hr)) { + return hr; + } + mTypeInfo = *ppTInfo; + return hr; +} + +HRESULT +DispatchForwarder::GetIDsOfNames(REFIID riid, LPOLESTR* rgszNames, UINT cNames, + LCID lcid, DISPID* rgDispId) { + HRESULT hr = E_UNEXPECTED; + auto fn = [&]() -> void { + hr = mTarget->GetIDsOfNames(riid, rgszNames, cNames, lcid, rgDispId); + }; + MainThreadInvoker invoker; + if (!invoker.Invoke( + NS_NewRunnableFunction("DispatchForwarder::GetIDsOfNames", fn))) { + return E_UNEXPECTED; + } + return hr; +} + +HRESULT +DispatchForwarder::Invoke(DISPID dispIdMember, REFIID riid, LCID lcid, + WORD wFlags, DISPPARAMS* pDispParams, + VARIANT* pVarResult, EXCEPINFO* pExcepInfo, + UINT* puArgErr) { + HRESULT hr; + if (!mInterface) { + if (!mTypeInfo) { + return E_UNEXPECTED; + } + TYPEATTR* typeAttr = nullptr; + hr = mTypeInfo->GetTypeAttr(&typeAttr); + if (FAILED(hr)) { + return hr; + } + hr = mInterceptor->QueryInterface(typeAttr->guid, + (void**)getter_AddRefs(mInterface)); + mTypeInfo->ReleaseTypeAttr(typeAttr); + if (FAILED(hr)) { + return hr; + } + } + // We don't invoke IDispatch on the target, but rather on the interceptor! + hr = ::DispInvoke(mInterface.get(), mTypeInfo, dispIdMember, wFlags, + pDispParams, pVarResult, pExcepInfo, puArgErr); + return hr; +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/DispatchForwarder.h b/ipc/mscom/DispatchForwarder.h new file mode 100644 index 0000000000..55d3fbb4f2 --- /dev/null +++ b/ipc/mscom/DispatchForwarder.h @@ -0,0 +1,79 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_DispatchForwarder_h +#define mozilla_mscom_DispatchForwarder_h + +#include <oaidl.h> + +#include "mozilla/mscom/Interceptor.h" +#include "mozilla/mscom/Ptr.h" + +namespace mozilla { +namespace mscom { + +class DispatchForwarder final : public IDispatch { + public: + static HRESULT Create(IInterceptor* aInterceptor, + STAUniquePtr<IDispatch>& aTarget, IUnknown** aOutput); + + // IUnknown + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // IDispatch + STDMETHODIMP GetTypeInfoCount( + /* [out] */ __RPC__out UINT* pctinfo) override; + + STDMETHODIMP GetTypeInfo( + /* [in] */ UINT iTInfo, + /* [in] */ LCID lcid, + /* [out] */ __RPC__deref_out_opt ITypeInfo** ppTInfo) override; + + STDMETHODIMP GetIDsOfNames( + /* [in] */ __RPC__in REFIID riid, + /* [size_is][in] */ __RPC__in_ecount_full(cNames) LPOLESTR* rgszNames, + /* [range][in] */ __RPC__in_range(0, 16384) UINT cNames, + /* [in] */ LCID lcid, + /* [size_is][out] */ __RPC__out_ecount_full(cNames) DISPID* rgDispId) + override; + + STDMETHODIMP Invoke( + /* [annotation][in] */ + _In_ DISPID dispIdMember, + /* [annotation][in] */ + _In_ REFIID riid, + /* [annotation][in] */ + _In_ LCID lcid, + /* [annotation][in] */ + _In_ WORD wFlags, + /* [annotation][out][in] */ + _In_ DISPPARAMS* pDispParams, + /* [annotation][out] */ + _Out_opt_ VARIANT* pVarResult, + /* [annotation][out] */ + _Out_opt_ EXCEPINFO* pExcepInfo, + /* [annotation][out] */ + _Out_opt_ UINT* puArgErr) override; + + private: + DispatchForwarder(IInterceptor* aInterceptor, + STAUniquePtr<IDispatch>& aTarget); + ~DispatchForwarder(); + + private: + ULONG mRefCnt; + RefPtr<IInterceptor> mInterceptor; + STAUniquePtr<IDispatch> mTarget; + RefPtr<ITypeInfo> mTypeInfo; + RefPtr<IUnknown> mInterface; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_DispatchForwarder_h diff --git a/ipc/mscom/EnsureMTA.cpp b/ipc/mscom/EnsureMTA.cpp new file mode 100644 index 0000000000..2258dd9dcd --- /dev/null +++ b/ipc/mscom/EnsureMTA.cpp @@ -0,0 +1,256 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/EnsureMTA.h" + +#include "mozilla/Assertions.h" +#include "mozilla/ClearOnShutdown.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/mscom/COMWrappers.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/SchedulerGroup.h" +#include "mozilla/StaticLocalPtr.h" +#include "nsThreadManager.h" +#include "nsThreadUtils.h" + +#include "private/pprthred.h" + +namespace { + +class EnterMTARunnable : public mozilla::Runnable { + public: + EnterMTARunnable() : mozilla::Runnable("EnterMTARunnable") {} + NS_IMETHOD Run() override { + mozilla::DebugOnly<HRESULT> hr = + mozilla::mscom::wrapped::CoInitializeEx(nullptr, COINIT_MULTITHREADED); + MOZ_ASSERT(SUCCEEDED(hr)); + return NS_OK; + } +}; + +class BackgroundMTAData { + public: + BackgroundMTAData() { + nsCOMPtr<nsIRunnable> runnable = new EnterMTARunnable(); + mozilla::DebugOnly<nsresult> rv = NS_NewNamedThread( + "COM MTA", getter_AddRefs(mThread), runnable.forget()); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "NS_NewNamedThread failed"); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + ~BackgroundMTAData() { + if (mThread) { + mThread->Dispatch( + NS_NewRunnableFunction("BackgroundMTAData::~BackgroundMTAData", + &mozilla::mscom::wrapped::CoUninitialize), + NS_DISPATCH_NORMAL); + mThread->Shutdown(); + } + } + + nsCOMPtr<nsIThread> GetThread() const { return mThread; } + + private: + nsCOMPtr<nsIThread> mThread; +}; + +} // anonymous namespace + +namespace mozilla { +namespace mscom { + +EnsureMTA::EnsureMTA() { + MOZ_ASSERT(NS_IsMainThread()); + + // It is possible that we're running so early that we might need to start + // the thread manager ourselves. We do this here to guarantee that we have + // the ability to start the persistent MTA thread at any moment beyond this + // point. + nsresult rv = nsThreadManager::get().Init(); + // We intentionally don't check rv unless we need it when + // CoIncremementMTAUsage is unavailable. + + // Calling this function initializes the MTA without needing to explicitly + // create a thread and call CoInitializeEx to do it. + // We don't retain the cookie because once we've incremented the MTA, we + // leave it that way for the lifetime of the process. + CO_MTA_USAGE_COOKIE mtaCookie = nullptr; + HRESULT hr = wrapped::CoIncrementMTAUsage(&mtaCookie); + if (SUCCEEDED(hr)) { + if (NS_SUCCEEDED(rv)) { + // Start the persistent MTA thread (mostly) asynchronously. + Unused << GetPersistentMTAThread(); + } + + return; + } + + // In the fallback case, we simply initialize our persistent MTA thread. + + // Make sure thread manager init succeeded before trying to initialize the + // persistent MTA thread. + MOZ_DIAGNOSTIC_ASSERT(NS_SUCCEEDED(rv)); + if (NS_FAILED(rv)) { + return; + } + + // Before proceeding any further, pump a runnable through the persistent MTA + // thread to ensure that it is up and running and has finished initializing + // the multi-threaded apartment. + nsCOMPtr<nsIRunnable> runnable(NS_NewRunnableFunction( + "EnsureMTA::EnsureMTA()", + []() { MOZ_RELEASE_ASSERT(IsCurrentThreadExplicitMTA()); })); + SyncDispatchToPersistentThread(runnable); +} + +/* static */ +RefPtr<EnsureMTA::CreateInstanceAgileRefPromise> +EnsureMTA::CreateInstanceInternal(REFCLSID aClsid, REFIID aIid) { + MOZ_ASSERT(IsCurrentThreadExplicitMTA()); + + RefPtr<IUnknown> iface; + HRESULT hr = wrapped::CoCreateInstance(aClsid, nullptr, CLSCTX_INPROC_SERVER, + aIid, getter_AddRefs(iface)); + if (FAILED(hr)) { + return CreateInstanceAgileRefPromise::CreateAndReject(hr, __func__); + } + + // We need to use the two argument constructor for AgileReference because our + // RefPtr is not parameterized on the specific interface being requested. + AgileReference agileRef(aIid, iface); + if (!agileRef) { + return CreateInstanceAgileRefPromise::CreateAndReject(agileRef.GetHResult(), + __func__); + } + + return CreateInstanceAgileRefPromise::CreateAndResolve(std::move(agileRef), + __func__); +} + +/* static */ +RefPtr<EnsureMTA::CreateInstanceAgileRefPromise> EnsureMTA::CreateInstance( + REFCLSID aClsid, REFIID aIid) { + MOZ_ASSERT(IsCOMInitializedOnCurrentThread()); + + const bool isClassOk = IsClassThreadAwareInprocServer(aClsid); + MOZ_ASSERT(isClassOk, + "mozilla::mscom::EnsureMTA::CreateInstance is not " + "safe/performant/necessary to use with this CLSID. This CLSID " + "either does not support creation from within a multithreaded " + "apartment, or it is not an in-process server."); + if (!isClassOk) { + return CreateInstanceAgileRefPromise::CreateAndReject(CO_E_NOT_SUPPORTED, + __func__); + } + + if (IsCurrentThreadExplicitMTA()) { + // It's safe to immediately call CreateInstanceInternal + return CreateInstanceInternal(aClsid, aIid); + } + + // aClsid and aIid are references. Make local copies that we can put into the + // lambda in case the sources of aClsid or aIid are not static data + CLSID localClsid = aClsid; + IID localIid = aIid; + + auto invoker = [localClsid, + localIid]() -> RefPtr<CreateInstanceAgileRefPromise> { + return CreateInstanceInternal(localClsid, localIid); + }; + + nsCOMPtr<nsIThread> mtaThread(GetPersistentMTAThread()); + + return InvokeAsync(mtaThread->SerialEventTarget(), __func__, + std::move(invoker)); +} + +/* static */ +nsCOMPtr<nsIThread> EnsureMTA::GetPersistentMTAThread() { + static StaticLocalAutoPtr<BackgroundMTAData> sMTAData( + []() -> BackgroundMTAData* { + BackgroundMTAData* bgData = new BackgroundMTAData(); + + auto setClearOnShutdown = [ptr = &sMTAData]() -> void { + ClearOnShutdown(ptr, ShutdownPhase::XPCOMShutdownThreads); + }; + + if (NS_IsMainThread()) { + setClearOnShutdown(); + return bgData; + } + + SchedulerGroup::Dispatch( + TaskCategory::Other, + NS_NewRunnableFunction("mscom::EnsureMTA::GetPersistentMTAThread", + std::move(setClearOnShutdown))); + + return bgData; + }()); + + MOZ_ASSERT(sMTAData); + + return sMTAData->GetThread(); +} + +/* static */ +void EnsureMTA::SyncDispatchToPersistentThread(nsIRunnable* aRunnable) { + nsCOMPtr<nsIThread> thread(GetPersistentMTAThread()); + MOZ_ASSERT(thread); + if (!thread) { + return; + } + + // Note that, due to APC dispatch, we might reenter this function while we + // wait on this event. We therefore need a unique event object for each + // entry into this function. If perf becomes an issue then we will want to + // maintain an array of events where the Nth event is unique to the Nth + // reentry. + nsAutoHandle event(::CreateEventW(nullptr, FALSE, FALSE, nullptr)); + if (!event) { + return; + } + + HANDLE eventHandle = event.get(); + auto eventSetter = [&aRunnable, eventHandle]() -> void { + aRunnable->Run(); + ::SetEvent(eventHandle); + }; + + nsresult rv = thread->Dispatch( + NS_NewRunnableFunction("mscom::EnsureMTA::SyncDispatchToPersistentThread", + std::move(eventSetter)), + NS_DISPATCH_NORMAL); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + if (NS_FAILED(rv)) { + return; + } + +#if defined(ACCESSIBILITY) + const BOOL alertable = XRE_IsContentProcess() && NS_IsMainThread(); +#else + const BOOL alertable = FALSE; +#endif // defined(ACCESSIBILITY) + + AUTO_PROFILER_THREAD_SLEEP; + DWORD waitResult; + while ((waitResult = ::WaitForSingleObjectEx(event, INFINITE, alertable)) == + WAIT_IO_COMPLETION) { + } + MOZ_ASSERT(waitResult == WAIT_OBJECT_0); +} + +/** + * While this function currently appears to be redundant, it may become more + * sophisticated in the future. For example, we could optionally dispatch to an + * MTA context if we wanted to utilize the MTA thread pool. + */ +/* static */ +void EnsureMTA::SyncDispatch(nsCOMPtr<nsIRunnable>&& aRunnable, Option aOpt) { + SyncDispatchToPersistentThread(aRunnable); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/EnsureMTA.h b/ipc/mscom/EnsureMTA.h new file mode 100644 index 0000000000..5e410d4daa --- /dev/null +++ b/ipc/mscom/EnsureMTA.h @@ -0,0 +1,191 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_EnsureMTA_h +#define mozilla_mscom_EnsureMTA_h + +#include "MainThreadUtils.h" +#include "mozilla/Attributes.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/MozPromise.h" +#include "mozilla/Unused.h" +#include "mozilla/mscom/AgileReference.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/RefPtr.h" +#include "nsCOMPtr.h" +#include "nsIThread.h" +#include "nsThreadUtils.h" +#include "nsWindowsHelpers.h" + +#include <windows.h> + +namespace mozilla { +namespace mscom { +namespace detail { + +// Forward declarations +template <typename T> +struct MTADelete; + +template <typename T> +struct MTARelease; + +template <typename T> +struct MTAReleaseInChildProcess; + +struct PreservedStreamDeleter; + +} // namespace detail + +class ProcessRuntime; + +// This class is OK to use as a temporary on the stack. +class MOZ_STACK_CLASS EnsureMTA final { + public: + enum class Option { + Default, + // Forcibly dispatch to the thread returned by GetPersistentMTAThread(), + // even if the current thread is already inside a MTA. + ForceDispatchToPersistentThread, + }; + + /** + * Synchronously run |aClosure| on a thread living in the COM multithreaded + * apartment. If the current thread lives inside the COM MTA, then it runs + * |aClosure| immediately unless |aOpt| == + * Option::ForceDispatchToPersistentThread. + */ + template <typename FuncT> + explicit EnsureMTA(FuncT&& aClosure, Option aOpt = Option::Default) { + if (aOpt != Option::ForceDispatchToPersistentThread && + IsCurrentThreadMTA()) { + // We're already on the MTA, we can run aClosure directly + aClosure(); + return; + } + + // In this case we need to run aClosure on a background thread in the MTA + nsCOMPtr<nsIRunnable> runnable( + NS_NewRunnableFunction("EnsureMTA::EnsureMTA", std::move(aClosure))); + SyncDispatch(std::move(runnable), aOpt); + } + + using CreateInstanceAgileRefPromise = + MozPromise<AgileReference, HRESULT, false>; + + /** + * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! *** + * + * Asynchronously instantiate a new COM object from a MTA thread, unless the + * current thread is already living inside the multithreaded apartment, in + * which case the object is immediately instantiated. + * + * This function only supports the most common configurations for creating + * a new object, so it only supports in-process servers. Furthermore, this + * function does not support aggregation (ie. the |pUnkOuter| parameter to + * CoCreateInstance). + * + * Given that attempting to instantiate an Apartment-threaded COM object + * inside the MTA results in a *loss* of performance, we assert when that + * situation arises. + * + * The resulting promise, once resolved, provides an AgileReference that may + * be passed between any COM-initialized thread in the current process. + * + * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! *** + * + * WARNING: + * Some COM objects do not support creation in the multithreaded apartment, + * in which case this function is not available as an option. In this case, + * the promise will always be rejected. In debug builds we will assert. + * + * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! *** + * + * WARNING: + * Any in-process COM objects whose interfaces accept HWNDs are probably + * *not* safe to instantiate in the multithreaded apartment! Even if this + * function succeeds when creating such an object, you *MUST NOT* do so, as + * these failures might not become apparent until your code is running out in + * the wild on the release channel! + * + * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! *** + * + * WARNING: + * When you obtain an interface from the AgileReference, it may or may not be + * a proxy to the real object. This depends entirely on the implementation of + * the underlying class and the multithreading capabilities that the class + * declares to the COM runtime. If the interface is proxied, it might be + * expensive to invoke methods on that interface! *Always* test the + * performance of your method calls when calling interfaces that are resolved + * via this function! + * + * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! *** + * + * (Despite this myriad of warnings, it is still *much* safer to use this + * function to asynchronously create COM objects than it is to roll your own!) + * + * *** A MSCOM PEER SHOULD REVIEW ALL NEW USES OF THIS API! *** + */ + static RefPtr<CreateInstanceAgileRefPromise> CreateInstance(REFCLSID aClsid, + REFIID aIid); + + private: + static RefPtr<CreateInstanceAgileRefPromise> CreateInstanceInternal( + REFCLSID aClsid, REFIID aIid); + + static nsCOMPtr<nsIThread> GetPersistentMTAThread(); + + static void SyncDispatch(nsCOMPtr<nsIRunnable>&& aRunnable, Option aOpt); + static void SyncDispatchToPersistentThread(nsIRunnable* aRunnable); + + // The following function is private in order to force any consumers to be + // declared as friends of EnsureMTA. The intention is to prevent + // AsyncOperation from becoming some kind of free-for-all mechanism for + // asynchronously executing work on a background thread. + template <typename FuncT> + static void AsyncOperation(FuncT&& aClosure) { + if (IsCurrentThreadMTA()) { + aClosure(); + return; + } + + nsCOMPtr<nsIThread> thread(GetPersistentMTAThread()); + MOZ_ASSERT(thread); + if (!thread) { + return; + } + + DebugOnly<nsresult> rv = thread->Dispatch( + NS_NewRunnableFunction("mscom::EnsureMTA::AsyncOperation", + std::move(aClosure)), + NS_DISPATCH_NORMAL); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } + + /** + * This constructor just ensures that the MTA is up and running. This should + * only be called by ProcessRuntime. + */ + EnsureMTA(); + + friend class mozilla::mscom::ProcessRuntime; + + template <typename T> + friend struct mozilla::mscom::detail::MTADelete; + + template <typename T> + friend struct mozilla::mscom::detail::MTARelease; + + template <typename T> + friend struct mozilla::mscom::detail::MTAReleaseInChildProcess; + + friend struct mozilla::mscom::detail::PreservedStreamDeleter; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_EnsureMTA_h diff --git a/ipc/mscom/FastMarshaler.cpp b/ipc/mscom/FastMarshaler.cpp new file mode 100644 index 0000000000..4ae62394a2 --- /dev/null +++ b/ipc/mscom/FastMarshaler.cpp @@ -0,0 +1,164 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/FastMarshaler.h" + +#include "mozilla/mscom/Utils.h" + +#include <objbase.h> + +namespace mozilla { +namespace mscom { + +HRESULT +FastMarshaler::Create(IUnknown* aOuter, IUnknown** aOutMarshalerUnk) { + MOZ_ASSERT(XRE_IsContentProcess()); + + if (!aOuter || !aOutMarshalerUnk) { + return E_INVALIDARG; + } + + *aOutMarshalerUnk = nullptr; + + HRESULT hr; + RefPtr<FastMarshaler> fm(new FastMarshaler(aOuter, &hr)); + if (FAILED(hr)) { + return hr; + } + + return fm->InternalQueryInterface(IID_IUnknown, (void**)aOutMarshalerUnk); +} + +FastMarshaler::FastMarshaler(IUnknown* aOuter, HRESULT* aResult) + : mRefCnt(0), mOuter(aOuter), mStdMarshalWeak(nullptr) { + *aResult = + ::CoGetStdMarshalEx(aOuter, SMEXF_SERVER, getter_AddRefs(mStdMarshalUnk)); + if (FAILED(*aResult)) { + return; + } + + *aResult = + mStdMarshalUnk->QueryInterface(IID_IMarshal, (void**)&mStdMarshalWeak); + if (FAILED(*aResult)) { + return; + } + + // mStdMarshalWeak is weak + mStdMarshalWeak->Release(); +} + +HRESULT +FastMarshaler::InternalQueryInterface(REFIID riid, void** ppv) { + if (!ppv) { + return E_INVALIDARG; + } + + if (riid == IID_IUnknown) { + RefPtr<IUnknown> punk(static_cast<IUnknown*>(&mInternalUnknown)); + punk.forget(ppv); + return S_OK; + } + + if (riid == IID_IMarshal) { + RefPtr<IMarshal> ptr(this); + ptr.forget(ppv); + return S_OK; + } + + return mStdMarshalUnk->QueryInterface(riid, ppv); +} + +ULONG +FastMarshaler::InternalAddRef() { return ++mRefCnt; } + +ULONG +FastMarshaler::InternalRelease() { + ULONG result = --mRefCnt; + if (!result) { + delete this; + } + + return result; +} + +DWORD +FastMarshaler::GetMarshalFlags(DWORD aDestContext, DWORD aMshlFlags) { + // Only worry about local contexts. + if (aDestContext != MSHCTX_LOCAL) { + return aMshlFlags; + } + + return aMshlFlags | MSHLFLAGS_NOPING; +} + +HRESULT +FastMarshaler::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) { + if (!mStdMarshalWeak) { + return E_POINTER; + } + + return mStdMarshalWeak->GetUnmarshalClass( + riid, pv, dwDestContext, pvDestContext, + GetMarshalFlags(dwDestContext, mshlflags), pCid); +} + +HRESULT +FastMarshaler::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) { + if (!mStdMarshalWeak) { + return E_POINTER; + } + + return mStdMarshalWeak->GetMarshalSizeMax( + riid, pv, dwDestContext, pvDestContext, + GetMarshalFlags(dwDestContext, mshlflags), pSize); +} + +HRESULT +FastMarshaler::MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) { + if (!mStdMarshalWeak) { + return E_POINTER; + } + + return mStdMarshalWeak->MarshalInterface( + pStm, riid, pv, dwDestContext, pvDestContext, + GetMarshalFlags(dwDestContext, mshlflags)); +} + +HRESULT +FastMarshaler::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) { + if (!mStdMarshalWeak) { + return E_POINTER; + } + + return mStdMarshalWeak->UnmarshalInterface(pStm, riid, ppv); +} + +HRESULT +FastMarshaler::ReleaseMarshalData(IStream* pStm) { + if (!mStdMarshalWeak) { + return E_POINTER; + } + + return mStdMarshalWeak->ReleaseMarshalData(pStm); +} + +HRESULT +FastMarshaler::DisconnectObject(DWORD dwReserved) { + if (!mStdMarshalWeak) { + return E_POINTER; + } + + return mStdMarshalWeak->DisconnectObject(dwReserved); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/FastMarshaler.h b/ipc/mscom/FastMarshaler.h new file mode 100644 index 0000000000..e1f3e88801 --- /dev/null +++ b/ipc/mscom/FastMarshaler.h @@ -0,0 +1,66 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_FastMarshaler_h +#define mozilla_mscom_FastMarshaler_h + +#include "mozilla/Atomics.h" +#include "mozilla/mscom/Aggregation.h" +#include "mozilla/RefPtr.h" + +#include <objidl.h> + +namespace mozilla { +namespace mscom { + +/** + * COM ping functionality is enabled by default and is designed to free strong + * references held by defunct client processes. However, this incurs a + * significant performance penalty in a11y code due to large numbers of remote + * objects being created and destroyed within a short period of time. Thus, we + * turn off pings to improve performance. + * ACHTUNG! When COM pings are disabled, Release calls from remote clients are + * never sent to the server! If you use this marshaler, you *must* explicitly + * disconnect clients using CoDisconnectObject when the object is no longer + * relevant. Otherwise, references to the object will never be released, causing + * a leak. + */ +class FastMarshaler final : public IMarshal { + public: + static HRESULT Create(IUnknown* aOuter, IUnknown** aOutMarshalerUnk); + + // IMarshal + STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) override; + STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) override; + STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) override; + STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid, + void** ppv) override; + STDMETHODIMP ReleaseMarshalData(IStream* pStm) override; + STDMETHODIMP DisconnectObject(DWORD dwReserved) override; + + private: + FastMarshaler(IUnknown* aOuter, HRESULT* aResult); + ~FastMarshaler() = default; + + static DWORD GetMarshalFlags(DWORD aDestContext, DWORD aMshlFlags); + + Atomic<ULONG> mRefCnt; + IUnknown* mOuter; + RefPtr<IUnknown> mStdMarshalUnk; + IMarshal* mStdMarshalWeak; + DECLARE_AGGREGATABLE(FastMarshaler); +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_FastMarshaler_h diff --git a/ipc/mscom/IHandlerProvider.h b/ipc/mscom/IHandlerProvider.h new file mode 100644 index 0000000000..76963fb9f6 --- /dev/null +++ b/ipc/mscom/IHandlerProvider.h @@ -0,0 +1,51 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_IHandlerProvider_h +#define mozilla_mscom_IHandlerProvider_h + +#include "mozilla/NotNull.h" +#include "mozilla/mscom/Ptr.h" + +#include <objidl.h> + +namespace mozilla { +namespace mscom { + +struct IInterceptor; + +struct HandlerProvider { + virtual STDMETHODIMP GetHandler(NotNull<CLSID*> aHandlerClsid) = 0; + virtual STDMETHODIMP GetHandlerPayloadSize( + NotNull<IInterceptor*> aInterceptor, NotNull<DWORD*> aOutPayloadSize) = 0; + virtual STDMETHODIMP WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor, + NotNull<IStream*> aStream) = 0; + virtual STDMETHODIMP_(REFIID) MarshalAs(REFIID aIid) = 0; + virtual STDMETHODIMP DisconnectHandlerRemotes() = 0; + + /** + * Determine whether this interface might be supported by objects using + * this HandlerProvider. + * This is used to avoid unnecessary cross-thread QueryInterface calls for + * interfaces known to be unsupported. + * Return S_OK if the interface might be supported, E_NOINTERFACE if it + * definitely isn't supported. + */ + virtual STDMETHODIMP IsInterfaceMaybeSupported(REFIID aIid) { return S_OK; } +}; + +struct IHandlerProvider : public IUnknown, public HandlerProvider { + virtual STDMETHODIMP_(REFIID) + GetEffectiveOutParamIid(REFIID aCallIid, ULONG aCallMethod) = 0; + virtual STDMETHODIMP NewInstance( + REFIID aIid, InterceptorTargetPtr<IUnknown> aTarget, + NotNull<IHandlerProvider**> aOutNewPayload) = 0; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_IHandlerProvider_h diff --git a/ipc/mscom/Interceptor.cpp b/ipc/mscom/Interceptor.cpp new file mode 100644 index 0000000000..0e223bba26 --- /dev/null +++ b/ipc/mscom/Interceptor.cpp @@ -0,0 +1,860 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#define INITGUID + +#include "mozilla/mscom/Interceptor.h" + +#include <utility> + +#include "MainThreadUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/ThreadLocal.h" +#include "mozilla/Unused.h" +#include "mozilla/mscom/DispatchForwarder.h" +#include "mozilla/mscom/FastMarshaler.h" +#include "mozilla/mscom/InterceptorLog.h" +#include "mozilla/mscom/MainThreadInvoker.h" +#include "mozilla/mscom/Objref.h" +#include "mozilla/mscom/Registration.h" +#include "mozilla/mscom/Utils.h" +#include "nsDirectoryServiceDefs.h" +#include "nsDirectoryServiceUtils.h" +#include "nsPrintfCString.h" +#include "nsRefPtrHashtable.h" +#include "nsThreadUtils.h" +#include "nsXULAppAPI.h" + +#define ENSURE_HR_SUCCEEDED(hr) \ + MOZ_ASSERT(SUCCEEDED((HRESULT)hr)); \ + if (FAILED((HRESULT)hr)) { \ + return hr; \ + } + +namespace mozilla { +namespace mscom { +namespace detail { + +class MOZ_CAPABILITY("mutex") LiveSet final { + public: + LiveSet() : mMutex("mozilla::mscom::LiveSet::mMutex") {} + + void Lock() MOZ_CAPABILITY_ACQUIRE(mMutex) { mMutex.Lock(); } + + void Unlock() MOZ_CAPABILITY_RELEASE(mMutex) { mMutex.Unlock(); } + + void Put(IUnknown* aKey, already_AddRefed<IWeakReference> aValue) { + mMutex.AssertCurrentThreadOwns(); + mLiveSet.InsertOrUpdate(aKey, RefPtr<IWeakReference>{std::move(aValue)}); + } + + RefPtr<IWeakReference> Get(IUnknown* aKey) { + mMutex.AssertCurrentThreadOwns(); + RefPtr<IWeakReference> result; + mLiveSet.Get(aKey, getter_AddRefs(result)); + return result; + } + + void Remove(IUnknown* aKey) { + mMutex.AssertCurrentThreadOwns(); + mLiveSet.Remove(aKey); + } + + private: + Mutex mMutex MOZ_UNANNOTATED; + nsRefPtrHashtable<nsPtrHashKey<IUnknown>, IWeakReference> mLiveSet; +}; + +/** + * We don't use the normal XPCOM BaseAutoLock because we need the ability + * to explicitly Unlock. + */ +class MOZ_RAII MOZ_SCOPED_CAPABILITY LiveSetAutoLock final { + public: + explicit LiveSetAutoLock(LiveSet& aLiveSet) MOZ_CAPABILITY_ACQUIRE(aLiveSet) + : mLiveSet(&aLiveSet) { + aLiveSet.Lock(); + } + + ~LiveSetAutoLock() MOZ_CAPABILITY_RELEASE() { + if (mLiveSet) { + mLiveSet->Unlock(); + } + } + + void Unlock() MOZ_CAPABILITY_RELEASE() { + MOZ_ASSERT(mLiveSet); + if (mLiveSet) { + mLiveSet->Unlock(); + mLiveSet = nullptr; + } + } + + LiveSetAutoLock(const LiveSetAutoLock& aOther) = delete; + LiveSetAutoLock(LiveSetAutoLock&& aOther) = delete; + LiveSetAutoLock& operator=(const LiveSetAutoLock& aOther) = delete; + LiveSetAutoLock& operator=(LiveSetAutoLock&& aOther) = delete; + + private: + LiveSet* mLiveSet; +}; + +class MOZ_RAII ReentrySentinel final { + public: + explicit ReentrySentinel(Interceptor* aCurrent) : mCurInterceptor(aCurrent) { + static const bool kHasTls = tlsSentinelStackTop.init(); + MOZ_RELEASE_ASSERT(kHasTls); + + mPrevSentinel = tlsSentinelStackTop.get(); + tlsSentinelStackTop.set(this); + } + + ~ReentrySentinel() { tlsSentinelStackTop.set(mPrevSentinel); } + + bool IsOutermost() const { + return !(mPrevSentinel && mPrevSentinel->IsMarshaling(mCurInterceptor)); + } + + ReentrySentinel(const ReentrySentinel&) = delete; + ReentrySentinel(ReentrySentinel&&) = delete; + ReentrySentinel& operator=(const ReentrySentinel&) = delete; + ReentrySentinel& operator=(ReentrySentinel&&) = delete; + + private: + bool IsMarshaling(Interceptor* aTopInterceptor) const { + return aTopInterceptor == mCurInterceptor || + (mPrevSentinel && mPrevSentinel->IsMarshaling(aTopInterceptor)); + } + + private: + Interceptor* mCurInterceptor; + ReentrySentinel* mPrevSentinel; + + static MOZ_THREAD_LOCAL(ReentrySentinel*) tlsSentinelStackTop; +}; + +MOZ_THREAD_LOCAL(ReentrySentinel*) ReentrySentinel::tlsSentinelStackTop; + +class MOZ_RAII LoggedQIResult final { + public: + explicit LoggedQIResult(REFIID aIid) + : mIid(aIid), + mHr(E_UNEXPECTED), + mTarget(nullptr), + mInterceptor(nullptr), + mBegin(TimeStamp::Now()) {} + + ~LoggedQIResult() { + if (!mTarget) { + return; + } + + TimeStamp end(TimeStamp::Now()); + TimeDuration total(end - mBegin); + TimeDuration overhead(total - mNonOverheadDuration); + + InterceptorLog::QI(mHr, mTarget, mIid, mInterceptor, &overhead, + &mNonOverheadDuration); + } + + void Log(IUnknown* aTarget, IUnknown* aInterceptor) { + mTarget = aTarget; + mInterceptor = aInterceptor; + } + + void operator=(HRESULT aHr) { mHr = aHr; } + + operator HRESULT() { return mHr; } + + operator TimeDuration*() { return &mNonOverheadDuration; } + + LoggedQIResult(const LoggedQIResult&) = delete; + LoggedQIResult(LoggedQIResult&&) = delete; + LoggedQIResult& operator=(const LoggedQIResult&) = delete; + LoggedQIResult& operator=(LoggedQIResult&&) = delete; + + private: + REFIID mIid; + HRESULT mHr; + IUnknown* mTarget; + IUnknown* mInterceptor; + TimeDuration mNonOverheadDuration; + TimeStamp mBegin; +}; + +} // namespace detail + +static detail::LiveSet& GetLiveSet() { + static detail::LiveSet sLiveSet; + return sLiveSet; +} + +MOZ_THREAD_LOCAL(bool) Interceptor::tlsCreatingStdMarshal; + +/* static */ +HRESULT Interceptor::Create(STAUniquePtr<IUnknown> aTarget, + IInterceptorSink* aSink, REFIID aInitialIid, + void** aOutInterface) { + MOZ_ASSERT(aOutInterface && aTarget && aSink); + if (!aOutInterface) { + return E_INVALIDARG; + } + + detail::LiveSetAutoLock lock(GetLiveSet()); + + RefPtr<IWeakReference> existingWeak(GetLiveSet().Get(aTarget.get())); + if (existingWeak) { + RefPtr<IWeakReferenceSource> existingStrong; + if (SUCCEEDED(existingWeak->ToStrongRef(getter_AddRefs(existingStrong)))) { + // QI on existingStrong may touch other threads. Since we now hold a + // strong ref on the interceptor, we may now release the lock. + lock.Unlock(); + return existingStrong->QueryInterface(aInitialIid, aOutInterface); + } + } + + *aOutInterface = nullptr; + + if (!aTarget || !aSink) { + return E_INVALIDARG; + } + + RefPtr<Interceptor> intcpt(new Interceptor(aSink)); + return intcpt->GetInitialInterceptorForIID(lock, aInitialIid, + std::move(aTarget), aOutInterface); +} + +Interceptor::Interceptor(IInterceptorSink* aSink) + : WeakReferenceSupport(WeakReferenceSupport::Flags::eDestroyOnMainThread), + mEventSink(aSink), + mInterceptorMapMutex("mozilla::mscom::Interceptor::mInterceptorMapMutex"), + mStdMarshalMutex("mozilla::mscom::Interceptor::mStdMarshalMutex"), + mStdMarshal(nullptr) { + static const bool kHasTls = tlsCreatingStdMarshal.init(); + MOZ_ASSERT(kHasTls); + Unused << kHasTls; + + MOZ_ASSERT(aSink); + RefPtr<IWeakReference> weakRef; + if (SUCCEEDED(GetWeakReference(getter_AddRefs(weakRef)))) { + aSink->SetInterceptor(weakRef); + } +} + +Interceptor::~Interceptor() { + { // Scope for lock + detail::LiveSetAutoLock lock(GetLiveSet()); + GetLiveSet().Remove(mTarget.get()); + } + + // This needs to run on the main thread because it releases target interface + // reference counts which may not be thread-safe. + MOZ_ASSERT(NS_IsMainThread()); + for (uint32_t index = 0, len = mInterceptorMap.Length(); index < len; + ++index) { + MapEntry& entry = mInterceptorMap[index]; + entry.mInterceptor = nullptr; + entry.mTargetInterface->Release(); + } +} + +HRESULT +Interceptor::GetClassForHandler(DWORD aDestContext, void* aDestContextPtr, + CLSID* aHandlerClsid) { + if (aDestContextPtr || !aHandlerClsid || + aDestContext == MSHCTX_DIFFERENTMACHINE) { + return E_INVALIDARG; + } + + MOZ_ASSERT(mEventSink); + return mEventSink->GetHandler(WrapNotNull(aHandlerClsid)); +} + +REFIID +Interceptor::MarshalAs(REFIID aIid) const { +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + return IsCallerExternalProcess() ? aIid : mEventSink->MarshalAs(aIid); +#else + return mEventSink->MarshalAs(aIid); +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) +} + +HRESULT +Interceptor::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) { + return mStdMarshal->GetUnmarshalClass(MarshalAs(riid), pv, dwDestContext, + pvDestContext, mshlflags, pCid); +} + +HRESULT +Interceptor::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) { + detail::ReentrySentinel sentinel(this); + + HRESULT hr = mStdMarshal->GetMarshalSizeMax( + MarshalAs(riid), pv, dwDestContext, pvDestContext, mshlflags, pSize); + if (FAILED(hr) || !sentinel.IsOutermost()) { + return hr; + } + +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + if (XRE_IsContentProcess() && IsCallerExternalProcess()) { + // The caller isn't our chrome process, so we do not provide a handler + // payload. Even though we're only getting the size here, calculating the + // payload size might actually require building the payload. + return hr; + } +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + + DWORD payloadSize = 0; + hr = mEventSink->GetHandlerPayloadSize(WrapNotNull(this), + WrapNotNull(&payloadSize)); + if (hr == E_NOTIMPL) { + return S_OK; + } + + if (SUCCEEDED(hr)) { + *pSize += payloadSize; + } + return hr; +} + +HRESULT +Interceptor::MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) { + detail::ReentrySentinel sentinel(this); + + HRESULT hr; + +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + // Save the current stream position + LARGE_INTEGER seekTo; + seekTo.QuadPart = 0; + + ULARGE_INTEGER objrefPos; + + hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &objrefPos); + if (FAILED(hr)) { + return hr; + } + +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + + hr = mStdMarshal->MarshalInterface(pStm, MarshalAs(riid), pv, dwDestContext, + pvDestContext, mshlflags); + if (FAILED(hr) || !sentinel.IsOutermost()) { + return hr; + } + +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + if (XRE_IsContentProcess() && IsCallerExternalProcess()) { + // The caller isn't our chrome process, so do not provide a handler. + + // First, save the current position that marks the current end of the + // OBJREF in the stream. + ULARGE_INTEGER endPos; + hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &endPos); + if (FAILED(hr)) { + return hr; + } + + // Now strip out the handler. + if (!StripHandlerFromOBJREF(WrapNotNull(pStm), objrefPos.QuadPart, + endPos.QuadPart)) { + return E_FAIL; + } + + return S_OK; + } +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + + hr = mEventSink->WriteHandlerPayload(WrapNotNull(this), WrapNotNull(pStm)); + if (hr == E_NOTIMPL) { + return S_OK; + } + + return hr; +} + +HRESULT +Interceptor::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) { + return mStdMarshal->UnmarshalInterface(pStm, riid, ppv); +} + +HRESULT +Interceptor::ReleaseMarshalData(IStream* pStm) { + return mStdMarshal->ReleaseMarshalData(pStm); +} + +HRESULT +Interceptor::DisconnectObject(DWORD dwReserved) { + mEventSink->DisconnectHandlerRemotes(); + return mStdMarshal->DisconnectObject(dwReserved); +} + +Interceptor::MapEntry* Interceptor::Lookup(REFIID aIid) { + mInterceptorMapMutex.AssertCurrentThreadOwns(); + + for (uint32_t index = 0, len = mInterceptorMap.Length(); index < len; + ++index) { + if (mInterceptorMap[index].mIID == aIid) { + return &mInterceptorMap[index]; + } + } + return nullptr; +} + +HRESULT +Interceptor::GetTargetForIID(REFIID aIid, + InterceptorTargetPtr<IUnknown>& aTarget) { + MutexAutoLock lock(mInterceptorMapMutex); + MapEntry* entry = Lookup(aIid); + if (entry) { + aTarget.reset(entry->mTargetInterface); + return S_OK; + } + + return E_NOINTERFACE; +} + +// CoGetInterceptor requires type metadata to be able to generate its emulated +// vtable. If no registered metadata is available, CoGetInterceptor returns +// kFileNotFound. +static const HRESULT kFileNotFound = HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + +HRESULT +Interceptor::CreateInterceptor(REFIID aIid, IUnknown* aOuter, + IUnknown** aOutput) { + // In order to aggregate, we *must* request IID_IUnknown as the initial + // interface for the interceptor, as that IUnknown is non-delegating. + // This is a fundamental rule for creating aggregated objects in COM. + HRESULT hr = ::CoGetInterceptor(aIid, aOuter, IID_IUnknown, (void**)aOutput); + if (hr != kFileNotFound) { + return hr; + } + + // In the case that CoGetInterceptor returns kFileNotFound, we can try to + // explicitly load typelib data from our runtime registration facility and + // pass that into CoGetInterceptorFromTypeInfo. + + RefPtr<ITypeInfo> typeInfo; + bool found = RegisteredProxy::Find(aIid, getter_AddRefs(typeInfo)); + // If this assert fires then we have omitted registering the typelib for a + // required interface. To fix this, review our calls to mscom::RegisterProxy + // and mscom::RegisterTypelib, and add the additional typelib as necessary. + MOZ_ASSERT(found); + if (!found) { + return kFileNotFound; + } + + hr = ::CoGetInterceptorFromTypeInfo(aIid, aOuter, typeInfo, IID_IUnknown, + (void**)aOutput); + // If this assert fires then the interceptor doesn't like something about + // the format of the typelib. One thing in particular that it doesn't like + // is complex types that contain unions. + MOZ_ASSERT(SUCCEEDED(hr)); + return hr; +} + +HRESULT +Interceptor::PublishTarget(detail::LiveSetAutoLock& aLiveSetLock, + RefPtr<IUnknown> aInterceptor, REFIID aTargetIid, + STAUniquePtr<IUnknown> aTarget) + MOZ_NO_THREAD_SAFETY_ANALYSIS { + // Suppress thread safety analysis as this conditionally releases locks. + RefPtr<IWeakReference> weakRef; + HRESULT hr = GetWeakReference(getter_AddRefs(weakRef)); + if (FAILED(hr)) { + return hr; + } + + // mTarget is a weak reference to aTarget. This is safe because we transfer + // ownership of aTarget into mInterceptorMap which remains live for the + // lifetime of this Interceptor. + mTarget = ToInterceptorTargetPtr(aTarget); + GetLiveSet().Put(mTarget.get(), weakRef.forget()); + + // Now we transfer aTarget's ownership into mInterceptorMap. + mInterceptorMap.AppendElement( + MapEntry(aTargetIid, aInterceptor, aTarget.release())); + + // Release the live set lock because subsequent operations may post work to + // the main thread, creating potential for deadlocks. + aLiveSetLock.Unlock(); + return S_OK; +} + +HRESULT +Interceptor::GetInitialInterceptorForIID( + detail::LiveSetAutoLock& aLiveSetLock, REFIID aTargetIid, + STAUniquePtr<IUnknown> aTarget, + void** aOutInterceptor) MOZ_NO_THREAD_SAFETY_ANALYSIS { + // Suppress thread safety analysis as this conditionally releases locks. + MOZ_ASSERT(aOutInterceptor); + MOZ_ASSERT(aTargetIid != IID_IMarshal); + MOZ_ASSERT(!IsProxy(aTarget.get())); + + HRESULT hr = E_UNEXPECTED; + + auto hasFailed = [&hr]() -> bool { return FAILED(hr); }; + + MOZ_PUSH_IGNORE_THREAD_SAFETY // Avoid the lambda upsetting analysis. + auto cleanup = [&aLiveSetLock]() -> void { aLiveSetLock.Unlock(); }; + MOZ_POP_THREAD_SAFETY + + ExecuteWhen<decltype(hasFailed), decltype(cleanup)> onFail(hasFailed, + cleanup); + + if (aTargetIid == IID_IUnknown) { + // We must lock mInterceptorMapMutex so that nothing can race with us once + // we have been published to the live set. + MutexAutoLock lock(mInterceptorMapMutex); + + hr = PublishTarget(aLiveSetLock, nullptr, aTargetIid, std::move(aTarget)); + ENSURE_HR_SUCCEEDED(hr); + + hr = QueryInterface(aTargetIid, aOutInterceptor); + ENSURE_HR_SUCCEEDED(hr); + return hr; + } + + // Raise the refcount for stabilization purposes during aggregation + WeakReferenceSupport::StabilizeRefCount stabilizer(*this); + + RefPtr<IUnknown> unkInterceptor; + hr = CreateInterceptor(aTargetIid, static_cast<WeakReferenceSupport*>(this), + getter_AddRefs(unkInterceptor)); + ENSURE_HR_SUCCEEDED(hr); + + RefPtr<ICallInterceptor> interceptor; + hr = unkInterceptor->QueryInterface(IID_ICallInterceptor, + getter_AddRefs(interceptor)); + ENSURE_HR_SUCCEEDED(hr); + + hr = interceptor->RegisterSink(mEventSink); + ENSURE_HR_SUCCEEDED(hr); + + // We must lock mInterceptorMapMutex so that nothing can race with us once we + // have been published to the live set. + MutexAutoLock lock(mInterceptorMapMutex); + + hr = PublishTarget(aLiveSetLock, unkInterceptor, aTargetIid, + std::move(aTarget)); + ENSURE_HR_SUCCEEDED(hr); + + if (MarshalAs(aTargetIid) == aTargetIid) { + hr = unkInterceptor->QueryInterface(aTargetIid, aOutInterceptor); + ENSURE_HR_SUCCEEDED(hr); + return hr; + } + + hr = GetInterceptorForIID(aTargetIid, aOutInterceptor, &lock); + ENSURE_HR_SUCCEEDED(hr); + return hr; +} + +HRESULT +Interceptor::GetInterceptorForIID(REFIID aIid, void** aOutInterceptor) { + return GetInterceptorForIID(aIid, aOutInterceptor, nullptr); +} + +/** + * This method contains the core guts of the handling of QueryInterface calls + * that are delegated to us from the ICallInterceptor. + * + * @param aIid ID of the desired interface + * @param aOutInterceptor The resulting emulated vtable that corresponds to + * the interface specified by aIid. + * @param aAlreadyLocked Proof of an existing lock on |mInterceptorMapMutex|, + * if present. + */ +HRESULT +Interceptor::GetInterceptorForIID(REFIID aIid, void** aOutInterceptor, + MutexAutoLock* aAlreadyLocked) { + detail::LoggedQIResult result(aIid); + + if (!aOutInterceptor) { + return E_INVALIDARG; + } + + if (aIid == IID_IUnknown) { + // Special case: When we see IUnknown, we just provide a reference to this + RefPtr<IInterceptor> intcpt(this); + intcpt.forget(aOutInterceptor); + return S_OK; + } + + REFIID interceptorIid = MarshalAs(aIid); + + RefPtr<IUnknown> unkInterceptor; + IUnknown* interfaceForQILog = nullptr; + + // (1) Check to see if we already have an existing interceptor for + // interceptorIid. + auto doLookup = [&]() -> void { + MapEntry* entry = Lookup(interceptorIid); + if (entry) { + unkInterceptor = entry->mInterceptor; + interfaceForQILog = entry->mTargetInterface; + } + }; + + if (aAlreadyLocked) { + doLookup(); + } else { + MutexAutoLock lock(mInterceptorMapMutex); + doLookup(); + } + + // (1a) A COM interceptor already exists for this interface, so all we need + // to do is run a QI on it. + if (unkInterceptor) { + // Technically we didn't actually execute a QI on the target interface, but + // for logging purposes we would like to record the fact that this interface + // was requested. + result.Log(mTarget.get(), interfaceForQILog); + result = unkInterceptor->QueryInterface(interceptorIid, aOutInterceptor); + ENSURE_HR_SUCCEEDED(result); + return result; + } + + // (2) Obtain a new target interface. + + // (2a) First, make sure that the target interface is available + // NB: We *MUST* query the correct interface! ICallEvents::Invoke casts its + // pvReceiver argument directly to the required interface! DO NOT assume + // that COM will use QI or upcast/downcast! + HRESULT hr; + + STAUniquePtr<IUnknown> targetInterface; + IUnknown* rawTargetInterface = nullptr; + hr = + QueryInterfaceTarget(interceptorIid, (void**)&rawTargetInterface, result); + targetInterface.reset(rawTargetInterface); + result = hr; + result.Log(mTarget.get(), targetInterface.get()); + MOZ_ASSERT(SUCCEEDED(hr) || hr == E_NOINTERFACE); + if (hr == E_NOINTERFACE) { + return hr; + } + ENSURE_HR_SUCCEEDED(hr); + + // We *really* shouldn't be adding interceptors to proxies + MOZ_ASSERT(aIid != IID_IMarshal); + + // (3) Create a new COM interceptor to that interface that delegates its + // IUnknown to |this|. + + // Raise the refcount for stabilization purposes during aggregation + WeakReferenceSupport::StabilizeRefCount stabilizer(*this); + + hr = CreateInterceptor(interceptorIid, + static_cast<WeakReferenceSupport*>(this), + getter_AddRefs(unkInterceptor)); + ENSURE_HR_SUCCEEDED(hr); + + // (4) Obtain the interceptor's ICallInterceptor interface and register our + // event sink. + RefPtr<ICallInterceptor> interceptor; + hr = unkInterceptor->QueryInterface(IID_ICallInterceptor, + (void**)getter_AddRefs(interceptor)); + ENSURE_HR_SUCCEEDED(hr); + + hr = interceptor->RegisterSink(mEventSink); + ENSURE_HR_SUCCEEDED(hr); + + // (5) Now that we have this new COM interceptor, insert it into the map. + auto doInsertion = [&]() -> void { + // We might have raced with another thread, so first check that we don't + // already have an entry for this + MapEntry* entry = Lookup(interceptorIid); + if (entry && entry->mInterceptor) { + // Bug 1433046: Because of aggregation, the QI for |interceptor| + // AddRefed |this|, not |unkInterceptor|. Thus, releasing |unkInterceptor| + // will destroy the object. Before we do that, we must first release + // |interceptor|. Otherwise, |interceptor| would be invalidated when + // |unkInterceptor| is destroyed. + interceptor = nullptr; + unkInterceptor = entry->mInterceptor; + } else { + // MapEntry has a RefPtr to unkInterceptor, OTOH we must not touch the + // refcount for the target interface because we are just moving it into + // the map and its refcounting might not be thread-safe. + IUnknown* rawTargetInterface = targetInterface.release(); + mInterceptorMap.AppendElement( + MapEntry(interceptorIid, unkInterceptor, rawTargetInterface)); + } + }; + + if (aAlreadyLocked) { + doInsertion(); + } else { + MutexAutoLock lock(mInterceptorMapMutex); + doInsertion(); + } + + hr = unkInterceptor->QueryInterface(interceptorIid, aOutInterceptor); + ENSURE_HR_SUCCEEDED(hr); + return hr; +} + +HRESULT +Interceptor::QueryInterfaceTarget(REFIID aIid, void** aOutput, + TimeDuration* aOutDuration) { + // NB: This QI needs to run on the main thread because the target object + // is probably Gecko code that is not thread-safe. Note that this main + // thread invocation is *synchronous*. + if (!NS_IsMainThread() && tlsCreatingStdMarshal.get()) { + mStdMarshalMutex.AssertCurrentThreadOwns(); + // COM queries for special interfaces such as IFastRundown when creating a + // marshaler. We don't want these being dispatched to the main thread, + // since this would cause a deadlock on mStdMarshalMutex if the main + // thread is also querying for IMarshal. If we do need to respond to these + // special interfaces, this should be done before this point; e.g. in + // Interceptor::QueryInterface like we do for INoMarshal. + return E_NOINTERFACE; + } + + if (mEventSink->IsInterfaceMaybeSupported(aIid) == E_NOINTERFACE) { + return E_NOINTERFACE; + } + + MainThreadInvoker invoker; + HRESULT hr; + auto runOnMainThread = [&]() -> void { + MOZ_ASSERT(NS_IsMainThread()); + hr = mTarget->QueryInterface(aIid, aOutput); + }; + if (!invoker.Invoke(NS_NewRunnableFunction("Interceptor::QueryInterface", + runOnMainThread))) { + return E_FAIL; + } + if (aOutDuration) { + *aOutDuration = invoker.GetDuration(); + } + return hr; +} + +HRESULT +Interceptor::QueryInterface(REFIID riid, void** ppv) { + if (riid == IID_INoMarshal) { + // This entire library is designed around marshaling, so there's no point + // propagating this QI request all over the place! + return E_NOINTERFACE; + } + + return WeakReferenceSupport::QueryInterface(riid, ppv); +} + +HRESULT +Interceptor::WeakRefQueryInterface(REFIID aIid, IUnknown** aOutInterface) { + if (aIid == IID_IStdMarshalInfo) { + detail::ReentrySentinel sentinel(this); + + if (!sentinel.IsOutermost()) { + return E_NOINTERFACE; + } + + // Do not indicate that this interface is available unless we actually + // support it. We'll check that by looking for a successful call to + // IInterceptorSink::GetHandler() + CLSID dummy; + if (FAILED(mEventSink->GetHandler(WrapNotNull(&dummy)))) { + return E_NOINTERFACE; + } + + RefPtr<IStdMarshalInfo> std(this); + std.forget(aOutInterface); + return S_OK; + } + + if (aIid == IID_IMarshal) { + MutexAutoLock lock(mStdMarshalMutex); + + HRESULT hr; + + if (!mStdMarshalUnk) { + MOZ_ASSERT(!tlsCreatingStdMarshal.get()); + tlsCreatingStdMarshal.set(true); + if (XRE_IsContentProcess()) { + hr = FastMarshaler::Create(static_cast<IWeakReferenceSource*>(this), + getter_AddRefs(mStdMarshalUnk)); + } else { + hr = ::CoGetStdMarshalEx(static_cast<IWeakReferenceSource*>(this), + SMEXF_SERVER, getter_AddRefs(mStdMarshalUnk)); + } + tlsCreatingStdMarshal.set(false); + + ENSURE_HR_SUCCEEDED(hr); + } + + if (!mStdMarshal) { + hr = mStdMarshalUnk->QueryInterface(IID_IMarshal, (void**)&mStdMarshal); + ENSURE_HR_SUCCEEDED(hr); + + // mStdMarshal is weak, so drop its refcount + mStdMarshal->Release(); + } + + RefPtr<IMarshal> marshal(this); + marshal.forget(aOutInterface); + return S_OK; + } + + if (aIid == IID_IInterceptor) { + RefPtr<IInterceptor> intcpt(this); + intcpt.forget(aOutInterface); + return S_OK; + } + + if (aIid == IID_IDispatch) { + STAUniquePtr<IDispatch> disp; + IDispatch* rawDisp = nullptr; + HRESULT hr = QueryInterfaceTarget(aIid, (void**)&rawDisp); + ENSURE_HR_SUCCEEDED(hr); + + disp.reset(rawDisp); + return DispatchForwarder::Create(this, disp, aOutInterface); + } + + return GetInterceptorForIID(aIid, (void**)aOutInterface, nullptr); +} + +ULONG +Interceptor::AddRef() { return WeakReferenceSupport::AddRef(); } + +ULONG +Interceptor::Release() { return WeakReferenceSupport::Release(); } + +/* static */ +HRESULT Interceptor::DisconnectRemotesForTarget(IUnknown* aTarget) { + MOZ_ASSERT(aTarget); + + detail::LiveSetAutoLock lock(GetLiveSet()); + + // It is not an error if the interceptor doesn't exist, so we return + // S_FALSE instead of an error in that case. + RefPtr<IWeakReference> existingWeak(GetLiveSet().Get(aTarget)); + if (!existingWeak) { + return S_FALSE; + } + + RefPtr<IWeakReferenceSource> existingStrong; + if (FAILED(existingWeak->ToStrongRef(getter_AddRefs(existingStrong)))) { + return S_FALSE; + } + // Since we now hold a strong ref on the interceptor, we may now release the + // lock. + lock.Unlock(); + + return ::CoDisconnectObject(existingStrong, 0); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/Interceptor.h b/ipc/mscom/Interceptor.h new file mode 100644 index 0000000000..8ab578092b --- /dev/null +++ b/ipc/mscom/Interceptor.h @@ -0,0 +1,199 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Interceptor_h +#define mozilla_mscom_Interceptor_h + +#include <callobj.h> +#include <objidl.h> + +#include <utility> + +#include "mozilla/Mutex.h" +#include "mozilla/RefPtr.h" +#include "mozilla/mscom/IHandlerProvider.h" +#include "mozilla/mscom/Ptr.h" +#include "mozilla/mscom/WeakRef.h" +#include "nsTArray.h" + +namespace mozilla { +namespace mscom { +namespace detail { + +class LiveSetAutoLock; + +} // namespace detail + +// {8831EB53-A937-42BC-9921-B3E1121FDF86} +DEFINE_GUID(IID_IInterceptorSink, 0x8831eb53, 0xa937, 0x42bc, 0x99, 0x21, 0xb3, + 0xe1, 0x12, 0x1f, 0xdf, 0x86); + +struct IInterceptorSink : public ICallFrameEvents, public HandlerProvider { + virtual STDMETHODIMP SetInterceptor(IWeakReference* aInterceptor) = 0; +}; + +// {3710799B-ECA2-4165-B9B0-3FA1E4A9B230} +DEFINE_GUID(IID_IInterceptor, 0x3710799b, 0xeca2, 0x4165, 0xb9, 0xb0, 0x3f, + 0xa1, 0xe4, 0xa9, 0xb2, 0x30); + +struct IInterceptor : public IUnknown { + virtual STDMETHODIMP GetTargetForIID( + REFIID aIid, InterceptorTargetPtr<IUnknown>& aTarget) = 0; + virtual STDMETHODIMP GetInterceptorForIID(REFIID aIid, + void** aOutInterceptor) = 0; + virtual STDMETHODIMP GetEventSink(IInterceptorSink** aSink) = 0; +}; + +/** + * The COM interceptor is the core functionality in mscom that allows us to + * redirect method calls to different threads. It emulates the vtable of a + * target interface. When a call is made on this emulated vtable, the call is + * packaged up into an instance of the ICallFrame interface which may be passed + * to other contexts for execution. + * + * In order to accomplish this, COM itself provides the CoGetInterceptor + * function, which instantiates an ICallInterceptor. Note, however, that + * ICallInterceptor only works on a single interface; we need to be able to + * interpose QueryInterface calls so that we can instantiate a new + * ICallInterceptor for each new interface that is requested. + * + * We accomplish this by using COM aggregation, which means that the + * ICallInterceptor delegates its IUnknown implementation to its outer object + * (the mscom::Interceptor we implement and control). + * + * ACHTUNG! mscom::Interceptor uses FastMarshaler to disable COM garbage + * collection. If you use this class, you *must* call + * Interceptor::DisconnectRemotesForTarget when an object is no longer relevant. + * Otherwise, the object will never be released, causing a leak. + */ +class Interceptor final : public WeakReferenceSupport, + public IStdMarshalInfo, + public IMarshal, + public IInterceptor { + public: + static HRESULT Create(STAUniquePtr<IUnknown> aTarget, IInterceptorSink* aSink, + REFIID aInitialIid, void** aOutInterface); + + /** + * Disconnect all remote clients for a given target. + * Because Interceptors disable COM garbage collection to improve + * performance, they never receive Release calls from remote clients. If + * the object can be shut down while clients still hold a reference, this + * function can be used to force COM to disconnect all remote connections + * (using CoDisconnectObject) and thus release the associated references to + * the Interceptor, its target and any objects associated with the + * HandlerProvider. + * Note that the specified target must be the same IUnknown pointer used to + * create the Interceptor. Where there is multiple inheritance, querying for + * IID_IUnknown and calling this function with that pointer alone will not + * disconnect remotes for all interfaces. If you expect that the same object + * may be fetched with different initial interfaces, you should call this + * function once for each possible IUnknown pointer. + * @return S_OK if there was an Interceptor for the given target, + * S_FALSE if there was not. + */ + static HRESULT DisconnectRemotesForTarget(IUnknown* aTarget); + + // IUnknown + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // IStdMarshalInfo + STDMETHODIMP GetClassForHandler(DWORD aDestContext, void* aDestContextPtr, + CLSID* aHandlerClsid) override; + + // IMarshal + STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) override; + STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) override; + STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) override; + STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid, + void** ppv) override; + STDMETHODIMP ReleaseMarshalData(IStream* pStm) override; + STDMETHODIMP DisconnectObject(DWORD dwReserved) override; + + // IInterceptor + STDMETHODIMP GetTargetForIID( + REFIID aIid, InterceptorTargetPtr<IUnknown>& aTarget) override; + STDMETHODIMP GetInterceptorForIID(REFIID aIid, + void** aOutInterceptor) override; + + STDMETHODIMP GetEventSink(IInterceptorSink** aSink) override { + RefPtr<IInterceptorSink> sink = mEventSink; + sink.forget(aSink); + return mEventSink ? S_OK : S_FALSE; + } + + private: + struct MapEntry { + MapEntry(REFIID aIid, IUnknown* aInterceptor, IUnknown* aTargetInterface) + : mIID(aIid), + mInterceptor(aInterceptor), + mTargetInterface(aTargetInterface) {} + + IID mIID; + RefPtr<IUnknown> mInterceptor; + IUnknown* mTargetInterface; + }; + + private: + explicit Interceptor(IInterceptorSink* aSink); + ~Interceptor(); + HRESULT GetInitialInterceptorForIID(detail::LiveSetAutoLock& aLiveSetLock, + REFIID aTargetIid, + STAUniquePtr<IUnknown> aTarget, + void** aOutInterface); + HRESULT GetInterceptorForIID(REFIID aIid, void** aOutInterceptor, + MutexAutoLock* aAlreadyLocked); + MapEntry* Lookup(REFIID aIid); + HRESULT QueryInterfaceTarget(REFIID aIid, void** aOutput, + TimeDuration* aOutDuration = nullptr); + HRESULT WeakRefQueryInterface(REFIID aIid, IUnknown** aOutInterface) override; + HRESULT CreateInterceptor(REFIID aIid, IUnknown* aOuter, IUnknown** aOutput); + REFIID MarshalAs(REFIID aIid) const; + HRESULT PublishTarget(detail::LiveSetAutoLock& aLiveSetLock, + RefPtr<IUnknown> aInterceptor, REFIID aTargetIid, + STAUniquePtr<IUnknown> aTarget); + + private: + InterceptorTargetPtr<IUnknown> mTarget; + RefPtr<IInterceptorSink> mEventSink; + mozilla::Mutex mInterceptorMapMutex + MOZ_UNANNOTATED; // Guards mInterceptorMap + // Using a nsTArray since the # of interfaces is not going to be very high + nsTArray<MapEntry> mInterceptorMap; + mozilla::Mutex mStdMarshalMutex + MOZ_UNANNOTATED; // Guards mStdMarshalUnk and mStdMarshal + RefPtr<IUnknown> mStdMarshalUnk; + IMarshal* mStdMarshal; // WEAK + static MOZ_THREAD_LOCAL(bool) tlsCreatingStdMarshal; +}; + +template <typename InterfaceT> +inline HRESULT CreateInterceptor(STAUniquePtr<InterfaceT> aTargetInterface, + IInterceptorSink* aEventSink, + InterfaceT** aOutInterface) { + if (!aTargetInterface || !aEventSink) { + return E_INVALIDARG; + } + + REFIID iidTarget = __uuidof(InterfaceT); + + STAUniquePtr<IUnknown> targetUnknown(aTargetInterface.release()); + return Interceptor::Create(std::move(targetUnknown), aEventSink, iidTarget, + (void**)aOutInterface); +} + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Interceptor_h diff --git a/ipc/mscom/InterceptorLog.cpp b/ipc/mscom/InterceptorLog.cpp new file mode 100644 index 0000000000..0e127c41b6 --- /dev/null +++ b/ipc/mscom/InterceptorLog.cpp @@ -0,0 +1,527 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/InterceptorLog.h" + +#include <callobj.h> + +#include <utility> + +#include "MainThreadUtils.h" +#include "mozilla/ClearOnShutdown.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/IntegerRange.h" +#include "mozilla/Mutex.h" +#include "mozilla/Services.h" +#include "mozilla/StaticPtr.h" +#include "mozilla/TimeStamp.h" +#include "mozilla/Unused.h" +#include "mozilla/mscom/Registration.h" +#include "mozilla/mscom/Utils.h" +#include "nsAppDirectoryServiceDefs.h" +#include "nsDirectoryServiceDefs.h" +#include "nsDirectoryServiceUtils.h" +#include "nsIObserver.h" +#include "nsIObserverService.h" +#include "nsIOutputStream.h" +#include "nsNetUtil.h" +#include "nsPrintfCString.h" +#include "nsReadableUtils.h" +#include "nsTArray.h" +#include "nsThreadUtils.h" +#include "nsXPCOMPrivate.h" +#include "nsXULAppAPI.h" +#include "prenv.h" + +using mozilla::DebugOnly; +using mozilla::Mutex; +using mozilla::MutexAutoLock; +using mozilla::NewNonOwningRunnableMethod; +using mozilla::StaticAutoPtr; +using mozilla::TimeDuration; +using mozilla::TimeStamp; +using mozilla::Unused; +using mozilla::mscom::ArrayData; +using mozilla::mscom::FindArrayData; +using mozilla::mscom::IsValidGUID; +using mozilla::services::GetObserverService; + +namespace { + +class ShutdownEvent final : public nsIObserver { + public: + NS_DECL_ISUPPORTS + NS_DECL_NSIOBSERVER + + private: + ~ShutdownEvent() {} +}; + +NS_IMPL_ISUPPORTS(ShutdownEvent, nsIObserver) + +class Logger final { + public: + explicit Logger(const nsACString& aLeafBaseName); + bool IsValid() { + MutexAutoLock lock(mMutex); + return !!mThread; + } + void LogQI(HRESULT aResult, IUnknown* aTarget, REFIID aIid, + IUnknown* aInterface, const TimeDuration* aOverheadDuration, + const TimeDuration* aGeckoDuration); + void CaptureFrame(ICallFrame* aCallFrame, IUnknown* aTargetInterface, + nsACString& aCapturedFrame); + void LogEvent(const nsACString& aCapturedFrame, + const TimeDuration& aOverheadDuration, + const TimeDuration& aGeckoDuration); + nsresult Shutdown(); + + private: + void OpenFile(); + void Flush(); + void CloseFile(); + void AssertRunningOnLoggerThread(); + bool VariantToString(const VARIANT& aVariant, nsACString& aOut, + LONG aIndex = 0); + bool TryParamAsGuid(REFIID aIid, ICallFrame* aCallFrame, + const CALLFRAMEPARAMINFO& aParamInfo, nsACString& aLine); + static double GetElapsedTime(); + + nsCOMPtr<nsIFile> mLogFileName; + nsCOMPtr<nsIOutputStream> mLogFile; // Only accessed by mThread + Mutex mMutex MOZ_UNANNOTATED; // Guards mThread and mEntries + nsCOMPtr<nsIThread> mThread; + nsTArray<nsCString> mEntries; +}; + +Logger::Logger(const nsACString& aLeafBaseName) + : mMutex("mozilla::com::InterceptorLog::Logger") { + MOZ_ASSERT(NS_IsMainThread()); + nsCOMPtr<nsIFile> logFileName; + GeckoProcessType procType = XRE_GetProcessType(); + nsAutoCString leafName(aLeafBaseName); + nsresult rv; + if (procType == GeckoProcessType_Default) { + leafName.AppendLiteral("-Parent-"); + rv = NS_GetSpecialDirectory(NS_OS_TEMP_DIR, getter_AddRefs(logFileName)); + } else if (procType == GeckoProcessType_Content) { + leafName.AppendLiteral("-Content-"); +#if defined(MOZ_SANDBOX) + rv = NS_GetSpecialDirectory(NS_APP_CONTENT_PROCESS_TEMP_DIR, + getter_AddRefs(logFileName)); +#else + rv = NS_GetSpecialDirectory(NS_OS_TEMP_DIR, getter_AddRefs(logFileName)); +#endif // defined(MOZ_SANDBOX) + } else { + return; + } + if (NS_FAILED(rv)) { + return; + } + DWORD pid = GetCurrentProcessId(); + leafName.AppendPrintf("%lu.log", pid); + // Using AppendNative here because Windows + rv = logFileName->AppendNative(leafName); + if (NS_FAILED(rv)) { + return; + } + mLogFileName.swap(logFileName); + + nsCOMPtr<nsIObserverService> obsSvc = GetObserverService(); + nsCOMPtr<nsIObserver> shutdownEvent = new ShutdownEvent(); + rv = obsSvc->AddObserver(shutdownEvent, NS_XPCOM_SHUTDOWN_THREADS_OBSERVER_ID, + false); + if (NS_FAILED(rv)) { + return; + } + + nsCOMPtr<nsIRunnable> openRunnable( + NewNonOwningRunnableMethod("Logger::OpenFile", this, &Logger::OpenFile)); + rv = NS_NewNamedThread("COM Intcpt Log", getter_AddRefs(mThread), + openRunnable); + if (NS_FAILED(rv)) { + obsSvc->RemoveObserver(shutdownEvent, + NS_XPCOM_SHUTDOWN_THREADS_OBSERVER_ID); + } +} + +void Logger::AssertRunningOnLoggerThread() { +#if defined(DEBUG) + nsCOMPtr<nsIThread> curThread; + if (NS_FAILED(NS_GetCurrentThread(getter_AddRefs(curThread)))) { + return; + } + MutexAutoLock lock(mMutex); + MOZ_ASSERT(curThread == mThread); +#endif +} + +void Logger::OpenFile() { + AssertRunningOnLoggerThread(); + MOZ_ASSERT(mLogFileName && !mLogFile); + NS_NewLocalFileOutputStream(getter_AddRefs(mLogFile), mLogFileName, + PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE, + PR_IRUSR | PR_IWUSR | PR_IRGRP); +} + +void Logger::CloseFile() { + AssertRunningOnLoggerThread(); + MOZ_ASSERT(mLogFile); + if (!mLogFile) { + return; + } + Flush(); + mLogFile->Close(); + mLogFile = nullptr; +} + +nsresult Logger::Shutdown() { + MOZ_ASSERT(NS_IsMainThread()); + nsresult rv = mThread->Dispatch( + NewNonOwningRunnableMethod("Logger::CloseFile", this, &Logger::CloseFile), + NS_DISPATCH_NORMAL); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "Dispatch failed"); + + rv = mThread->Shutdown(); + NS_WARNING_ASSERTION(NS_SUCCEEDED(rv), "Shutdown failed"); + (void)rv; + return NS_OK; +} + +bool Logger::VariantToString(const VARIANT& aVariant, nsACString& aOut, + LONG aIndex) { + switch (aVariant.vt) { + case VT_DISPATCH: { + aOut.AppendPrintf("(IDispatch*) 0x%p", aVariant.pdispVal); + return true; + } + case VT_DISPATCH | VT_BYREF: { + aOut.AppendPrintf("(IDispatch*) 0x%p", (aVariant.ppdispVal)[aIndex]); + return true; + } + case VT_UNKNOWN: { + aOut.AppendPrintf("(IUnknown*) 0x%p", aVariant.punkVal); + return true; + } + case VT_UNKNOWN | VT_BYREF: { + aOut.AppendPrintf("(IUnknown*) 0x%p", (aVariant.ppunkVal)[aIndex]); + return true; + } + case VT_VARIANT | VT_BYREF: { + return VariantToString((aVariant.pvarVal)[aIndex], aOut); + } + case VT_I4 | VT_BYREF: { + aOut.AppendPrintf("%ld", aVariant.plVal[aIndex]); + return true; + } + case VT_UI4 | VT_BYREF: { + aOut.AppendPrintf("%lu", aVariant.pulVal[aIndex]); + return true; + } + case VT_I4: { + aOut.AppendPrintf("%ld", aVariant.lVal); + return true; + } + case VT_UI4: { + aOut.AppendPrintf("%lu", aVariant.ulVal); + return true; + } + case VT_EMPTY: { + aOut.AppendLiteral("(empty VARIANT)"); + return true; + } + case VT_NULL: { + aOut.AppendLiteral("(null VARIANT)"); + return true; + } + case VT_BSTR: { + aOut.AppendPrintf("\"%S\"", aVariant.bstrVal); + return true; + } + case VT_BSTR | VT_BYREF: { + aOut.AppendPrintf("\"%S\"", *aVariant.pbstrVal); + return true; + } + default: { + aOut.AppendPrintf("(VariantToString failed, VARTYPE == 0x%04hx)", + aVariant.vt); + return false; + } + } +} + +/* static */ +double Logger::GetElapsedTime() { + TimeStamp ts = TimeStamp::Now(); + TimeDuration duration = ts - TimeStamp::ProcessCreation(); + return duration.ToMicroseconds(); +} + +void Logger::LogQI(HRESULT aResult, IUnknown* aTarget, REFIID aIid, + IUnknown* aInterface, const TimeDuration* aOverheadDuration, + const TimeDuration* aGeckoDuration) { + if (FAILED(aResult)) { + return; + } + + double elapsed = GetElapsedTime(); + + nsAutoCString strOverheadDuration; + if (aOverheadDuration) { + strOverheadDuration.AppendPrintf("%.3f", + aOverheadDuration->ToMicroseconds()); + } else { + strOverheadDuration.AppendLiteral("(none)"); + } + + nsAutoCString strGeckoDuration; + if (aGeckoDuration) { + strGeckoDuration.AppendPrintf("%.3f", aGeckoDuration->ToMicroseconds()); + } else { + strGeckoDuration.AppendLiteral("(none)"); + } + + nsPrintfCString line("%.3f\t%s\t%s\t0x%p\tIUnknown::QueryInterface\t([in] ", + elapsed, strOverheadDuration.get(), + strGeckoDuration.get(), aTarget); + + WCHAR buf[39] = {0}; + if (StringFromGUID2(aIid, buf, mozilla::ArrayLength(buf))) { + line.AppendPrintf("%S", buf); + } else { + line.AppendLiteral("(IID Conversion Failed)"); + } + line.AppendPrintf(", [out] 0x%p)\t0x%08lX\n", aInterface, aResult); + + MutexAutoLock lock(mMutex); + mEntries.AppendElement(std::move(line)); + mThread->Dispatch( + NewNonOwningRunnableMethod("Logger::Flush", this, &Logger::Flush), + NS_DISPATCH_NORMAL); +} + +bool Logger::TryParamAsGuid(REFIID aIid, ICallFrame* aCallFrame, + const CALLFRAMEPARAMINFO& aParamInfo, + nsACString& aLine) { + if (aIid != IID_IServiceProvider) { + return false; + } + + GUID** guid = reinterpret_cast<GUID**>( + static_cast<BYTE*>(aCallFrame->GetStackLocation()) + + aParamInfo.stackOffset); + + if (!IsValidGUID(**guid)) { + return false; + } + + WCHAR buf[39] = {0}; + if (!StringFromGUID2(**guid, buf, mozilla::ArrayLength(buf))) { + return false; + } + + aLine.AppendPrintf("%S", buf); + return true; +} + +void Logger::CaptureFrame(ICallFrame* aCallFrame, IUnknown* aTargetInterface, + nsACString& aCapturedFrame) { + aCapturedFrame.Truncate(); + + // (1) Gather info about the call + CALLFRAMEINFO callInfo; + HRESULT hr = aCallFrame->GetInfo(&callInfo); + if (FAILED(hr)) { + return; + } + + PWSTR interfaceName = nullptr; + PWSTR methodName = nullptr; + hr = aCallFrame->GetNames(&interfaceName, &methodName); + if (FAILED(hr)) { + return; + } + + // (2) Serialize the call + nsPrintfCString line("0x%p\t%S::%S\t(", aTargetInterface, interfaceName, + methodName); + + CoTaskMemFree(interfaceName); + interfaceName = nullptr; + CoTaskMemFree(methodName); + methodName = nullptr; + + // Check for supplemental array data + const ArrayData* arrayData = FindArrayData(callInfo.iid, callInfo.iMethod); + + for (ULONG paramIndex = 0; paramIndex < callInfo.cParams; ++paramIndex) { + CALLFRAMEPARAMINFO paramInfo; + hr = aCallFrame->GetParamInfo(paramIndex, ¶mInfo); + if (SUCCEEDED(hr)) { + line.AppendLiteral("["); + if (paramInfo.fIn) { + line.AppendLiteral("in"); + } + if (paramInfo.fOut) { + line.AppendLiteral("out"); + } + line.AppendLiteral("] "); + } + VARIANT paramValue; + hr = aCallFrame->GetParam(paramIndex, ¶mValue); + if (SUCCEEDED(hr)) { + if (arrayData && paramIndex == arrayData->mArrayParamIndex) { + VARIANT lengthParam; + hr = aCallFrame->GetParam(arrayData->mLengthParamIndex, &lengthParam); + if (SUCCEEDED(hr)) { + line.AppendLiteral("{ "); + StringJoinAppend(line, ", "_ns, + mozilla::IntegerRange<LONG>(0, *lengthParam.plVal), + [this, ¶mValue](nsACString& line, const LONG i) { + VariantToString(paramValue, line, i); + }); + line.AppendLiteral(" }"); + } else { + line.AppendPrintf("(GetParam failed with HRESULT 0x%08lX)", hr); + } + } else { + VariantToString(paramValue, line); + } + } else if (hr != DISP_E_BADVARTYPE || + !TryParamAsGuid(callInfo.iid, aCallFrame, paramInfo, line)) { + line.AppendPrintf("(GetParam failed with HRESULT 0x%08lX)", hr); + } + if (paramIndex < callInfo.cParams - 1) { + line.AppendLiteral(", "); + } + } + line.AppendLiteral(")\t"); + + HRESULT callResult = aCallFrame->GetReturnValue(); + line.AppendPrintf("0x%08lX\n", callResult); + + aCapturedFrame = std::move(line); +} + +void Logger::LogEvent(const nsACString& aCapturedFrame, + const TimeDuration& aOverheadDuration, + const TimeDuration& aGeckoDuration) { + double elapsed = GetElapsedTime(); + + nsPrintfCString line("%.3f\t%.3f\t%.3f\t%s", elapsed, + aOverheadDuration.ToMicroseconds(), + aGeckoDuration.ToMicroseconds(), + PromiseFlatCString(aCapturedFrame).get()); + + MutexAutoLock lock(mMutex); + mEntries.AppendElement(line); + mThread->Dispatch( + NewNonOwningRunnableMethod("Logger::Flush", this, &Logger::Flush), + NS_DISPATCH_NORMAL); +} + +void Logger::Flush() { + AssertRunningOnLoggerThread(); + MOZ_ASSERT(mLogFile); + if (!mLogFile) { + return; + } + nsTArray<nsCString> linesToWrite; + { // Scope for lock + MutexAutoLock lock(mMutex); + linesToWrite = std::move(mEntries); + } + + for (uint32_t i = 0, len = linesToWrite.Length(); i < len; ++i) { + uint32_t bytesWritten; + nsCString& line = linesToWrite[i]; + nsresult rv = mLogFile->Write(line.get(), line.Length(), &bytesWritten); + Unused << NS_WARN_IF(NS_FAILED(rv)); + } +} + +StaticAutoPtr<Logger> sLogger; + +NS_IMETHODIMP +ShutdownEvent::Observe(nsISupports* aSubject, const char* aTopic, + const char16_t* aData) { + if (strcmp(aTopic, NS_XPCOM_SHUTDOWN_THREADS_OBSERVER_ID)) { + MOZ_ASSERT(false); + return NS_ERROR_NOT_IMPLEMENTED; + } + MOZ_ASSERT(sLogger); + Unused << NS_WARN_IF(NS_FAILED(sLogger->Shutdown())); + nsCOMPtr<nsIObserver> kungFuDeathGrip(this); + nsCOMPtr<nsIObserverService> obsSvc = GetObserverService(); + obsSvc->RemoveObserver(this, aTopic); + return NS_OK; +} +} // anonymous namespace + +static bool MaybeCreateLog(const char* aEnvVarName) { + MOZ_ASSERT(NS_IsMainThread()); + MOZ_ASSERT(XRE_IsContentProcess() || XRE_IsParentProcess()); + MOZ_ASSERT(!sLogger); + const char* leafBaseName = PR_GetEnv(aEnvVarName); + if (!leafBaseName) { + return false; + } + nsDependentCString strLeafBaseName(leafBaseName); + if (strLeafBaseName.IsEmpty()) { + return false; + } + sLogger = new Logger(strLeafBaseName); + if (!sLogger->IsValid()) { + sLogger = nullptr; + return false; + } + ClearOnShutdown(&sLogger); + return true; +} + +namespace mozilla { +namespace mscom { + +/* static */ +bool InterceptorLog::Init() { + static const bool isEnabled = MaybeCreateLog("MOZ_MSCOM_LOG_BASENAME"); + return isEnabled; +} + +/* static */ +void InterceptorLog::QI(HRESULT aResult, IUnknown* aTarget, REFIID aIid, + IUnknown* aInterface, + const TimeDuration* aOverheadDuration, + const TimeDuration* aGeckoDuration) { + if (!sLogger) { + return; + } + sLogger->LogQI(aResult, aTarget, aIid, aInterface, aOverheadDuration, + aGeckoDuration); +} + +/* static */ +void InterceptorLog::CaptureFrame(ICallFrame* aCallFrame, + IUnknown* aTargetInterface, + nsACString& aCapturedFrame) { + if (!sLogger) { + return; + } + sLogger->CaptureFrame(aCallFrame, aTargetInterface, aCapturedFrame); +} + +/* static */ +void InterceptorLog::Event(const nsACString& aCapturedFrame, + const TimeDuration& aOverheadDuration, + const TimeDuration& aGeckoDuration) { + if (!sLogger) { + return; + } + sLogger->LogEvent(aCapturedFrame, aOverheadDuration, aGeckoDuration); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/InterceptorLog.h b/ipc/mscom/InterceptorLog.h new file mode 100644 index 0000000000..893c2fcbb5 --- /dev/null +++ b/ipc/mscom/InterceptorLog.h @@ -0,0 +1,36 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_InterceptorLog_h +#define mozilla_mscom_InterceptorLog_h + +#include "mozilla/TimeStamp.h" +#include "nsString.h" + +struct ICallFrame; +struct IUnknown; + +namespace mozilla { +namespace mscom { + +class InterceptorLog { + public: + static bool Init(); + static void QI(HRESULT aResult, IUnknown* aTarget, REFIID aIid, + IUnknown* aInterface, + const TimeDuration* aOverheadDuration = nullptr, + const TimeDuration* aGeckoDuration = nullptr); + static void CaptureFrame(ICallFrame* aCallFrame, IUnknown* aTarget, + nsACString& aCapturedFrame); + static void Event(const nsACString& aCapturedFrame, + const TimeDuration& aOverheadDuration, + const TimeDuration& aGeckoDuration); +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_InterceptorLog_h diff --git a/ipc/mscom/MainThreadHandoff.cpp b/ipc/mscom/MainThreadHandoff.cpp new file mode 100644 index 0000000000..544befd559 --- /dev/null +++ b/ipc/mscom/MainThreadHandoff.cpp @@ -0,0 +1,697 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#define INITGUID + +#include "mozilla/mscom/MainThreadHandoff.h" + +#include <utility> + +#include "mozilla/Assertions.h" +#include "mozilla/Attributes.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/ThreadLocal.h" +#include "mozilla/TimeStamp.h" +#include "mozilla/Unused.h" +#include "mozilla/mscom/AgileReference.h" +#include "mozilla/mscom/InterceptorLog.h" +#include "mozilla/mscom/Registration.h" +#include "mozilla/mscom/Utils.h" +#include "nsProxyRelease.h" +#include "nsThreadUtils.h" + +using mozilla::DebugOnly; +using mozilla::Unused; +using mozilla::mscom::AgileReference; + +namespace { + +class MOZ_NON_TEMPORARY_CLASS InParamWalker : private ICallFrameWalker { + public: + InParamWalker() : mPreHandoff(true) {} + + void SetHandoffDone() { + mPreHandoff = false; + mAgileRefsItr = mAgileRefs.begin(); + } + + HRESULT Walk(ICallFrame* aFrame) { + MOZ_ASSERT(aFrame); + if (!aFrame) { + return E_INVALIDARG; + } + + return aFrame->WalkFrame(CALLFRAME_WALK_IN, this); + } + + private: + // IUnknown + STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override { + if (!aOutInterface) { + return E_INVALIDARG; + } + *aOutInterface = nullptr; + + if (aIid == IID_IUnknown || aIid == IID_ICallFrameWalker) { + *aOutInterface = static_cast<ICallFrameWalker*>(this); + return S_OK; + } + + return E_NOINTERFACE; + } + + STDMETHODIMP_(ULONG) AddRef() override { return 2; } + + STDMETHODIMP_(ULONG) Release() override { return 1; } + + // ICallFrameWalker + STDMETHODIMP OnWalkInterface(REFIID aIid, PVOID* aInterface, BOOL aIn, + BOOL aOut) override { + MOZ_ASSERT(aIn); + if (!aIn) { + return E_UNEXPECTED; + } + + IUnknown* origInterface = static_cast<IUnknown*>(*aInterface); + if (!origInterface) { + // Nothing to do + return S_OK; + } + + if (mPreHandoff) { + mAgileRefs.AppendElement(AgileReference(aIid, origInterface)); + return S_OK; + } + + MOZ_ASSERT(mAgileRefsItr != mAgileRefs.end()); + if (mAgileRefsItr == mAgileRefs.end()) { + return E_UNEXPECTED; + } + + HRESULT hr = mAgileRefsItr->Resolve(aIid, aInterface); + MOZ_ASSERT(SUCCEEDED(hr)); + if (SUCCEEDED(hr)) { + ++mAgileRefsItr; + } + + return hr; + } + + InParamWalker(const InParamWalker&) = delete; + InParamWalker(InParamWalker&&) = delete; + InParamWalker& operator=(const InParamWalker&) = delete; + InParamWalker& operator=(InParamWalker&&) = delete; + + private: + bool mPreHandoff; + AutoTArray<AgileReference, 1> mAgileRefs; + nsTArray<AgileReference>::iterator mAgileRefsItr; +}; + +class HandoffRunnable : public mozilla::Runnable { + public: + explicit HandoffRunnable(ICallFrame* aCallFrame, IUnknown* aTargetInterface) + : Runnable("HandoffRunnable"), + mCallFrame(aCallFrame), + mTargetInterface(aTargetInterface), + mResult(E_UNEXPECTED) { + DebugOnly<HRESULT> hr = mInParamWalker.Walk(aCallFrame); + MOZ_ASSERT(SUCCEEDED(hr)); + } + + NS_IMETHOD Run() override { + mInParamWalker.SetHandoffDone(); + // We declare hr a DebugOnly because if mInParamWalker.Walk() fails, then + // mCallFrame->Invoke will fail anyway. + DebugOnly<HRESULT> hr = mInParamWalker.Walk(mCallFrame); + MOZ_ASSERT(SUCCEEDED(hr)); + mResult = mCallFrame->Invoke(mTargetInterface); + return NS_OK; + } + + HRESULT GetResult() const { return mResult; } + + private: + ICallFrame* mCallFrame; + InParamWalker mInParamWalker; + IUnknown* mTargetInterface; + HRESULT mResult; +}; + +class MOZ_RAII SavedCallFrame final { + public: + explicit SavedCallFrame(mozilla::NotNull<ICallFrame*> aFrame) + : mCallFrame(aFrame) { + static const bool sIsInit = tlsFrame.init(); + MOZ_ASSERT(sIsInit); + MOZ_ASSERT(!tlsFrame.get()); + tlsFrame.set(this); + Unused << sIsInit; + } + + ~SavedCallFrame() { + MOZ_ASSERT(tlsFrame.get()); + tlsFrame.set(nullptr); + } + + HRESULT GetIidAndMethod(mozilla::NotNull<IID*> aIid, + mozilla::NotNull<ULONG*> aMethod) const { + return mCallFrame->GetIIDAndMethod(aIid, aMethod); + } + + static const SavedCallFrame& Get() { + SavedCallFrame* saved = tlsFrame.get(); + MOZ_ASSERT(saved); + + return *saved; + } + + SavedCallFrame(const SavedCallFrame&) = delete; + SavedCallFrame(SavedCallFrame&&) = delete; + SavedCallFrame& operator=(const SavedCallFrame&) = delete; + SavedCallFrame& operator=(SavedCallFrame&&) = delete; + + private: + ICallFrame* mCallFrame; + + private: + static MOZ_THREAD_LOCAL(SavedCallFrame*) tlsFrame; +}; + +MOZ_THREAD_LOCAL(SavedCallFrame*) SavedCallFrame::tlsFrame; + +class MOZ_RAII LogEvent final { + public: + LogEvent() : mCallStart(mozilla::TimeStamp::Now()) {} + + ~LogEvent() { + if (mCapturedFrame.IsEmpty()) { + return; + } + + mozilla::TimeStamp callEnd(mozilla::TimeStamp::Now()); + mozilla::TimeDuration totalTime(callEnd - mCallStart); + mozilla::TimeDuration overhead(totalTime - mGeckoDuration - + mCaptureDuration); + + mozilla::mscom::InterceptorLog::Event(mCapturedFrame, overhead, + mGeckoDuration); + } + + void CaptureFrame(ICallFrame* aFrame, IUnknown* aTarget, + const mozilla::TimeDuration& aGeckoDuration) { + mozilla::TimeStamp captureStart(mozilla::TimeStamp::Now()); + + mozilla::mscom::InterceptorLog::CaptureFrame(aFrame, aTarget, + mCapturedFrame); + mGeckoDuration = aGeckoDuration; + + mozilla::TimeStamp captureEnd(mozilla::TimeStamp::Now()); + + // Make sure that the time we spent in CaptureFrame isn't charged against + // overall overhead + mCaptureDuration = captureEnd - captureStart; + } + + LogEvent(const LogEvent&) = delete; + LogEvent(LogEvent&&) = delete; + LogEvent& operator=(const LogEvent&) = delete; + LogEvent& operator=(LogEvent&&) = delete; + + private: + mozilla::TimeStamp mCallStart; + mozilla::TimeDuration mGeckoDuration; + mozilla::TimeDuration mCaptureDuration; + nsAutoCString mCapturedFrame; +}; + +} // anonymous namespace + +namespace mozilla { +namespace mscom { + +/* static */ +HRESULT MainThreadHandoff::Create(IHandlerProvider* aHandlerProvider, + IInterceptorSink** aOutput) { + RefPtr<MainThreadHandoff> handoff(new MainThreadHandoff(aHandlerProvider)); + return handoff->QueryInterface(IID_IInterceptorSink, (void**)aOutput); +} + +MainThreadHandoff::MainThreadHandoff(IHandlerProvider* aHandlerProvider) + : mRefCnt(0), mHandlerProvider(aHandlerProvider) {} + +MainThreadHandoff::~MainThreadHandoff() { MOZ_ASSERT(NS_IsMainThread()); } + +HRESULT +MainThreadHandoff::QueryInterface(REFIID riid, void** ppv) { + IUnknown* punk = nullptr; + if (!ppv) { + return E_INVALIDARG; + } + + if (riid == IID_IUnknown || riid == IID_ICallFrameEvents || + riid == IID_IInterceptorSink || riid == IID_IMainThreadHandoff) { + punk = static_cast<IMainThreadHandoff*>(this); + } else if (riid == IID_ICallFrameWalker) { + punk = static_cast<ICallFrameWalker*>(this); + } + + *ppv = punk; + if (!punk) { + return E_NOINTERFACE; + } + + punk->AddRef(); + return S_OK; +} + +ULONG +MainThreadHandoff::AddRef() { + return (ULONG)InterlockedIncrement((LONG*)&mRefCnt); +} + +ULONG +MainThreadHandoff::Release() { + ULONG newRefCnt = (ULONG)InterlockedDecrement((LONG*)&mRefCnt); + if (newRefCnt == 0) { + // It is possible for the last Release() call to happen off-main-thread. + // If so, we need to dispatch an event to delete ourselves. + if (NS_IsMainThread()) { + delete this; + } else { + // We need to delete this object on the main thread, but we aren't on the + // main thread right now, so we send a reference to ourselves to the main + // thread to be re-released there. + RefPtr<MainThreadHandoff> self = this; + NS_ReleaseOnMainThread("MainThreadHandoff", self.forget()); + } + } + return newRefCnt; +} + +HRESULT +MainThreadHandoff::FixIServiceProvider(ICallFrame* aFrame) { + MOZ_ASSERT(aFrame); + + CALLFRAMEPARAMINFO iidOutParamInfo; + HRESULT hr = aFrame->GetParamInfo(1, &iidOutParamInfo); + if (FAILED(hr)) { + return hr; + } + + VARIANT varIfaceOut; + hr = aFrame->GetParam(2, &varIfaceOut); + if (FAILED(hr)) { + return hr; + } + + MOZ_ASSERT(varIfaceOut.vt == (VT_UNKNOWN | VT_BYREF)); + if (varIfaceOut.vt != (VT_UNKNOWN | VT_BYREF)) { + return DISP_E_BADVARTYPE; + } + + IID** iidOutParam = + reinterpret_cast<IID**>(static_cast<BYTE*>(aFrame->GetStackLocation()) + + iidOutParamInfo.stackOffset); + + return OnWalkInterface(**iidOutParam, + reinterpret_cast<void**>(varIfaceOut.ppunkVal), FALSE, + TRUE); +} + +HRESULT +MainThreadHandoff::OnCall(ICallFrame* aFrame) { + LogEvent logEvent; + + // (1) Get info about the method call + HRESULT hr; + IID iid; + ULONG method; + hr = aFrame->GetIIDAndMethod(&iid, &method); + if (FAILED(hr)) { + return hr; + } + + RefPtr<IInterceptor> interceptor; + hr = mInterceptor->Resolve(IID_IInterceptor, + (void**)getter_AddRefs(interceptor)); + if (FAILED(hr)) { + return hr; + } + + InterceptorTargetPtr<IUnknown> targetInterface; + hr = interceptor->GetTargetForIID(iid, targetInterface); + if (FAILED(hr)) { + return hr; + } + + // (2) Execute the method call synchronously on the main thread + RefPtr<HandoffRunnable> handoffInfo( + new HandoffRunnable(aFrame, targetInterface.get())); + MainThreadInvoker invoker; + if (!invoker.Invoke(do_AddRef(handoffInfo))) { + MOZ_ASSERT(false); + return E_UNEXPECTED; + } + hr = handoffInfo->GetResult(); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + return hr; + } + + // (3) Capture *before* wrapping outputs so that the log will contain pointers + // to the true target interface, not the wrapped ones. + logEvent.CaptureFrame(aFrame, targetInterface.get(), invoker.GetDuration()); + + // (4) Scan the function call for outparams that contain interface pointers. + // Those will need to be wrapped with MainThreadHandoff so that they too will + // be exeuted on the main thread. + + hr = aFrame->GetReturnValue(); + if (FAILED(hr)) { + // If the call resulted in an error then there's not going to be anything + // that needs to be wrapped. + return S_OK; + } + + if (iid == IID_IServiceProvider) { + // The only possible method index for IID_IServiceProvider is for + // QueryService at index 3; its other methods are inherited from IUnknown + // and are not processed here. + MOZ_ASSERT(method == 3); + // (5) If our interface is IServiceProvider, we need to manually ensure + // that the correct IID is provided for the interface outparam in + // IServiceProvider::QueryService. + hr = FixIServiceProvider(aFrame); + if (FAILED(hr)) { + return hr; + } + } else if (const ArrayData* arrayData = FindArrayData(iid, method)) { + // (6) Unfortunately ICallFrame::WalkFrame does not correctly handle array + // outparams. Instead, we find out whether anybody has called + // mscom::RegisterArrayData to supply array parameter information and use it + // if available. This is a terrible hack, but it works for the short term. + // In the longer term we want to be able to use COM proxy/stub metadata to + // resolve array information for us. + hr = FixArrayElements(aFrame, *arrayData); + if (FAILED(hr)) { + return hr; + } + } else { + SavedCallFrame savedFrame(WrapNotNull(aFrame)); + + // (7) Scan the outputs looking for any outparam interfaces that need + // wrapping. NB: WalkFrame does not correctly handle array outparams. It + // processes the first element of an array but not the remaining elements + // (if any). + hr = aFrame->WalkFrame(CALLFRAME_WALK_OUT, this); + if (FAILED(hr)) { + return hr; + } + } + + return S_OK; +} + +static PVOID ResolveArrayPtr(VARIANT& aVariant) { + if (!(aVariant.vt & VT_BYREF)) { + return nullptr; + } + return aVariant.byref; +} + +static PVOID* ResolveInterfacePtr(PVOID aArrayPtr, VARTYPE aVartype, + LONG aIndex) { + if (aVartype != (VT_VARIANT | VT_BYREF)) { + IUnknown** ifaceArray = reinterpret_cast<IUnknown**>(aArrayPtr); + return reinterpret_cast<PVOID*>(&ifaceArray[aIndex]); + } + VARIANT* variantArray = reinterpret_cast<VARIANT*>(aArrayPtr); + VARIANT& element = variantArray[aIndex]; + return &element.byref; +} + +HRESULT +MainThreadHandoff::FixArrayElements(ICallFrame* aFrame, + const ArrayData& aArrayData) { + // Extract the array length + VARIANT paramVal; + VariantInit(¶mVal); + HRESULT hr = aFrame->GetParam(aArrayData.mLengthParamIndex, ¶mVal); + MOZ_ASSERT(SUCCEEDED(hr) && (paramVal.vt == (VT_I4 | VT_BYREF) || + paramVal.vt == (VT_UI4 | VT_BYREF))); + if (FAILED(hr) || (paramVal.vt != (VT_I4 | VT_BYREF) && + paramVal.vt != (VT_UI4 | VT_BYREF))) { + return hr; + } + + const LONG arrayLength = *(paramVal.plVal); + if (!arrayLength) { + // Nothing to do + return S_OK; + } + + // Extract the array parameter + VariantInit(¶mVal); + PVOID arrayPtr = nullptr; + hr = aFrame->GetParam(aArrayData.mArrayParamIndex, ¶mVal); + if (hr == DISP_E_BADVARTYPE) { + // ICallFrame::GetParam is not able to coerce the param into a VARIANT. + // That's ok, we can try to do it ourselves. + CALLFRAMEPARAMINFO paramInfo; + hr = aFrame->GetParamInfo(aArrayData.mArrayParamIndex, ¶mInfo); + if (FAILED(hr)) { + return hr; + } + PVOID stackBase = aFrame->GetStackLocation(); + if (aArrayData.mFlag == ArrayData::Flag::eAllocatedByServer) { + // In order for the server to allocate the array's buffer and store it in + // an outparam, the parameter must be typed as Type***. Since the base + // of the array is Type*, we must dereference twice. + arrayPtr = **reinterpret_cast<PVOID**>( + reinterpret_cast<PBYTE>(stackBase) + paramInfo.stackOffset); + } else { + // We dereference because we need to obtain the value of a parameter + // from a stack offset. This pointer is the base of the array. + arrayPtr = *reinterpret_cast<PVOID*>(reinterpret_cast<PBYTE>(stackBase) + + paramInfo.stackOffset); + } + } else if (FAILED(hr)) { + return hr; + } else { + arrayPtr = ResolveArrayPtr(paramVal); + } + + MOZ_ASSERT(arrayPtr); + if (!arrayPtr) { + return DISP_E_BADVARTYPE; + } + + // We walk the elements of the array and invoke OnWalkInterface to wrap each + // one, just as ICallFrame::WalkFrame would do. + for (LONG index = 0; index < arrayLength; ++index) { + hr = OnWalkInterface(aArrayData.mArrayParamIid, + ResolveInterfacePtr(arrayPtr, paramVal.vt, index), + FALSE, TRUE); + if (FAILED(hr)) { + return hr; + } + } + return S_OK; +} + +HRESULT +MainThreadHandoff::SetInterceptor(IWeakReference* aInterceptor) { + mInterceptor = aInterceptor; + return S_OK; +} + +HRESULT +MainThreadHandoff::GetHandler(NotNull<CLSID*> aHandlerClsid) { + if (!mHandlerProvider) { + return E_NOTIMPL; + } + + return mHandlerProvider->GetHandler(aHandlerClsid); +} + +HRESULT +MainThreadHandoff::GetHandlerPayloadSize(NotNull<IInterceptor*> aInterceptor, + NotNull<DWORD*> aOutPayloadSize) { + if (!mHandlerProvider) { + return E_NOTIMPL; + } + return mHandlerProvider->GetHandlerPayloadSize(aInterceptor, aOutPayloadSize); +} + +HRESULT +MainThreadHandoff::WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor, + NotNull<IStream*> aStream) { + if (!mHandlerProvider) { + return E_NOTIMPL; + } + return mHandlerProvider->WriteHandlerPayload(aInterceptor, aStream); +} + +REFIID +MainThreadHandoff::MarshalAs(REFIID aIid) { + if (!mHandlerProvider) { + return aIid; + } + return mHandlerProvider->MarshalAs(aIid); +} + +HRESULT +MainThreadHandoff::DisconnectHandlerRemotes() { + if (!mHandlerProvider) { + return E_NOTIMPL; + } + + return mHandlerProvider->DisconnectHandlerRemotes(); +} + +HRESULT +MainThreadHandoff::IsInterfaceMaybeSupported(REFIID aIid) { + if (!mHandlerProvider) { + return S_OK; + } + return mHandlerProvider->IsInterfaceMaybeSupported(aIid); +} + +HRESULT +MainThreadHandoff::OnWalkInterface(REFIID aIid, PVOID* aInterface, + BOOL aIsInParam, BOOL aIsOutParam) { + MOZ_ASSERT(aInterface && aIsOutParam); + if (!aInterface || !aIsOutParam) { + return E_UNEXPECTED; + } + + // Adopt aInterface for the time being. We can't touch its refcount off + // the main thread, so we'll use STAUniquePtr so that we can safely + // Release() it if necessary. + STAUniquePtr<IUnknown> origInterface(static_cast<IUnknown*>(*aInterface)); + *aInterface = nullptr; + + if (!origInterface) { + // Nothing to wrap. + return S_OK; + } + + // First make sure that aInterface isn't a proxy - we don't want to wrap + // those. + if (IsProxy(origInterface.get())) { + *aInterface = origInterface.release(); + return S_OK; + } + + RefPtr<IInterceptor> interceptor; + HRESULT hr = mInterceptor->Resolve(IID_IInterceptor, + (void**)getter_AddRefs(interceptor)); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + return hr; + } + + // Now make sure that origInterface isn't referring to the same IUnknown + // as an interface that we are already managing. We can determine this by + // querying (NOT casting!) both objects for IUnknown and then comparing the + // resulting pointers. + InterceptorTargetPtr<IUnknown> existingTarget; + hr = interceptor->GetTargetForIID(aIid, existingTarget); + if (SUCCEEDED(hr)) { + // We'll start by checking the raw pointers. If they are equal, then the + // objects are equal. OTOH, if they differ, we must compare their + // IUnknown pointers to know for sure. + bool areTargetsEqual = existingTarget.get() == origInterface.get(); + + if (!areTargetsEqual) { + // This check must be done on the main thread + auto checkFn = [&existingTarget, &origInterface, + &areTargetsEqual]() -> void { + RefPtr<IUnknown> unkExisting; + HRESULT hrExisting = existingTarget->QueryInterface( + IID_IUnknown, (void**)getter_AddRefs(unkExisting)); + RefPtr<IUnknown> unkNew; + HRESULT hrNew = origInterface->QueryInterface( + IID_IUnknown, (void**)getter_AddRefs(unkNew)); + areTargetsEqual = + SUCCEEDED(hrExisting) && SUCCEEDED(hrNew) && unkExisting == unkNew; + }; + + MainThreadInvoker invoker; + invoker.Invoke(NS_NewRunnableFunction( + "MainThreadHandoff::OnWalkInterface", checkFn)); + } + + if (areTargetsEqual) { + // The existing interface and the new interface both belong to the same + // target object. Let's just use the existing one. + void* intercepted = nullptr; + hr = interceptor->GetInterceptorForIID(aIid, &intercepted); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + return hr; + } + *aInterface = intercepted; + return S_OK; + } + } + + IID effectiveIid = aIid; + + RefPtr<IHandlerProvider> payload; + if (mHandlerProvider) { + if (aIid == IID_IUnknown) { + const SavedCallFrame& curFrame = SavedCallFrame::Get(); + + IID callIid; + ULONG callMethod; + hr = curFrame.GetIidAndMethod(WrapNotNull(&callIid), + WrapNotNull(&callMethod)); + if (FAILED(hr)) { + return hr; + } + + effectiveIid = + mHandlerProvider->GetEffectiveOutParamIid(callIid, callMethod); + } + + hr = mHandlerProvider->NewInstance( + effectiveIid, ToInterceptorTargetPtr(origInterface), + WrapNotNull((IHandlerProvider**)getter_AddRefs(payload))); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + return hr; + } + } + + // Now create a new MainThreadHandoff wrapper... + RefPtr<IInterceptorSink> handoff; + hr = MainThreadHandoff::Create(payload, getter_AddRefs(handoff)); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + return hr; + } + + REFIID interceptorIid = + payload ? payload->MarshalAs(effectiveIid) : effectiveIid; + + RefPtr<IUnknown> wrapped; + hr = Interceptor::Create(std::move(origInterface), handoff, interceptorIid, + getter_AddRefs(wrapped)); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + return hr; + } + + // And replace the original interface pointer with the wrapped one. + wrapped.forget(reinterpret_cast<IUnknown**>(aInterface)); + + return S_OK; +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/MainThreadHandoff.h b/ipc/mscom/MainThreadHandoff.h new file mode 100644 index 0000000000..e790cea4b4 --- /dev/null +++ b/ipc/mscom/MainThreadHandoff.h @@ -0,0 +1,105 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_MainThreadHandoff_h +#define mozilla_mscom_MainThreadHandoff_h + +#include <utility> + +#include "mozilla/Assertions.h" +#include "mozilla/Mutex.h" +#include "mozilla/mscom/Interceptor.h" +#include "mozilla/mscom/MainThreadInvoker.h" +#include "mozilla/mscom/Utils.h" +#include "nsTArray.h" + +namespace mozilla { +namespace mscom { + +// {9a907000-7829-47f1-80eb-f67a26f47b34} +DEFINE_GUID(IID_IMainThreadHandoff, 0x9a907000, 0x7829, 0x47f1, 0x80, 0xeb, + 0xf6, 0x7a, 0x26, 0xf4, 0x7b, 0x34); + +struct IMainThreadHandoff : public IInterceptorSink { + virtual STDMETHODIMP GetHandlerProvider(IHandlerProvider** aProvider) = 0; +}; + +struct ArrayData; + +class MainThreadHandoff final : public IMainThreadHandoff, + public ICallFrameWalker { + public: + static HRESULT Create(IHandlerProvider* aHandlerProvider, + IInterceptorSink** aOutput); + + template <typename Interface> + static HRESULT WrapInterface(STAUniquePtr<Interface> aTargetInterface, + Interface** aOutInterface) { + return WrapInterface<Interface>(std::move(aTargetInterface), nullptr, + aOutInterface); + } + + template <typename Interface> + static HRESULT WrapInterface(STAUniquePtr<Interface> aTargetInterface, + IHandlerProvider* aHandlerProvider, + Interface** aOutInterface) { + MOZ_ASSERT(!IsProxy(aTargetInterface.get())); + RefPtr<IInterceptorSink> handoff; + HRESULT hr = + MainThreadHandoff::Create(aHandlerProvider, getter_AddRefs(handoff)); + if (FAILED(hr)) { + return hr; + } + return CreateInterceptor(std::move(aTargetInterface), handoff, + aOutInterface); + } + + // IUnknown + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // ICallFrameEvents + STDMETHODIMP OnCall(ICallFrame* aFrame) override; + + // IInterceptorSink + STDMETHODIMP SetInterceptor(IWeakReference* aInterceptor) override; + STDMETHODIMP GetHandler(NotNull<CLSID*> aHandlerClsid) override; + STDMETHODIMP GetHandlerPayloadSize(NotNull<IInterceptor*> aInterceptor, + NotNull<DWORD*> aOutPayloadSize) override; + STDMETHODIMP WriteHandlerPayload(NotNull<IInterceptor*> aInterceptor, + NotNull<IStream*> aStream) override; + STDMETHODIMP_(REFIID) MarshalAs(REFIID aIid) override; + STDMETHODIMP DisconnectHandlerRemotes() override; + STDMETHODIMP IsInterfaceMaybeSupported(REFIID aIid) override; + + // IMainThreadHandoff + STDMETHODIMP GetHandlerProvider(IHandlerProvider** aProvider) override { + RefPtr<IHandlerProvider> provider = mHandlerProvider; + provider.forget(aProvider); + return mHandlerProvider ? S_OK : S_FALSE; + } + + // ICallFrameWalker + STDMETHODIMP OnWalkInterface(REFIID aIid, PVOID* aInterface, BOOL aIsInParam, + BOOL aIsOutParam) override; + + private: + explicit MainThreadHandoff(IHandlerProvider* aHandlerProvider); + ~MainThreadHandoff(); + HRESULT FixArrayElements(ICallFrame* aFrame, const ArrayData& aArrayData); + HRESULT FixIServiceProvider(ICallFrame* aFrame); + + private: + ULONG mRefCnt; + RefPtr<IWeakReference> mInterceptor; + RefPtr<IHandlerProvider> mHandlerProvider; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_MainThreadHandoff_h diff --git a/ipc/mscom/MainThreadInvoker.cpp b/ipc/mscom/MainThreadInvoker.cpp new file mode 100644 index 0000000000..8062800249 --- /dev/null +++ b/ipc/mscom/MainThreadInvoker.cpp @@ -0,0 +1,177 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/MainThreadInvoker.h" + +#include "MainThreadUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/BackgroundHangMonitor.h" +#include "mozilla/ClearOnShutdown.h" +#include "mozilla/ProfilerThreadState.h" +#include "mozilla/SchedulerGroup.h" +#include "mozilla/mscom/SpinEvent.h" +#include "mozilla/RefPtr.h" +#include "mozilla/Unused.h" +#include "private/prpriv.h" // For PR_GetThreadID +#include <winternl.h> // For NTSTATUS and NTAPI + +namespace { + +typedef NTSTATUS(NTAPI* NtTestAlertPtr)(VOID); + +/** + * SyncRunnable implements different code paths depending on whether or not + * we are running on a multiprocessor system. In the multiprocessor case, we + * leave the thread in a spin loop while waiting for the main thread to execute + * our runnable. Since spinning is pointless in the uniprocessor case, we block + * on an event that is set by the main thread once it has finished the runnable. + */ +class SyncRunnable : public mozilla::Runnable { + public: + explicit SyncRunnable(already_AddRefed<nsIRunnable> aRunnable) + : mozilla::Runnable("MainThreadInvoker"), mRunnable(aRunnable) { + static const bool gotStatics = InitStatics(); + MOZ_ASSERT(gotStatics); + Unused << gotStatics; + } + + ~SyncRunnable() = default; + + NS_IMETHOD Run() override { + if (mHasRun) { + // The APC already ran, so we have nothing to do. + return NS_OK; + } + + // Run the pending APC in the queue. + MOZ_ASSERT(sNtTestAlert); + sNtTestAlert(); + return NS_OK; + } + + // This is called by MainThreadInvoker::MainThreadAPC. + void APCRun() { + mHasRun = true; + + TimeStamp runStart(TimeStamp::Now()); + mRunnable->Run(); + TimeStamp runEnd(TimeStamp::Now()); + + mDuration = runEnd - runStart; + + mEvent.Signal(); + } + + bool WaitUntilComplete() { + return mEvent.Wait(mozilla::mscom::MainThreadInvoker::GetTargetThread()); + } + + const mozilla::TimeDuration& GetDuration() const { return mDuration; } + + private: + bool mHasRun = false; + nsCOMPtr<nsIRunnable> mRunnable; + mozilla::mscom::SpinEvent mEvent; + mozilla::TimeDuration mDuration; + + static NtTestAlertPtr sNtTestAlert; + + static bool InitStatics() { + sNtTestAlert = reinterpret_cast<NtTestAlertPtr>( + ::GetProcAddress(::GetModuleHandleW(L"ntdll.dll"), "NtTestAlert")); + MOZ_ASSERT(sNtTestAlert); + return sNtTestAlert; + } +}; + +NtTestAlertPtr SyncRunnable::sNtTestAlert = nullptr; + +} // anonymous namespace + +namespace mozilla { +namespace mscom { + +HANDLE MainThreadInvoker::sMainThread = nullptr; + +/* static */ +bool MainThreadInvoker::InitStatics() { + nsCOMPtr<nsIThread> mainThread; + nsresult rv = ::NS_GetMainThread(getter_AddRefs(mainThread)); + if (NS_FAILED(rv)) { + return false; + } + + PRThread* mainPrThread = nullptr; + rv = mainThread->GetPRThread(&mainPrThread); + if (NS_FAILED(rv)) { + return false; + } + + PRUint32 tid = ::PR_GetThreadID(mainPrThread); + sMainThread = ::OpenThread(SYNCHRONIZE | THREAD_SET_CONTEXT, FALSE, tid); + + return !!sMainThread; +} + +MainThreadInvoker::MainThreadInvoker() { + static const bool gotStatics = InitStatics(); + MOZ_ASSERT(gotStatics); + Unused << gotStatics; +} + +bool MainThreadInvoker::Invoke(already_AddRefed<nsIRunnable>&& aRunnable) { + nsCOMPtr<nsIRunnable> runnable(std::move(aRunnable)); + if (!runnable) { + return false; + } + + if (NS_IsMainThread()) { + runnable->Run(); + return true; + } + + RefPtr<SyncRunnable> syncRunnable = new SyncRunnable(runnable.forget()); + + // The main thread could be either blocked on a condition variable waiting + // for a Gecko event, or it could be blocked waiting on a Windows HANDLE in + // IPC code (doing a sync message send). In the former case, we wake it by + // posting a Gecko runnable to the main thread. In the latter case, we wake + // it using an APC. However, the latter case doesn't happen very often now + // and APCs aren't otherwise run by the main thread. To ensure the + // SyncRunnable is cleaned up, we need both to run consistently. + // To do this, we: + // 1. Queue an APC which does the actual work. + // This ref gets released in MainThreadAPC when it runs. + SyncRunnable* syncRunnableRef = syncRunnable.get(); + NS_ADDREF(syncRunnableRef); + if (!::QueueUserAPC(&MainThreadAPC, sMainThread, + reinterpret_cast<UINT_PTR>(syncRunnableRef))) { + return false; + } + + // 2. Post a Gecko runnable (which always runs). If the APC hasn't run, the + // Gecko runnable runs it. Otherwise, it does nothing. + if (NS_FAILED(SchedulerGroup::Dispatch(TaskCategory::Other, + do_AddRef(syncRunnable)))) { + return false; + } + + bool result = syncRunnable->WaitUntilComplete(); + mDuration = syncRunnable->GetDuration(); + return result; +} + +/* static */ VOID CALLBACK MainThreadInvoker::MainThreadAPC(ULONG_PTR aParam) { + AUTO_PROFILER_THREAD_WAKE; + mozilla::BackgroundHangMonitor().NotifyActivity(); + MOZ_ASSERT(NS_IsMainThread()); + auto runnable = reinterpret_cast<SyncRunnable*>(aParam); + runnable->APCRun(); + NS_RELEASE(runnable); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/MainThreadInvoker.h b/ipc/mscom/MainThreadInvoker.h new file mode 100644 index 0000000000..8b56f9bed0 --- /dev/null +++ b/ipc/mscom/MainThreadInvoker.h @@ -0,0 +1,56 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_MainThreadInvoker_h +#define mozilla_mscom_MainThreadInvoker_h + +#include <windows.h> + +#include <utility> + +#include "mozilla/AlreadyAddRefed.h" +#include "mozilla/StaticPtr.h" +#include "mozilla/TimeStamp.h" +#include "nsCOMPtr.h" +#include "nsThreadUtils.h" + +class nsIRunnable; + +namespace mozilla { +namespace mscom { + +class MainThreadInvoker { + public: + MainThreadInvoker(); + + bool Invoke(already_AddRefed<nsIRunnable>&& aRunnable); + const TimeDuration& GetDuration() const { return mDuration; } + static HANDLE GetTargetThread() { return sMainThread; } + + private: + TimeDuration mDuration; + + static bool InitStatics(); + static VOID CALLBACK MainThreadAPC(ULONG_PTR aParam); + + static HANDLE sMainThread; +}; + +template <typename Class, typename... Args> +inline bool InvokeOnMainThread(const char* aName, Class* aObject, + void (Class::*aMethod)(Args...), + Args&&... aArgs) { + nsCOMPtr<nsIRunnable> runnable(NewNonOwningRunnableMethod<Args...>( + aName, aObject, aMethod, std::forward<Args>(aArgs)...)); + + MainThreadInvoker invoker; + return invoker.Invoke(runnable.forget()); +} + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_MainThreadInvoker_h diff --git a/ipc/mscom/Objref.cpp b/ipc/mscom/Objref.cpp new file mode 100644 index 0000000000..37108c71ac --- /dev/null +++ b/ipc/mscom/Objref.cpp @@ -0,0 +1,411 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/Objref.h" + +#include "mozilla/Assertions.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/RefPtr.h" +#include "mozilla/ScopeExit.h" +#include "mozilla/UniquePtr.h" + +#include <guiddef.h> +#include <objidl.h> +#include <winnt.h> + +// {00000027-0000-0008-C000-000000000046} +static const GUID CLSID_AggStdMarshal = { + 0x27, 0x0, 0x8, {0xC0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x46}}; + +namespace { + +#pragma pack(push, 1) + +typedef uint64_t OID; +typedef uint64_t OXID; +typedef GUID IPID; + +struct STDOBJREF { + uint32_t mFlags; + uint32_t mPublicRefs; + OXID mOxid; + OID mOid; + IPID mIpid; +}; + +enum STDOBJREF_FLAGS { SORF_PING = 0, SORF_NOPING = 0x1000 }; + +struct DUALSTRINGARRAY { + static size_t SizeFromNumEntries(const uint16_t aNumEntries) { + return sizeof(mNumEntries) + sizeof(mSecurityOffset) + + aNumEntries * sizeof(uint16_t); + } + + size_t SizeOf() const { return SizeFromNumEntries(mNumEntries); } + + uint16_t mNumEntries; + uint16_t mSecurityOffset; + uint16_t mStringArray[1]; // Length is mNumEntries +}; + +struct OBJREF_STANDARD { + static size_t SizeOfFixedLenHeader() { return sizeof(mStd); } + + size_t SizeOf() const { return SizeOfFixedLenHeader() + mResAddr.SizeOf(); } + + STDOBJREF mStd; + DUALSTRINGARRAY mResAddr; +}; + +struct OBJREF_HANDLER { + static size_t SizeOfFixedLenHeader() { return sizeof(mStd) + sizeof(mClsid); } + + size_t SizeOf() const { return SizeOfFixedLenHeader() + mResAddr.SizeOf(); } + + STDOBJREF mStd; + CLSID mClsid; + DUALSTRINGARRAY mResAddr; +}; + +struct OBJREF_CUSTOM { + static size_t SizeOfFixedLenHeader() { + return sizeof(mClsid) + sizeof(mCbExtension) + sizeof(mReserved); + } + + CLSID mClsid; + uint32_t mCbExtension; + uint32_t mReserved; + uint8_t mPayload[1]; +}; + +enum OBJREF_FLAGS { + OBJREF_TYPE_STANDARD = 0x00000001UL, + OBJREF_TYPE_HANDLER = 0x00000002UL, + OBJREF_TYPE_CUSTOM = 0x00000004UL, + OBJREF_TYPE_EXTENDED = 0x00000008UL, +}; + +struct OBJREF { + static size_t SizeOfFixedLenHeader(OBJREF_FLAGS aFlags) { + size_t size = sizeof(mSignature) + sizeof(mFlags) + sizeof(mIid); + + switch (aFlags) { + case OBJREF_TYPE_STANDARD: + size += OBJREF_STANDARD::SizeOfFixedLenHeader(); + break; + case OBJREF_TYPE_HANDLER: + size += OBJREF_HANDLER::SizeOfFixedLenHeader(); + break; + case OBJREF_TYPE_CUSTOM: + size += OBJREF_CUSTOM::SizeOfFixedLenHeader(); + break; + default: + MOZ_ASSERT_UNREACHABLE("Unsupported OBJREF type"); + return 0; + } + + return size; + } + + size_t SizeOf() const { + size_t size = sizeof(mSignature) + sizeof(mFlags) + sizeof(mIid); + + switch (mFlags) { + case OBJREF_TYPE_STANDARD: + size += mObjRefStd.SizeOf(); + break; + case OBJREF_TYPE_HANDLER: + size += mObjRefHandler.SizeOf(); + break; + default: + MOZ_ASSERT_UNREACHABLE("Unsupported OBJREF type"); + return 0; + } + + return size; + } + + uint32_t mSignature; + uint32_t mFlags; + IID mIid; + union { + OBJREF_STANDARD mObjRefStd; + OBJREF_HANDLER mObjRefHandler; + OBJREF_CUSTOM mObjRefCustom; + // There are others but we're not supporting them here + }; +}; + +enum OBJREF_SIGNATURES { OBJREF_SIGNATURE = 0x574F454DUL }; + +#pragma pack(pop) + +struct ByteArrayDeleter { + void operator()(void* aPtr) { delete[] reinterpret_cast<uint8_t*>(aPtr); } +}; + +template <typename T> +using VarStructUniquePtr = mozilla::UniquePtr<T, ByteArrayDeleter>; + +} // anonymous namespace + +namespace mozilla { +namespace mscom { + +bool StripHandlerFromOBJREF(NotNull<IStream*> aStream, const uint64_t aStartPos, + const uint64_t aEndPos) { + // Ensure that the current stream position is set to the beginning + LARGE_INTEGER seekTo; + seekTo.QuadPart = aStartPos; + + HRESULT hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr); + if (FAILED(hr)) { + return false; + } + + ULONG bytesRead; + + uint32_t signature; + hr = aStream->Read(&signature, sizeof(signature), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(signature) || + signature != OBJREF_SIGNATURE) { + return false; + } + + uint32_t type; + hr = aStream->Read(&type, sizeof(type), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(type)) { + return false; + } + if (type != OBJREF_TYPE_HANDLER) { + // If we're not a handler then just seek to the end of the OBJREF and return + // success; there is nothing left to do. + seekTo.QuadPart = aEndPos; + return SUCCEEDED(aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr)); + } + + IID iid; + hr = aStream->Read(&iid, sizeof(iid), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(iid) || !IsValidGUID(iid)) { + return false; + } + + // Seek past fixed-size STDOBJREF and CLSID + seekTo.QuadPart = sizeof(STDOBJREF) + sizeof(CLSID); + hr = aStream->Seek(seekTo, STREAM_SEEK_CUR, nullptr); + if (FAILED(hr)) { + return false; + } + + uint16_t numEntries; + hr = aStream->Read(&numEntries, sizeof(numEntries), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(numEntries)) { + return false; + } + + // We'll try to use a stack buffer if resAddrSize <= kMinDualStringArraySize + const uint32_t kMinDualStringArraySize = 12; + uint16_t staticResAddrBuf[kMinDualStringArraySize / sizeof(uint16_t)]; + + size_t resAddrSize = DUALSTRINGARRAY::SizeFromNumEntries(numEntries); + + DUALSTRINGARRAY* resAddr; + VarStructUniquePtr<DUALSTRINGARRAY> dynamicResAddrBuf; + + if (resAddrSize <= kMinDualStringArraySize) { + resAddr = reinterpret_cast<DUALSTRINGARRAY*>(staticResAddrBuf); + } else { + dynamicResAddrBuf.reset( + reinterpret_cast<DUALSTRINGARRAY*>(new uint8_t[resAddrSize])); + resAddr = dynamicResAddrBuf.get(); + } + + resAddr->mNumEntries = numEntries; + + // Because we've already read numEntries + ULONG bytesToRead = resAddrSize - sizeof(numEntries); + + hr = aStream->Read(&resAddr->mSecurityOffset, bytesToRead, &bytesRead); + if (FAILED(hr) || bytesRead != bytesToRead) { + return false; + } + + // Signature doesn't change so we'll seek past that + seekTo.QuadPart = aStartPos + sizeof(signature); + hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr); + if (FAILED(hr)) { + return false; + } + + ULONG bytesWritten; + + uint32_t newType = OBJREF_TYPE_STANDARD; + hr = aStream->Write(&newType, sizeof(newType), &bytesWritten); + if (FAILED(hr) || bytesWritten != sizeof(newType)) { + return false; + } + + // Skip past IID and STDOBJREF since those don't change + seekTo.QuadPart = sizeof(IID) + sizeof(STDOBJREF); + hr = aStream->Seek(seekTo, STREAM_SEEK_CUR, nullptr); + if (FAILED(hr)) { + return false; + } + + hr = aStream->Write(resAddr, resAddrSize, &bytesWritten); + if (FAILED(hr) || bytesWritten != resAddrSize) { + return false; + } + + // The difference between a OBJREF_STANDARD and an OBJREF_HANDLER is + // sizeof(CLSID), so we'll zero out the remaining bytes. + hr = aStream->Write(&CLSID_NULL, sizeof(CLSID), &bytesWritten); + if (FAILED(hr) || bytesWritten != sizeof(CLSID)) { + return false; + } + + // Back up to the end of the tweaked OBJREF. + // There are now sizeof(CLSID) less bytes. + // Bug 1403180: Using -sizeof(CLSID) with a relative seek sometimes + // doesn't work on Windows 7. + // It succeeds, but doesn't seek the stream for some unknown reason. + // Use an absolute seek instead. + seekTo.QuadPart = aEndPos - sizeof(CLSID); + return SUCCEEDED(aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr)); +} + +uint32_t GetOBJREFSize(NotNull<IStream*> aStream) { + // Make a clone so that we don't manipulate aStream's seek pointer + RefPtr<IStream> cloned; + HRESULT hr = aStream->Clone(getter_AddRefs(cloned)); + if (FAILED(hr)) { + return 0; + } + + uint32_t accumulatedSize = 0; + + ULONG bytesRead; + + uint32_t signature; + hr = cloned->Read(&signature, sizeof(signature), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(signature) || + signature != OBJREF_SIGNATURE) { + return 0; + } + + accumulatedSize += bytesRead; + + uint32_t type; + hr = cloned->Read(&type, sizeof(type), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(type)) { + return 0; + } + + accumulatedSize += bytesRead; + + IID iid; + hr = cloned->Read(&iid, sizeof(iid), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(iid) || !IsValidGUID(iid)) { + return 0; + } + + accumulatedSize += bytesRead; + + LARGE_INTEGER seekTo; + + if (type == OBJREF_TYPE_CUSTOM) { + CLSID clsid; + hr = cloned->Read(&clsid, sizeof(CLSID), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(CLSID)) { + return 0; + } + + if (clsid != CLSID_StdMarshal && clsid != CLSID_AggStdMarshal) { + // We can only calulate the size if the payload is a standard OBJREF as + // identified by clsid being CLSID_StdMarshal or CLSID_AggStdMarshal. + // (CLSID_AggStdMarshal, the aggregated standard marshaler, is used when + // the handler marshals an interface.) + return 0; + } + + accumulatedSize += bytesRead; + + seekTo.QuadPart = + sizeof(OBJREF_CUSTOM::mCbExtension) + sizeof(OBJREF_CUSTOM::mReserved); + hr = cloned->Seek(seekTo, STREAM_SEEK_CUR, nullptr); + if (FAILED(hr)) { + return 0; + } + + accumulatedSize += seekTo.LowPart; + + uint32_t payloadLen = GetOBJREFSize(WrapNotNull(cloned.get())); + if (!payloadLen) { + return 0; + } + + accumulatedSize += payloadLen; + return accumulatedSize; + } + + switch (type) { + case OBJREF_TYPE_STANDARD: + seekTo.QuadPart = OBJREF_STANDARD::SizeOfFixedLenHeader(); + break; + case OBJREF_TYPE_HANDLER: + seekTo.QuadPart = OBJREF_HANDLER::SizeOfFixedLenHeader(); + break; + default: + return 0; + } + + hr = cloned->Seek(seekTo, STREAM_SEEK_CUR, nullptr); + if (FAILED(hr)) { + return 0; + } + + accumulatedSize += seekTo.LowPart; + + uint16_t numEntries; + hr = cloned->Read(&numEntries, sizeof(numEntries), &bytesRead); + if (FAILED(hr) || bytesRead != sizeof(numEntries)) { + return 0; + } + + accumulatedSize += DUALSTRINGARRAY::SizeFromNumEntries(numEntries); + return accumulatedSize; +} + +bool SetIID(NotNull<IStream*> aStream, const uint64_t aStart, REFIID aNewIid) { + ULARGE_INTEGER initialStreamPos; + + LARGE_INTEGER seekTo; + seekTo.QuadPart = 0LL; + HRESULT hr = aStream->Seek(seekTo, STREAM_SEEK_CUR, &initialStreamPos); + if (FAILED(hr)) { + return false; + } + + auto resetStreamPos = MakeScopeExit([&]() { + seekTo.QuadPart = initialStreamPos.QuadPart; + hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr); + MOZ_DIAGNOSTIC_ASSERT(SUCCEEDED(hr)); + }); + + seekTo.QuadPart = + aStart + sizeof(OBJREF::mSignature) + sizeof(OBJREF::mFlags); + hr = aStream->Seek(seekTo, STREAM_SEEK_SET, nullptr); + if (FAILED(hr)) { + return false; + } + + ULONG bytesWritten; + hr = aStream->Write(&aNewIid, sizeof(IID), &bytesWritten); + return SUCCEEDED(hr) && bytesWritten == sizeof(IID); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/Objref.h b/ipc/mscom/Objref.h new file mode 100644 index 0000000000..1b19b68644 --- /dev/null +++ b/ipc/mscom/Objref.h @@ -0,0 +1,53 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Objref_h +#define mozilla_mscom_Objref_h + +#include "mozilla/NotNull.h" +#include "mozilla/RefPtr.h" + +#include <guiddef.h> + +struct IStream; + +namespace mozilla { +namespace mscom { + +/** + * Given a buffer containing a serialized proxy to an interface with a handler, + * this function strips out the handler and converts it to a standard one. + * @param aStream IStream containing a serialized proxy. + * There should be nothing else written to the stream past the + * current OBJREF. + * @param aStart Absolute position of the beginning of the OBJREF. + * @param aEnd Absolute position of the end of the OBJREF. + * @return true if the handler was successfully stripped, otherwise false. + */ +bool StripHandlerFromOBJREF(NotNull<IStream*> aStream, const uint64_t aStart, + const uint64_t aEnd); + +/** + * Given a buffer containing a serialized proxy to an interface, this function + * returns the length of the serialized data. + * @param aStream IStream containing a serialized proxy. The stream pointer + * must be positioned at the beginning of the OBJREF. + * @return The size of the serialized proxy, or 0 on error. + */ +uint32_t GetOBJREFSize(NotNull<IStream*> aStream); + +/** + * Overrides the IID in a serialized proxy with the specified IID. + * @param aStream Pointer to a stream containing a serialized proxy. + * @param aStart Offset to the beginning of the serialized proxy within aStream. + * @param aNewIid The replacement IID to apply to the serialized proxy. + */ +bool SetIID(NotNull<IStream*> aStream, const uint64_t aStart, REFIID aNewIid); + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Objref_h diff --git a/ipc/mscom/PassthruProxy.cpp b/ipc/mscom/PassthruProxy.cpp new file mode 100644 index 0000000000..39bc7dba7e --- /dev/null +++ b/ipc/mscom/PassthruProxy.cpp @@ -0,0 +1,393 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/PassthruProxy.h" +#include "mozilla/mscom/ProxyStream.h" +#include "VTableBuilder.h" + +// {96EF5801-CE6D-416E-A50A-0C2959AEAE1C} +static const GUID CLSID_PassthruProxy = { + 0x96ef5801, 0xce6d, 0x416e, {0xa5, 0xa, 0xc, 0x29, 0x59, 0xae, 0xae, 0x1c}}; + +namespace mozilla { +namespace mscom { + +PassthruProxy::PassthruProxy() + : mRefCnt(0), + mWrappedIid(), + mVTableSize(0), + mVTable(nullptr), + mForgetPreservedStream(false) {} + +PassthruProxy::PassthruProxy(ProxyStream::Environment* aEnv, REFIID aIidToWrap, + uint32_t aVTableSize, + NotNull<IUnknown*> aObjToWrap) + : mRefCnt(0), + mWrappedIid(aIidToWrap), + mVTableSize(aVTableSize), + mVTable(nullptr), + mForgetPreservedStream(false) { + ProxyStream proxyStream(aIidToWrap, aObjToWrap, aEnv, + ProxyStreamFlags::ePreservable); + mPreservedStream = proxyStream.GetPreservedStream(); + MOZ_ASSERT(mPreservedStream); +} + +PassthruProxy::~PassthruProxy() { + if (mForgetPreservedStream) { + // We want to release the ref without clearing marshal data + IStream* stream = mPreservedStream.release(); + stream->Release(); + } + + if (mVTable) { + DeleteNullVTable(mVTable); + } +} + +HRESULT +PassthruProxy::QueryProxyInterface(void** aOutInterface) { + // Even though we don't really provide the methods for the interface that + // we are proxying, we need to support it in QueryInterface. Instead we + // return an interface that, other than IUnknown, contains nullptr for all of + // its vtable entires. Obviously this interface is not intended to actually + // be called, it just has to be there. + + if (!mVTable) { + MOZ_ASSERT(mVTableSize); + mVTable = BuildNullVTable(static_cast<IMarshal*>(this), mVTableSize); + MOZ_ASSERT(mVTable); + } + + *aOutInterface = mVTable; + mVTable->AddRef(); + return S_OK; +} + +HRESULT +PassthruProxy::QueryInterface(REFIID aIid, void** aOutInterface) { + if (!aOutInterface) { + return E_INVALIDARG; + } + + *aOutInterface = nullptr; + + if (aIid == IID_IUnknown || aIid == IID_IMarshal) { + RefPtr<IMarshal> ptr(this); + ptr.forget(aOutInterface); + return S_OK; + } + + if (!IsInitialMarshal()) { + // We implement IClientSecurity so that IsProxy() recognizes us as such + if (aIid == IID_IClientSecurity) { + RefPtr<IClientSecurity> ptr(this); + ptr.forget(aOutInterface); + return S_OK; + } + + if (aIid == mWrappedIid) { + return QueryProxyInterface(aOutInterface); + } + } + + return E_NOINTERFACE; +} + +ULONG +PassthruProxy::AddRef() { return ++mRefCnt; } + +ULONG +PassthruProxy::Release() { + ULONG result = --mRefCnt; + if (!result) { + delete this; + } + + return result; +} + +HRESULT +PassthruProxy::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) { + if (!pCid) { + return E_INVALIDARG; + } + + if (IsInitialMarshal()) { + // To properly use this class we need to be using TABLESTRONG marshaling + MOZ_ASSERT(mshlflags & MSHLFLAGS_TABLESTRONG); + + // When we're marshaling for the first time, we identify ourselves as the + // class to use for unmarshaling. + *pCid = CLSID_PassthruProxy; + } else { + // Subsequent marshals use the standard marshaler. + *pCid = CLSID_StdMarshal; + } + + return S_OK; +} + +HRESULT +PassthruProxy::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) { + STATSTG statstg; + HRESULT hr; + + if (!IsInitialMarshal()) { + // If we are not the initial marshal then we are just copying mStream out + // to the marshal stream, so we just use mStream's size. + hr = mStream->Stat(&statstg, STATFLAG_NONAME); + if (FAILED(hr)) { + return hr; + } + + *pSize = statstg.cbSize.LowPart; + + return hr; + } + + // To properly use this class we need to be using TABLESTRONG marshaling + MOZ_ASSERT(mshlflags & MSHLFLAGS_TABLESTRONG); + + if (!mPreservedStream) { + return E_POINTER; + } + + hr = mPreservedStream->Stat(&statstg, STATFLAG_NONAME); + if (FAILED(hr)) { + return hr; + } + + *pSize = statstg.cbSize.LowPart + sizeof(mVTableSize) + sizeof(mWrappedIid); + return hr; +} + +HRESULT +PassthruProxy::MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) { + MOZ_ASSERT(riid == mWrappedIid); + if (riid != mWrappedIid) { + return E_NOINTERFACE; + } + + MOZ_ASSERT(pv == mVTable); + if (pv != mVTable) { + return E_INVALIDARG; + } + + HRESULT hr; + RefPtr<IStream> cloned; + + if (IsInitialMarshal()) { + // To properly use this class we need to be using TABLESTRONG marshaling + MOZ_ASSERT(mshlflags & MSHLFLAGS_TABLESTRONG); + + if (!mPreservedStream) { + return E_POINTER; + } + + // We write out the vtable size and the IID so that the wrapped proxy knows + // how to build its vtable on the content side. + ULONG bytesWritten; + hr = pStm->Write(&mVTableSize, sizeof(mVTableSize), &bytesWritten); + if (FAILED(hr)) { + return hr; + } + if (bytesWritten != sizeof(mVTableSize)) { + return E_UNEXPECTED; + } + + hr = pStm->Write(&mWrappedIid, sizeof(mWrappedIid), &bytesWritten); + if (FAILED(hr)) { + return hr; + } + if (bytesWritten != sizeof(mWrappedIid)) { + return E_UNEXPECTED; + } + + hr = mPreservedStream->Clone(getter_AddRefs(cloned)); + } else { + hr = mStream->Clone(getter_AddRefs(cloned)); + } + + if (FAILED(hr)) { + return hr; + } + + STATSTG statstg; + hr = cloned->Stat(&statstg, STATFLAG_NONAME); + if (FAILED(hr)) { + return hr; + } + + // Copy the proxy data + hr = cloned->CopyTo(pStm, statstg.cbSize, nullptr, nullptr); + + if (SUCCEEDED(hr) && IsInitialMarshal() && mPreservedStream && + (mshlflags & MSHLFLAGS_TABLESTRONG)) { + // If we have successfully copied mPreservedStream at least once for a + // MSHLFLAGS_TABLESTRONG marshal, then we want to forget our reference to + // it. This is because the COM runtime will manage it from here on out. + mForgetPreservedStream = true; + } + + return hr; +} + +HRESULT +PassthruProxy::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) { + // Read out the interface info that we copied during marshaling + ULONG bytesRead; + HRESULT hr = pStm->Read(&mVTableSize, sizeof(mVTableSize), &bytesRead); + if (FAILED(hr)) { + return hr; + } + if (bytesRead != sizeof(mVTableSize)) { + return E_UNEXPECTED; + } + + hr = pStm->Read(&mWrappedIid, sizeof(mWrappedIid), &bytesRead); + if (FAILED(hr)) { + return hr; + } + if (bytesRead != sizeof(mWrappedIid)) { + return E_UNEXPECTED; + } + + // Now we copy the proxy inside pStm into mStream + hr = CopySerializedProxy(pStm, getter_AddRefs(mStream)); + if (FAILED(hr)) { + return hr; + } + + return QueryInterface(riid, ppv); +} + +HRESULT +PassthruProxy::ReleaseMarshalData(IStream* pStm) { + if (!IsInitialMarshal()) { + return S_OK; + } + + if (!pStm) { + return E_INVALIDARG; + } + + if (mPreservedStream) { + // If we still have mPreservedStream, then simply clearing it will release + // its marshal data automagically. + mPreservedStream = nullptr; + return S_OK; + } + + // Skip past the metadata that we wrote during initial marshaling. + LARGE_INTEGER seekTo; + seekTo.QuadPart = sizeof(mVTableSize) + sizeof(mWrappedIid); + HRESULT hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, nullptr); + if (FAILED(hr)) { + return hr; + } + + // Now release the "inner" marshal data + return ::CoReleaseMarshalData(pStm); +} + +HRESULT +PassthruProxy::DisconnectObject(DWORD dwReserved) { return S_OK; } + +// The remainder of this code is just boilerplate COM stuff that provides the +// association between CLSID_PassthruProxy and the PassthruProxy class itself. + +class PassthruProxyClassObject final : public IClassFactory { + public: + PassthruProxyClassObject(); + + // IUnknown + STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // IClassFactory + STDMETHODIMP CreateInstance(IUnknown* aOuter, REFIID aIid, + void** aOutObject) override; + STDMETHODIMP LockServer(BOOL aLock) override; + + private: + ~PassthruProxyClassObject() = default; + + Atomic<ULONG> mRefCnt; +}; + +PassthruProxyClassObject::PassthruProxyClassObject() : mRefCnt(0) {} + +HRESULT +PassthruProxyClassObject::QueryInterface(REFIID aIid, void** aOutInterface) { + if (!aOutInterface) { + return E_INVALIDARG; + } + + *aOutInterface = nullptr; + + if (aIid == IID_IUnknown || aIid == IID_IClassFactory) { + RefPtr<IClassFactory> ptr(this); + ptr.forget(aOutInterface); + return S_OK; + } + + return E_NOINTERFACE; +} + +ULONG +PassthruProxyClassObject::AddRef() { return ++mRefCnt; } + +ULONG +PassthruProxyClassObject::Release() { + ULONG result = --mRefCnt; + if (!result) { + delete this; + } + + return result; +} + +HRESULT +PassthruProxyClassObject::CreateInstance(IUnknown* aOuter, REFIID aIid, + void** aOutObject) { + // We don't expect to aggregate + MOZ_ASSERT(!aOuter); + if (aOuter) { + return E_INVALIDARG; + } + + RefPtr<PassthruProxy> ptr(new PassthruProxy()); + return ptr->QueryInterface(aIid, aOutObject); +} + +HRESULT +PassthruProxyClassObject::LockServer(BOOL aLock) { + // No-op since xul.dll is always in memory + return S_OK; +} + +/* static */ +HRESULT PassthruProxy::Register() { + DWORD cookie; + RefPtr<IClassFactory> classObj(new PassthruProxyClassObject()); + return ::CoRegisterClassObject(CLSID_PassthruProxy, classObj, + CLSCTX_INPROC_SERVER, REGCLS_MULTIPLEUSE, + &cookie); +} + +} // namespace mscom +} // namespace mozilla + +HRESULT +RegisterPassthruProxy() { return mozilla::mscom::PassthruProxy::Register(); } diff --git a/ipc/mscom/PassthruProxy.h b/ipc/mscom/PassthruProxy.h new file mode 100644 index 0000000000..afdb5c648e --- /dev/null +++ b/ipc/mscom/PassthruProxy.h @@ -0,0 +1,127 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_PassthruProxy_h +#define mozilla_mscom_PassthruProxy_h + +#include "mozilla/Atomics.h" +#include "mozilla/mscom/ProxyStream.h" +#include "mozilla/mscom/Ptr.h" +#include "mozilla/NotNull.h" +#if defined(MOZ_SANDBOX) +# include "mozilla/SandboxSettings.h" +#endif // defined(MOZ_SANDBOX) + +#include <objbase.h> + +namespace mozilla { +namespace mscom { +namespace detail { + +template <typename Iface> +struct VTableSizer; + +template <> +struct VTableSizer<IDispatch> { + enum { Size = 7 }; +}; + +} // namespace detail + +class PassthruProxy final : public IMarshal, public IClientSecurity { + public: + template <typename Iface> + static RefPtr<Iface> Wrap(NotNull<Iface*> aIn) { + static_assert(detail::VTableSizer<Iface>::Size >= 3, "VTable too small"); + +#if defined(MOZ_SANDBOX) + if (mozilla::GetEffectiveContentSandboxLevel() < 3) { + // The sandbox isn't strong enough to be a problem; no wrapping required + return aIn.get(); + } + + typename detail::EnvironmentSelector<Iface>::Type env; + + RefPtr<PassthruProxy> passthru(new PassthruProxy( + &env, __uuidof(Iface), detail::VTableSizer<Iface>::Size, aIn)); + + RefPtr<Iface> result; + if (FAILED(passthru->QueryProxyInterface(getter_AddRefs(result)))) { + return nullptr; + } + + return result; +#else + // No wrapping required + return aIn.get(); +#endif // defined(MOZ_SANDBOX) + } + + static HRESULT Register(); + + PassthruProxy(); + + // IUnknown + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // IMarshal + STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) override; + STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) override; + STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) override; + STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid, + void** ppv) override; + STDMETHODIMP ReleaseMarshalData(IStream* pStm) override; + STDMETHODIMP DisconnectObject(DWORD dwReserved) override; + + // IClientSecurity - we don't actually implement this interface, but its + // presence signals to mscom::IsProxy() that we are a proxy. + STDMETHODIMP QueryBlanket(IUnknown* aProxy, DWORD* aAuthnSvc, + DWORD* aAuthzSvc, OLECHAR** aSrvPrincName, + DWORD* aAuthnLevel, DWORD* aImpLevel, + void** aAuthInfo, DWORD* aCapabilities) override { + return E_NOTIMPL; + } + + STDMETHODIMP SetBlanket(IUnknown* aProxy, DWORD aAuthnSvc, DWORD aAuthzSvc, + OLECHAR* aSrvPrincName, DWORD aAuthnLevel, + DWORD aImpLevel, void* aAuthInfo, + DWORD aCapabilities) override { + return E_NOTIMPL; + } + + STDMETHODIMP CopyProxy(IUnknown* aProxy, IUnknown** aOutCopy) override { + return E_NOTIMPL; + } + + private: + PassthruProxy(ProxyStream::Environment* aEnv, REFIID aIidToWrap, + uint32_t aVTableSize, NotNull<IUnknown*> aObjToWrap); + ~PassthruProxy(); + + bool IsInitialMarshal() const { return !mStream; } + HRESULT QueryProxyInterface(void** aOutInterface); + + Atomic<ULONG> mRefCnt; + IID mWrappedIid; + PreservedStreamPtr mPreservedStream; + RefPtr<IStream> mStream; + uint32_t mVTableSize; + IUnknown* mVTable; + bool mForgetPreservedStream; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_PassthruProxy_h diff --git a/ipc/mscom/ProcessRuntime.cpp b/ipc/mscom/ProcessRuntime.cpp new file mode 100644 index 0000000000..2b4e613d75 --- /dev/null +++ b/ipc/mscom/ProcessRuntime.cpp @@ -0,0 +1,480 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/ProcessRuntime.h" + +#if defined(ACCESSIBILITY) && \ + (defined(MOZILLA_INTERNAL_API) || defined(MOZ_HAS_MOZGLUE)) +# include "mozilla/mscom/ActCtxResource.h" +#endif // defined(ACCESSIBILITY) && (defined(MOZILLA_INTERNAL_API) || + // defined(MOZ_HAS_MOZGLUE)) +#include "mozilla/Assertions.h" +#include "mozilla/DynamicallyLinkedFunctionPtr.h" +#include "mozilla/mscom/COMWrappers.h" +#include "mozilla/mscom/ProcessRuntimeShared.h" +#include "mozilla/RefPtr.h" +#include "mozilla/UniquePtr.h" +#include "mozilla/Unused.h" +#include "mozilla/Vector.h" +#include "mozilla/WindowsProcessMitigations.h" +#include "mozilla/WindowsVersion.h" + +#if defined(MOZILLA_INTERNAL_API) +# include "mozilla/mscom/EnsureMTA.h" +# if defined(MOZ_SANDBOX) +# include "mozilla/sandboxTarget.h" +# endif // defined(MOZ_SANDBOX) +#endif // defined(MOZILLA_INTERNAL_API) + +#include <accctrl.h> +#include <aclapi.h> +#include <objbase.h> +#include <objidl.h> + +// This API from oleaut32.dll is not declared in Windows SDK headers +extern "C" void __cdecl SetOaNoCache(void); + +using namespace mozilla::mscom::detail; + +namespace mozilla { +namespace mscom { + +#if defined(MOZILLA_INTERNAL_API) +ProcessRuntime* ProcessRuntime::sInstance = nullptr; + +ProcessRuntime::ProcessRuntime() : ProcessRuntime(XRE_GetProcessType()) {} + +ProcessRuntime::ProcessRuntime(const GeckoProcessType aProcessType) + : ProcessRuntime(aProcessType == GeckoProcessType_Default + ? ProcessCategory::GeckoBrowserParent + : ProcessCategory::GeckoChild) {} +#endif // defined(MOZILLA_INTERNAL_API) + +ProcessRuntime::ProcessRuntime(const ProcessCategory aProcessCategory) + : mInitResult(CO_E_NOTINITIALIZED), mProcessCategory(aProcessCategory) { +#if defined(ACCESSIBILITY) +# if defined(MOZILLA_INTERNAL_API) + // If we're inside XUL, and we're the parent process, then we trust that + // this has already been initialized for us prior to XUL being loaded. + // Only required in the child if the Resource ID has been passed down. + if (aProcessCategory != ProcessCategory::GeckoBrowserParent && + ActCtxResource::GetAccessibilityResourceId()) { + mActCtxRgn.emplace(ActCtxResource::GetAccessibilityResource()); + } +# elif defined(MOZ_HAS_MOZGLUE) + // If we're here, then we're in mozglue and initializing this for the parent + // process. + MOZ_ASSERT(aProcessCategory == ProcessCategory::GeckoBrowserParent); + mActCtxRgn.emplace(ActCtxResource::GetAccessibilityResource()); +# endif +#endif // defined(ACCESSIBILITY) + +#if defined(MOZILLA_INTERNAL_API) + MOZ_DIAGNOSTIC_ASSERT(!sInstance); + sInstance = this; + + EnsureMTA(); + /** + * From this point forward, all threads in this process are implicitly + * members of the multi-threaded apartment, with the following exceptions: + * 1. If any Win32 GUI APIs were called on the current thread prior to + * executing this constructor, then this thread has already been implicitly + * initialized as the process's main STA thread; or + * 2. A thread explicitly and successfully calls CoInitialize(Ex) to specify + * otherwise. + */ + + const bool isCurThreadImplicitMTA = IsCurrentThreadImplicitMTA(); + // We only assert that the implicit MTA precondition holds when not running + // as the Gecko parent process. + MOZ_DIAGNOSTIC_ASSERT(aProcessCategory == + ProcessCategory::GeckoBrowserParent || + isCurThreadImplicitMTA); + +# if defined(MOZ_SANDBOX) + const bool isLockedDownChildProcess = + mProcessCategory == ProcessCategory::GeckoChild && IsWin32kLockedDown(); + // If our process is running under Win32k lockdown, we cannot initialize + // COM with a single-threaded apartment. This is because STAs create a hidden + // window, which implicitly requires user32 and Win32k, which are blocked. + // Instead we start the multi-threaded apartment and conduct our process-wide + // COM initialization there. + if (isLockedDownChildProcess) { + // Make sure we're still running with the sandbox's privileged impersonation + // token. + HANDLE rawCurThreadImpToken; + if (!::OpenThreadToken(::GetCurrentThread(), TOKEN_DUPLICATE | TOKEN_QUERY, + FALSE, &rawCurThreadImpToken)) { + mInitResult = HRESULT_FROM_WIN32(::GetLastError()); + return; + } + nsAutoHandle curThreadImpToken(rawCurThreadImpToken); + + // Ensure that our current token is still an impersonation token (ie, we + // have not yet called RevertToSelf() on this thread). + DWORD len; + TOKEN_TYPE tokenType; + MOZ_RELEASE_ASSERT( + ::GetTokenInformation(rawCurThreadImpToken, TokenType, &tokenType, + sizeof(tokenType), &len) && + len == sizeof(tokenType) && tokenType == TokenImpersonation); + + // Ideally we want our current thread to be running implicitly inside the + // MTA, but if for some wacky reason we did not end up with that, we may + // compensate by completing initialization via EnsureMTA's persistent + // thread. + if (!isCurThreadImplicitMTA) { + InitUsingPersistentMTAThread(curThreadImpToken); + return; + } + } +# endif // defined(MOZ_SANDBOX) +#endif // defined(MOZILLA_INTERNAL_API) + + mAptRegion.Init(GetDesiredApartmentType(mProcessCategory)); + + // It can happen that we are not the outermost COM initialization on this + // thread. In fact it should regularly be the case that the outermost + // initialization occurs from outside of XUL, before we show the skeleton UI, + // at which point we still need to run some things here from within XUL. + if (!mAptRegion.IsValidOutermost()) { + mInitResult = mAptRegion.GetHResult(); +#if defined(MOZILLA_INTERNAL_API) + MOZ_ASSERT(mProcessCategory == ProcessCategory::GeckoBrowserParent); + if (mProcessCategory != ProcessCategory::GeckoBrowserParent) { + // This is unexpected unless we're GeckoBrowserParent + return; + } + + ProcessInitLock lock; + + // Is another instance of ProcessRuntime responsible for the outer + // initialization? + const bool prevInit = + lock.GetInitState() == ProcessInitState::FullyInitialized; + MOZ_ASSERT(prevInit); + if (prevInit) { + PostInit(); + } +#endif // defined(MOZILLA_INTERNAL_API) + return; + } + + InitInsideApartment(); + if (FAILED(mInitResult)) { + return; + } + +#if defined(MOZILLA_INTERNAL_API) +# if defined(MOZ_SANDBOX) + if (isLockedDownChildProcess) { + // In locked-down child processes, defer PostInit until priv drop + SandboxTarget::Instance()->RegisterSandboxStartCallback([self = this]() { + // Ensure that we're still live and the init was successful before + // calling PostInit() + if (self == sInstance && SUCCEEDED(self->mInitResult)) { + PostInit(); + } + }); + return; + } +# endif // defined(MOZ_SANDBOX) + + PostInit(); +#endif // defined(MOZILLA_INTERNAL_API) +} + +#if defined(MOZILLA_INTERNAL_API) +ProcessRuntime::~ProcessRuntime() { + MOZ_DIAGNOSTIC_ASSERT(sInstance == this); + sInstance = nullptr; +} + +# if defined(MOZ_SANDBOX) +void ProcessRuntime::InitUsingPersistentMTAThread( + const nsAutoHandle& aCurThreadToken) { + // Create an impersonation token based on the current thread's token + HANDLE rawMtaThreadImpToken = nullptr; + if (!::DuplicateToken(aCurThreadToken, SecurityImpersonation, + &rawMtaThreadImpToken)) { + mInitResult = HRESULT_FROM_WIN32(::GetLastError()); + return; + } + nsAutoHandle mtaThreadImpToken(rawMtaThreadImpToken); + + // Impersonate and initialize. + bool tokenSet = false; + EnsureMTA( + [this, rawMtaThreadImpToken, &tokenSet]() -> void { + if (!::SetThreadToken(nullptr, rawMtaThreadImpToken)) { + mInitResult = HRESULT_FROM_WIN32(::GetLastError()); + return; + } + + tokenSet = true; + InitInsideApartment(); + }, + EnsureMTA::Option::ForceDispatchToPersistentThread); + + if (!tokenSet) { + return; + } + + SandboxTarget::Instance()->RegisterSandboxStartCallback( + [self = this]() -> void { + EnsureMTA( + []() -> void { + // This is a security risk if it fails, so we release assert + MOZ_RELEASE_ASSERT(::RevertToSelf(), + "mscom::ProcessRuntime RevertToSelf failed"); + }, + EnsureMTA::Option::ForceDispatchToPersistentThread); + + // Ensure that we're still live and the init was successful before + // calling PostInit() + if (self == sInstance && SUCCEEDED(self->mInitResult)) { + PostInit(); + } + }); +} +# endif // defined(MOZ_SANDBOX) +#endif // defined(MOZILLA_INTERNAL_API) + +/* static */ +COINIT ProcessRuntime::GetDesiredApartmentType( + const ProcessRuntime::ProcessCategory aProcessCategory) { + switch (aProcessCategory) { + case ProcessCategory::GeckoBrowserParent: + return COINIT_APARTMENTTHREADED; + case ProcessCategory::GeckoChild: + if (!IsWin32kLockedDown()) { + // If Win32k is not locked down then we probably still need STA. + // We disable DDE since that is not usable from child processes. + return static_cast<COINIT>(COINIT_APARTMENTTHREADED | + COINIT_DISABLE_OLE1DDE); + } + + [[fallthrough]]; + default: + return COINIT_MULTITHREADED; + } +} + +void ProcessRuntime::InitInsideApartment() { + ProcessInitLock lock; + const ProcessInitState prevInitState = lock.GetInitState(); + if (prevInitState == ProcessInitState::FullyInitialized) { + // COM has already been initialized by a previous ProcessRuntime instance + mInitResult = S_OK; + return; + } + + if (prevInitState < ProcessInitState::PartialSecurityInitialized) { + // We are required to initialize security prior to configuring global + // options. + mInitResult = InitializeSecurity(mProcessCategory); + MOZ_DIAGNOSTIC_ASSERT(SUCCEEDED(mInitResult)); + + // Even though this isn't great, we should try to proceed even when + // CoInitializeSecurity has previously been called: the additional settings + // we want to change are important enough that we don't want to skip them. + if (FAILED(mInitResult) && mInitResult != RPC_E_TOO_LATE) { + return; + } + + lock.SetInitState(ProcessInitState::PartialSecurityInitialized); + } + + if (prevInitState < ProcessInitState::PartialGlobalOptions) { + RefPtr<IGlobalOptions> globalOpts; + mInitResult = wrapped::CoCreateInstance( + CLSID_GlobalOptions, nullptr, CLSCTX_INPROC_SERVER, IID_IGlobalOptions, + getter_AddRefs(globalOpts)); + MOZ_ASSERT(SUCCEEDED(mInitResult)); + if (FAILED(mInitResult)) { + return; + } + + // Disable COM's catch-all exception handler + mInitResult = globalOpts->Set(COMGLB_EXCEPTION_HANDLING, + COMGLB_EXCEPTION_DONOT_HANDLE_ANY); + MOZ_ASSERT(SUCCEEDED(mInitResult)); + if (FAILED(mInitResult)) { + return; + } + + lock.SetInitState(ProcessInitState::PartialGlobalOptions); + } + + // Disable the BSTR cache (as it never invalidates, thus leaking memory) + // (This function is itself idempotent, so we do not concern ourselves with + // tracking whether or not we've already called it.) + ::SetOaNoCache(); + + lock.SetInitState(ProcessInitState::FullyInitialized); +} + +#if defined(MOZILLA_INTERNAL_API) +/** + * Guaranteed to run *after* the COM (and possible sandboxing) initialization + * has successfully completed and stabilized. This method MUST BE IDEMPOTENT! + */ +/* static */ void ProcessRuntime::PostInit() { + // Currently "roughed-in" but unused. +} +#endif // defined(MOZILLA_INTERNAL_API) + +/* static */ +DWORD +ProcessRuntime::GetClientThreadId() { + DWORD callerTid; + HRESULT hr = ::CoGetCallerTID(&callerTid); + // Don't return callerTid unless the call succeeded and returned S_FALSE, + // indicating that the caller originates from a different process. + if (hr != S_FALSE) { + return 0; + } + + return callerTid; +} + +/* static */ +HRESULT +ProcessRuntime::InitializeSecurity(const ProcessCategory aProcessCategory) { + HANDLE rawToken = nullptr; + BOOL ok = ::OpenProcessToken(::GetCurrentProcess(), TOKEN_QUERY, &rawToken); + if (!ok) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + nsAutoHandle token(rawToken); + + DWORD len = 0; + ok = ::GetTokenInformation(token, TokenUser, nullptr, len, &len); + DWORD win32Error = ::GetLastError(); + if (!ok && win32Error != ERROR_INSUFFICIENT_BUFFER) { + return HRESULT_FROM_WIN32(win32Error); + } + + auto tokenUserBuf = MakeUnique<BYTE[]>(len); + TOKEN_USER& tokenUser = *reinterpret_cast<TOKEN_USER*>(tokenUserBuf.get()); + ok = ::GetTokenInformation(token, TokenUser, tokenUserBuf.get(), len, &len); + if (!ok) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + len = 0; + ok = ::GetTokenInformation(token, TokenPrimaryGroup, nullptr, len, &len); + win32Error = ::GetLastError(); + if (!ok && win32Error != ERROR_INSUFFICIENT_BUFFER) { + return HRESULT_FROM_WIN32(win32Error); + } + + auto tokenPrimaryGroupBuf = MakeUnique<BYTE[]>(len); + TOKEN_PRIMARY_GROUP& tokenPrimaryGroup = + *reinterpret_cast<TOKEN_PRIMARY_GROUP*>(tokenPrimaryGroupBuf.get()); + ok = ::GetTokenInformation(token, TokenPrimaryGroup, + tokenPrimaryGroupBuf.get(), len, &len); + if (!ok) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + SECURITY_DESCRIPTOR sd; + if (!::InitializeSecurityDescriptor(&sd, SECURITY_DESCRIPTOR_REVISION)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + BYTE systemSid[SECURITY_MAX_SID_SIZE]; + DWORD systemSidSize = sizeof(systemSid); + if (!::CreateWellKnownSid(WinLocalSystemSid, nullptr, systemSid, + &systemSidSize)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + BYTE adminSid[SECURITY_MAX_SID_SIZE]; + DWORD adminSidSize = sizeof(adminSid); + if (!::CreateWellKnownSid(WinBuiltinAdministratorsSid, nullptr, adminSid, + &adminSidSize)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + const bool allowAppContainers = + aProcessCategory == ProcessCategory::GeckoBrowserParent && + IsWin8OrLater(); + + BYTE appContainersSid[SECURITY_MAX_SID_SIZE]; + DWORD appContainersSidSize = sizeof(appContainersSid); + if (allowAppContainers) { + if (!::CreateWellKnownSid(WinBuiltinAnyPackageSid, nullptr, + appContainersSid, &appContainersSidSize)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + } + + // Grant access to SYSTEM, Administrators, the user, and when running as the + // browser process on Windows 8+, all app containers. + const size_t kMaxInlineEntries = 4; + mozilla::Vector<EXPLICIT_ACCESS_W, kMaxInlineEntries> entries; + + Unused << entries.append(EXPLICIT_ACCESS_W{ + COM_RIGHTS_EXECUTE, + GRANT_ACCESS, + NO_INHERITANCE, + {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID, TRUSTEE_IS_USER, + reinterpret_cast<LPWSTR>(systemSid)}}); + + Unused << entries.append(EXPLICIT_ACCESS_W{ + COM_RIGHTS_EXECUTE, + GRANT_ACCESS, + NO_INHERITANCE, + {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID, + TRUSTEE_IS_WELL_KNOWN_GROUP, reinterpret_cast<LPWSTR>(adminSid)}}); + + Unused << entries.append(EXPLICIT_ACCESS_W{ + COM_RIGHTS_EXECUTE, + GRANT_ACCESS, + NO_INHERITANCE, + {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID, TRUSTEE_IS_USER, + reinterpret_cast<LPWSTR>(tokenUser.User.Sid)}}); + + if (allowAppContainers) { + Unused << entries.append( + EXPLICIT_ACCESS_W{COM_RIGHTS_EXECUTE, + GRANT_ACCESS, + NO_INHERITANCE, + {nullptr, NO_MULTIPLE_TRUSTEE, TRUSTEE_IS_SID, + TRUSTEE_IS_WELL_KNOWN_GROUP, + reinterpret_cast<LPWSTR>(appContainersSid)}}); + } + + PACL rawDacl = nullptr; + win32Error = + ::SetEntriesInAclW(entries.length(), entries.begin(), nullptr, &rawDacl); + if (win32Error != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(win32Error); + } + + UniquePtr<ACL, LocalFreeDeleter> dacl(rawDacl); + + if (!::SetSecurityDescriptorDacl(&sd, TRUE, dacl.get(), FALSE)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + if (!::SetSecurityDescriptorOwner(&sd, tokenUser.User.Sid, FALSE)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + if (!::SetSecurityDescriptorGroup(&sd, tokenPrimaryGroup.PrimaryGroup, + FALSE)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + return wrapped::CoInitializeSecurity( + &sd, -1, nullptr, nullptr, RPC_C_AUTHN_LEVEL_DEFAULT, + RPC_C_IMP_LEVEL_IDENTIFY, nullptr, EOAC_NONE, nullptr); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/ProcessRuntime.h b/ipc/mscom/ProcessRuntime.h new file mode 100644 index 0000000000..7b52a56b8a --- /dev/null +++ b/ipc/mscom/ProcessRuntime.h @@ -0,0 +1,96 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_ProcessRuntime_h +#define mozilla_mscom_ProcessRuntime_h + +#include "mozilla/Attributes.h" +#if defined(ACCESSIBILITY) +# include "mozilla/mscom/ActivationContext.h" +# include "mozilla/Maybe.h" +#endif // defined(ACCESSIBILITY) +#include "mozilla/mscom/ApartmentRegion.h" +#include "nsWindowsHelpers.h" +#if defined(MOZILLA_INTERNAL_API) +# include "nsXULAppAPI.h" +#endif // defined(MOZILLA_INTERNAL_API) + +namespace mozilla { +namespace mscom { + +class MOZ_NON_TEMPORARY_CLASS ProcessRuntime final { +#if !defined(MOZILLA_INTERNAL_API) + public: +#endif // defined(MOZILLA_INTERNAL_API) + enum class ProcessCategory { + GeckoBrowserParent, + // We give Launcher its own process category, but internally to this class + // it should be treated identically to GeckoBrowserParent. + Launcher = GeckoBrowserParent, + GeckoChild, + Service, + }; + + // This constructor is only public when compiled outside of XUL + explicit ProcessRuntime(const ProcessCategory aProcessCategory); + + public: +#if defined(MOZILLA_INTERNAL_API) + ProcessRuntime(); + ~ProcessRuntime(); +#else + ~ProcessRuntime() = default; +#endif // defined(MOZILLA_INTERNAL_API) + + explicit operator bool() const { return SUCCEEDED(mInitResult); } + HRESULT GetHResult() const { return mInitResult; } + + ProcessRuntime(const ProcessRuntime&) = delete; + ProcessRuntime(ProcessRuntime&&) = delete; + ProcessRuntime& operator=(const ProcessRuntime&) = delete; + ProcessRuntime& operator=(ProcessRuntime&&) = delete; + + /** + * @return 0 if call is in-process or resolving the calling thread failed, + * otherwise contains the thread id of the calling thread. + */ + static DWORD GetClientThreadId(); + + private: +#if defined(MOZILLA_INTERNAL_API) + explicit ProcessRuntime(const GeckoProcessType aProcessType); +# if defined(MOZ_SANDBOX) + void InitUsingPersistentMTAThread(const nsAutoHandle& aCurThreadToken); +# endif // defined(MOZ_SANDBOX) +#endif // defined(MOZILLA_INTERNAL_API) + void InitInsideApartment(); + +#if defined(MOZILLA_INTERNAL_API) + static void PostInit(); +#endif // defined(MOZILLA_INTERNAL_API) + static HRESULT InitializeSecurity(const ProcessCategory aProcessCategory); + static COINIT GetDesiredApartmentType(const ProcessCategory aProcessCategory); + + private: + HRESULT mInitResult; + const ProcessCategory mProcessCategory; +#if defined(ACCESSIBILITY) && \ + (defined(MOZILLA_INTERNAL_API) || defined(MOZ_HAS_MOZGLUE)) + Maybe<ActivationContextRegion> mActCtxRgn; +#endif // defined(ACCESSIBILITY) && (defined(MOZILLA_INTERNAL_API) || + // defined(MOZ_HAS_MOZGLUE)) + ApartmentRegion mAptRegion; + + private: +#if defined(MOZILLA_INTERNAL_API) + static ProcessRuntime* sInstance; +#endif // defined(MOZILLA_INTERNAL_API) +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_ProcessRuntime_h diff --git a/ipc/mscom/ProfilerMarkers.cpp b/ipc/mscom/ProfilerMarkers.cpp new file mode 100644 index 0000000000..29d033723f --- /dev/null +++ b/ipc/mscom/ProfilerMarkers.cpp @@ -0,0 +1,236 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ProfilerMarkers.h" + +#include "MainThreadUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/Atomics.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/ProfilerMarkers.h" +#include "mozilla/Services.h" +#include "nsCOMPtr.h" +#include "nsIObserver.h" +#include "nsIObserverService.h" +#include "nsISupportsImpl.h" +#include "nsString.h" +#include "nsXULAppAPI.h" + +#include <objbase.h> +#include <objidlbase.h> + +// {9DBE6B28-E5E7-4FDE-AF00-9404604E74DC} +static const GUID GUID_MozProfilerMarkerExtension = { + 0x9dbe6b28, 0xe5e7, 0x4fde, {0xaf, 0x0, 0x94, 0x4, 0x60, 0x4e, 0x74, 0xdc}}; + +namespace { + +class ProfilerMarkerChannelHook final : public IChannelHook { + ~ProfilerMarkerChannelHook() = default; + + public: + ProfilerMarkerChannelHook() : mRefCnt(0) {} + + // IUnknown + STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + /** + * IChannelHook exposes six methods: The Client* methods are called when + * a client is sending an IPC request, whereas the Server* methods are called + * when a server is receiving an IPC request. + * + * For our purposes, we only care about the client-side methods. The COM + * runtime invokes the methods in the following order: + * 1. ClientGetSize, where the hook specifies the size of its payload; + * 2. ClientFillBuffer, where the hook fills the channel's buffer with its + * payload information. NOTE: This method is only called when ClientGetSize + * specifies a non-zero payload size. For our purposes, since we are not + * sending a payload, this method will never be called! + * 3. ClientNotify, when the response has been received from the server. + * + * Since we want to use these hooks to record the beginning and end of a COM + * IPC call, we use ClientGetSize for logging the start, and ClientNotify for + * logging the end. + * + * Finally, our implementation responds to any request matching our extension + * ID, however we only care about main thread COM calls. + */ + + // IChannelHook + STDMETHODIMP_(void) + ClientGetSize(REFGUID aExtensionId, REFIID aIid, + ULONG* aOutDataSize) override; + + // No-op (see the large comment above) + STDMETHODIMP_(void) + ClientFillBuffer(REFGUID aExtensionId, REFIID aIid, ULONG* aDataSize, + void* aDataBuf) override {} + + STDMETHODIMP_(void) + ClientNotify(REFGUID aExtensionId, REFIID aIid, ULONG aDataSize, + void* aDataBuffer, DWORD aDataRep, HRESULT aFault) override; + + // We don't care about the server-side notifications, so leave as no-ops. + STDMETHODIMP_(void) + ServerNotify(REFGUID aExtensionId, REFIID aIid, ULONG aDataSize, + void* aDataBuf, DWORD aDataRep) override {} + + STDMETHODIMP_(void) + ServerGetSize(REFGUID aExtensionId, REFIID aIid, HRESULT aFault, + ULONG* aOutDataSize) override {} + + STDMETHODIMP_(void) + ServerFillBuffer(REFGUID aExtensionId, REFIID aIid, ULONG* aDataSize, + void* aDataBuf, HRESULT aFault) override {} + + private: + void BuildMarkerName(REFIID aIid, nsACString& aOutMarkerName); + + private: + mozilla::Atomic<ULONG> mRefCnt; +}; + +HRESULT ProfilerMarkerChannelHook::QueryInterface(REFIID aIid, + void** aOutInterface) { + if (aIid == IID_IChannelHook || aIid == IID_IUnknown) { + RefPtr<IChannelHook> ptr(this); + ptr.forget(aOutInterface); + return S_OK; + } + + return E_NOINTERFACE; +} + +ULONG ProfilerMarkerChannelHook::AddRef() { return ++mRefCnt; } + +ULONG ProfilerMarkerChannelHook::Release() { + ULONG result = --mRefCnt; + if (!result) { + delete this; + } + + return result; +} + +void ProfilerMarkerChannelHook::BuildMarkerName(REFIID aIid, + nsACString& aOutMarkerName) { + aOutMarkerName.AssignLiteral("ORPC Call for "); + + nsAutoCString iidStr; + mozilla::mscom::DiagnosticNameForIID(aIid, iidStr); + aOutMarkerName.Append(iidStr); +} + +void ProfilerMarkerChannelHook::ClientGetSize(REFGUID aExtensionId, REFIID aIid, + ULONG* aOutDataSize) { + if (aExtensionId == GUID_MozProfilerMarkerExtension) { + if (NS_IsMainThread()) { + nsAutoCString markerName; + BuildMarkerName(aIid, markerName); + PROFILER_MARKER(markerName, IPC, mozilla::MarkerTiming::IntervalStart(), + Tracing, "MSCOM"); + } + + if (aOutDataSize) { + // We don't add any payload data to the channel + *aOutDataSize = 0UL; + } + } +} + +void ProfilerMarkerChannelHook::ClientNotify(REFGUID aExtensionId, REFIID aIid, + ULONG aDataSize, void* aDataBuffer, + DWORD aDataRep, HRESULT aFault) { + if (NS_IsMainThread() && aExtensionId == GUID_MozProfilerMarkerExtension) { + nsAutoCString markerName; + BuildMarkerName(aIid, markerName); + PROFILER_MARKER(markerName, IPC, mozilla::MarkerTiming::IntervalEnd(), + Tracing, "MSCOM"); + } +} + +} // anonymous namespace + +static void RegisterChannelHook() { + RefPtr<ProfilerMarkerChannelHook> hook(new ProfilerMarkerChannelHook()); + mozilla::DebugOnly<HRESULT> hr = + ::CoRegisterChannelHook(GUID_MozProfilerMarkerExtension, hook); + MOZ_ASSERT(SUCCEEDED(hr)); +} + +namespace { + +class ProfilerStartupObserver final : public nsIObserver { + ~ProfilerStartupObserver() = default; + + public: + NS_DECL_ISUPPORTS + NS_DECL_NSIOBSERVER +}; + +NS_IMPL_ISUPPORTS(ProfilerStartupObserver, nsIObserver) + +NS_IMETHODIMP ProfilerStartupObserver::Observe(nsISupports* aSubject, + const char* aTopic, + const char16_t* aData) { + if (strcmp(aTopic, "profiler-started")) { + return NS_OK; + } + + RegisterChannelHook(); + + // Once we've set the channel hook, we don't care about this notification + // anymore; our channel hook will remain set for the lifetime of the process. + nsCOMPtr<nsIObserverService> obsServ(mozilla::services::GetObserverService()); + MOZ_ASSERT(!!obsServ); + if (!obsServ) { + return NS_OK; + } + + obsServ->RemoveObserver(this, "profiler-started"); + return NS_OK; +} + +} // anonymous namespace + +namespace mozilla { +namespace mscom { + +void InitProfilerMarkers() { + if (!XRE_IsParentProcess()) { + return; + } + + MOZ_ASSERT(NS_IsMainThread()); + if (!NS_IsMainThread()) { + return; + } + + if (profiler_is_active()) { + // If the profiler is already running, we'll immediately register our + // channel hook. + RegisterChannelHook(); + return; + } + + // The profiler is not running yet. To avoid unnecessary invocations of the + // channel hook, we won't bother with installing it until the profiler starts. + // Set up an observer to watch for this. + nsCOMPtr<nsIObserverService> obsServ(mozilla::services::GetObserverService()); + MOZ_ASSERT(!!obsServ); + if (!obsServ) { + return; + } + + nsCOMPtr<nsIObserver> obs(new ProfilerStartupObserver()); + obsServ->AddObserver(obs, "profiler-started", false); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/ProfilerMarkers.h b/ipc/mscom/ProfilerMarkers.h new file mode 100644 index 0000000000..a21d5b375a --- /dev/null +++ b/ipc/mscom/ProfilerMarkers.h @@ -0,0 +1,18 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_ProfilerMarkers_h +#define mozilla_mscom_ProfilerMarkers_h + +namespace mozilla { +namespace mscom { + +void InitProfilerMarkers(); + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_ProfilerMarkers_h diff --git a/ipc/mscom/ProxyStream.cpp b/ipc/mscom/ProxyStream.cpp new file mode 100644 index 0000000000..5a3457e147 --- /dev/null +++ b/ipc/mscom/ProxyStream.cpp @@ -0,0 +1,411 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <utility> +#if defined(ACCESSIBILITY) +# include "HandlerData.h" +# include "mozilla/a11y/Platform.h" +# include "mozilla/mscom/ActivationContext.h" +#endif // defined(ACCESSIBILITY) +#include "mozilla/mscom/EnsureMTA.h" +#include "mozilla/mscom/ProxyStream.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/ScopeExit.h" + +#include "mozilla/mscom/Objref.h" +#include "nsExceptionHandler.h" +#include "nsPrintfCString.h" +#include "RegistrationAnnotator.h" + +#include <windows.h> +#include <objbase.h> +#include <shlwapi.h> + +namespace mozilla { +namespace mscom { + +ProxyStream::ProxyStream() + : mGlobalLockedBuf(nullptr), + mHGlobal(nullptr), + mBufSize(0), + mPreserveStream(false) {} + +// GetBuffer() fails with this variant, but that's okay because we're just +// reconstructing the stream from a buffer anyway. +ProxyStream::ProxyStream(REFIID aIID, const BYTE* aInitBuf, + const int aInitBufSize, Environment* aEnv) + : mGlobalLockedBuf(nullptr), + mHGlobal(nullptr), + mBufSize(aInitBufSize), + mPreserveStream(false) { + CrashReporter::Annotation kCrashReportKey = + CrashReporter::Annotation::ProxyStreamUnmarshalStatus; + + if (!aInitBufSize) { + CrashReporter::AnnotateCrashReport(kCrashReportKey, "!aInitBufSize"_ns); + // We marshaled a nullptr. Nothing else to do here. + return; + } + + HRESULT createStreamResult = + CreateStream(aInitBuf, aInitBufSize, getter_AddRefs(mStream)); + if (FAILED(createStreamResult)) { + nsPrintfCString hrAsStr("0x%08lX", createStreamResult); + CrashReporter::AnnotateCrashReport(kCrashReportKey, hrAsStr); + return; + } + + // NB: We can't check for a null mStream until after we have checked for + // the zero aInitBufSize above. This is because InitStream will also fail + // in that case, even though marshaling a nullptr is allowable. + MOZ_ASSERT(mStream); + if (!mStream) { + CrashReporter::AnnotateCrashReport(kCrashReportKey, "!mStream"_ns); + return; + } + +#if defined(ACCESSIBILITY) + const uint32_t expectedStreamLen = GetOBJREFSize(WrapNotNull(mStream)); + nsAutoCString strActCtx; + nsAutoString manifestPath; +#endif // defined(ACCESSIBILITY) + + HRESULT unmarshalResult = S_OK; + + // We need to convert to an interface here otherwise we mess up const + // correctness with IPDL. We'll request an IUnknown and then QI the + // actual interface later. + +#if defined(ACCESSIBILITY) + auto marshalFn = [this, &strActCtx, &manifestPath, &unmarshalResult, &aIID, + aEnv]() -> void +#else + auto marshalFn = [this, &unmarshalResult, &aIID, aEnv]() -> void +#endif // defined(ACCESSIBILITY) + { + if (aEnv) { + bool pushOk = aEnv->Push(); + MOZ_DIAGNOSTIC_ASSERT(pushOk); + if (!pushOk) { + return; + } + } + + auto popEnv = MakeScopeExit([aEnv]() -> void { + if (!aEnv) { + return; + } + +#ifdef MOZ_DIAGNOSTIC_ASSERT_ENABLED + bool popOk = +#endif + aEnv->Pop(); + MOZ_DIAGNOSTIC_ASSERT(popOk); + }); + +#if defined(ACCESSIBILITY) + auto curActCtx = ActivationContext::GetCurrent(); + if (curActCtx.isOk()) { + strActCtx.AppendPrintf("0x%" PRIxPTR, curActCtx.unwrap()); + } else { + strActCtx.AppendPrintf("HRESULT 0x%08lX", curActCtx.unwrapErr()); + } + + ActivationContext::GetCurrentManifestPath(manifestPath); +#endif // defined(ACCESSIBILITY) + + unmarshalResult = ::CoUnmarshalInterface(mStream, aIID, + getter_AddRefs(mUnmarshaledProxy)); + MOZ_ASSERT(SUCCEEDED(unmarshalResult)); + }; + + if (XRE_IsParentProcess()) { + // We'll marshal this stuff directly using the current thread, therefore its + // proxy will reside in the same apartment as the current thread. + marshalFn(); + } else { + // When marshaling in child processes, we want to force the MTA. + EnsureMTA mta(marshalFn); + } + + mStream = nullptr; + + if (FAILED(unmarshalResult) || !mUnmarshaledProxy) { + nsPrintfCString hrAsStr("0x%08lX", unmarshalResult); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::CoUnmarshalInterfaceResult, hrAsStr); + AnnotateInterfaceRegistration(aIID); + if (!mUnmarshaledProxy) { + CrashReporter::AnnotateCrashReport(kCrashReportKey, + "!mUnmarshaledProxy"_ns); + } + +#if defined(ACCESSIBILITY) + AnnotateClassRegistration(CLSID_AccessibleHandler); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::UnmarshalActCtx, strActCtx); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::UnmarshalActCtxManifestPath, + NS_ConvertUTF16toUTF8(manifestPath)); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::A11yHandlerRegistered, + a11y::IsHandlerRegistered() ? "true"_ns : "false"_ns); + + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::ExpectedStreamLen, expectedStreamLen); + + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::ActualStreamLen, aInitBufSize); +#endif // defined(ACCESSIBILITY) + } +} + +ProxyStream::ProxyStream(ProxyStream&& aOther) + : mGlobalLockedBuf(nullptr), + mHGlobal(nullptr), + mBufSize(0), + mPreserveStream(false) { + *this = std::move(aOther); +} + +ProxyStream& ProxyStream::operator=(ProxyStream&& aOther) { + if (mHGlobal && mGlobalLockedBuf) { + DebugOnly<BOOL> result = ::GlobalUnlock(mHGlobal); + MOZ_ASSERT(!result && ::GetLastError() == NO_ERROR); + } + + mStream = std::move(aOther.mStream); + + mGlobalLockedBuf = aOther.mGlobalLockedBuf; + aOther.mGlobalLockedBuf = nullptr; + + // ::GlobalFree() was called implicitly when mStream was replaced. + mHGlobal = aOther.mHGlobal; + aOther.mHGlobal = nullptr; + + mBufSize = aOther.mBufSize; + aOther.mBufSize = 0; + + mUnmarshaledProxy = std::move(aOther.mUnmarshaledProxy); + + mPreserveStream = aOther.mPreserveStream; + return *this; +} + +ProxyStream::~ProxyStream() { + if (mHGlobal && mGlobalLockedBuf) { + DebugOnly<BOOL> result = ::GlobalUnlock(mHGlobal); + MOZ_ASSERT(!result && ::GetLastError() == NO_ERROR); + // ::GlobalFree() is called implicitly when mStream is released + } + + // If this assert triggers then we will be leaking a marshaled proxy! + // Call GetPreservedStream to obtain a preservable stream and then save it + // until the proxy is no longer needed. + MOZ_ASSERT(!mPreserveStream); +} + +const BYTE* ProxyStream::GetBuffer(int& aReturnedBufSize) const { + aReturnedBufSize = 0; + if (!mStream) { + return nullptr; + } + if (!mGlobalLockedBuf) { + return nullptr; + } + aReturnedBufSize = mBufSize; + return mGlobalLockedBuf; +} + +PreservedStreamPtr ProxyStream::GetPreservedStream() { + MOZ_ASSERT(mStream); + MOZ_ASSERT(mHGlobal); + + if (!mStream || !mPreserveStream) { + return nullptr; + } + + // Clone the stream so that the result has a distinct seek pointer. + RefPtr<IStream> cloned; + HRESULT hr = mStream->Clone(getter_AddRefs(cloned)); + if (FAILED(hr)) { + return nullptr; + } + + // Ensure the stream is rewound. We do this because CoReleaseMarshalData needs + // the stream to be pointing to the beginning of the marshal data. + LARGE_INTEGER pos; + pos.QuadPart = 0LL; + hr = cloned->Seek(pos, STREAM_SEEK_SET, nullptr); + if (FAILED(hr)) { + return nullptr; + } + + mPreserveStream = false; + return ToPreservedStreamPtr(std::move(cloned)); +} + +bool ProxyStream::GetInterface(void** aOutInterface) { + // We should not have a locked buffer on this side + MOZ_ASSERT(!mGlobalLockedBuf); + MOZ_ASSERT(aOutInterface); + + if (!aOutInterface) { + return false; + } + + *aOutInterface = mUnmarshaledProxy.release(); + return true; +} + +ProxyStream::ProxyStream(REFIID aIID, IUnknown* aObject, Environment* aEnv, + ProxyStreamFlags aFlags) + : mGlobalLockedBuf(nullptr), + mHGlobal(nullptr), + mBufSize(0), + mPreserveStream(aFlags & ProxyStreamFlags::ePreservable) { + if (!aObject) { + return; + } + + RefPtr<IStream> stream; + HGLOBAL hglobal = NULL; + int streamSize = 0; + DWORD mshlFlags = mPreserveStream ? MSHLFLAGS_TABLESTRONG : MSHLFLAGS_NORMAL; + + HRESULT createStreamResult = S_OK; + HRESULT marshalResult = S_OK; + HRESULT statResult = S_OK; + HRESULT getHGlobalResult = S_OK; + +#if defined(ACCESSIBILITY) + nsAutoString manifestPath; + auto marshalFn = [&aIID, aObject, mshlFlags, &stream, &streamSize, &hglobal, + &createStreamResult, &marshalResult, &statResult, + &getHGlobalResult, aEnv, &manifestPath]() -> void { +#else + auto marshalFn = [&aIID, aObject, mshlFlags, &stream, &streamSize, &hglobal, + &createStreamResult, &marshalResult, &statResult, + &getHGlobalResult, aEnv]() -> void { +#endif // defined(ACCESSIBILITY) + if (aEnv) { + bool pushOk = aEnv->Push(); + MOZ_DIAGNOSTIC_ASSERT(pushOk); + if (!pushOk) { + return; + } + } + + auto popEnv = MakeScopeExit([aEnv]() -> void { + if (!aEnv) { + return; + } + +#ifdef MOZ_DIAGNOSTIC_ASSERT_ENABLED + bool popOk = +#endif + aEnv->Pop(); + MOZ_DIAGNOSTIC_ASSERT(popOk); + }); + + createStreamResult = + ::CreateStreamOnHGlobal(nullptr, TRUE, getter_AddRefs(stream)); + if (FAILED(createStreamResult)) { + return; + } + +#if defined(ACCESSIBILITY) + ActivationContext::GetCurrentManifestPath(manifestPath); +#endif // defined(ACCESSIBILITY) + + marshalResult = ::CoMarshalInterface(stream, aIID, aObject, MSHCTX_LOCAL, + nullptr, mshlFlags); + MOZ_ASSERT(marshalResult != E_INVALIDARG); + if (FAILED(marshalResult)) { + return; + } + + STATSTG statstg; + statResult = stream->Stat(&statstg, STATFLAG_NONAME); + if (SUCCEEDED(statResult)) { + streamSize = static_cast<int>(statstg.cbSize.LowPart); + } else { + return; + } + + getHGlobalResult = ::GetHGlobalFromStream(stream, &hglobal); + MOZ_ASSERT(SUCCEEDED(getHGlobalResult)); + }; + + if (XRE_IsParentProcess()) { + // We'll marshal this stuff directly using the current thread, therefore its + // stub will reside in the same apartment as the current thread. + marshalFn(); + } else { + // When marshaling in child processes, we want to force the MTA. + EnsureMTA mta(marshalFn); + } + + if (FAILED(createStreamResult)) { + nsPrintfCString hrAsStr("0x%08lX", createStreamResult); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::CreateStreamOnHGlobalFailure, hrAsStr); + } + + if (FAILED(marshalResult)) { + AnnotateInterfaceRegistration(aIID); + nsPrintfCString hrAsStr("0x%08lX", marshalResult); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::CoMarshalInterfaceFailure, hrAsStr); +#if defined(ACCESSIBILITY) + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::MarshalActCtxManifestPath, + NS_ConvertUTF16toUTF8(manifestPath)); +#endif // defined(ACCESSIBILITY) + } + + if (FAILED(statResult)) { + nsPrintfCString hrAsStr("0x%08lX", statResult); + CrashReporter::AnnotateCrashReport(CrashReporter::Annotation::StatFailure, + hrAsStr); + } + + if (FAILED(getHGlobalResult)) { + nsPrintfCString hrAsStr("0x%08lX", getHGlobalResult); + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::GetHGlobalFromStreamFailure, hrAsStr); + } + + mStream = std::move(stream); + + if (streamSize) { + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::ProxyStreamSizeFrom, "IStream::Stat"_ns); + mBufSize = streamSize; + } + + if (!hglobal) { + return; + } + + mGlobalLockedBuf = reinterpret_cast<BYTE*>(::GlobalLock(hglobal)); + mHGlobal = hglobal; + + // If we couldn't get the stream size directly from mStream, we may use + // the size of the memory block allocated by the HGLOBAL, though it might + // be larger than the actual stream size. + if (!streamSize) { + CrashReporter::AnnotateCrashReport( + CrashReporter::Annotation::ProxyStreamSizeFrom, "GlobalSize"_ns); + mBufSize = static_cast<int>(::GlobalSize(hglobal)); + } + + CrashReporter::AnnotateCrashReport(CrashReporter::Annotation::ProxyStreamSize, + mBufSize); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/ProxyStream.h b/ipc/mscom/ProxyStream.h new file mode 100644 index 0000000000..92b55745ca --- /dev/null +++ b/ipc/mscom/ProxyStream.h @@ -0,0 +1,88 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_ProxyStream_h +#define mozilla_mscom_ProxyStream_h + +#include "ipc/IPCMessageUtils.h" + +#include "mozilla/mscom/Ptr.h" +#include "mozilla/RefPtr.h" +#include "mozilla/TypedEnumBits.h" +#include "mozilla/UniquePtr.h" + +namespace mozilla { +namespace mscom { + +enum class ProxyStreamFlags : uint32_t { + eDefault = 0, + // When ePreservable is set on a ProxyStream, its caller *must* call + // GetPreservableStream() before the ProxyStream is destroyed. + ePreservable = 1 +}; + +MOZ_MAKE_ENUM_CLASS_BITWISE_OPERATORS(ProxyStreamFlags); + +class ProxyStream final { + public: + class MOZ_RAII Environment { + public: + virtual ~Environment() = default; + virtual bool Push() = 0; + virtual bool Pop() = 0; + }; + + class MOZ_RAII DefaultEnvironment : public Environment { + public: + bool Push() override { return true; } + bool Pop() override { return true; } + }; + + ProxyStream(); + ProxyStream(REFIID aIID, IUnknown* aObject, Environment* aEnv, + ProxyStreamFlags aFlags = ProxyStreamFlags::eDefault); + ProxyStream(REFIID aIID, const BYTE* aInitBuf, const int aInitBufSize, + Environment* aEnv); + + ~ProxyStream(); + + // Not copyable because this would mess up the COM marshaling. + ProxyStream(const ProxyStream& aOther) = delete; + ProxyStream& operator=(const ProxyStream& aOther) = delete; + + ProxyStream(ProxyStream&& aOther); + ProxyStream& operator=(ProxyStream&& aOther); + + inline bool IsValid() const { return !(mUnmarshaledProxy && mStream); } + + bool GetInterface(void** aOutInterface); + const BYTE* GetBuffer(int& aReturnedBufSize) const; + + PreservedStreamPtr GetPreservedStream(); + + bool operator==(const ProxyStream& aOther) const { return this == &aOther; } + + private: + RefPtr<IStream> mStream; + BYTE* mGlobalLockedBuf; + HGLOBAL mHGlobal; + int mBufSize; + ProxyUniquePtr<IUnknown> mUnmarshaledProxy; + bool mPreserveStream; +}; + +namespace detail { + +template <typename Interface> +struct EnvironmentSelector { + typedef ProxyStream::DefaultEnvironment Type; +}; + +} // namespace detail +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_ProxyStream_h diff --git a/ipc/mscom/Ptr.h b/ipc/mscom/Ptr.h new file mode 100644 index 0000000000..fd4bd9c91f --- /dev/null +++ b/ipc/mscom/Ptr.h @@ -0,0 +1,306 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Ptr_h +#define mozilla_mscom_Ptr_h + +#include "mozilla/Assertions.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/mscom/EnsureMTA.h" +#include "mozilla/SchedulerGroup.h" +#include "mozilla/UniquePtr.h" +#include "nsError.h" +#include "nsThreadUtils.h" +#include "nsXULAppAPI.h" + +#include <objidl.h> + +/** + * The glue code in mozilla::mscom often needs to pass around interface pointers + * belonging to a different apartment from the current one. We must not touch + * the reference counts of those objects on the wrong apartment. By using these + * UniquePtr specializations, we may ensure that the reference counts are always + * handled correctly. + */ + +namespace mozilla { +namespace mscom { + +namespace detail { + +template <typename T> +struct MainThreadRelease { + void operator()(T* aPtr) { + if (!aPtr) { + return; + } + if (NS_IsMainThread()) { + aPtr->Release(); + return; + } + DebugOnly<nsresult> rv = SchedulerGroup::Dispatch( + TaskCategory::Other, + NewNonOwningRunnableMethod("mscom::MainThreadRelease", aPtr, + &T::Release)); + MOZ_ASSERT(NS_SUCCEEDED(rv)); + } +}; + +template <typename T> +struct MTADelete { + void operator()(T* aPtr) { + if (!aPtr) { + return; + } + + EnsureMTA::AsyncOperation([aPtr]() -> void { delete aPtr; }); + } +}; + +template <typename T> +struct MTARelease { + void operator()(T* aPtr) { + if (!aPtr) { + return; + } + + // Static analysis doesn't recognize that, even though aPtr escapes the + // current scope, we are in effect moving our strong ref into the lambda. + void* ptr = aPtr; + EnsureMTA::AsyncOperation( + [ptr]() -> void { reinterpret_cast<T*>(ptr)->Release(); }); + } +}; + +template <typename T> +struct MTAReleaseInChildProcess { + void operator()(T* aPtr) { + if (!aPtr) { + return; + } + + if (XRE_IsParentProcess()) { + MOZ_ASSERT(NS_IsMainThread()); + aPtr->Release(); + return; + } + + // Static analysis doesn't recognize that, even though aPtr escapes the + // current scope, we are in effect moving our strong ref into the lambda. + void* ptr = aPtr; + EnsureMTA::AsyncOperation( + [ptr]() -> void { reinterpret_cast<T*>(ptr)->Release(); }); + } +}; + +struct InterceptorTargetDeleter { + void operator()(IUnknown* aPtr) { + // We intentionally do not touch the refcounts of interceptor targets! + } +}; + +struct PreservedStreamDeleter { + void operator()(IStream* aPtr) { + if (!aPtr) { + return; + } + + // Static analysis doesn't recognize that, even though aPtr escapes the + // current scope, we are in effect moving our strong ref into the lambda. + void* ptr = aPtr; + auto cleanup = [ptr]() -> void { + DebugOnly<HRESULT> hr = + ::CoReleaseMarshalData(reinterpret_cast<LPSTREAM>(ptr)); + MOZ_ASSERT(SUCCEEDED(hr)); + reinterpret_cast<LPSTREAM>(ptr)->Release(); + }; + + if (XRE_IsParentProcess()) { + MOZ_ASSERT(NS_IsMainThread()); + cleanup(); + return; + } + + EnsureMTA::AsyncOperation(cleanup); + } +}; + +} // namespace detail + +template <typename T> +using STAUniquePtr = mozilla::UniquePtr<T, detail::MainThreadRelease<T>>; + +template <typename T> +using MTAUniquePtr = mozilla::UniquePtr<T, detail::MTARelease<T>>; + +template <typename T> +using MTADeletePtr = mozilla::UniquePtr<T, detail::MTADelete<T>>; + +template <typename T> +using ProxyUniquePtr = + mozilla::UniquePtr<T, detail::MTAReleaseInChildProcess<T>>; + +template <typename T> +using InterceptorTargetPtr = + mozilla::UniquePtr<T, detail::InterceptorTargetDeleter>; + +using PreservedStreamPtr = + mozilla::UniquePtr<IStream, detail::PreservedStreamDeleter>; + +namespace detail { + +// We don't have direct access to UniquePtr's storage, so we use mPtrStorage +// to receive the pointer and then set the target inside the destructor. +template <typename T, typename Deleter> +class UniquePtrGetterAddRefs { + public: + explicit UniquePtrGetterAddRefs(UniquePtr<T, Deleter>& aSmartPtr) + : mTargetSmartPtr(aSmartPtr), mPtrStorage(nullptr) {} + + ~UniquePtrGetterAddRefs() { mTargetSmartPtr.reset(mPtrStorage); } + + operator void**() { return reinterpret_cast<void**>(&mPtrStorage); } + + operator T**() { return &mPtrStorage; } + + T*& operator*() { return mPtrStorage; } + + private: + UniquePtr<T, Deleter>& mTargetSmartPtr; + T* mPtrStorage; +}; + +} // namespace detail + +template <typename T> +inline STAUniquePtr<T> ToSTAUniquePtr(RefPtr<T>&& aRefPtr) { + return STAUniquePtr<T>(aRefPtr.forget().take()); +} + +template <typename T> +inline STAUniquePtr<T> ToSTAUniquePtr(const RefPtr<T>& aRefPtr) { + MOZ_ASSERT(NS_IsMainThread()); + return STAUniquePtr<T>(do_AddRef(aRefPtr).take()); +} + +template <typename T> +inline STAUniquePtr<T> ToSTAUniquePtr(T* aRawPtr) { + MOZ_ASSERT(NS_IsMainThread()); + if (aRawPtr) { + aRawPtr->AddRef(); + } + return STAUniquePtr<T>(aRawPtr); +} + +template <typename T, typename U> +inline STAUniquePtr<T> ToSTAUniquePtr(const InterceptorTargetPtr<U>& aTarget) { + MOZ_ASSERT(NS_IsMainThread()); + RefPtr<T> newRef(static_cast<T*>(aTarget.get())); + return ToSTAUniquePtr(std::move(newRef)); +} + +template <typename T> +inline MTAUniquePtr<T> ToMTAUniquePtr(RefPtr<T>&& aRefPtr) { + return MTAUniquePtr<T>(aRefPtr.forget().take()); +} + +template <typename T> +inline MTAUniquePtr<T> ToMTAUniquePtr(const RefPtr<T>& aRefPtr) { + MOZ_ASSERT(IsCurrentThreadMTA()); + return MTAUniquePtr<T>(do_AddRef(aRefPtr).take()); +} + +template <typename T> +inline MTAUniquePtr<T> ToMTAUniquePtr(T* aRawPtr) { + MOZ_ASSERT(IsCurrentThreadMTA()); + if (aRawPtr) { + aRawPtr->AddRef(); + } + return MTAUniquePtr<T>(aRawPtr); +} + +template <typename T> +inline ProxyUniquePtr<T> ToProxyUniquePtr(RefPtr<T>&& aRefPtr) { + return ProxyUniquePtr<T>(aRefPtr.forget().take()); +} + +template <typename T> +inline ProxyUniquePtr<T> ToProxyUniquePtr(const RefPtr<T>& aRefPtr) { + MOZ_ASSERT(IsProxy(aRefPtr)); + MOZ_ASSERT((XRE_IsParentProcess() && NS_IsMainThread()) || + (XRE_IsContentProcess() && IsCurrentThreadMTA())); + + return ProxyUniquePtr<T>(do_AddRef(aRefPtr).take()); +} + +template <typename T> +inline ProxyUniquePtr<T> ToProxyUniquePtr(T* aRawPtr) { + MOZ_ASSERT(IsProxy(aRawPtr)); + MOZ_ASSERT((XRE_IsParentProcess() && NS_IsMainThread()) || + (XRE_IsContentProcess() && IsCurrentThreadMTA())); + + if (aRawPtr) { + aRawPtr->AddRef(); + } + return ProxyUniquePtr<T>(aRawPtr); +} + +template <typename T, typename Deleter> +inline InterceptorTargetPtr<T> ToInterceptorTargetPtr( + const UniquePtr<T, Deleter>& aTargetPtr) { + return InterceptorTargetPtr<T>(aTargetPtr.get()); +} + +inline PreservedStreamPtr ToPreservedStreamPtr(RefPtr<IStream>&& aStream) { + return PreservedStreamPtr(aStream.forget().take()); +} + +inline PreservedStreamPtr ToPreservedStreamPtr( + already_AddRefed<IStream>& aStream) { + return PreservedStreamPtr(aStream.take()); +} + +template <typename T, typename Deleter> +inline detail::UniquePtrGetterAddRefs<T, Deleter> getter_AddRefs( + UniquePtr<T, Deleter>& aSmartPtr) { + return detail::UniquePtrGetterAddRefs<T, Deleter>(aSmartPtr); +} + +} // namespace mscom +} // namespace mozilla + +// This block makes it possible for these smart pointers to be correctly +// applied in NewRunnableMethod and friends +namespace detail { + +template <typename T> +struct SmartPointerStorageClass<mozilla::mscom::STAUniquePtr<T>> { + typedef StoreCopyPassByRRef<mozilla::mscom::STAUniquePtr<T>> Type; +}; + +template <typename T> +struct SmartPointerStorageClass<mozilla::mscom::MTAUniquePtr<T>> { + typedef StoreCopyPassByRRef<mozilla::mscom::MTAUniquePtr<T>> Type; +}; + +template <typename T> +struct SmartPointerStorageClass<mozilla::mscom::ProxyUniquePtr<T>> { + typedef StoreCopyPassByRRef<mozilla::mscom::ProxyUniquePtr<T>> Type; +}; + +template <typename T> +struct SmartPointerStorageClass<mozilla::mscom::InterceptorTargetPtr<T>> { + typedef StoreCopyPassByRRef<mozilla::mscom::InterceptorTargetPtr<T>> Type; +}; + +template <> +struct SmartPointerStorageClass<mozilla::mscom::PreservedStreamPtr> { + typedef StoreCopyPassByRRef<mozilla::mscom::PreservedStreamPtr> Type; +}; + +} // namespace detail + +#endif // mozilla_mscom_Ptr_h diff --git a/ipc/mscom/Registration.cpp b/ipc/mscom/Registration.cpp new file mode 100644 index 0000000000..e984ad0de4 --- /dev/null +++ b/ipc/mscom/Registration.cpp @@ -0,0 +1,534 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// COM registration data structures are built with C code, so we need to +// simulate that in our C++ code by defining CINTERFACE before including +// anything else that could possibly pull in Windows header files. +#define CINTERFACE + +#include "mozilla/mscom/Registration.h" + +#include <utility> + +#include "mozilla/ArrayUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/RefPtr.h" +#include "mozilla/StaticPtr.h" +#include "mozilla/Vector.h" +#include "mozilla/mscom/ActivationContext.h" +#include "mozilla/mscom/Utils.h" +#include "nsWindowsHelpers.h" + +#if defined(MOZILLA_INTERNAL_API) +# include "mozilla/ClearOnShutdown.h" +# include "mozilla/mscom/EnsureMTA.h" +HRESULT RegisterPassthruProxy(); +#else +# include <stdlib.h> +#endif // defined(MOZILLA_INTERNAL_API) + +#include <oaidl.h> +#include <objidl.h> +#include <rpcproxy.h> +#include <shlwapi.h> + +#include <algorithm> + +/* This code MUST NOT use any non-inlined internal Mozilla APIs, as it will be + compiled into DLLs that COM may load into non-Mozilla processes! */ + +extern "C" { + +// This function is defined in generated code for proxy DLLs but is not declared +// in rpcproxy.h, so we need this declaration. +void RPC_ENTRY GetProxyDllInfo(const ProxyFileInfo*** aInfo, const CLSID** aId); +} + +namespace mozilla { +namespace mscom { + +static bool GetContainingLibPath(wchar_t* aBuffer, size_t aBufferLen) { + HMODULE thisModule = reinterpret_cast<HMODULE>(GetContainingModuleHandle()); + if (!thisModule) { + return false; + } + + DWORD fileNameResult = GetModuleFileName(thisModule, aBuffer, aBufferLen); + if (!fileNameResult || (fileNameResult == aBufferLen && + ::GetLastError() == ERROR_INSUFFICIENT_BUFFER)) { + return false; + } + + return true; +} + +static bool BuildLibPath(RegistrationFlags aFlags, wchar_t* aBuffer, + size_t aBufferLen, const wchar_t* aLeafName) { + if (aFlags == RegistrationFlags::eUseBinDirectory) { + if (!GetContainingLibPath(aBuffer, aBufferLen)) { + return false; + } + + if (!PathRemoveFileSpec(aBuffer)) { + return false; + } + } else if (aFlags == RegistrationFlags::eUseSystemDirectory) { + UINT result = GetSystemDirectoryW(aBuffer, static_cast<UINT>(aBufferLen)); + if (!result || result > aBufferLen) { + return false; + } + } else { + return false; + } + + if (!PathAppend(aBuffer, aLeafName)) { + return false; + } + return true; +} + +static bool RegisterPSClsids(const ProxyFileInfo** aProxyInfo, + const CLSID* aProxyClsid) { + while (*aProxyInfo) { + const ProxyFileInfo& curInfo = **aProxyInfo; + for (unsigned short idx = 0, size = curInfo.TableSize; idx < size; ++idx) { + HRESULT hr = CoRegisterPSClsid(*(curInfo.pStubVtblList[idx]->header.piid), + *aProxyClsid); + if (FAILED(hr)) { + return false; + } + } + ++aProxyInfo; + } + + return true; +} + +#if !defined(MOZILLA_INTERNAL_API) +using GetProxyDllInfoFnT = decltype(&GetProxyDllInfo); + +static GetProxyDllInfoFnT ResolveGetProxyDllInfo() { + HMODULE thisModule = reinterpret_cast<HMODULE>(GetContainingModuleHandle()); + if (!thisModule) { + return nullptr; + } + + return reinterpret_cast<GetProxyDllInfoFnT>( + GetProcAddress(thisModule, "GetProxyDllInfo")); +} +#endif // !defined(MOZILLA_INTERNAL_API) + +UniquePtr<RegisteredProxy> RegisterProxy() { +#if !defined(MOZILLA_INTERNAL_API) + GetProxyDllInfoFnT GetProxyDllInfoFn = ResolveGetProxyDllInfo(); + MOZ_ASSERT(!!GetProxyDllInfoFn); + if (!GetProxyDllInfoFn) { + return nullptr; + } +#endif // !defined(MOZILLA_INTERNAL_API) + + const ProxyFileInfo** proxyInfo = nullptr; + const CLSID* proxyClsid = nullptr; +#if defined(MOZILLA_INTERNAL_API) + GetProxyDllInfo(&proxyInfo, &proxyClsid); +#else + GetProxyDllInfoFn(&proxyInfo, &proxyClsid); +#endif // defined(MOZILLA_INTERNAL_API) + if (!proxyInfo || !proxyClsid) { + return nullptr; + } + + IUnknown* classObject = nullptr; + HRESULT hr = + DllGetClassObject(*proxyClsid, IID_IUnknown, (void**)&classObject); + if (FAILED(hr)) { + return nullptr; + } + + DWORD regCookie; + hr = CoRegisterClassObject(*proxyClsid, classObject, CLSCTX_INPROC_SERVER, + REGCLS_MULTIPLEUSE, ®Cookie); + if (FAILED(hr)) { + classObject->lpVtbl->Release(classObject); + return nullptr; + } + + wchar_t modulePathBuf[MAX_PATH + 1] = {0}; + if (!GetContainingLibPath(modulePathBuf, ArrayLength(modulePathBuf))) { + CoRevokeClassObject(regCookie); + classObject->lpVtbl->Release(classObject); + return nullptr; + } + + ITypeLib* typeLib = nullptr; + hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + CoRevokeClassObject(regCookie); + classObject->lpVtbl->Release(classObject); + return nullptr; + } + +#if defined(MOZILLA_INTERNAL_API) + hr = RegisterPassthruProxy(); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + CoRevokeClassObject(regCookie); + classObject->lpVtbl->Release(classObject); + return nullptr; + } +#endif // defined(MOZILLA_INTERNAL_API) + + // RegisteredProxy takes ownership of classObject and typeLib references + auto result(MakeUnique<RegisteredProxy>(classObject, regCookie, typeLib)); + + if (!RegisterPSClsids(proxyInfo, proxyClsid)) { + return nullptr; + } + + return result; +} + +UniquePtr<RegisteredProxy> RegisterProxy(const wchar_t* aLeafName, + RegistrationFlags aFlags) { + wchar_t modulePathBuf[MAX_PATH + 1] = {0}; + if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf), + aLeafName)) { + return nullptr; + } + + nsModuleHandle proxyDll(LoadLibrary(modulePathBuf)); + if (!proxyDll.get()) { + return nullptr; + } + + // Instantiate an activation context so that CoGetClassObject will use any + // COM metadata embedded in proxyDll's manifest to resolve CLSIDs. + ActivationContextRegion actCtxRgn(proxyDll.get()); + + auto GetProxyDllInfoFn = reinterpret_cast<decltype(&GetProxyDllInfo)>( + GetProcAddress(proxyDll, "GetProxyDllInfo")); + if (!GetProxyDllInfoFn) { + return nullptr; + } + + const ProxyFileInfo** proxyInfo = nullptr; + const CLSID* proxyClsid = nullptr; + GetProxyDllInfoFn(&proxyInfo, &proxyClsid); + if (!proxyInfo || !proxyClsid) { + return nullptr; + } + + // We call CoGetClassObject instead of DllGetClassObject because it forces + // the COM runtime to manage the lifetime of the DLL. + IUnknown* classObject = nullptr; + HRESULT hr = CoGetClassObject(*proxyClsid, CLSCTX_INPROC_SERVER, nullptr, + IID_IUnknown, (void**)&classObject); + if (FAILED(hr)) { + return nullptr; + } + + DWORD regCookie; + hr = CoRegisterClassObject(*proxyClsid, classObject, CLSCTX_INPROC_SERVER, + REGCLS_MULTIPLEUSE, ®Cookie); + if (FAILED(hr)) { + classObject->lpVtbl->Release(classObject); + return nullptr; + } + + ITypeLib* typeLib = nullptr; + hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib); + MOZ_ASSERT(SUCCEEDED(hr)); + if (FAILED(hr)) { + CoRevokeClassObject(regCookie); + classObject->lpVtbl->Release(classObject); + return nullptr; + } + + // RegisteredProxy takes ownership of proxyDll, classObject, and typeLib + // references + auto result(MakeUnique<RegisteredProxy>( + reinterpret_cast<uintptr_t>(proxyDll.disown()), classObject, regCookie, + typeLib)); + + if (!RegisterPSClsids(proxyInfo, proxyClsid)) { + return nullptr; + } + + return result; +} + +UniquePtr<RegisteredProxy> RegisterTypelib(const wchar_t* aLeafName, + RegistrationFlags aFlags) { + wchar_t modulePathBuf[MAX_PATH + 1] = {0}; + if (!BuildLibPath(aFlags, modulePathBuf, ArrayLength(modulePathBuf), + aLeafName)) { + return nullptr; + } + + ITypeLib* typeLib = nullptr; + HRESULT hr = LoadTypeLibEx(modulePathBuf, REGKIND_NONE, &typeLib); + if (FAILED(hr)) { + return nullptr; + } + + // RegisteredProxy takes ownership of typeLib reference + auto result(MakeUnique<RegisteredProxy>(typeLib)); + return result; +} + +RegisteredProxy::RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject, + uint32_t aRegCookie, ITypeLib* aTypeLib) + : mModule(aModule), + mClassObject(aClassObject), + mRegCookie(aRegCookie), + mTypeLib(aTypeLib) +#if defined(MOZILLA_INTERNAL_API) + , + mIsRegisteredInMTA(IsCurrentThreadMTA()) +#endif // defined(MOZILLA_INTERNAL_API) +{ + MOZ_ASSERT(aClassObject); + MOZ_ASSERT(aTypeLib); + AddToRegistry(this); +} + +RegisteredProxy::RegisteredProxy(IUnknown* aClassObject, uint32_t aRegCookie, + ITypeLib* aTypeLib) + : mModule(0), + mClassObject(aClassObject), + mRegCookie(aRegCookie), + mTypeLib(aTypeLib) +#if defined(MOZILLA_INTERNAL_API) + , + mIsRegisteredInMTA(IsCurrentThreadMTA()) +#endif // defined(MOZILLA_INTERNAL_API) +{ + MOZ_ASSERT(aClassObject); + MOZ_ASSERT(aTypeLib); + AddToRegistry(this); +} + +// If we're initializing from a typelib, it doesn't matter which apartment we +// run in, so mIsRegisteredInMTA may always be set to false in this case. +RegisteredProxy::RegisteredProxy(ITypeLib* aTypeLib) + : mModule(0), + mClassObject(nullptr), + mRegCookie(0), + mTypeLib(aTypeLib) +#if defined(MOZILLA_INTERNAL_API) + , + mIsRegisteredInMTA(false) +#endif // defined(MOZILLA_INTERNAL_API) +{ + MOZ_ASSERT(aTypeLib); + AddToRegistry(this); +} + +void RegisteredProxy::Clear() { + if (mTypeLib) { + mTypeLib->lpVtbl->Release(mTypeLib); + mTypeLib = nullptr; + } + if (mClassObject) { + // NB: mClassObject and mRegCookie must be freed from inside the apartment + // which they were created in. + auto cleanupFn = [&]() -> void { + ::CoRevokeClassObject(mRegCookie); + mRegCookie = 0; + mClassObject->lpVtbl->Release(mClassObject); + mClassObject = nullptr; + }; +#if defined(MOZILLA_INTERNAL_API) + // This code only supports MTA when built internally + if (mIsRegisteredInMTA) { + EnsureMTA mta(cleanupFn); + } else { + cleanupFn(); + } +#else + cleanupFn(); +#endif // defined(MOZILLA_INTERNAL_API) + } + if (mModule) { + ::FreeLibrary(reinterpret_cast<HMODULE>(mModule)); + mModule = 0; + } +} + +RegisteredProxy::~RegisteredProxy() { + DeleteFromRegistry(this); + Clear(); +} + +RegisteredProxy::RegisteredProxy(RegisteredProxy&& aOther) + : mModule(0), + mClassObject(nullptr), + mRegCookie(0), + mTypeLib(nullptr) +#if defined(MOZILLA_INTERNAL_API) + , + mIsRegisteredInMTA(false) +#endif // defined(MOZILLA_INTERNAL_API) +{ + *this = std::forward<RegisteredProxy>(aOther); + AddToRegistry(this); +} + +RegisteredProxy& RegisteredProxy::operator=(RegisteredProxy&& aOther) { + Clear(); + + mModule = aOther.mModule; + aOther.mModule = 0; + mClassObject = aOther.mClassObject; + aOther.mClassObject = nullptr; + mRegCookie = aOther.mRegCookie; + aOther.mRegCookie = 0; + mTypeLib = aOther.mTypeLib; + aOther.mTypeLib = nullptr; + +#if defined(MOZILLA_INTERNAL_API) + mIsRegisteredInMTA = aOther.mIsRegisteredInMTA; +#endif // defined(MOZILLA_INTERNAL_API) + + return *this; +} + +HRESULT +RegisteredProxy::GetTypeInfoForGuid(REFGUID aGuid, + ITypeInfo** aOutTypeInfo) const { + if (!aOutTypeInfo) { + return E_INVALIDARG; + } + if (!mTypeLib) { + return E_UNEXPECTED; + } + return mTypeLib->lpVtbl->GetTypeInfoOfGuid(mTypeLib, aGuid, aOutTypeInfo); +} + +static StaticAutoPtr<Vector<RegisteredProxy*>> sRegistry; + +namespace UseGetMutexForAccess { + +// This must not be accessed directly; use GetMutex() instead +static CRITICAL_SECTION sMutex; + +} // namespace UseGetMutexForAccess + +static CRITICAL_SECTION* GetMutex() { + static CRITICAL_SECTION& mutex = []() -> CRITICAL_SECTION& { +#if defined(RELEASE_OR_BETA) + DWORD flags = CRITICAL_SECTION_NO_DEBUG_INFO; +#else + DWORD flags = 0; +#endif + InitializeCriticalSectionEx(&UseGetMutexForAccess::sMutex, 4000, flags); +#if !defined(MOZILLA_INTERNAL_API) + atexit([]() { DeleteCriticalSection(&UseGetMutexForAccess::sMutex); }); +#endif + return UseGetMutexForAccess::sMutex; + }(); + return &mutex; +} + +/* static */ +bool RegisteredProxy::Find(REFIID aIid, ITypeInfo** aTypeInfo) { + AutoCriticalSection lock(GetMutex()); + + if (!sRegistry) { + return false; + } + + for (auto&& proxy : *sRegistry) { + if (SUCCEEDED(proxy->GetTypeInfoForGuid(aIid, aTypeInfo))) { + return true; + } + } + + return false; +} + +/* static */ +void RegisteredProxy::AddToRegistry(RegisteredProxy* aProxy) { + MOZ_ASSERT(aProxy); + + AutoCriticalSection lock(GetMutex()); + + if (!sRegistry) { + sRegistry = new Vector<RegisteredProxy*>(); + +#if !defined(MOZILLA_INTERNAL_API) + // sRegistry allocation is fallible outside of Mozilla processes + if (!sRegistry) { + return; + } +#endif + } + + MOZ_ALWAYS_TRUE(sRegistry->emplaceBack(aProxy)); +} + +/* static */ +void RegisteredProxy::DeleteFromRegistry(RegisteredProxy* aProxy) { + MOZ_ASSERT(aProxy); + + AutoCriticalSection lock(GetMutex()); + + MOZ_ASSERT(sRegistry && !sRegistry->empty()); + + if (!sRegistry) { + return; + } + + sRegistry->erase(std::remove(sRegistry->begin(), sRegistry->end(), aProxy), + sRegistry->end()); + + if (sRegistry->empty()) { + sRegistry = nullptr; + } +} + +#if defined(MOZILLA_INTERNAL_API) + +static StaticAutoPtr<Vector<std::pair<const ArrayData*, size_t>>> sArrayData; + +void RegisterArrayData(const ArrayData* aArrayData, size_t aLength) { + AutoCriticalSection lock(GetMutex()); + + if (!sArrayData) { + sArrayData = new Vector<std::pair<const ArrayData*, size_t>>(); + ClearOnShutdown(&sArrayData, ShutdownPhase::XPCOMShutdownThreads); + } + + MOZ_ALWAYS_TRUE(sArrayData->emplaceBack(std::make_pair(aArrayData, aLength))); +} + +const ArrayData* FindArrayData(REFIID aIid, ULONG aMethodIndex) { + AutoCriticalSection lock(GetMutex()); + + if (!sArrayData) { + return nullptr; + } + + for (auto&& data : *sArrayData) { + for (size_t innerIdx = 0, innerLen = data.second; innerIdx < innerLen; + ++innerIdx) { + const ArrayData* array = data.first; + if (aMethodIndex == array[innerIdx].mMethodIndex && + IsInterfaceEqualToOrInheritedFrom(aIid, array[innerIdx].mIid, + aMethodIndex)) { + return &array[innerIdx]; + } + } + } + + return nullptr; +} + +#endif // defined(MOZILLA_INTERNAL_API) + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/Registration.h b/ipc/mscom/Registration.h new file mode 100644 index 0000000000..3dfd5458cd --- /dev/null +++ b/ipc/mscom/Registration.h @@ -0,0 +1,142 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Registration_h +#define mozilla_mscom_Registration_h + +#include "mozilla/RefPtr.h" +#include "mozilla/UniquePtr.h" + +#include <objbase.h> + +struct ITypeInfo; +struct ITypeLib; + +namespace mozilla { +namespace mscom { + +/** + * Assumptions: + * (1) The DLL exports GetProxyDllInfo. This is not exported by default; it must + * be specified in the EXPORTS section of the DLL's module definition file. + */ +class RegisteredProxy { + public: + RegisteredProxy(uintptr_t aModule, IUnknown* aClassObject, + uint32_t aRegCookie, ITypeLib* aTypeLib); + RegisteredProxy(IUnknown* aClassObject, uint32_t aRegCookie, + ITypeLib* aTypeLib); + explicit RegisteredProxy(ITypeLib* aTypeLib); + RegisteredProxy(RegisteredProxy&& aOther); + RegisteredProxy& operator=(RegisteredProxy&& aOther); + + ~RegisteredProxy(); + + HRESULT GetTypeInfoForGuid(REFGUID aGuid, ITypeInfo** aOutTypeInfo) const; + + static bool Find(REFIID aIid, ITypeInfo** aOutTypeInfo); + + private: + RegisteredProxy() = delete; + RegisteredProxy(RegisteredProxy&) = delete; + RegisteredProxy& operator=(RegisteredProxy&) = delete; + + void Clear(); + + static void AddToRegistry(RegisteredProxy* aProxy); + static void DeleteFromRegistry(RegisteredProxy* aProxy); + + private: + // Not using Windows types here: We shouldn't #include windows.h + // since it might pull in COM code which we want to do very carefully in + // Registration.cpp. + uintptr_t mModule; + IUnknown* mClassObject; + uint32_t mRegCookie; + ITypeLib* mTypeLib; +#if defined(MOZILLA_INTERNAL_API) + bool mIsRegisteredInMTA; +#endif // defined(MOZILLA_INTERNAL_API) +}; + +enum class RegistrationFlags { eUseBinDirectory, eUseSystemDirectory }; + +// For our own DLL that we are currently executing in (ie, xul). +// Assumes corresponding TLB is embedded in resources. +UniquePtr<RegisteredProxy> RegisterProxy(); + +// For DLL files. Assumes corresponding TLB is embedded in resources. +UniquePtr<RegisteredProxy> RegisterProxy( + const wchar_t* aLeafName, + RegistrationFlags aFlags = RegistrationFlags::eUseBinDirectory); +// For standalone TLB files. +UniquePtr<RegisteredProxy> RegisterTypelib( + const wchar_t* aLeafName, + RegistrationFlags aFlags = RegistrationFlags::eUseBinDirectory); + +#if defined(MOZILLA_INTERNAL_API) + +/** + * The COM interceptor uses type library information to build its interface + * proxies. Unfortunately type libraries do not encode size_is and length_is + * annotations that have been specified in IDL. This structure allows us to + * explicitly declare such relationships so that the COM interceptor may + * be made aware of them. + */ +struct ArrayData { + enum class Flag { + eNone = 0, + eAllocatedByServer = 1 // This implies an extra level of indirection + }; + + ArrayData(REFIID aIid, ULONG aMethodIndex, ULONG aArrayParamIndex, + VARTYPE aArrayParamType, REFIID aArrayParamIid, + ULONG aLengthParamIndex, Flag aFlag = Flag::eNone) + : mIid(aIid), + mMethodIndex(aMethodIndex), + mArrayParamIndex(aArrayParamIndex), + mArrayParamType(aArrayParamType), + mArrayParamIid(aArrayParamIid), + mLengthParamIndex(aLengthParamIndex), + mFlag(aFlag) {} + + ArrayData(const ArrayData& aOther) { *this = aOther; } + + ArrayData& operator=(const ArrayData& aOther) { + mIid = aOther.mIid; + mMethodIndex = aOther.mMethodIndex; + mArrayParamIndex = aOther.mArrayParamIndex; + mArrayParamType = aOther.mArrayParamType; + mArrayParamIid = aOther.mArrayParamIid; + mLengthParamIndex = aOther.mLengthParamIndex; + mFlag = aOther.mFlag; + return *this; + } + + IID mIid; + ULONG mMethodIndex; + ULONG mArrayParamIndex; + VARTYPE mArrayParamType; + IID mArrayParamIid; + ULONG mLengthParamIndex; + Flag mFlag; +}; + +void RegisterArrayData(const ArrayData* aArrayData, size_t aLength); + +template <size_t N> +inline void RegisterArrayData(const ArrayData (&aData)[N]) { + RegisterArrayData(aData, N); +} + +const ArrayData* FindArrayData(REFIID aIid, ULONG aMethodIndex); + +#endif // defined(MOZILLA_INTERNAL_API) + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Registration_h diff --git a/ipc/mscom/RegistrationAnnotator.cpp b/ipc/mscom/RegistrationAnnotator.cpp new file mode 100644 index 0000000000..7c45920975 --- /dev/null +++ b/ipc/mscom/RegistrationAnnotator.cpp @@ -0,0 +1,385 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "RegistrationAnnotator.h" + +#include "mozilla/JSONStringWriteFuncs.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/NotNull.h" +#include "nsExceptionHandler.h" +#include "nsPrintfCString.h" +#include "nsWindowsHelpers.h" +#include "nsXULAppAPI.h" + +#include <oleauto.h> + +namespace mozilla { +namespace mscom { + +static const char16_t kSoftwareClasses[] = u"SOFTWARE\\Classes"; +static const char16_t kInterface[] = u"\\Interface\\"; +static const char16_t kDefaultValue[] = u""; +static const char16_t kThreadingModel[] = u"ThreadingModel"; +static const char16_t kBackslash[] = u"\\"; +static const char16_t kFlags[] = u"FLAGS"; +static const char16_t kProxyStubClsid32[] = u"\\ProxyStubClsid32"; +static const char16_t kClsid[] = u"\\CLSID\\"; +static const char16_t kInprocServer32[] = u"\\InprocServer32"; +static const char16_t kInprocHandler32[] = u"\\InprocHandler32"; +static const char16_t kTypeLib[] = u"\\TypeLib"; +static const char16_t kVersion[] = u"Version"; +static const char16_t kWin32[] = u"Win32"; +static const char16_t kWin64[] = u"Win64"; + +static bool GetStringValue(HKEY aBaseKey, const nsAString& aStrSubKey, + const nsAString& aValueName, nsAString& aOutput) { + const nsString& flatSubKey = PromiseFlatString(aStrSubKey); + const nsString& flatValueName = PromiseFlatString(aValueName); + LPCWSTR valueName = aValueName.IsEmpty() ? nullptr : flatValueName.get(); + + DWORD type = 0; + DWORD numBytes = 0; + LONG result = RegGetValue(aBaseKey, flatSubKey.get(), valueName, RRF_RT_ANY, + &type, nullptr, &numBytes); + if (result != ERROR_SUCCESS || (type != REG_SZ && type != REG_EXPAND_SZ)) { + return false; + } + + int numChars = (numBytes + 1) / sizeof(wchar_t); + aOutput.SetLength(numChars); + + DWORD acceptFlag = type == REG_SZ ? RRF_RT_REG_SZ : RRF_RT_REG_EXPAND_SZ; + + result = RegGetValue(aBaseKey, flatSubKey.get(), valueName, acceptFlag, + nullptr, aOutput.BeginWriting(), &numBytes); + if (result == ERROR_SUCCESS) { + // Truncate null terminator + aOutput.SetLength(((numBytes + 1) / sizeof(wchar_t)) - 1); + } + + return result == ERROR_SUCCESS; +} + +template <size_t N> +inline static bool GetStringValue(HKEY aBaseKey, const nsAString& aStrSubKey, + const char16_t (&aValueName)[N], + nsAString& aOutput) { + return GetStringValue(aBaseKey, aStrSubKey, nsLiteralString(aValueName), + aOutput); +} + +/** + * This function fails unless the entire string has been converted. + * (eg, the string "FLAGS" will convert to 0xF but we will return false) + */ +static bool ConvertLCID(const wchar_t* aStr, NotNull<unsigned long*> aOutLcid) { + wchar_t* endChar; + *aOutLcid = wcstoul(aStr, &endChar, 16); + return *endChar == 0; +} + +static bool GetLoadedPath(nsAString& aPath) { + // These paths may be REG_EXPAND_SZ, so we expand any environment strings + DWORD bufCharLen = + ExpandEnvironmentStrings(PromiseFlatString(aPath).get(), nullptr, 0); + + auto buf = MakeUnique<WCHAR[]>(bufCharLen); + + if (!ExpandEnvironmentStrings(PromiseFlatString(aPath).get(), buf.get(), + bufCharLen)) { + return false; + } + + // Use LoadLibrary so that the DLL is resolved using the loader's DLL search + // rules + nsModuleHandle mod(LoadLibrary(buf.get())); + if (!mod) { + return false; + } + + WCHAR finalPath[MAX_PATH + 1] = {}; + DWORD result = GetModuleFileNameW(mod, finalPath, ArrayLength(finalPath)); + if (!result || (result == ArrayLength(finalPath) && + GetLastError() == ERROR_INSUFFICIENT_BUFFER)) { + return false; + } + + aPath = nsDependentString(finalPath, result); + return true; +} + +static void AnnotateClsidRegistrationForHive( + JSONWriter& aJson, HKEY aHive, const nsAString& aClsid, + const JSONWriter::CollectionStyle aStyle) { + nsAutoString clsidSubkey; + clsidSubkey.AppendLiteral(kSoftwareClasses); + clsidSubkey.AppendLiteral(kClsid); + clsidSubkey.Append(aClsid); + + nsAutoString className; + if (GetStringValue(aHive, clsidSubkey, kDefaultValue, className)) { + aJson.StringProperty("ClassName", NS_ConvertUTF16toUTF8(className)); + } + + nsAutoString inprocServerSubkey(clsidSubkey); + inprocServerSubkey.AppendLiteral(kInprocServer32); + + nsAutoString pathToServerDll; + if (GetStringValue(aHive, inprocServerSubkey, kDefaultValue, + pathToServerDll)) { + aJson.StringProperty("Path", NS_ConvertUTF16toUTF8(pathToServerDll)); + if (GetLoadedPath(pathToServerDll)) { + aJson.StringProperty("LoadedPath", + NS_ConvertUTF16toUTF8(pathToServerDll)); + } + } + + nsAutoString apartment; + if (GetStringValue(aHive, inprocServerSubkey, kThreadingModel, apartment)) { + aJson.StringProperty("ThreadingModel", NS_ConvertUTF16toUTF8(apartment)); + } + + nsAutoString inprocHandlerSubkey(clsidSubkey); + inprocHandlerSubkey.AppendLiteral(kInprocHandler32); + nsAutoString pathToHandlerDll; + if (GetStringValue(aHive, inprocHandlerSubkey, kDefaultValue, + pathToHandlerDll)) { + aJson.StringProperty("HandlerPath", + NS_ConvertUTF16toUTF8(pathToHandlerDll)); + if (GetLoadedPath(pathToHandlerDll)) { + aJson.StringProperty("LoadedHandlerPath", + NS_ConvertUTF16toUTF8(pathToHandlerDll)); + } + } + + nsAutoString handlerApartment; + if (GetStringValue(aHive, inprocHandlerSubkey, kThreadingModel, + handlerApartment)) { + aJson.StringProperty("HandlerThreadingModel", + NS_ConvertUTF16toUTF8(handlerApartment)); + } +} + +static void CheckTlbPath(JSONWriter& aJson, const nsAString& aTypelibPath) { + const nsString& flatPath = PromiseFlatString(aTypelibPath); + DWORD bufCharLen = ExpandEnvironmentStrings(flatPath.get(), nullptr, 0); + + auto buf = MakeUnique<WCHAR[]>(bufCharLen); + + if (!ExpandEnvironmentStrings(flatPath.get(), buf.get(), bufCharLen)) { + return; + } + + // See whether this tlb can actually be loaded + RefPtr<ITypeLib> typeLib; + HRESULT hr = LoadTypeLibEx(buf.get(), REGKIND_NONE, getter_AddRefs(typeLib)); + + nsPrintfCString loadResult("0x%08lX", hr); + aJson.StringProperty("LoadResult", loadResult); +} + +template <size_t N> +static void AnnotateTypelibPlatform(JSONWriter& aJson, HKEY aBaseKey, + const nsAString& aLcidSubkey, + const char16_t (&aPlatform)[N], + const JSONWriter::CollectionStyle aStyle) { + nsLiteralString platform(aPlatform); + + nsAutoString fullSubkey(aLcidSubkey); + fullSubkey.AppendLiteral(kBackslash); + fullSubkey.Append(platform); + + nsAutoString tlbPath; + if (GetStringValue(aBaseKey, fullSubkey, kDefaultValue, tlbPath)) { + aJson.StartObjectProperty(NS_ConvertUTF16toUTF8(platform), aStyle); + aJson.StringProperty("Path", NS_ConvertUTF16toUTF8(tlbPath)); + CheckTlbPath(aJson, tlbPath); + aJson.EndObject(); + } +} + +static void AnnotateTypelibRegistrationForHive( + JSONWriter& aJson, HKEY aHive, const nsAString& aTypelibId, + const nsAString& aTypelibVersion, + const JSONWriter::CollectionStyle aStyle) { + nsAutoString typelibSubKey; + typelibSubKey.AppendLiteral(kSoftwareClasses); + typelibSubKey.AppendLiteral(kTypeLib); + typelibSubKey.AppendLiteral(kBackslash); + typelibSubKey.Append(aTypelibId); + typelibSubKey.AppendLiteral(kBackslash); + typelibSubKey.Append(aTypelibVersion); + + nsAutoString typelibDesc; + if (GetStringValue(aHive, typelibSubKey, kDefaultValue, typelibDesc)) { + aJson.StringProperty("Description", NS_ConvertUTF16toUTF8(typelibDesc)); + } + + nsAutoString flagsSubKey(typelibSubKey); + flagsSubKey.AppendLiteral(kBackslash); + flagsSubKey.AppendLiteral(kFlags); + + nsAutoString typelibFlags; + if (GetStringValue(aHive, flagsSubKey, kDefaultValue, typelibFlags)) { + aJson.StringProperty("Flags", NS_ConvertUTF16toUTF8(typelibFlags)); + } + + HKEY rawTypelibKey; + LONG result = + RegOpenKeyEx(aHive, typelibSubKey.get(), 0, KEY_READ, &rawTypelibKey); + if (result != ERROR_SUCCESS) { + return; + } + nsAutoRegKey typelibKey(rawTypelibKey); + + const size_t kMaxLcidCharLen = 9; + WCHAR keyName[kMaxLcidCharLen]; + + for (DWORD index = 0; result == ERROR_SUCCESS; ++index) { + DWORD keyNameLength = ArrayLength(keyName); + result = RegEnumKeyEx(typelibKey, index, keyName, &keyNameLength, nullptr, + nullptr, nullptr, nullptr); + + unsigned long lcid; + if (result == ERROR_SUCCESS && ConvertLCID(keyName, WrapNotNull(&lcid))) { + nsDependentString strLcid(keyName, keyNameLength); + aJson.StartObjectProperty(NS_ConvertUTF16toUTF8(strLcid), aStyle); + AnnotateTypelibPlatform(aJson, typelibKey, strLcid, kWin32, aStyle); +#if defined(HAVE_64BIT_BUILD) + AnnotateTypelibPlatform(aJson, typelibKey, strLcid, kWin64, aStyle); +#endif + aJson.EndObject(); + } + } +} + +static void AnnotateInterfaceRegistrationForHive( + JSONWriter& aJson, HKEY aHive, REFIID aIid, + const JSONWriter::CollectionStyle aStyle) { + nsAutoString interfaceSubKey; + interfaceSubKey.AppendLiteral(kSoftwareClasses); + interfaceSubKey.AppendLiteral(kInterface); + nsAutoString iid; + GUIDToString(aIid, iid); + interfaceSubKey.Append(iid); + + nsAutoString interfaceName; + if (GetStringValue(aHive, interfaceSubKey, kDefaultValue, interfaceName)) { + aJson.StringProperty("InterfaceName", NS_ConvertUTF16toUTF8(interfaceName)); + } + + nsAutoString psSubKey(interfaceSubKey); + psSubKey.AppendLiteral(kProxyStubClsid32); + + nsAutoString psClsid; + if (GetStringValue(aHive, psSubKey, kDefaultValue, psClsid)) { + aJson.StartObjectProperty("ProxyStub", aStyle); + aJson.StringProperty("CLSID", NS_ConvertUTF16toUTF8(psClsid)); + AnnotateClsidRegistrationForHive(aJson, aHive, psClsid, aStyle); + aJson.EndObject(); + } + + nsAutoString typelibSubKey(interfaceSubKey); + typelibSubKey.AppendLiteral(kTypeLib); + + nsAutoString typelibId; + bool haveTypelibId = + GetStringValue(aHive, typelibSubKey, kDefaultValue, typelibId); + + nsAutoString typelibVersion; + bool haveTypelibVersion = + GetStringValue(aHive, typelibSubKey, kVersion, typelibVersion); + + if (haveTypelibId || haveTypelibVersion) { + aJson.StartObjectProperty("TypeLib", aStyle); + } + + if (haveTypelibId) { + aJson.StringProperty("ID", NS_ConvertUTF16toUTF8(typelibId)); + } + + if (haveTypelibVersion) { + aJson.StringProperty("Version", NS_ConvertUTF16toUTF8(typelibVersion)); + } + + if (haveTypelibId && haveTypelibVersion) { + AnnotateTypelibRegistrationForHive(aJson, aHive, typelibId, typelibVersion, + aStyle); + } + + if (haveTypelibId || haveTypelibVersion) { + aJson.EndObject(); + } +} + +void AnnotateInterfaceRegistration(REFIID aIid) { +#if defined(DEBUG) + const JSONWriter::CollectionStyle style = JSONWriter::MultiLineStyle; +#else + const JSONWriter::CollectionStyle style = JSONWriter::SingleLineStyle; +#endif + + JSONStringWriteFunc<nsCString> jsonString; + JSONWriter json(jsonString); + + json.Start(style); + + json.StartObjectProperty("HKLM", style); + AnnotateInterfaceRegistrationForHive(json, HKEY_LOCAL_MACHINE, aIid, style); + json.EndObject(); + + json.StartObjectProperty("HKCU", style); + AnnotateInterfaceRegistrationForHive(json, HKEY_CURRENT_USER, aIid, style); + json.EndObject(); + + json.End(); + + CrashReporter::Annotation annotationKey; + if (XRE_IsParentProcess()) { + annotationKey = CrashReporter::Annotation::InterfaceRegistrationInfoParent; + } else { + annotationKey = CrashReporter::Annotation::InterfaceRegistrationInfoChild; + } + CrashReporter::AnnotateCrashReport(annotationKey, jsonString.StringCRef()); +} + +void AnnotateClassRegistration(REFCLSID aClsid) { +#if defined(DEBUG) + const JSONWriter::CollectionStyle style = JSONWriter::MultiLineStyle; +#else + const JSONWriter::CollectionStyle style = JSONWriter::SingleLineStyle; +#endif + + nsAutoString strClsid; + GUIDToString(aClsid, strClsid); + + JSONStringWriteFunc<nsCString> jsonString; + JSONWriter json(jsonString); + + json.Start(style); + + json.StartObjectProperty("HKLM", style); + AnnotateClsidRegistrationForHive(json, HKEY_LOCAL_MACHINE, strClsid, style); + json.EndObject(); + + json.StartObjectProperty("HKCU", style); + AnnotateClsidRegistrationForHive(json, HKEY_CURRENT_USER, strClsid, style); + json.EndObject(); + + json.End(); + + CrashReporter::Annotation annotationKey; + if (XRE_IsParentProcess()) { + annotationKey = CrashReporter::Annotation::ClassRegistrationInfoParent; + } else { + annotationKey = CrashReporter::Annotation::ClassRegistrationInfoChild; + } + + CrashReporter::AnnotateCrashReport(annotationKey, jsonString.StringCRef()); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/RegistrationAnnotator.h b/ipc/mscom/RegistrationAnnotator.h new file mode 100644 index 0000000000..3065595045 --- /dev/null +++ b/ipc/mscom/RegistrationAnnotator.h @@ -0,0 +1,19 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_RegistrationAnnotator_h +#define mozilla_mscom_RegistrationAnnotator_h + +namespace mozilla { +namespace mscom { + +void AnnotateInterfaceRegistration(REFIID aIid); +void AnnotateClassRegistration(REFCLSID aClsid); + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_RegistrationAnnotator_h diff --git a/ipc/mscom/SpinEvent.cpp b/ipc/mscom/SpinEvent.cpp new file mode 100644 index 0000000000..493a1646b7 --- /dev/null +++ b/ipc/mscom/SpinEvent.cpp @@ -0,0 +1,77 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/SpinEvent.h" + +#include "mozilla/ArrayUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/TimeStamp.h" +#include "nsServiceManagerUtils.h" +#include "nsString.h" +#include "nsSystemInfo.h" + +namespace mozilla { +namespace mscom { + +static const TimeDuration kMaxSpinTime = TimeDuration::FromMilliseconds(30); +bool SpinEvent::sIsMulticore = false; + +/* static */ +bool SpinEvent::InitStatics() { + SYSTEM_INFO sysInfo; + ::GetSystemInfo(&sysInfo); + sIsMulticore = sysInfo.dwNumberOfProcessors > 1; + return true; +} + +SpinEvent::SpinEvent() : mDone(false) { + static const bool gotStatics = InitStatics(); + MOZ_ALWAYS_TRUE(gotStatics); + + mDoneEvent.own(::CreateEventW(nullptr, FALSE, FALSE, nullptr)); + MOZ_ASSERT(mDoneEvent); +} + +bool SpinEvent::Wait(HANDLE aTargetThread) { + MOZ_ASSERT(aTargetThread); + if (!aTargetThread) { + return false; + } + + if (sIsMulticore) { + // Bug 1311834: Spinning allows for faster response than waiting on an + // event, as events are constrained by the system's timer resolution. + // Bug 1429665: However, we only want to spin for a very short time. If + // we're waiting for a while, we don't want to be burning CPU for the + // entire time. At that point, a few extra ms isn't going to make much + // difference to perceived responsiveness. + TimeStamp start(TimeStamp::Now()); + while (!mDone) { + TimeDuration elapsed(TimeStamp::Now() - start); + if (elapsed >= kMaxSpinTime) { + break; + } + YieldProcessor(); + } + if (mDone) { + return true; + } + } + + MOZ_ASSERT(mDoneEvent); + HANDLE handles[] = {mDoneEvent, aTargetThread}; + DWORD waitResult = ::WaitForMultipleObjects(mozilla::ArrayLength(handles), + handles, FALSE, INFINITE); + return waitResult == WAIT_OBJECT_0; +} + +void SpinEvent::Signal() { + ::SetEvent(mDoneEvent); + mDone = true; +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/SpinEvent.h b/ipc/mscom/SpinEvent.h new file mode 100644 index 0000000000..461865a6bc --- /dev/null +++ b/ipc/mscom/SpinEvent.h @@ -0,0 +1,40 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_SpinEvent_h +#define mozilla_mscom_SpinEvent_h + +#include "mozilla/Atomics.h" +#include "mozilla/Attributes.h" +#include "nsWindowsHelpers.h" + +namespace mozilla { +namespace mscom { + +class MOZ_NON_TEMPORARY_CLASS SpinEvent final { + public: + SpinEvent(); + ~SpinEvent() = default; + + bool Wait(HANDLE aTargetThread); + void Signal(); + + SpinEvent(const SpinEvent&) = delete; + SpinEvent(SpinEvent&&) = delete; + SpinEvent& operator=(SpinEvent&&) = delete; + SpinEvent& operator=(const SpinEvent&) = delete; + + private: + Atomic<bool, ReleaseAcquire> mDone; + nsAutoHandle mDoneEvent; + static bool InitStatics(); + static bool sIsMulticore; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_SpinEvent_h diff --git a/ipc/mscom/StructStream.cpp b/ipc/mscom/StructStream.cpp new file mode 100644 index 0000000000..4a24962a6f --- /dev/null +++ b/ipc/mscom/StructStream.cpp @@ -0,0 +1,23 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <malloc.h> +#include <rpc.h> + +/** + * These functions need to be defined in order for the types that use + * mozilla::mscom::StructToStream and mozilla::mscom::StructFromStream to work. + */ +extern "C" { + +void __RPC_FAR* __RPC_USER midl_user_allocate(size_t aNumBytes) { + const unsigned long kRpcReqdBufAlignment = 8; + return _aligned_malloc(aNumBytes, kRpcReqdBufAlignment); +} + +void __RPC_USER midl_user_free(void* aBuffer) { _aligned_free(aBuffer); } + +} // extern "C" diff --git a/ipc/mscom/StructStream.h b/ipc/mscom/StructStream.h new file mode 100644 index 0000000000..45a32c7687 --- /dev/null +++ b/ipc/mscom/StructStream.h @@ -0,0 +1,239 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_StructStream_h +#define mozilla_mscom_StructStream_h + +#include "mozilla/Attributes.h" +#include "mozilla/UniquePtr.h" +#include "nscore.h" + +#include <memory.h> +#include <midles.h> +#include <objidl.h> +#include <rpc.h> + +/** + * This code is used for (de)serializing data structures that have been + * declared using midl, thus allowing us to use Microsoft RPC for marshaling + * data for our COM handlers that may run in other processes that are not ours. + */ + +namespace mozilla { +namespace mscom { + +namespace detail { + +typedef ULONG EncodedLenT; + +} // namespace detail + +class MOZ_NON_TEMPORARY_CLASS StructToStream { + public: + /** + * This constructor variant represents an empty/null struct to be serialized. + */ + StructToStream() + : mStatus(RPC_S_OK), mHandle(nullptr), mBuffer(nullptr), mEncodedLen(0) {} + + template <typename StructT> + StructToStream(StructT& aSrcStruct, void (*aEncodeFnPtr)(handle_t, StructT*)) + : mStatus(RPC_X_INVALID_BUFFER), + mHandle(nullptr), + mBuffer(nullptr), + mEncodedLen(0) { + mStatus = + ::MesEncodeDynBufferHandleCreate(&mBuffer, &mEncodedLen, &mHandle); + if (mStatus != RPC_S_OK) { + return; + } + + MOZ_SEH_TRY { aEncodeFnPtr(mHandle, &aSrcStruct); } +#ifdef HAVE_SEH_EXCEPTIONS + MOZ_SEH_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { + mStatus = ::RpcExceptionCode(); + return; + } +#endif + + if (!mBuffer || !mEncodedLen) { + mStatus = RPC_X_NO_MEMORY; + return; + } + } + + ~StructToStream() { + if (mHandle) { + ::MesHandleFree(mHandle); + } + if (mBuffer) { + // Bug 1440564: You'd think that MesHandleFree would free the buffer, + // since it was created by RPC, but it doesn't. + midl_user_free(mBuffer); + } + } + + static unsigned long GetEmptySize() { return sizeof(detail::EncodedLenT); } + + static HRESULT WriteEmpty(IStream* aDestStream) { + StructToStream emptyStruct; + return emptyStruct.Write(aDestStream); + } + + explicit operator bool() const { return mStatus == RPC_S_OK; } + + bool IsEmpty() const { return mStatus == RPC_S_OK && !mEncodedLen; } + + unsigned long GetSize() const { return sizeof(mEncodedLen) + mEncodedLen; } + + HRESULT Write(IStream* aDestStream) { + if (!aDestStream) { + return E_INVALIDARG; + } + if (mStatus != RPC_S_OK) { + return E_FAIL; + } + + ULONG bytesWritten; + HRESULT hr = + aDestStream->Write(&mEncodedLen, sizeof(mEncodedLen), &bytesWritten); + if (FAILED(hr)) { + return hr; + } + if (bytesWritten != sizeof(mEncodedLen)) { + return E_UNEXPECTED; + } + + if (mBuffer && mEncodedLen) { + hr = aDestStream->Write(mBuffer, mEncodedLen, &bytesWritten); + if (FAILED(hr)) { + return hr; + } + if (bytesWritten != mEncodedLen) { + return E_UNEXPECTED; + } + } + + return hr; + } + + StructToStream(const StructToStream&) = delete; + StructToStream(StructToStream&&) = delete; + StructToStream& operator=(const StructToStream&) = delete; + StructToStream& operator=(StructToStream&&) = delete; + + private: + RPC_STATUS mStatus; + handle_t mHandle; + char* mBuffer; + detail::EncodedLenT mEncodedLen; +}; + +class MOZ_NON_TEMPORARY_CLASS StructFromStream { + struct AlignedFreeDeleter { + void operator()(void* aPtr) { ::_aligned_free(aPtr); } + }; + + static const detail::EncodedLenT kRpcReqdBufAlignment = 8; + + public: + explicit StructFromStream(IStream* aStream) + : mStatus(RPC_X_INVALID_BUFFER), mHandle(nullptr) { + MOZ_ASSERT(aStream); + + // Read the length of the encoded data first + detail::EncodedLenT encodedLen = 0; + ULONG bytesRead = 0; + HRESULT hr = aStream->Read(&encodedLen, sizeof(encodedLen), &bytesRead); + if (FAILED(hr)) { + return; + } + + // NB: Some implementations of IStream return S_FALSE to indicate EOF, + // other implementations return S_OK and set the number of bytes read to 0. + // We must handle both. + if (hr == S_FALSE || !bytesRead) { + mStatus = RPC_S_OBJECT_NOT_FOUND; + return; + } + + if (bytesRead != sizeof(encodedLen)) { + return; + } + + if (!encodedLen) { + mStatus = RPC_S_OBJECT_NOT_FOUND; + return; + } + + MOZ_ASSERT(encodedLen % kRpcReqdBufAlignment == 0); + if (encodedLen % kRpcReqdBufAlignment) { + return; + } + + // This memory allocation is fallible + mEncodedBuffer.reset(static_cast<char*>( + ::_aligned_malloc(encodedLen, kRpcReqdBufAlignment))); + if (!mEncodedBuffer) { + return; + } + + ULONG bytesReadFromStream = 0; + hr = aStream->Read(mEncodedBuffer.get(), encodedLen, &bytesReadFromStream); + if (FAILED(hr) || bytesReadFromStream != encodedLen) { + return; + } + + mStatus = ::MesDecodeBufferHandleCreate(mEncodedBuffer.get(), encodedLen, + &mHandle); + } + + ~StructFromStream() { + if (mHandle) { + ::MesHandleFree(mHandle); + } + } + + explicit operator bool() const { return mStatus == RPC_S_OK || IsEmpty(); } + + bool IsEmpty() const { return mStatus == RPC_S_OBJECT_NOT_FOUND; } + + template <typename StructT> + bool Read(StructT* aDestStruct, void (*aDecodeFnPtr)(handle_t, StructT*)) { + if (!aDestStruct || !aDecodeFnPtr || mStatus != RPC_S_OK) { + return false; + } + + // NB: Deserialization will fail with BSTRs unless the destination data + // is zeroed out! + ZeroMemory(aDestStruct, sizeof(StructT)); + + MOZ_SEH_TRY { aDecodeFnPtr(mHandle, aDestStruct); } +#ifdef HAVE_SEH_EXCEPTIONS + MOZ_SEH_EXCEPT(EXCEPTION_EXECUTE_HANDLER) { + mStatus = ::RpcExceptionCode(); + return false; + } +#endif + + return true; + } + + StructFromStream(const StructFromStream&) = delete; + StructFromStream(StructFromStream&&) = delete; + StructFromStream& operator=(const StructFromStream&) = delete; + StructFromStream& operator=(StructFromStream&&) = delete; + + private: + RPC_STATUS mStatus; + handle_t mHandle; + UniquePtr<char, AlignedFreeDeleter> mEncodedBuffer; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_StructStream_h diff --git a/ipc/mscom/Utils.cpp b/ipc/mscom/Utils.cpp new file mode 100644 index 0000000000..377bbeba5c --- /dev/null +++ b/ipc/mscom/Utils.cpp @@ -0,0 +1,600 @@ +/* -*- Mode: C; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#if defined(MOZILLA_INTERNAL_API) +# include "MainThreadUtils.h" +# include "mozilla/dom/ContentChild.h" +#endif + +#if defined(ACCESSIBILITY) +# include "mozilla/mscom/Registration.h" +# if defined(MOZILLA_INTERNAL_API) +# include "nsTArray.h" +# endif +#endif + +#include "mozilla/ArrayUtils.h" +#include "mozilla/mscom/COMWrappers.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/mscom/Objref.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/RefPtr.h" +#include "mozilla/WindowsVersion.h" + +#include <objidl.h> +#include <shlwapi.h> +#include <winnt.h> + +#include <utility> + +#if defined(_MSC_VER) +extern "C" IMAGE_DOS_HEADER __ImageBase; +#endif + +namespace mozilla { +namespace mscom { + +bool IsCOMInitializedOnCurrentThread() { + APTTYPE aptType; + APTTYPEQUALIFIER aptTypeQualifier; + HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier); + return hr != CO_E_NOTINITIALIZED; +} + +bool IsCurrentThreadMTA() { + APTTYPE aptType; + APTTYPEQUALIFIER aptTypeQualifier; + HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier); + if (FAILED(hr)) { + return false; + } + + return aptType == APTTYPE_MTA; +} + +bool IsCurrentThreadExplicitMTA() { + APTTYPE aptType; + APTTYPEQUALIFIER aptTypeQualifier; + HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier); + if (FAILED(hr)) { + return false; + } + + return aptType == APTTYPE_MTA && + aptTypeQualifier != APTTYPEQUALIFIER_IMPLICIT_MTA; +} + +bool IsCurrentThreadImplicitMTA() { + APTTYPE aptType; + APTTYPEQUALIFIER aptTypeQualifier; + HRESULT hr = wrapped::CoGetApartmentType(&aptType, &aptTypeQualifier); + if (FAILED(hr)) { + return false; + } + + return aptType == APTTYPE_MTA && + aptTypeQualifier == APTTYPEQUALIFIER_IMPLICIT_MTA; +} + +#if defined(MOZILLA_INTERNAL_API) +bool IsCurrentThreadNonMainMTA() { + if (NS_IsMainThread()) { + return false; + } + + return IsCurrentThreadMTA(); +} +#endif // defined(MOZILLA_INTERNAL_API) + +bool IsProxy(IUnknown* aUnknown) { + if (!aUnknown) { + return false; + } + + // Only proxies implement this interface, so if it is present then we must + // be dealing with a proxied object. + RefPtr<IClientSecurity> clientSecurity; + HRESULT hr = aUnknown->QueryInterface(IID_IClientSecurity, + (void**)getter_AddRefs(clientSecurity)); + if (SUCCEEDED(hr) || hr == RPC_E_WRONG_THREAD) { + return true; + } + return false; +} + +bool IsValidGUID(REFGUID aCheckGuid) { + // This function determines whether or not aCheckGuid conforms to RFC4122 + // as it applies to Microsoft COM. + + BYTE variant = aCheckGuid.Data4[0]; + if (!(variant & 0x80)) { + // NCS Reserved + return false; + } + if ((variant & 0xE0) == 0xE0) { + // Reserved for future use + return false; + } + if ((variant & 0xC0) == 0xC0) { + // Microsoft Reserved. + return true; + } + + BYTE version = HIBYTE(aCheckGuid.Data3) >> 4; + // Other versions are specified in RFC4122 but these are the two used by COM. + return version == 1 || version == 4; +} + +uintptr_t GetContainingModuleHandle() { + HMODULE thisModule = nullptr; +#if defined(_MSC_VER) + thisModule = reinterpret_cast<HMODULE>(&__ImageBase); +#else + if (!GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast<LPCTSTR>(&GetContainingModuleHandle), + &thisModule)) { + return 0; + } +#endif + return reinterpret_cast<uintptr_t>(thisModule); +} + +namespace detail { + +long BuildRegGuidPath(REFGUID aGuid, const GuidType aGuidType, wchar_t* aBuf, + const size_t aBufLen) { + constexpr wchar_t kClsid[] = L"CLSID\\"; + constexpr wchar_t kAppid[] = L"AppID\\"; + constexpr wchar_t kSubkeyBase[] = L"SOFTWARE\\Classes\\"; + + // We exclude null terminators in these length calculations because we include + // the stringified GUID's null terminator at the end. Since kClsid and kAppid + // have identical lengths, we just choose one to compute this length. + constexpr size_t kSubkeyBaseLen = mozilla::ArrayLength(kSubkeyBase) - 1; + constexpr size_t kSubkeyLen = + kSubkeyBaseLen + mozilla::ArrayLength(kClsid) - 1; + // Guid length as formatted for the registry (including curlies and dashes), + // but excluding null terminator. + constexpr size_t kGuidLen = kGuidRegFormatCharLenInclNul - 1; + constexpr size_t kExpectedPathLenInclNul = kSubkeyLen + kGuidLen + 1; + + if (aBufLen < kExpectedPathLenInclNul) { + // Buffer is too short + return E_INVALIDARG; + } + + if (wcscpy_s(aBuf, aBufLen, kSubkeyBase)) { + return E_INVALIDARG; + } + + const wchar_t* strGuidType = aGuidType == GuidType::CLSID ? kClsid : kAppid; + if (wcscat_s(aBuf, aBufLen, strGuidType)) { + return E_INVALIDARG; + } + + int guidConversionResult = + ::StringFromGUID2(aGuid, &aBuf[kSubkeyLen], aBufLen - kSubkeyLen); + if (!guidConversionResult) { + return E_INVALIDARG; + } + + return S_OK; +} + +} // namespace detail + +long CreateStream(const uint8_t* aInitBuf, const uint32_t aInitBufSize, + IStream** aOutStream) { + if (!aInitBufSize || !aOutStream) { + return E_INVALIDARG; + } + + *aOutStream = nullptr; + + HRESULT hr; + RefPtr<IStream> stream; + + if (IsWin8OrLater()) { + // SHCreateMemStream is not safe for us to use until Windows 8. On older + // versions of Windows it is not thread-safe and it creates IStreams that do + // not support the full IStream API. + + // If aInitBuf is null then initSize must be 0. + UINT initSize = aInitBuf ? aInitBufSize : 0; + stream = already_AddRefed<IStream>(::SHCreateMemStream(aInitBuf, initSize)); + if (!stream) { + return E_OUTOFMEMORY; + } + + if (!aInitBuf) { + // Now we'll set the required size + ULARGE_INTEGER newSize; + newSize.QuadPart = aInitBufSize; + hr = stream->SetSize(newSize); + if (FAILED(hr)) { + return hr; + } + } + } else { + HGLOBAL hglobal = ::GlobalAlloc(GMEM_MOVEABLE, aInitBufSize); + if (!hglobal) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + // stream takes ownership of hglobal if this call is successful + hr = ::CreateStreamOnHGlobal(hglobal, TRUE, getter_AddRefs(stream)); + if (FAILED(hr)) { + ::GlobalFree(hglobal); + return hr; + } + + // The default stream size is derived from ::GlobalSize(hglobal), which due + // to rounding may be larger than aInitBufSize. We forcibly set the correct + // stream size here. + ULARGE_INTEGER streamSize; + streamSize.QuadPart = aInitBufSize; + hr = stream->SetSize(streamSize); + if (FAILED(hr)) { + return hr; + } + + if (aInitBuf) { + ULONG bytesWritten; + hr = stream->Write(aInitBuf, aInitBufSize, &bytesWritten); + if (FAILED(hr)) { + return hr; + } + + if (bytesWritten != aInitBufSize) { + return E_UNEXPECTED; + } + } + } + + // Ensure that the stream is rewound + LARGE_INTEGER streamOffset; + streamOffset.QuadPart = 0LL; + hr = stream->Seek(streamOffset, STREAM_SEEK_SET, nullptr); + if (FAILED(hr)) { + return hr; + } + + stream.forget(aOutStream); + return S_OK; +} + +long CopySerializedProxy(IStream* aInStream, IStream** aOutStream) { + if (!aInStream || !aOutStream) { + return E_INVALIDARG; + } + + *aOutStream = nullptr; + + uint32_t desiredStreamSize = GetOBJREFSize(WrapNotNull(aInStream)); + if (!desiredStreamSize) { + return E_INVALIDARG; + } + + RefPtr<IStream> stream; + HRESULT hr = CreateStream(nullptr, desiredStreamSize, getter_AddRefs(stream)); + if (FAILED(hr)) { + return hr; + } + + ULARGE_INTEGER numBytesToCopy; + numBytesToCopy.QuadPart = desiredStreamSize; + hr = aInStream->CopyTo(stream, numBytesToCopy, nullptr, nullptr); + if (FAILED(hr)) { + return hr; + } + + LARGE_INTEGER seekTo; + seekTo.QuadPart = 0LL; + hr = stream->Seek(seekTo, STREAM_SEEK_SET, nullptr); + if (FAILED(hr)) { + return hr; + } + + stream.forget(aOutStream); + return S_OK; +} + +#if defined(MOZILLA_INTERNAL_API) + +void GUIDToString(REFGUID aGuid, nsAString& aOutString) { + // This buffer length is long enough to hold a GUID string that is formatted + // to include curly braces and dashes. + const int kBufLenWithNul = 39; + aOutString.SetLength(kBufLenWithNul); + int result = StringFromGUID2(aGuid, char16ptr_t(aOutString.BeginWriting()), + kBufLenWithNul); + MOZ_ASSERT(result); + if (result) { + // Truncate the terminator + aOutString.SetLength(result - 1); + } +} + +// Undocumented IIDs that are relevant for diagnostic purposes +static const IID IID_ISCMLocalActivator = { + 0x00000136, + 0x0000, + 0x0000, + {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}}; +static const IID IID_IRundown = { + 0x00000134, + 0x0000, + 0x0000, + {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}}; +static const IID IID_IRemUnknown = { + 0x00000131, + 0x0000, + 0x0000, + {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}}; +static const IID IID_IRemUnknown2 = { + 0x00000143, + 0x0000, + 0x0000, + {0xC0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x46}}; + +struct IIDToLiteralMapEntry { + constexpr IIDToLiteralMapEntry(REFIID aIid, nsLiteralCString&& aStr) + : mIid(aIid), mStr(std::forward<nsLiteralCString>(aStr)) {} + + REFIID mIid; + const nsLiteralCString mStr; +}; + +/** + * Given the name of an interface, the IID_ENTRY macro generates a pair + * containing a reference to the interface ID and a stringified version of + * the interface name. + * + * For example: + * + * {IID_ENTRY(IUnknown)} + * is expanded to: + * {IID_IUnknown, "IUnknown"_ns} + * + */ +// clang-format off +# define IID_ENTRY_STRINGIFY(iface) #iface##_ns +# define IID_ENTRY(iface) IID_##iface, IID_ENTRY_STRINGIFY(iface) +// clang-format on + +// Mapping of selected IIDs to friendly, human readable descriptions for each +// interface. +static constexpr IIDToLiteralMapEntry sIidDiagStrs[] = { + {IID_ENTRY(IUnknown)}, + {IID_IRemUnknown, "cross-apartment IUnknown"_ns}, + {IID_IRundown, "cross-apartment object management"_ns}, + {IID_ISCMLocalActivator, "out-of-process object instantiation"_ns}, + {IID_IRemUnknown2, "cross-apartment IUnknown"_ns}}; + +# undef IID_ENTRY +# undef IID_ENTRY_STRINGIFY + +void DiagnosticNameForIID(REFIID aIid, nsACString& aOutString) { + // If the IID matches something in sIidDiagStrs, output its string. + for (const auto& curEntry : sIidDiagStrs) { + if (curEntry.mIid == aIid) { + aOutString.Assign(curEntry.mStr); + return; + } + } + + // Otherwise just convert the IID to string form and output that. + nsAutoString strIid; + GUIDToString(aIid, strIid); + + aOutString.AssignLiteral("IID "); + AppendUTF16toUTF8(strIid, aOutString); +} + +#else + +void GUIDToString(REFGUID aGuid, + wchar_t (&aOutBuf)[kGuidRegFormatCharLenInclNul]) { + DebugOnly<int> result = + ::StringFromGUID2(aGuid, aOutBuf, ArrayLength(aOutBuf)); + MOZ_ASSERT(result); +} + +#endif // defined(MOZILLA_INTERNAL_API) + +#if defined(ACCESSIBILITY) + +static bool IsVtableIndexFromParentInterface(TYPEATTR* aTypeAttr, + unsigned long aVtableIndex) { + MOZ_ASSERT(aTypeAttr); + + // This is the number of functions declared in this interface (excluding + // parent interfaces). + unsigned int numExclusiveFuncs = aTypeAttr->cFuncs; + + // This is the number of vtable entries (which includes parent interfaces). + // TYPEATTR::cbSizeVft is the entire vtable size in bytes, so we need to + // divide in order to compute the number of entries. + unsigned int numVtblEntries = aTypeAttr->cbSizeVft / sizeof(void*); + + // This is the index of the first entry in the vtable that belongs to this + // interface and not a parent. + unsigned int firstVtblIndex = numVtblEntries - numExclusiveFuncs; + + // If aVtableIndex is less than firstVtblIndex, then we're asking for an + // index that may belong to a parent interface. + return aVtableIndex < firstVtblIndex; +} + +bool IsVtableIndexFromParentInterface(REFIID aInterface, + unsigned long aVtableIndex) { + RefPtr<ITypeInfo> typeInfo; + if (!RegisteredProxy::Find(aInterface, getter_AddRefs(typeInfo))) { + return false; + } + + TYPEATTR* typeAttr = nullptr; + HRESULT hr = typeInfo->GetTypeAttr(&typeAttr); + if (FAILED(hr)) { + return false; + } + + bool result = IsVtableIndexFromParentInterface(typeAttr, aVtableIndex); + + typeInfo->ReleaseTypeAttr(typeAttr); + return result; +} + +# if defined(MOZILLA_INTERNAL_API) + +bool IsCallerExternalProcess() { + MOZ_ASSERT(XRE_IsContentProcess()); + + /** + * CoGetCallerTID() gives us the caller's thread ID when that thread resides + * in a single-threaded apartment. Since our chrome main thread does live + * inside an STA, we will therefore be able to check whether the caller TID + * equals our chrome main thread TID. This enables us to distinguish + * between our chrome thread vs other out-of-process callers. We check for + * S_FALSE to ensure that the caller is a different process from ours, which + * is the only scenario that we care about. + */ + DWORD callerTid; + if (::CoGetCallerTID(&callerTid) != S_FALSE) { + return false; + } + + // Now check whether the caller is our parent process main thread. + const DWORD parentMainTid = + dom::ContentChild::GetSingleton()->GetChromeMainThreadId(); + return callerTid != parentMainTid; +} + +bool IsInterfaceEqualToOrInheritedFrom(REFIID aInterface, REFIID aFrom, + unsigned long aVtableIndexHint) { + if (aInterface == aFrom) { + return true; + } + + // We expect this array to be length 1 but that is not guaranteed by the API. + AutoTArray<RefPtr<ITypeInfo>, 1> typeInfos; + + // Grab aInterface's ITypeInfo so that we may obtain information about its + // inheritance hierarchy. + RefPtr<ITypeInfo> typeInfo; + if (RegisteredProxy::Find(aInterface, getter_AddRefs(typeInfo))) { + typeInfos.AppendElement(std::move(typeInfo)); + } + + /** + * The main loop of this function searches the hierarchy of aInterface's + * parent interfaces, searching for aFrom. + */ + while (!typeInfos.IsEmpty()) { + RefPtr<ITypeInfo> curTypeInfo(typeInfos.PopLastElement()); + + TYPEATTR* typeAttr = nullptr; + HRESULT hr = curTypeInfo->GetTypeAttr(&typeAttr); + if (FAILED(hr)) { + break; + } + + bool isFromParentVtable = + IsVtableIndexFromParentInterface(typeAttr, aVtableIndexHint); + WORD numParentInterfaces = typeAttr->cImplTypes; + + curTypeInfo->ReleaseTypeAttr(typeAttr); + typeAttr = nullptr; + + if (!isFromParentVtable) { + // The vtable index cannot belong to this interface (otherwise the IIDs + // would already have matched and we would have returned true). Since we + // now also know that the vtable index cannot possibly be contained inside + // curTypeInfo's parent interface, there is no point searching any further + // up the hierarchy from here. OTOH we still should check any remaining + // entries that are still in the typeInfos array, so we continue. + continue; + } + + for (WORD i = 0; i < numParentInterfaces; ++i) { + HREFTYPE refCookie; + hr = curTypeInfo->GetRefTypeOfImplType(i, &refCookie); + if (FAILED(hr)) { + continue; + } + + RefPtr<ITypeInfo> nextTypeInfo; + hr = curTypeInfo->GetRefTypeInfo(refCookie, getter_AddRefs(nextTypeInfo)); + if (FAILED(hr)) { + continue; + } + + hr = nextTypeInfo->GetTypeAttr(&typeAttr); + if (FAILED(hr)) { + continue; + } + + IID nextIid = typeAttr->guid; + + nextTypeInfo->ReleaseTypeAttr(typeAttr); + typeAttr = nullptr; + + if (nextIid == aFrom) { + return true; + } + + typeInfos.AppendElement(std::move(nextTypeInfo)); + } + } + + return false; +} + +# endif // defined(MOZILLA_INTERNAL_API) + +#endif // defined(ACCESSIBILITY) + +#if defined(MOZILLA_INTERNAL_API) +bool IsClassThreadAwareInprocServer(REFCLSID aClsid) { + nsAutoString strClsid; + GUIDToString(aClsid, strClsid); + + nsAutoString inprocServerSubkey(u"CLSID\\"_ns); + inprocServerSubkey.Append(strClsid); + inprocServerSubkey.Append(u"\\InprocServer32"_ns); + + // Of the possible values, "Apartment" is the longest, so we'll make this + // buffer large enough to hold that one. + wchar_t threadingModelBuf[ArrayLength(L"Apartment")] = {}; + + DWORD numBytes = sizeof(threadingModelBuf); + LONG result = ::RegGetValueW(HKEY_CLASSES_ROOT, inprocServerSubkey.get(), + L"ThreadingModel", RRF_RT_REG_SZ, nullptr, + threadingModelBuf, &numBytes); + if (result != ERROR_SUCCESS) { + // This will also handle the case where the CLSID is not an inproc server. + return false; + } + + DWORD numChars = numBytes / sizeof(wchar_t); + // numChars includes the null terminator + if (numChars <= 1) { + return false; + } + + nsDependentString threadingModel(threadingModelBuf, numChars - 1); + + // Ensure that the threading model is one of the known values that indicates + // that the class can operate natively (ie, no proxying) inside a MTA. + return threadingModel.LowerCaseEqualsLiteral("both") || + threadingModel.LowerCaseEqualsLiteral("free") || + threadingModel.LowerCaseEqualsLiteral("neutral"); +} +#endif // defined(MOZILLA_INTERNAL_API) + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/Utils.h b/ipc/mscom/Utils.h new file mode 100644 index 0000000000..c36efde3d3 --- /dev/null +++ b/ipc/mscom/Utils.h @@ -0,0 +1,168 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Utils_h +#define mozilla_mscom_Utils_h + +#if defined(MOZILLA_INTERNAL_API) +# include "nsString.h" +#endif // defined(MOZILLA_INTERNAL_API) + +#include "mozilla/Attributes.h" +#include <guiddef.h> +#include <stdint.h> + +struct IStream; +struct IUnknown; + +namespace mozilla { +namespace mscom { +namespace detail { + +enum class GuidType { + CLSID, + AppID, +}; + +long BuildRegGuidPath(REFGUID aGuid, const GuidType aGuidType, wchar_t* aBuf, + const size_t aBufLen); + +} // namespace detail + +bool IsCOMInitializedOnCurrentThread(); +bool IsCurrentThreadMTA(); +bool IsCurrentThreadExplicitMTA(); +bool IsCurrentThreadImplicitMTA(); +#if defined(MOZILLA_INTERNAL_API) +bool IsCurrentThreadNonMainMTA(); +#endif // defined(MOZILLA_INTERNAL_API) +bool IsProxy(IUnknown* aUnknown); +bool IsValidGUID(REFGUID aCheckGuid); +uintptr_t GetContainingModuleHandle(); + +template <size_t N> +inline long BuildAppidPath(REFGUID aAppId, wchar_t (&aPath)[N]) { + return detail::BuildRegGuidPath(aAppId, detail::GuidType::AppID, aPath, N); +} + +template <size_t N> +inline long BuildClsidPath(REFCLSID aClsid, wchar_t (&aPath)[N]) { + return detail::BuildRegGuidPath(aClsid, detail::GuidType::CLSID, aPath, N); +} + +/** + * Given a buffer, create a new IStream object. + * @param aBuf Buffer containing data to initialize the stream. This parameter + * may be nullptr, causing the stream to be created with aBufLen + * bytes of uninitialized data. + * @param aBufLen Length of data in aBuf, or desired stream size if aBuf is + * nullptr. + * @param aOutStream Outparam to receive the newly created stream. + * @return HRESULT error code. + */ +long CreateStream(const uint8_t* aBuf, const uint32_t aBufLen, + IStream** aOutStream); + +/** + * Creates a deep copy of a proxy contained in a stream. + * @param aInStream Stream containing the proxy to copy. Its seek pointer must + * be positioned to point at the beginning of the proxy data. + * @param aOutStream Outparam to receive the newly created stream. + * @return HRESULT error code. + */ +long CopySerializedProxy(IStream* aInStream, IStream** aOutStream); + +/** + * Length of a stringified GUID as formatted for the registry, i.e. including + * curly-braces and dashes. + */ +constexpr size_t kGuidRegFormatCharLenInclNul = 39; + +#if defined(MOZILLA_INTERNAL_API) +/** + * Checks the registry to see if |aClsid| is a thread-aware in-process server. + * + * In DCOM, an in-process server is a server that is implemented inside a DLL + * that is loaded into the client's process for execution. If |aClsid| declares + * itself to be a local server (that is, a server that resides in another + * process), this function returns false. + * + * For the server to be thread-aware, its registry entry must declare a + * ThreadingModel that is one of "Free", "Both", or "Neutral". If the threading + * model is "Apartment" or some other, invalid value, the class is treated as + * being single-threaded. + * + * NB: This function cannot check CLSIDs that were registered via manifests, + * as unfortunately there is not a documented API available to query for those. + * This should not be an issue for most CLSIDs that Gecko is interested in, as + * we typically instantiate system CLSIDs which are available in the registry. + * + * @param aClsid The CLSID of the COM class to be checked. + * @return true if the class meets the above criteria, otherwise false. + */ +bool IsClassThreadAwareInprocServer(REFCLSID aClsid); + +void GUIDToString(REFGUID aGuid, nsAString& aOutString); + +/** + * Converts an IID to a human-readable string for the purposes of diagnostic + * tools such as the profiler. For some special cases, we output a friendly + * string that describes the purpose of the interface. If no such description + * exists, we simply fall back to outputting the IID as a string formatted by + * GUIDToString(). + */ +void DiagnosticNameForIID(REFIID aIid, nsACString& aOutString); +#else +void GUIDToString(REFGUID aGuid, + wchar_t (&aOutBuf)[kGuidRegFormatCharLenInclNul]); +#endif // defined(MOZILLA_INTERNAL_API) + +#if defined(ACCESSIBILITY) +bool IsVtableIndexFromParentInterface(REFIID aInterface, + unsigned long aVtableIndex); + +# if defined(MOZILLA_INTERNAL_API) +bool IsCallerExternalProcess(); + +bool IsInterfaceEqualToOrInheritedFrom(REFIID aInterface, REFIID aFrom, + unsigned long aVtableIndexHint); +# endif // defined(MOZILLA_INTERNAL_API) + +#endif // defined(ACCESSIBILITY) + +/** + * Execute cleanup code when going out of scope if a condition is met. + * This is useful when, for example, particular cleanup needs to be performed + * whenever a call returns a failure HRESULT. + * Both the condition and cleanup code are provided as functions (usually + * lambdas). + */ +template <typename CondFnT, typename ExeFnT> +class MOZ_RAII ExecuteWhen final { + public: + ExecuteWhen(CondFnT& aCondFn, ExeFnT& aExeFn) + : mCondFn(aCondFn), mExeFn(aExeFn) {} + + ~ExecuteWhen() { + if (mCondFn()) { + mExeFn(); + } + } + + ExecuteWhen(const ExecuteWhen&) = delete; + ExecuteWhen(ExecuteWhen&&) = delete; + ExecuteWhen& operator=(const ExecuteWhen&) = delete; + ExecuteWhen& operator=(ExecuteWhen&&) = delete; + + private: + CondFnT& mCondFn; + ExeFnT& mExeFn; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Utils_h diff --git a/ipc/mscom/VTableBuilder.c b/ipc/mscom/VTableBuilder.c new file mode 100644 index 0000000000..48550503c2 --- /dev/null +++ b/ipc/mscom/VTableBuilder.c @@ -0,0 +1,54 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "VTableBuilder.h" + +#include <stdlib.h> + +static HRESULT STDMETHODCALLTYPE QueryInterfaceThunk(IUnknown* aThis, + REFIID aIid, + void** aOutInterface) { + void** table = (void**)aThis; + IUnknown* real = (IUnknown*)table[1]; + return real->lpVtbl->QueryInterface(real, aIid, aOutInterface); +} + +static ULONG STDMETHODCALLTYPE AddRefThunk(IUnknown* aThis) { + void** table = (void**)aThis; + IUnknown* real = (IUnknown*)table[1]; + return real->lpVtbl->AddRef(real); +} + +static ULONG STDMETHODCALLTYPE ReleaseThunk(IUnknown* aThis) { + void** table = (void**)aThis; + IUnknown* real = (IUnknown*)table[1]; + return real->lpVtbl->Release(real); +} + +IUnknown* BuildNullVTable(IUnknown* aUnk, uint32_t aVtblSize) { + void** table; + + if (aVtblSize < 3) { + return NULL; + } + + // We need to allocate slots for two additional pointers: The |lpVtbl| + // pointer, as well as |aUnk| which is needed to get |this| right when + // something calls through one of the IUnknown thunks. + table = calloc(aVtblSize + 2, sizeof(void*)); + + table[0] = &table[2]; // |lpVtbl|, points to the first entry of the vtable + table[1] = aUnk; // |this| + // Now the actual vtable entries for IUnknown + table[2] = &QueryInterfaceThunk; + table[3] = &AddRefThunk; + table[4] = &ReleaseThunk; + // Remaining entries are NULL thanks to calloc zero-initializing everything. + + return (IUnknown*)table; +} + +void DeleteNullVTable(IUnknown* aUnk) { free(aUnk); } diff --git a/ipc/mscom/VTableBuilder.h b/ipc/mscom/VTableBuilder.h new file mode 100644 index 0000000000..765ec51b0b --- /dev/null +++ b/ipc/mscom/VTableBuilder.h @@ -0,0 +1,37 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_VTableBuilder_h +#define mozilla_mscom_VTableBuilder_h + +#include <stdint.h> +#include <unknwn.h> + +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * This function constructs an interface with |aVtblSize| entries, where the + * IUnknown methods are valid and redirect to |aUnk|. Other than the IUnknown + * methods, the resulting interface is not intended to actually be called; the + * remaining vtable entries are null pointers to enforce that. + * + * @param aUnk The IUnknown to which the null vtable will redirect its IUnknown + * methods. + * @param aVtblSize The total size of the vtable, including the IUnknown + * entries (thus this parameter must be >= 3). + * @return The resulting IUnknown, or nullptr on error. + */ +IUnknown* BuildNullVTable(IUnknown* aUnk, uint32_t aVtblSize); + +void DeleteNullVTable(IUnknown* aUnk); + +#if defined(__cplusplus) +} +#endif + +#endif // mozilla_mscom_VTableBuilder_h diff --git a/ipc/mscom/WeakRef.cpp b/ipc/mscom/WeakRef.cpp new file mode 100644 index 0000000000..b15c027153 --- /dev/null +++ b/ipc/mscom/WeakRef.cpp @@ -0,0 +1,225 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#define INITGUID +#include "mozilla/mscom/WeakRef.h" + +#include "mozilla/DebugOnly.h" +#include "mozilla/Mutex.h" +#include "nsThreadUtils.h" +#include "nsWindowsHelpers.h" +#include "nsProxyRelease.h" + +static void InitializeCS(CRITICAL_SECTION& aCS) { + DWORD flags = 0; +#if defined(RELEASE_OR_BETA) + flags |= CRITICAL_SECTION_NO_DEBUG_INFO; +#endif + InitializeCriticalSectionEx(&aCS, 4000, flags); +} + +namespace mozilla { +namespace mscom { + +namespace detail { + +SharedRef::SharedRef(WeakReferenceSupport* aSupport) : mSupport(aSupport) { + ::InitializeCS(mCS); +} + +SharedRef::~SharedRef() { ::DeleteCriticalSection(&mCS); } + +void SharedRef::Lock() { ::EnterCriticalSection(&mCS); } + +void SharedRef::Unlock() { ::LeaveCriticalSection(&mCS); } + +HRESULT +SharedRef::ToStrongRef(IWeakReferenceSource** aOutStrongReference) { + RefPtr<IWeakReferenceSource> strongRef; + + { // Scope for lock + AutoCriticalSection lock(&mCS); + if (!mSupport) { + return E_POINTER; + } + strongRef = mSupport; + } + + strongRef.forget(aOutStrongReference); + return S_OK; +} + +HRESULT +SharedRef::Resolve(REFIID aIid, void** aOutStrongReference) { + RefPtr<WeakReferenceSupport> strongRef; + + { // Scope for lock + AutoCriticalSection lock(&mCS); + if (!mSupport) { + return E_POINTER; + } + strongRef = mSupport; + } + + return strongRef->QueryInterface(aIid, aOutStrongReference); +} + +void SharedRef::Clear() { + AutoCriticalSection lock(&mCS); + mSupport = nullptr; +} + +} // namespace detail + +typedef mozilla::detail::BaseAutoLock<detail::SharedRef&> SharedRefAutoLock; +typedef mozilla::detail::BaseAutoUnlock<detail::SharedRef&> SharedRefAutoUnlock; + +WeakReferenceSupport::WeakReferenceSupport(Flags aFlags) + : mRefCnt(0), mFlags(aFlags) { + mSharedRef = new detail::SharedRef(this); +} + +HRESULT +WeakReferenceSupport::QueryInterface(REFIID riid, void** ppv) { + RefPtr<IUnknown> punk; + if (!ppv) { + return E_INVALIDARG; + } + *ppv = nullptr; + + // Raise the refcount for stabilization purposes during aggregation + StabilizeRefCount stabilize(*this); + + if (riid == IID_IUnknown || riid == IID_IWeakReferenceSource) { + punk = static_cast<IUnknown*>(this); + } else { + HRESULT hr = WeakRefQueryInterface(riid, getter_AddRefs(punk)); + if (FAILED(hr)) { + return hr; + } + } + + if (!punk) { + return E_NOINTERFACE; + } + + punk.forget(ppv); + return S_OK; +} + +WeakReferenceSupport::StabilizeRefCount::StabilizeRefCount( + WeakReferenceSupport& aObject) + : mObject(aObject) { + SharedRefAutoLock lock(*mObject.mSharedRef); + ++mObject.mRefCnt; +} + +WeakReferenceSupport::StabilizeRefCount::~StabilizeRefCount() { + // We directly access these fields instead of calling Release() because we + // want to adjust the ref count without the other side effects (such as + // deleting this if the count drops back to zero, which may happen during + // an initial QI during object creation). + SharedRefAutoLock lock(*mObject.mSharedRef); + --mObject.mRefCnt; +} + +ULONG +WeakReferenceSupport::AddRef() { + SharedRefAutoLock lock(*mSharedRef); + ULONG result = ++mRefCnt; + NS_LOG_ADDREF(this, result, "mscom::WeakReferenceSupport", sizeof(*this)); + return result; +} + +ULONG +WeakReferenceSupport::Release() { + ULONG newRefCnt; + { // Scope for lock + SharedRefAutoLock lock(*mSharedRef); + newRefCnt = --mRefCnt; + if (newRefCnt == 0) { + mSharedRef->Clear(); + } + } + NS_LOG_RELEASE(this, newRefCnt, "mscom::WeakReferenceSupport"); + if (newRefCnt == 0) { + if (mFlags != Flags::eDestroyOnMainThread || NS_IsMainThread()) { + delete this; + } else { + // We need to delete this object on the main thread, but we aren't on the + // main thread right now, so we send a reference to ourselves to the main + // thread to be re-released there. + RefPtr<WeakReferenceSupport> self = this; + NS_ReleaseOnMainThread("WeakReferenceSupport", self.forget()); + } + } + return newRefCnt; +} + +HRESULT +WeakReferenceSupport::GetWeakReference(IWeakReference** aOutWeakRef) { + if (!aOutWeakRef) { + return E_INVALIDARG; + } + + RefPtr<WeakRef> weakRef = MakeAndAddRef<WeakRef>(mSharedRef); + return weakRef->QueryInterface(IID_IWeakReference, (void**)aOutWeakRef); +} + +WeakRef::WeakRef(RefPtr<detail::SharedRef>& aSharedRef) + : mRefCnt(0), mSharedRef(aSharedRef) { + MOZ_ASSERT(aSharedRef); +} + +HRESULT +WeakRef::QueryInterface(REFIID riid, void** ppv) { + IUnknown* punk = nullptr; + if (!ppv) { + return E_INVALIDARG; + } + + if (riid == IID_IUnknown || riid == IID_IWeakReference) { + punk = static_cast<IUnknown*>(this); + } + + *ppv = punk; + if (!punk) { + return E_NOINTERFACE; + } + + punk->AddRef(); + return S_OK; +} + +ULONG +WeakRef::AddRef() { + ULONG result = ++mRefCnt; + NS_LOG_ADDREF(this, result, "mscom::WeakRef", sizeof(*this)); + return result; +} + +ULONG +WeakRef::Release() { + ULONG newRefCnt = --mRefCnt; + NS_LOG_RELEASE(this, newRefCnt, "mscom::WeakRef"); + if (newRefCnt == 0) { + delete this; + } + return newRefCnt; +} + +HRESULT +WeakRef::ToStrongRef(IWeakReferenceSource** aOutStrongReference) { + return mSharedRef->ToStrongRef(aOutStrongReference); +} + +HRESULT +WeakRef::Resolve(REFIID aIid, void** aOutStrongReference) { + return mSharedRef->Resolve(aIid, aOutStrongReference); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/WeakRef.h b/ipc/mscom/WeakRef.h new file mode 100644 index 0000000000..ba87eecb70 --- /dev/null +++ b/ipc/mscom/WeakRef.h @@ -0,0 +1,142 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_WeakRef_h +#define mozilla_mscom_WeakRef_h + +#include <guiddef.h> +#include <unknwn.h> + +#include "mozilla/Assertions.h" +#include "mozilla/Atomics.h" +#include "mozilla/RefPtr.h" +#include "nsISupportsImpl.h" + +/** + * Thread-safe weak references for COM that works pre-Windows 8 and do not + * require WinRT. + */ + +namespace mozilla { +namespace mscom { + +struct IWeakReferenceSource; +class WeakReferenceSupport; + +namespace detail { + +class SharedRef final { + public: + explicit SharedRef(WeakReferenceSupport* aSupport); + void Lock(); + void Unlock(); + + HRESULT ToStrongRef(IWeakReferenceSource** aOutStringReference); + HRESULT Resolve(REFIID aIid, void** aOutStrongReference); + void Clear(); + + NS_INLINE_DECL_THREADSAFE_REFCOUNTING(SharedRef) + + SharedRef(const SharedRef&) = delete; + SharedRef(SharedRef&&) = delete; + SharedRef& operator=(const SharedRef&) = delete; + SharedRef& operator=(SharedRef&&) = delete; + + private: + ~SharedRef(); + + private: + CRITICAL_SECTION mCS; + WeakReferenceSupport* mSupport; +}; + +} // namespace detail + +// {F841AEFA-064C-49A4-B73D-EBD14A90F012} +DEFINE_GUID(IID_IWeakReference, 0xf841aefa, 0x64c, 0x49a4, 0xb7, 0x3d, 0xeb, + 0xd1, 0x4a, 0x90, 0xf0, 0x12); + +struct IWeakReference : public IUnknown { + virtual STDMETHODIMP ToStrongRef( + IWeakReferenceSource** aOutStrongReference) = 0; + virtual STDMETHODIMP Resolve(REFIID aIid, void** aOutStrongReference) = 0; +}; + +// {87611F0C-9BBB-4F78-9D43-CAC5AD432CA1} +DEFINE_GUID(IID_IWeakReferenceSource, 0x87611f0c, 0x9bbb, 0x4f78, 0x9d, 0x43, + 0xca, 0xc5, 0xad, 0x43, 0x2c, 0xa1); + +struct IWeakReferenceSource : public IUnknown { + virtual STDMETHODIMP GetWeakReference(IWeakReference** aOutWeakRef) = 0; +}; + +class WeakRef; + +class WeakReferenceSupport : public IWeakReferenceSource { + public: + enum class Flags { eNone = 0, eDestroyOnMainThread = 1 }; + + // IUnknown + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // IWeakReferenceSource + STDMETHODIMP GetWeakReference(IWeakReference** aOutWeakRef) override; + + protected: + explicit WeakReferenceSupport(Flags aFlags); + virtual ~WeakReferenceSupport() = default; + + virtual HRESULT WeakRefQueryInterface(REFIID aIid, + IUnknown** aOutInterface) = 0; + + class MOZ_RAII StabilizeRefCount final { + public: + explicit StabilizeRefCount(WeakReferenceSupport& aObject); + ~StabilizeRefCount(); + + StabilizeRefCount(const StabilizeRefCount&) = delete; + StabilizeRefCount(StabilizeRefCount&&) = delete; + StabilizeRefCount& operator=(const StabilizeRefCount&) = delete; + StabilizeRefCount& operator=(StabilizeRefCount&&) = delete; + + private: + WeakReferenceSupport& mObject; + }; + + friend class StabilizeRefCount; + + private: + RefPtr<detail::SharedRef> mSharedRef; + ULONG mRefCnt; + Flags mFlags; +}; + +class WeakRef final : public IWeakReference { + public: + // IUnknown + STDMETHODIMP QueryInterface(REFIID riid, void** ppv) override; + STDMETHODIMP_(ULONG) AddRef() override; + STDMETHODIMP_(ULONG) Release() override; + + // IWeakReference + STDMETHODIMP ToStrongRef(IWeakReferenceSource** aOutStrongReference) override; + STDMETHODIMP Resolve(REFIID aIid, void** aOutStrongReference) override; + + explicit WeakRef(RefPtr<detail::SharedRef>& aSharedRef); + + private: + ~WeakRef() = default; + + Atomic<ULONG> mRefCnt; + RefPtr<detail::SharedRef> mSharedRef; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_WeakRef_h diff --git a/ipc/mscom/moz.build b/ipc/mscom/moz.build new file mode 100644 index 0000000000..d615bd82d3 --- /dev/null +++ b/ipc/mscom/moz.build @@ -0,0 +1,97 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +EXPORTS.mozilla.mscom += [ + "Aggregation.h", + "AgileReference.h", + "ApartmentRegion.h", + "AsyncInvoker.h", + "COMPtrHolder.h", + "COMWrappers.h", + "EnsureMTA.h", + "Objref.h", + "PassthruProxy.h", + "ProcessRuntime.h", + "ProfilerMarkers.h", + "ProxyStream.h", + "Ptr.h", + "Utils.h", +] + +DIRS += [ + "mozglue", +] + +SOURCES += [ + "VTableBuilder.c", +] + +UNIFIED_SOURCES += [ + "AgileReference.cpp", + "COMWrappers.cpp", + "EnsureMTA.cpp", + "Objref.cpp", + "PassthruProxy.cpp", + "ProcessRuntime.cpp", + "ProfilerMarkers.cpp", + "ProxyStream.cpp", + "RegistrationAnnotator.cpp", + "Utils.cpp", +] + +if CONFIG["ACCESSIBILITY"]: + DIRS += [ + "oop", + ] + + EXPORTS.mozilla.mscom += [ + "ActivationContext.h", + "DispatchForwarder.h", + "FastMarshaler.h", + "IHandlerProvider.h", + "Interceptor.h", + "InterceptorLog.h", + "MainThreadHandoff.h", + "MainThreadInvoker.h", + "Registration.h", + "SpinEvent.h", + "StructStream.h", + "WeakRef.h", + ] + + SOURCES += [ + "Interceptor.cpp", + "MainThreadHandoff.cpp", + "Registration.cpp", + "SpinEvent.cpp", + "WeakRef.cpp", + ] + + UNIFIED_SOURCES += [ + "ActivationContext.cpp", + "DispatchForwarder.cpp", + "FastMarshaler.cpp", + "InterceptorLog.cpp", + "MainThreadInvoker.cpp", + "StructStream.cpp", + ] + +LOCAL_INCLUDES += [ + "/xpcom/base", + "/xpcom/build", +] + +DEFINES["MOZ_MSCOM_REMARSHAL_NO_HANDLER"] = True + +include("/ipc/chromium/chromium-config.mozbuild") + +FINAL_LIBRARY = "xul" + +with Files("**"): + BUG_COMPONENT = ("Core", "IPC: MSCOM") + SCHEDULES.exclusive = ["windows"] + +REQUIRES_UNIFIED_BUILD = True diff --git a/ipc/mscom/mozglue/ActCtxResource.cpp b/ipc/mscom/mozglue/ActCtxResource.cpp new file mode 100644 index 0000000000..778e36c4bc --- /dev/null +++ b/ipc/mscom/mozglue/ActCtxResource.cpp @@ -0,0 +1,239 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ActCtxResource.h" + +#include <string> + +#include "mozilla/GetKnownFolderPath.h" +#include "mozilla/WindowsVersion.h" +#include "nsWindowsHelpers.h" + +namespace mozilla { +namespace mscom { + +#if !defined(HAVE_64BIT_BUILD) + +static bool ReadCOMRegDefaultString(const std::wstring& aRegPath, + std::wstring& aOutBuf) { + aOutBuf.clear(); + + std::wstring fullyQualifiedRegPath; + fullyQualifiedRegPath.append(L"SOFTWARE\\Classes\\"); + fullyQualifiedRegPath.append(aRegPath); + + // Get the required size and type of the registry value. + // We expect either REG_SZ or REG_EXPAND_SZ. + DWORD type; + DWORD bufLen = 0; + LONG result = + ::RegGetValueW(HKEY_LOCAL_MACHINE, fullyQualifiedRegPath.c_str(), nullptr, + RRF_RT_ANY, &type, nullptr, &bufLen); + if (result != ERROR_SUCCESS || (type != REG_SZ && type != REG_EXPAND_SZ)) { + return false; + } + + // Now obtain the value + DWORD flags = type == REG_SZ ? RRF_RT_REG_SZ : RRF_RT_REG_EXPAND_SZ; + + aOutBuf.resize((bufLen + 1) / sizeof(char16_t)); + + result = ::RegGetValueW(HKEY_LOCAL_MACHINE, fullyQualifiedRegPath.c_str(), + nullptr, flags, nullptr, &aOutBuf[0], &bufLen); + if (result != ERROR_SUCCESS) { + aOutBuf.clear(); + return false; + } + + // Truncate terminator + aOutBuf.resize((bufLen + 1) / sizeof(char16_t) - 1); + return true; +} + +static bool IsSystemOleAcc(HANDLE aFile) { + if (aFile == INVALID_HANDLE_VALUE) { + return false; + } + + BY_HANDLE_FILE_INFORMATION info = {}; + if (!::GetFileInformationByHandle(aFile, &info)) { + return false; + } + + // Use FOLDERID_SystemX86 so that Windows doesn't give us a redirected + // system32 if we're a 32-bit process running on a 64-bit OS. This is + // necessary because the values that we are reading from the registry + // are not redirected; they reference SysWOW64 directly. + auto systemPath = GetKnownFolderPath(FOLDERID_SystemX86); + if (!systemPath) { + return false; + } + + std::wstring oleAccPath(systemPath.get()); + oleAccPath.append(L"\\oleacc.dll"); + + nsAutoHandle oleAcc( + ::CreateFileW(oleAccPath.c_str(), GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)); + + if (oleAcc.get() == INVALID_HANDLE_VALUE) { + return false; + } + + BY_HANDLE_FILE_INFORMATION oleAccInfo = {}; + if (!::GetFileInformationByHandle(oleAcc, &oleAccInfo)) { + return false; + } + + return info.dwVolumeSerialNumber == oleAccInfo.dwVolumeSerialNumber && + info.nFileIndexLow == oleAccInfo.nFileIndexLow && + info.nFileIndexHigh == oleAccInfo.nFileIndexHigh; +} + +static bool IsTypelibPreferred() { + // If IAccessible's Proxy/Stub CLSID is kUniversalMarshalerClsid, then any + // external a11y clients are expecting to use a typelib. + const wchar_t kUniversalMarshalerClsid[] = + L"{00020424-0000-0000-C000-000000000046}"; + + const wchar_t kIAccessiblePSClsidPath[] = + L"Interface\\{618736E0-3C3D-11CF-810C-00AA00389B71}" + L"\\ProxyStubClsid32"; + + std::wstring psClsid; + if (!ReadCOMRegDefaultString(kIAccessiblePSClsidPath, psClsid)) { + return false; + } + + if (psClsid.size() != + sizeof(kUniversalMarshalerClsid) / sizeof(kUniversalMarshalerClsid)[0] - + 1) { + return false; + } + + int index = 0; + while (kUniversalMarshalerClsid[index]) { + if (toupper(psClsid[index]) != kUniversalMarshalerClsid[index]) { + return false; + } + index++; + } + return true; +} + +static bool IsIAccessibleTypelibRegistered() { + // The system default IAccessible typelib is always registered with version + // 1.1, under the neutral locale (LCID 0). + const wchar_t kIAccessibleTypelibRegPath[] = + L"TypeLib\\{1EA4DBF0-3C3B-11CF-810C-00AA00389B71}\\1.1\\0\\win32"; + + std::wstring typelibPath; + if (!ReadCOMRegDefaultString(kIAccessibleTypelibRegPath, typelibPath)) { + return false; + } + + nsAutoHandle libTestFile( + ::CreateFileW(typelibPath.c_str(), GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)); + + return IsSystemOleAcc(libTestFile); +} + +static bool IsIAccessiblePSRegistered() { + const wchar_t kIAccessiblePSRegPath[] = + L"CLSID\\{03022430-ABC4-11D0-BDE2-00AA001A1953}\\InProcServer32"; + + std::wstring proxyStubPath; + if (!ReadCOMRegDefaultString(kIAccessiblePSRegPath, proxyStubPath)) { + return false; + } + + nsAutoHandle libTestFile( + ::CreateFileW(proxyStubPath.c_str(), GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + nullptr, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr)); + + return IsSystemOleAcc(libTestFile); +} + +static bool UseIAccessibleProxyStub() { + // If a typelib is preferred then external clients are expecting to use + // typelib marshaling, so we should use that whenever available. + if (IsTypelibPreferred() && IsIAccessibleTypelibRegistered()) { + return false; + } + + // Otherwise we try the proxy/stub + if (IsIAccessiblePSRegistered()) { + return true; + } + + return false; +} + +#endif // !defined(HAVE_64BIT_BUILD) + +#if defined(_MSC_VER) +extern "C" IMAGE_DOS_HEADER __ImageBase; +#endif + +static HMODULE GetContainingModuleHandle() { + HMODULE thisModule = nullptr; +#if defined(_MSC_VER) + thisModule = reinterpret_cast<HMODULE>(&__ImageBase); +#else + if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast<LPCWSTR>(&GetContainingModuleHandle), + &thisModule)) { + return 0; + } +#endif + return thisModule; +} + +static uint16_t sActCtxResourceId = 0; + +/* static */ +void ActCtxResource::SetAccessibilityResourceId(uint16_t aResourceId) { + sActCtxResourceId = aResourceId; +} + +/* static */ +uint16_t ActCtxResource::GetAccessibilityResourceId() { + return sActCtxResourceId; +} + +static void EnsureAccessibilityResourceId() { + if (!sActCtxResourceId) { +#if defined(HAVE_64BIT_BUILD) + // The manifest for 64-bit Windows is embedded with resource ID 64. + sActCtxResourceId = 64; +#else + // The manifest for 32-bit Windows is embedded with resource ID 32. + // Beginning with Windows 10 Creators Update, 32-bit builds always use the + // 64-bit manifest. Older builds of Windows may or may not require the + // 64-bit manifest: UseIAccessibleProxyStub() determines the course of + // action. + if (mozilla::IsWin10CreatorsUpdateOrLater() || UseIAccessibleProxyStub()) { + sActCtxResourceId = 64; + } else { + sActCtxResourceId = 32; + } +#endif // defined(HAVE_64BIT_BUILD) + } +} + +ActCtxResource ActCtxResource::GetAccessibilityResource() { + ActCtxResource result = {}; + result.mModule = GetContainingModuleHandle(); + EnsureAccessibilityResourceId(); + result.mId = GetAccessibilityResourceId(); + return result; +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/mozglue/ActCtxResource.h b/ipc/mscom/mozglue/ActCtxResource.h new file mode 100644 index 0000000000..4e9c16c8e0 --- /dev/null +++ b/ipc/mscom/mozglue/ActCtxResource.h @@ -0,0 +1,40 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef ACT_CTX_RESOURCE_H +#define ACT_CTX_RESOURCE_H + +#include <stdint.h> +#include <windows.h> +#include "mozilla/Types.h" + +namespace mozilla { +namespace mscom { + +struct ActCtxResource { + uint16_t mId; + HMODULE mModule; + + /** + * Set the resource ID used by GetAccessibilityResource. This is so that + * sandboxed child processes can use a value passed down from the parent. + */ + static MFBT_API void SetAccessibilityResourceId(uint16_t aResourceId); + + /** + * Get the resource ID used by GetAccessibilityResource. + */ + static MFBT_API uint16_t GetAccessibilityResourceId(); + + /** + * @return ActCtxResource of a11y manifest resource to be passed to + * mscom::ActivationContext + */ + static MFBT_API ActCtxResource GetAccessibilityResource(); +}; + +} // namespace mscom +} // namespace mozilla + +#endif diff --git a/ipc/mscom/mozglue/ProcessRuntimeShared.cpp b/ipc/mscom/mozglue/ProcessRuntimeShared.cpp new file mode 100644 index 0000000000..ee4f6b221f --- /dev/null +++ b/ipc/mscom/mozglue/ProcessRuntimeShared.cpp @@ -0,0 +1,31 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/mscom/ProcessRuntimeShared.h" + +#include "mozilla/glue/WinUtils.h" + +// We allow multiple ProcessRuntime instances to exist simultaneously (even +// on separate threads), but only one should be doing the process-wide +// initialization. These variables provide that mutual exclusion. +static mozilla::glue::Win32SRWLock gLock; +static mozilla::mscom::detail::ProcessInitState gProcessInitState = + mozilla::mscom::detail::ProcessInitState::Uninitialized; + +namespace mozilla { +namespace mscom { +namespace detail { + +MFBT_API ProcessInitState& BeginProcessRuntimeInit() { + gLock.LockExclusive(); + return gProcessInitState; +} + +MFBT_API void EndProcessRuntimeInit() { gLock.UnlockExclusive(); } + +} // namespace detail +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/mozglue/ProcessRuntimeShared.h b/ipc/mscom/mozglue/ProcessRuntimeShared.h new file mode 100644 index 0000000000..d5a58a7b92 --- /dev/null +++ b/ipc/mscom/mozglue/ProcessRuntimeShared.h @@ -0,0 +1,55 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_ProcessRuntimeShared_h +#define mozilla_mscom_ProcessRuntimeShared_h + +#include "mozilla/Assertions.h" +#include "mozilla/Attributes.h" +#include "mozilla/Types.h" + +namespace mozilla { +namespace mscom { +namespace detail { + +enum class ProcessInitState : uint32_t { + Uninitialized = 0, + PartialSecurityInitialized, + PartialGlobalOptions, + FullyInitialized, +}; + +MFBT_API ProcessInitState& BeginProcessRuntimeInit(); +MFBT_API void EndProcessRuntimeInit(); + +} // namespace detail + +class MOZ_RAII ProcessInitLock final { + public: + ProcessInitLock() : mInitState(detail::BeginProcessRuntimeInit()) {} + + ~ProcessInitLock() { detail::EndProcessRuntimeInit(); } + + detail::ProcessInitState GetInitState() const { return mInitState; } + + void SetInitState(const detail::ProcessInitState aNewState) { + MOZ_DIAGNOSTIC_ASSERT(aNewState > mInitState); + mInitState = aNewState; + } + + ProcessInitLock(const ProcessInitLock&) = delete; + ProcessInitLock(ProcessInitLock&&) = delete; + ProcessInitLock operator=(const ProcessInitLock&) = delete; + ProcessInitLock operator=(ProcessInitLock&&) = delete; + + private: + detail::ProcessInitState& mInitState; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_ProcessRuntimeShared_h diff --git a/ipc/mscom/mozglue/moz.build b/ipc/mscom/mozglue/moz.build new file mode 100644 index 0000000000..635b5264cf --- /dev/null +++ b/ipc/mscom/mozglue/moz.build @@ -0,0 +1,19 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +FINAL_LIBRARY = "mozglue" + +EXPORTS.mozilla.mscom += [ + "ActCtxResource.h", + "ProcessRuntimeShared.h", +] + +UNIFIED_SOURCES += [ + "ActCtxResource.cpp", + "ProcessRuntimeShared.cpp", +] + +REQUIRES_UNIFIED_BUILD = True diff --git a/ipc/mscom/oop/Factory.h b/ipc/mscom/oop/Factory.h new file mode 100644 index 0000000000..e95f1d2499 --- /dev/null +++ b/ipc/mscom/oop/Factory.h @@ -0,0 +1,142 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Factory_h +#define mozilla_mscom_Factory_h + +#if defined(MOZILLA_INTERNAL_API) +# error This code is NOT for internal Gecko use! +#endif // defined(MOZILLA_INTERNAL_API) + +#include <objbase.h> +#include <unknwn.h> + +#include <utility> + +#include "Module.h" +#include "mozilla/Attributes.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/RefPtr.h" +#include "mozilla/StaticPtr.h" + +/* WARNING! The code in this file may be loaded into the address spaces of other + processes! It MUST NOT link against xul.dll or other Gecko binaries! Only + inline code may be included! */ + +namespace mozilla { +namespace mscom { + +template <typename T> +class MOZ_NONHEAP_CLASS Factory : public IClassFactory { + template <typename... Args> + HRESULT DoCreate(Args&&... args) { + MOZ_DIAGNOSTIC_ASSERT(false, "This should not be executed"); + return E_NOTIMPL; + } + + template <typename... Args> + HRESULT DoCreate(HRESULT (*aFnPtr)(IUnknown*, REFIID, void**), + Args&&... args) { + return aFnPtr(std::forward<Args>(args)...); + } + + public: + // IUnknown + STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override { + if (!aOutInterface) { + return E_INVALIDARG; + } + + if (aIid == IID_IUnknown || aIid == IID_IClassFactory) { + RefPtr<IClassFactory> punk(this); + punk.forget(aOutInterface); + return S_OK; + } + + *aOutInterface = nullptr; + + return E_NOINTERFACE; + } + + STDMETHODIMP_(ULONG) AddRef() override { + Module::Lock(); + return 2; + } + + STDMETHODIMP_(ULONG) Release() override { + Module::Unlock(); + return 1; + } + + // IClassFactory + STDMETHODIMP CreateInstance(IUnknown* aOuter, REFIID aIid, + void** aOutInterface) override { + return DoCreate(&T::Create, aOuter, aIid, aOutInterface); + } + + STDMETHODIMP LockServer(BOOL aLock) override { + if (aLock) { + Module::Lock(); + } else { + Module::Unlock(); + } + return S_OK; + } +}; + +template <typename T> +class MOZ_NONHEAP_CLASS SingletonFactory : public Factory<T> { + public: + STDMETHODIMP CreateInstance(IUnknown* aOuter, REFIID aIid, + void** aOutInterface) override { + if (aOuter || !aOutInterface) { + return E_INVALIDARG; + } + + RefPtr<T> obj(sInstance); + if (!obj) { + obj = GetOrCreateSingleton(); + } + + return obj->QueryInterface(aIid, aOutInterface); + } + + RefPtr<T> GetOrCreateSingleton() { + if (!sInstance) { + RefPtr<T> object; + if (FAILED(T::Create(getter_AddRefs(object)))) { + return nullptr; + } + + sInstance = object.forget(); + } + + return sInstance; + } + + RefPtr<T> GetSingleton() { return sInstance; } + + void ClearSingleton() { + if (!sInstance) { + return; + } + + DebugOnly<HRESULT> hr = ::CoDisconnectObject(sInstance.get(), 0); + MOZ_ASSERT(SUCCEEDED(hr)); + sInstance = nullptr; + } + + private: + static StaticRefPtr<T> sInstance; +}; + +template <typename T> +StaticRefPtr<T> SingletonFactory<T>::sInstance; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Factory_h diff --git a/ipc/mscom/oop/Handler.cpp b/ipc/mscom/oop/Handler.cpp new file mode 100644 index 0000000000..1c40e5300f --- /dev/null +++ b/ipc/mscom/oop/Handler.cpp @@ -0,0 +1,281 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "Handler.h" +#include "Module.h" + +#include "mozilla/ArrayUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/mscom/Objref.h" +#include "nsWindowsHelpers.h" + +#include <objbase.h> +#include <shlwapi.h> +#include <string.h> + +/* WARNING! The code in this file may be loaded into the address spaces of other + processes! It MUST NOT link against xul.dll or other Gecko binaries! Only + inline code may be included! */ + +namespace mozilla { +namespace mscom { + +Handler::Handler(IUnknown* aOuter, HRESULT* aResult) + : mRefCnt(0), mOuter(aOuter), mUnmarshal(nullptr), mHasPayload(false) { + MOZ_ASSERT(aResult); + + if (!aOuter) { + *aResult = E_INVALIDARG; + return; + } + + StabilizedRefCount<ULONG> stabilizer(mRefCnt); + + *aResult = + ::CoGetStdMarshalEx(aOuter, SMEXF_HANDLER, getter_AddRefs(mInnerUnk)); + if (FAILED(*aResult)) { + return; + } + + *aResult = mInnerUnk->QueryInterface(IID_IMarshal, (void**)&mUnmarshal); + if (FAILED(*aResult)) { + return; + } + + // mUnmarshal is a weak ref + mUnmarshal->Release(); +} + +HRESULT +Handler::InternalQueryInterface(REFIID riid, void** ppv) { + if (!ppv) { + return E_INVALIDARG; + } + + if (riid == IID_IUnknown) { + RefPtr<IUnknown> punk(static_cast<IUnknown*>(&mInternalUnknown)); + punk.forget(ppv); + return S_OK; + } + + if (riid == IID_IMarshal) { + RefPtr<IMarshal> ptr(this); + ptr.forget(ppv); + return S_OK; + } + + // Try the handler implementation + HRESULT hr = QueryHandlerInterface(mInnerUnk, riid, ppv); + if (hr == S_FALSE) { + // The handler knows this interface is not available, so don't bother + // asking the proxy. + return E_NOINTERFACE; + } + if (hr != E_NOINTERFACE) { + return hr; + } + + // Now forward to the marshaler's inner + return mInnerUnk->QueryInterface(riid, ppv); +} + +ULONG +Handler::InternalAddRef() { + if (!mRefCnt) { + Module::Lock(); + } + return ++mRefCnt; +} + +ULONG +Handler::InternalRelease() { + ULONG newRefCnt = --mRefCnt; + if (newRefCnt == 0) { + delete this; + Module::Unlock(); + } + return newRefCnt; +} + +HRESULT +Handler::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, CLSID* pCid) { + return mUnmarshal->GetUnmarshalClass(MarshalAs(riid), pv, dwDestContext, + pvDestContext, mshlflags, pCid); +} + +HRESULT +Handler::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, DWORD* pSize) { + if (!pSize) { + return E_INVALIDARG; + } + + *pSize = 0; + + RefPtr<IUnknown> unkToMarshal; + HRESULT hr; + + REFIID marshalAs = MarshalAs(riid); + if (marshalAs == riid) { + unkToMarshal = static_cast<IUnknown*>(pv); + } else { + hr = mInnerUnk->QueryInterface(marshalAs, getter_AddRefs(unkToMarshal)); + if (FAILED(hr)) { + return hr; + } + } + + // We do not necessarily want to use the pv that COM is giving us; we may want + // to marshal a different proxy that is more appropriate to what we're + // wrapping... + hr = mUnmarshal->GetMarshalSizeMax(marshalAs, unkToMarshal.get(), + dwDestContext, pvDestContext, mshlflags, + pSize); + +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + return hr; +#else + if (FAILED(hr)) { + return hr; + } + + if (!HasPayload()) { + return S_OK; + } + + DWORD payloadSize = 0; + hr = GetHandlerPayloadSize(marshalAs, &payloadSize); + if (FAILED(hr)) { + return hr; + } + + *pSize += payloadSize; + return S_OK; +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) +} + +HRESULT +Handler::GetMarshalInterface(REFIID aMarshalAsIid, NotNull<IUnknown*> aProxy, + NotNull<IID*> aOutIid, + NotNull<IUnknown**> aOutUnk) { + *aOutIid = aMarshalAsIid; + return aProxy->QueryInterface( + aMarshalAsIid, + reinterpret_cast<void**>(static_cast<IUnknown**>(aOutUnk))); +} + +HRESULT +Handler::MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) { + // We do not necessarily want to use the pv that COM is giving us; we may want + // to marshal a different proxy that is more appropriate to what we're + // wrapping... + RefPtr<IUnknown> unkToMarshal; + HRESULT hr; + +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + LARGE_INTEGER seekTo; + seekTo.QuadPart = 0; + + ULARGE_INTEGER objrefPos; + + // Save the current position as it points to the location where the OBJREF + // will be written. + hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &objrefPos); + if (FAILED(hr)) { + return hr; + } +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + + REFIID marshalAs = MarshalAs(riid); + IID marshalOutAs; + + hr = GetMarshalInterface( + marshalAs, WrapNotNull<IUnknown*>(mInnerUnk), WrapNotNull(&marshalOutAs), + WrapNotNull<IUnknown**>(getter_AddRefs(unkToMarshal))); + if (FAILED(hr)) { + return hr; + } + + hr = mUnmarshal->MarshalInterface(pStm, marshalAs, unkToMarshal.get(), + dwDestContext, pvDestContext, mshlflags); + if (FAILED(hr)) { + return hr; + } + +#if defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) + // Obtain the current stream position which is the end of the OBJREF + ULARGE_INTEGER endPos; + hr = pStm->Seek(seekTo, STREAM_SEEK_CUR, &endPos); + if (FAILED(hr)) { + return hr; + } + + // Now strip out the handler. + if (!StripHandlerFromOBJREF(WrapNotNull(pStm), objrefPos.QuadPart, + endPos.QuadPart)) { + return E_FAIL; + } + + // Fix the IID + if (!SetIID(WrapNotNull(pStm), objrefPos.QuadPart, marshalOutAs)) { + return E_FAIL; + } + + return S_OK; +#else + if (!HasPayload()) { + return S_OK; + } + + // Unfortunately when COM re-marshals a proxy that prevouisly had a payload, + // we must re-serialize it. + return WriteHandlerPayload(pStm, marshalAs); +#endif // defined(MOZ_MSCOM_REMARSHAL_NO_HANDLER) +} + +HRESULT +Handler::UnmarshalInterface(IStream* pStm, REFIID riid, void** ppv) { + REFIID unmarshalAs = MarshalAs(riid); + HRESULT hr = mUnmarshal->UnmarshalInterface(pStm, unmarshalAs, ppv); + if (FAILED(hr)) { + return hr; + } + + // This method may be called on the same object multiple times (as new + // interfaces are queried off the proxy). Not all interfaces will necessarily + // refresh the payload, so we set mHasPayload using OR to reflect that fact. + // (Otherwise mHasPayload could be cleared and the handler would think that + // it doesn't have a payload even though it actually does). + mHasPayload |= (ReadHandlerPayload(pStm, unmarshalAs) == S_OK); + + return hr; +} + +HRESULT +Handler::ReleaseMarshalData(IStream* pStm) { + return mUnmarshal->ReleaseMarshalData(pStm); +} + +HRESULT +Handler::DisconnectObject(DWORD dwReserved) { + return mUnmarshal->DisconnectObject(dwReserved); +} + +HRESULT +Handler::Unregister(REFCLSID aClsid) { return Module::Deregister(aClsid); } + +HRESULT +Handler::Register(REFCLSID aClsid, const bool aMsixContainer) { + return Module::Register(aClsid, Module::ThreadingModel::DedicatedUiThreadOnly, + Module::ClassType::InprocHandler, nullptr, + aMsixContainer); +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/oop/Handler.h b/ipc/mscom/oop/Handler.h new file mode 100644 index 0000000000..b22c925c22 --- /dev/null +++ b/ipc/mscom/oop/Handler.h @@ -0,0 +1,134 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Handler_h +#define mozilla_mscom_Handler_h + +#if defined(MOZILLA_INTERNAL_API) +# error This code is NOT for internal Gecko use! +#endif // defined(MOZILLA_INTERNAL_API) + +#include <objidl.h> + +#include "mozilla/mscom/Aggregation.h" +#include "mozilla/NotNull.h" +#include "mozilla/RefPtr.h" + +/* WARNING! The code in this file may be loaded into the address spaces of other + processes! It MUST NOT link against xul.dll or other Gecko binaries! Only + inline code may be included! */ + +namespace mozilla { +namespace mscom { + +class Handler : public IMarshal { + public: + // IMarshal + STDMETHODIMP GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + CLSID* pCid) override; + STDMETHODIMP GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, + void* pvDestContext, DWORD mshlflags, + DWORD* pSize) override; + STDMETHODIMP MarshalInterface(IStream* pStm, REFIID riid, void* pv, + DWORD dwDestContext, void* pvDestContext, + DWORD mshlflags) override; + STDMETHODIMP UnmarshalInterface(IStream* pStm, REFIID riid, + void** ppv) override; + STDMETHODIMP ReleaseMarshalData(IStream* pStm) override; + STDMETHODIMP DisconnectObject(DWORD dwReserved) override; + + /** + * This method allows the handler to return its own interfaces that override + * those interfaces that are exposed by the underlying COM proxy. + * @param aProxyUnknown is the IUnknown of the underlying COM proxy. This is + * provided to give the handler implementation an + * opportunity to acquire interfaces to the underlying + * remote object, if needed. + * @param aIid Interface requested, similar to IUnknown::QueryInterface + * @param aOutInterface Outparam for the resulting interface to return to the + * client. + * @return The usual HRESULT codes similarly to IUnknown::QueryInterface. + * If E_NOINTERFACE is returned, the proxy will be queried. + * If the handler is certain that this interface is not available, + * it can return S_FALSE to avoid querying the proxy. This will be + * translated to E_NOINTERFACE before it is returned to the client. + */ + virtual HRESULT QueryHandlerInterface(IUnknown* aProxyUnknown, REFIID aIid, + void** aOutInterface) = 0; + /** + * Called when the implementer should deserialize data in aStream. + * @return S_OK on success; + * S_FALSE if the deserialization was successful but there was no + * data; HRESULT error code otherwise. + */ + virtual HRESULT ReadHandlerPayload(IStream* aStream, REFIID aIid) { + return S_FALSE; + } + + /** + * Unfortunately when COM marshals a proxy, it doesn't implicitly marshal + * the payload that was originally sent with the proxy. We must implement + * that code in the handler in order to make this happen. + */ + + /** + * This function allows the implementer to substitute a different interface + * for marshaling than the one that COM is intending to marshal. For example, + * the implementer might want to marshal a proxy for an interface that is + * derived from the requested interface. + * + * The default implementation is the identity function. + */ + virtual REFIID MarshalAs(REFIID aRequestedIid) { return aRequestedIid; } + + virtual HRESULT GetMarshalInterface(REFIID aMarshalAsIid, + NotNull<IUnknown*> aProxy, + NotNull<IID*> aOutIid, + NotNull<IUnknown**> aOutUnk); + + /** + * Called when the implementer must provide the size of the payload. + */ + virtual HRESULT GetHandlerPayloadSize(REFIID aIid, DWORD* aOutPayloadSize) { + if (!aOutPayloadSize) { + return E_INVALIDARG; + } + *aOutPayloadSize = 0; + return S_OK; + } + + /** + * Called when the implementer should serialize the payload data into aStream. + */ + virtual HRESULT WriteHandlerPayload(IStream* aStream, REFIID aIid) { + return S_OK; + } + + IUnknown* GetProxy() const { return mInnerUnk; } + + static HRESULT Register(REFCLSID aClsid, const bool aMsixContainer = false); + static HRESULT Unregister(REFCLSID aClsid); + + protected: + Handler(IUnknown* aOuter, HRESULT* aResult); + virtual ~Handler() {} + bool HasPayload() const { return mHasPayload; } + IUnknown* GetOuter() const { return mOuter; } + + private: + ULONG mRefCnt; + IUnknown* mOuter; + RefPtr<IUnknown> mInnerUnk; + IMarshal* mUnmarshal; // WEAK + bool mHasPayload; + DECLARE_AGGREGATABLE(Handler); +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Handler_h diff --git a/ipc/mscom/oop/Module.cpp b/ipc/mscom/oop/Module.cpp new file mode 100644 index 0000000000..88ebf91531 --- /dev/null +++ b/ipc/mscom/oop/Module.cpp @@ -0,0 +1,292 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "Module.h" + +#include <stdlib.h> + +#include <ktmw32.h> +#include <memory.h> +#include <rpc.h> + +#include "mozilla/ArrayUtils.h" +#include "mozilla/Assertions.h" +#include "mozilla/mscom/Utils.h" +#include "mozilla/Range.h" +#include "nsWindowsHelpers.h" + +template <size_t N> +static const mozilla::Range<const wchar_t> LiteralToRange( + const wchar_t (&aArg)[N]) { + return mozilla::Range(aArg, N); +} + +namespace mozilla { +namespace mscom { + +ULONG Module::sRefCount = 0; + +static const wchar_t* SubkeyNameFromClassType( + const Module::ClassType aClassType) { + switch (aClassType) { + case Module::ClassType::InprocServer: + return L"InprocServer32"; + case Module::ClassType::InprocHandler: + return L"InprocHandler32"; + default: + MOZ_CRASH("Unknown ClassType"); + return nullptr; + } +} + +static const Range<const wchar_t> ThreadingModelAsString( + const Module::ThreadingModel aThreadingModel) { + switch (aThreadingModel) { + case Module::ThreadingModel::DedicatedUiThreadOnly: + return LiteralToRange(L"Apartment"); + case Module::ThreadingModel::MultiThreadedApartmentOnly: + return LiteralToRange(L"Free"); + case Module::ThreadingModel::DedicatedUiThreadXorMultiThreadedApartment: + return LiteralToRange(L"Both"); + case Module::ThreadingModel::AllThreadsAllApartments: + return LiteralToRange(L"Neutral"); + default: + MOZ_CRASH("Unknown ThreadingModel"); + return Range<const wchar_t>(); + } +} + +/* static */ +HRESULT Module::Register(const CLSID* const* aClsids, const size_t aNumClsids, + const ThreadingModel aThreadingModel, + const ClassType aClassType, const GUID* const aAppId, + const bool aMsixContainer) { + MOZ_ASSERT(aClsids && aNumClsids); + if (!aClsids || !aNumClsids) { + return E_INVALIDARG; + } + MOZ_ASSERT(!aAppId || !aMsixContainer, + "aAppId isn't valid in an MSIX container"); + + const wchar_t* inprocName = SubkeyNameFromClassType(aClassType); + + const Range<const wchar_t> threadingModelStr = + ThreadingModelAsString(aThreadingModel); + const DWORD threadingModelStrLenBytesInclNul = + threadingModelStr.length() * sizeof(wchar_t); + + wchar_t strAppId[kGuidRegFormatCharLenInclNul] = {}; + if (aAppId) { + GUIDToString(*aAppId, strAppId); + } + + // Obtain the full path to this DLL + HMODULE thisModule; + if (!::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast<LPCWSTR>(&Module::CanUnload), + &thisModule)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + + wchar_t absThisModulePath[MAX_PATH + 1] = {}; + DWORD actualPathLenCharsExclNul = ::GetModuleFileNameW( + thisModule, absThisModulePath, ArrayLength(absThisModulePath)); + if (!actualPathLenCharsExclNul || + actualPathLenCharsExclNul == ArrayLength(absThisModulePath)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + const DWORD actualPathLenBytesInclNul = + (actualPathLenCharsExclNul + 1) * sizeof(wchar_t); + + nsAutoHandle txn; + // RegCreateKeyTransacted doesn't work in MSIX containers. + if (!aMsixContainer) { + // Use the name of this DLL as the name of the transaction + wchar_t txnName[_MAX_FNAME] = {}; + if (_wsplitpath_s(absThisModulePath, nullptr, 0, nullptr, 0, txnName, + ArrayLength(txnName), nullptr, 0)) { + return HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND); + } + + // Manipulate the registry using a transaction so that any failures are + // rolled back. + txn.own(::CreateTransaction(nullptr, nullptr, TRANSACTION_DO_NOT_PROMOTE, 0, + 0, 0, txnName)); + if (txn.get() == INVALID_HANDLE_VALUE) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + } + + HRESULT hr; + LSTATUS status; + + // A single DLL may serve multiple components. For each CLSID, we register + // this DLL as its server and, when an AppId is specified, set up a reference + // from the CLSID to the specified AppId. + for (size_t idx = 0; idx < aNumClsids; ++idx) { + if (!aClsids[idx]) { + return E_INVALIDARG; + } + + wchar_t clsidKeyPath[256]; + hr = BuildClsidPath(*aClsids[idx], clsidKeyPath); + if (FAILED(hr)) { + return hr; + } + + // Create the CLSID key + HKEY rawClsidKey; + // Subtle: If aMsixContainer is true, as well as calling a different + // function, we also use HKEY_CURRENT_USER. When false, we use + // HKEY_LOCAL_MACHINE. + if (aMsixContainer) { + status = ::RegCreateKeyExW(HKEY_CURRENT_USER, clsidKeyPath, 0, nullptr, + REG_OPTION_NON_VOLATILE, KEY_ALL_ACCESS, + nullptr, &rawClsidKey, nullptr); + } else { + status = ::RegCreateKeyTransactedW( + HKEY_LOCAL_MACHINE, clsidKeyPath, 0, nullptr, REG_OPTION_NON_VOLATILE, + KEY_ALL_ACCESS, nullptr, &rawClsidKey, nullptr, txn, nullptr); + } + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + nsAutoRegKey clsidKey(rawClsidKey); + + if (aAppId) { + // This value associates the registered CLSID with the specified AppID + status = ::RegSetValueExW(clsidKey, L"AppID", 0, REG_SZ, + reinterpret_cast<const BYTE*>(strAppId), + ArrayLength(strAppId) * sizeof(wchar_t)); + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + } + + HKEY rawInprocKey; + if (aMsixContainer) { + status = ::RegCreateKeyExW(clsidKey, inprocName, 0, nullptr, + REG_OPTION_NON_VOLATILE, KEY_ALL_ACCESS, + nullptr, &rawInprocKey, nullptr); + } else { + status = ::RegCreateKeyTransactedW( + clsidKey, inprocName, 0, nullptr, REG_OPTION_NON_VOLATILE, + KEY_ALL_ACCESS, nullptr, &rawInprocKey, nullptr, txn, nullptr); + } + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + nsAutoRegKey inprocKey(rawInprocKey); + + // Set the component's path to this DLL + status = ::RegSetValueExW(inprocKey, nullptr, 0, REG_EXPAND_SZ, + reinterpret_cast<const BYTE*>(absThisModulePath), + actualPathLenBytesInclNul); + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + + status = ::RegSetValueExW( + inprocKey, L"ThreadingModel", 0, REG_SZ, + reinterpret_cast<const BYTE*>(threadingModelStr.begin().get()), + threadingModelStrLenBytesInclNul); + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + } + + if (aAppId) { + // When specified, we must also create a key for the AppID. + wchar_t appidKeyPath[256]; + hr = BuildAppidPath(*aAppId, appidKeyPath); + if (FAILED(hr)) { + return hr; + } + + HKEY rawAppidKey; + status = ::RegCreateKeyTransactedW( + HKEY_LOCAL_MACHINE, appidKeyPath, 0, nullptr, REG_OPTION_NON_VOLATILE, + KEY_ALL_ACCESS, nullptr, &rawAppidKey, nullptr, txn, nullptr); + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + nsAutoRegKey appidKey(rawAppidKey); + + // Setting DllSurrogate to a null or empty string indicates to Windows that + // we want to use the default surrogate (i.e. dllhost.exe) to load our DLL. + status = + ::RegSetValueExW(appidKey, L"DllSurrogate", 0, REG_SZ, + reinterpret_cast<const BYTE*>(L""), sizeof(wchar_t)); + if (status != ERROR_SUCCESS) { + return HRESULT_FROM_WIN32(status); + } + } + + if (!aMsixContainer) { + if (!::CommitTransaction(txn)) { + return HRESULT_FROM_WIN32(::GetLastError()); + } + } + + return S_OK; +} + +/** + * Unfortunately the registry transaction APIs are not as well-developed for + * deleting things as they are for creating them. We just use RegDeleteTree + * for the implementation of this method. + */ +HRESULT Module::Deregister(const CLSID* const* aClsids, const size_t aNumClsids, + const GUID* const aAppId) { + MOZ_ASSERT(aClsids && aNumClsids); + if (!aClsids || !aNumClsids) { + return E_INVALIDARG; + } + + HRESULT hr; + LSTATUS status; + + // Delete the key for each CLSID. This will also delete any references to + // the AppId. + for (size_t idx = 0; idx < aNumClsids; ++idx) { + if (!aClsids[idx]) { + return E_INVALIDARG; + } + + wchar_t clsidKeyPath[256]; + hr = BuildClsidPath(*aClsids[idx], clsidKeyPath); + if (FAILED(hr)) { + return hr; + } + + status = ::RegDeleteTreeW(HKEY_LOCAL_MACHINE, clsidKeyPath); + // We allow the deletion to succeed if the key was already gone + if (status != ERROR_SUCCESS && status != ERROR_FILE_NOT_FOUND) { + return HRESULT_FROM_WIN32(status); + } + } + + // Now delete the AppID key, if desired. + if (aAppId) { + wchar_t appidKeyPath[256]; + hr = BuildAppidPath(*aAppId, appidKeyPath); + if (FAILED(hr)) { + return hr; + } + + status = ::RegDeleteTreeW(HKEY_LOCAL_MACHINE, appidKeyPath); + // We allow the deletion to succeed if the key was already gone + if (status != ERROR_SUCCESS && status != ERROR_FILE_NOT_FOUND) { + return HRESULT_FROM_WIN32(status); + } + } + + return S_OK; +} + +} // namespace mscom +} // namespace mozilla diff --git a/ipc/mscom/oop/Module.h b/ipc/mscom/oop/Module.h new file mode 100644 index 0000000000..1fa4b31c49 --- /dev/null +++ b/ipc/mscom/oop/Module.h @@ -0,0 +1,98 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=8 sts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_mscom_Module_h +#define mozilla_mscom_Module_h + +#if defined(MOZILLA_INTERNAL_API) +# error This code is NOT for internal Gecko use! +#endif // defined(MOZILLA_INTERNAL_API) + +#include <objbase.h> + +/* WARNING! The code in this file may be loaded into the address spaces of other + processes! It MUST NOT link against xul.dll or other Gecko binaries! Only + inline code may be included! */ + +namespace mozilla { +namespace mscom { + +class Module { + public: + static HRESULT CanUnload() { return sRefCount == 0 ? S_OK : S_FALSE; } + + static void Lock() { ++sRefCount; } + static void Unlock() { --sRefCount; } + + enum class ThreadingModel { + DedicatedUiThreadOnly, + MultiThreadedApartmentOnly, + DedicatedUiThreadXorMultiThreadedApartment, + AllThreadsAllApartments, + }; + + enum class ClassType { + InprocServer, + InprocHandler, + }; + + /** + * In the Register functions, the aMsixContainer parameter specifies whether + * this registration is being performed inside an MSIX container. If true, + * the CLSID is registered in HKCU and a registry transaction is not used, as + * registry transactions don't work in an MSIX container. If false (the + * default), the CLSID is registered in HKLM and a registry transaction is + * used so that any failures roll back the entire operation. Specifying aAppId + * is only valid when aMsixContainer is false. + */ + static HRESULT Register(REFCLSID aClsid, const ThreadingModel aThreadingModel, + const ClassType aClassType = ClassType::InprocServer, + const GUID* const aAppId = nullptr, + const bool aMsixContainer = false) { + const CLSID* clsidArray[] = {&aClsid}; + return Register(clsidArray, aThreadingModel, aClassType, aAppId, + aMsixContainer); + } + + template <size_t N> + static HRESULT Register(const CLSID* (&aClsids)[N], + const ThreadingModel aThreadingModel, + const ClassType aClassType = ClassType::InprocServer, + const GUID* const aAppId = nullptr, + const bool aMsixContainer = false) { + return Register(aClsids, N, aThreadingModel, aClassType, aAppId, + aMsixContainer); + } + + static HRESULT Deregister(REFCLSID aClsid, + const GUID* const aAppId = nullptr) { + const CLSID* clsidArray[] = {&aClsid}; + return Deregister(clsidArray, aAppId); + } + + template <size_t N> + static HRESULT Deregister(const CLSID* (&aClsids)[N], + const GUID* const aAppId = nullptr) { + return Deregister(aClsids, N, aAppId); + } + + private: + static HRESULT Register(const CLSID* const* aClsids, const size_t aNumClsids, + const ThreadingModel aThreadingModel, + const ClassType aClassType, const GUID* const aAppId, + const bool aMsixContainer); + + static HRESULT Deregister(const CLSID* const* aClsids, + const size_t aNumClsids, const GUID* const aAppId); + + private: + static ULONG sRefCount; +}; + +} // namespace mscom +} // namespace mozilla + +#endif // mozilla_mscom_Module_h diff --git a/ipc/mscom/oop/moz.build b/ipc/mscom/oop/moz.build new file mode 100644 index 0000000000..5ead655e6f --- /dev/null +++ b/ipc/mscom/oop/moz.build @@ -0,0 +1,43 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +Library("mscom_oop") + +SOURCES += [ + "../ActivationContext.cpp", + "../COMWrappers.cpp", + "../Objref.cpp", + "../Registration.cpp", + "../StructStream.cpp", + "../Utils.cpp", +] + +UNIFIED_SOURCES += [ + "Handler.cpp", + "Module.cpp", +] + +OS_LIBS += [ + "ktmw32", + "ole32", + "oleaut32", + "shlwapi", +] + +LIBRARY_DEFINES["UNICODE"] = True +LIBRARY_DEFINES["_UNICODE"] = True +LIBRARY_DEFINES["MOZ_NO_MOZALLOC"] = True +LIBRARY_DEFINES["MOZ_MSCOM_REMARSHAL_NO_HANDLER"] = True + +DisableStlWrapping() +NO_EXPAND_LIBS = True +FORCE_STATIC_LIB = True + +# This DLL may be loaded into other processes, so we need static libs for +# Windows 7 and Windows 8. +USE_STATIC_LIBS = True + +REQUIRES_UNIFIED_BUILD = True |