summaryrefslogtreecommitdiffstats
path: root/toolkit/xre/dllservices/mozglue/WindowsDllServices.h
blob: 70ba5e98b0e5fa1a32acf44b7ac37a2f1a0061d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef mozilla_glue_WindowsDllServices_h
#define mozilla_glue_WindowsDllServices_h

#include <utility>

#include "mozilla/Assertions.h"
#include "mozilla/Authenticode.h"
#include "mozilla/LoaderAPIInterfaces.h"
#include "mozilla/UniquePtr.h"
#include "mozilla/Vector.h"
#include "mozilla/WinHeaderOnlyUtils.h"
#include "mozilla/WindowsDllBlocklist.h"
#include "mozilla/mozalloc.h"

#if defined(MOZILLA_INTERNAL_API)
#  include "MainThreadUtils.h"
#  include "nsISupportsImpl.h"
#  include "nsString.h"
#  include "nsThreadUtils.h"
#  include "prthread.h"
#  include "mozilla/SchedulerGroup.h"
#endif  // defined(MOZILLA_INTERNAL_API)

// For PCUNICODE_STRING
#include <winternl.h>

namespace mozilla {
namespace glue {
namespace detail {

/**
 * Common base for the DLL services objects handed to mozglue's DLL blocklist
 * (see EnableFull / EnableBasic). It forwards Authenticode queries to an
 * externally supplied implementation and exposes the launcher-provided
 * service table (nt::WinLauncherServices). Instances are neither copyable
 * nor movable.
 */
class DllServicesBase : public Authenticode {
 public:
  /**
   * WARNING: This method is called from within an unsafe context that holds
   *          multiple locks inside the Windows loader. The only thing that
   *          this function should be used for is dispatching the event to our
   *          event loop so that it may be handled in a safe context.
   */
  virtual void DispatchDllLoadNotification(ModuleLoadInfo&& aModLoadInfo) = 0;

  /**
   * This function accepts module load events to be processed later for
   * the untrusted modules telemetry ping.
   *
   * WARNING: This method is run from within the Windows loader and should
   *          only perform trivial, loader-friendly operations.
   */
  virtual void DispatchModuleLoadBacklogNotification(
      ModuleLoadInfoVec&& aEvents) = 0;

  /**
   * Sets the Authenticode implementation that GetBinaryOrgName delegates to.
   * The pointer is stored as-is; this object does not take ownership of it.
   */
  void SetAuthenticodeImpl(Authenticode* aAuthenticode) {
    mAuthenticode = aAuthenticode;
  }

  /**
   * Stores a copy of the launcher service table consulted by
   * InitDllBlocklistOOP, HandleLauncherError and GetSharedSection.
   */
  void SetWinLauncherServices(const nt::WinLauncherServices& aWinLauncher) {
    mWinLauncher = aWinLauncher;
  }

  /**
   * Forwards all arguments to the launcher's mInitDllBlocklistOOP.
   * Release-asserts that the function was supplied via SetWinLauncherServices.
   */
  template <typename... Args>
  LauncherVoidResultWithLineInfo InitDllBlocklistOOP(Args&&... aArgs) {
    MOZ_RELEASE_ASSERT(mWinLauncher.mInitDllBlocklistOOP);
    return mWinLauncher.mInitDllBlocklistOOP(std::forward<Args>(aArgs)...);
  }

  /**
   * Forwards all arguments to the launcher's mHandleLauncherError.
   * Release-asserts that the function was supplied via SetWinLauncherServices.
   */
  template <typename... Args>
  void HandleLauncherError(Args&&... aArgs) {
    MOZ_RELEASE_ASSERT(mWinLauncher.mHandleLauncherError);
    mWinLauncher.mHandleLauncherError(std::forward<Args>(aArgs)...);
  }

  /**
   * Returns the shared section pointer from the stored launcher service
   * table. It is whatever SetWinLauncherServices last provided.
   */
  nt::SharedSection* GetSharedSection() { return mWinLauncher.mSharedSection; }

  // In debug builds we override GetBinaryOrgName to add a Gecko-specific
  // assertion. OTOH, we normally do not want people overriding this function,
  // so we'll make it final in the release case, thus covering all bases.
  //
  // Delegates to the implementation set by SetAuthenticodeImpl, or returns
  // nullptr when none has been set.
#if defined(DEBUG)
  UniquePtr<wchar_t[]> GetBinaryOrgName(
      const wchar_t* aFilePath,
      AuthenticodeFlags aFlags = AuthenticodeFlags::Default) override
#else
  UniquePtr<wchar_t[]> GetBinaryOrgName(
      const wchar_t* aFilePath,
      AuthenticodeFlags aFlags = AuthenticodeFlags::Default) final
#endif  // defined(DEBUG)
  {
    if (!mAuthenticode) {
      return nullptr;
    }

    return mAuthenticode->GetBinaryOrgName(aFilePath, aFlags);
  }

  /**
   * Unregisters the full DLL services by passing nullptr to
   * DllBlocklist_SetFullDllServices.
   */
  virtual void DisableFull() { DllBlocklist_SetFullDllServices(nullptr); }

  DllServicesBase(const DllServicesBase&) = delete;
  DllServicesBase(DllServicesBase&&) = delete;
  DllServicesBase& operator=(const DllServicesBase&) = delete;
  DllServicesBase& operator=(DllServicesBase&&) = delete;

 protected:
  DllServicesBase() = default;

  virtual ~DllServicesBase() = default;

  // Register this object with the blocklist as the full / basic services.
  void EnableFull() { DllBlocklist_SetFullDllServices(this); }
  void EnableBasic() { DllBlocklist_SetBasicDllServices(this); }

 private:
  // Not owned. Stays nullptr until SetAuthenticodeImpl is called.
  Authenticode* mAuthenticode = nullptr;
  // Value-initialized so its members are never left indeterminate, even if
  // nt::WinLauncherServices has no default member initializers of its own.
  // The MOZ_RELEASE_ASSERTs above depend on unset entries reading as null.
  nt::WinLauncherServices mWinLauncher{};
};

}  // namespace detail

#if defined(MOZILLA_INTERNAL_API)

/**
 * Move-only wrapper around a ModuleLoadInfo. It also carries the name of the
 * thread that performed the load, which is captured only if the constructor
 * runs on that same thread.
 */
struct EnhancedModuleLoadInfo final {
  using BacktraceType = decltype(ModuleLoadInfo::mBacktrace);

  /**
   * Takes ownership of aModLoadInfo. mThreadName is filled in only when the
   * current thread id matches the thread id recorded in the load info;
   * otherwise it stays empty. Debug builds assert the section name is set.
   */
  explicit EnhancedModuleLoadInfo(ModuleLoadInfo&& aModLoadInfo)
      : mNtLoadInfo(std::move(aModLoadInfo)) {
    const bool isLoadingThread =
        ::GetCurrentThreadId() == mNtLoadInfo.mThreadId;
    if (isLoadingThread) {
      mThreadName = PR_GetThreadName(PR_GetCurrentThread());
    }
    MOZ_ASSERT(!mNtLoadInfo.mSectionName.IsEmpty());
  }

  // Movable, but never copyable.
  EnhancedModuleLoadInfo(EnhancedModuleLoadInfo&&) = default;
  EnhancedModuleLoadInfo& operator=(EnhancedModuleLoadInfo&&) = default;
  EnhancedModuleLoadInfo(const EnhancedModuleLoadInfo&) = delete;
  EnhancedModuleLoadInfo& operator=(const EnhancedModuleLoadInfo&) = delete;

  /** Returns the wrapped load info's section name as a dependent string. */
  nsDependentString GetSectionName() const {
    return mNtLoadInfo.mSectionName.AsString();
  }

  ModuleLoadInfo mNtLoadInfo;
  nsCString mThreadName;
};

class DllServices : public detail::DllServicesBase {
 public:
  void DispatchDllLoadNotification(ModuleLoadInfo&& aModLoadInfo) final {
    // We only notify one blocked DLL load event per blocked DLL for the main
    // thread, because dispatching a notification can trigger a new blocked
    // DLL load if the DLL is registered as a WH_GETMESSAGE hook. In that case,
    // dispatching a notification with every load results in an infinite cycle,
    // see bug 1823412.
    if (aModLoadInfo.WasBlocked() && NS_IsMainThread()) {
      nsDependentString sectionName(aModLoadInfo.mSectionName.AsString());

      for (const auto& blockedModule : mMainThreadBlockedModules) {
        if (sectionName == blockedModule) {
          return;
        }
      }

      MOZ_ALWAYS_TRUE(mMainThreadBlockedModules.append(sectionName));
    }

    nsCOMPtr<nsIRunnable> runnable(
        NewRunnableMethod<StoreCopyPassByRRef<EnhancedModuleLoadInfo>>(
            "DllServices::NotifyDllLoad", this, &DllServices::NotifyDllLoad,
            std::move(aModLoadInfo)));

    SchedulerGroup::Dispatch(TaskCategory::Other, runnable.forget());
  }

  void DispatchModuleLoadBacklogNotification(
      ModuleLoadInfoVec&& aEvents) final {
    nsCOMPtr<nsIRunnable> runnable(
        NewRunnableMethod<StoreCopyPassByRRef<ModuleLoadInfoVec>>(
            "DllServices::NotifyModuleLoadBacklog", this,
            &DllServices::NotifyModuleLoadBacklog, std::move(aEvents)));

    SchedulerGroup::Dispatch(TaskCategory::Other, runnable.forget());
  }

#  if defined(DEBUG)
  UniquePtr<wchar_t[]> GetBinaryOrgName(
      const wchar_t* aFilePath,
      AuthenticodeFlags aFlags = AuthenticodeFlags::Default) final {
    // This function may perform disk I/O, so we should never call it on the
    // main thread.
    MOZ_ASSERT(!NS_IsMainThread());
    return detail::DllServicesBase::GetBinaryOrgName(aFilePath, aFlags);
  }
#  endif  // defined(DEBUG)

  NS_INLINE_DECL_THREADSAFE_VIRTUAL_REFCOUNTING(DllServices)

 protected:
  DllServices() = default;
  ~DllServices() = default;

  virtual void NotifyDllLoad(EnhancedModuleLoadInfo&& aModLoadInfo) = 0;
  virtual void NotifyModuleLoadBacklog(ModuleLoadInfoVec&& aEvents) = 0;

 private:
  // This vector has no associated lock. It must only be used on the main
  // thread.
  Vector<nsString> mMainThreadBlockedModules;
};

#else

class BasicDllServices final : public detail::DllServicesBase {
 public:
  BasicDllServices() { EnableBasic(); }

  ~BasicDllServices() = default;

  // Not useful in this class, so provide a default implementation
  virtual void DispatchDllLoadNotification(
      ModuleLoadInfo&& aModLoadInfo) override {}

  virtual void DispatchModuleLoadBacklogNotification(
      ModuleLoadInfoVec&& aEvents) override {}
};

#endif  // defined(MOZILLA_INTERNAL_API)

}  // namespace glue
}  // namespace mozilla

#endif  // mozilla_glue_WindowsDllServices_h