summaryrefslogtreecommitdiffstats
path: root/toolkit/components/aboutthirdparty/tests/TestShellEx/RegUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'toolkit/components/aboutthirdparty/tests/TestShellEx/RegUtils.cpp')
-rw-r--r--toolkit/components/aboutthirdparty/tests/TestShellEx/RegUtils.cpp148
1 files changed, 148 insertions, 0 deletions
diff --git a/toolkit/components/aboutthirdparty/tests/TestShellEx/RegUtils.cpp b/toolkit/components/aboutthirdparty/tests/TestShellEx/RegUtils.cpp
new file mode 100644
index 0000000000..33b2feb12a
--- /dev/null
+++ b/toolkit/components/aboutthirdparty/tests/TestShellEx/RegUtils.cpp
@@ -0,0 +1,148 @@
+/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*-
+ * vim: sw=2 ts=2 et lcs=trail\:.,tab\:>~ :
+ * This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/UniquePtr.h"
+#include "RegUtils.h"
+
+#include <windows.h>
+#include <strsafe.h>
+
+extern std::wstring gDllPath;
+
+const wchar_t kClsIdPrefix[] = L"CLSID\\";
+const wchar_t* kExtensionSubkeys[] = {
+ L".zzz\\shellex\\IconHandler",
+};
+
+bool RegKey::SetStringInternal(const wchar_t* aValueName,
+ const wchar_t* aValueData,
+ DWORD aValueDataLength) {
+ if (!mKey) {
+ return false;
+ }
+
+ return ::RegSetValueExW(mKey, aValueName, 0, REG_SZ,
+ reinterpret_cast<const BYTE*>(aValueData),
+ aValueDataLength) == ERROR_SUCCESS;
+}
+
+RegKey::RegKey(HKEY root, const wchar_t* aSubkey) : mKey(nullptr) {
+ ::RegCreateKeyExW(root, aSubkey, 0, nullptr, 0, KEY_ALL_ACCESS, nullptr,
+ &mKey, nullptr);
+}
+
+RegKey::~RegKey() {
+ if (mKey) {
+ ::RegCloseKey(mKey);
+ }
+}
+
+bool RegKey::SetString(const wchar_t* aValueName, const wchar_t* aValueData) {
+ return SetStringInternal(
+ aValueName, aValueData,
+ aValueData
+ ? static_cast<DWORD>((wcslen(aValueData) + 1) * sizeof(wchar_t))
+ : 0);
+}
+
+bool RegKey::SetString(const wchar_t* aValueName,
+ const std::wstring& aValueData) {
+ return SetStringInternal(
+ aValueName, aValueData.c_str(),
+ static_cast<DWORD>((aValueData.size() + 1) * sizeof(wchar_t)));
+}
+
+std::wstring RegKey::GetString(const wchar_t* aValueName) {
+ DWORD len = 0;
+ LSTATUS status = ::RegGetValueW(mKey, aValueName, nullptr, RRF_RT_REG_SZ,
+ nullptr, nullptr, &len);
+
+ mozilla::UniquePtr<uint8_t[]> buf = mozilla::MakeUnique<uint8_t[]>(len);
+ status = ::RegGetValueW(mKey, aValueName, nullptr, RRF_RT_REG_SZ, nullptr,
+ buf.get(), &len);
+ if (status != ERROR_SUCCESS) {
+ return L"";
+ }
+
+ return reinterpret_cast<wchar_t*>(buf.get());
+}
+
+ComRegisterer::ComRegisterer(const GUID& aClsId, const wchar_t* aFriendlyName)
+ : mClassRoot(HKEY_CURRENT_USER, L"Software\\Classes"),
+ mFriendlyName(aFriendlyName) {
+ wchar_t guidStr[64];
+ HRESULT hr = ::StringCbPrintfW(
+ guidStr, sizeof(guidStr),
+ L"{%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x}", aClsId.Data1,
+ aClsId.Data2, aClsId.Data3, aClsId.Data4[0], aClsId.Data4[1],
+ aClsId.Data4[2], aClsId.Data4[3], aClsId.Data4[4], aClsId.Data4[5],
+ aClsId.Data4[6], aClsId.Data4[7]);
+ if (FAILED(hr)) {
+ return;
+ }
+
+ mClsId = guidStr;
+}
+
+bool ComRegisterer::UnregisterAll() {
+ bool isOk = true;
+ LSTATUS ls;
+
+ for (const wchar_t* subkey : kExtensionSubkeys) {
+ RegKey root(mClassRoot, subkey);
+
+ std::wstring currentHandler = root.GetString(nullptr);
+ if (currentHandler != mClsId) {
+ // If another extension is registered, don't overwrite it.
+ continue;
+ }
+
+ // Set an empty string instead of deleting the key.
+ if (!root.SetString(nullptr)) {
+ isOk = false;
+ }
+ }
+
+ std::wstring subkey(kClsIdPrefix);
+ subkey += mClsId;
+ ls = ::RegDeleteTreeW(mClassRoot, subkey.c_str());
+ if (ls != ERROR_SUCCESS && ls != ERROR_FILE_NOT_FOUND) {
+ isOk = false;
+ }
+
+ return isOk;
+}
+
+bool ComRegisterer::RegisterObject(const wchar_t* aThreadModel) {
+ std::wstring subkey(kClsIdPrefix);
+ subkey += mClsId;
+
+ RegKey root(mClassRoot, subkey.c_str());
+ if (!root || !root.SetString(nullptr, mFriendlyName)) {
+ return false;
+ }
+
+ RegKey inproc(root, L"InprocServer32");
+ return inproc && inproc.SetString(nullptr, gDllPath) &&
+ inproc.SetString(L"ThreadingModel", aThreadModel);
+}
+
+bool ComRegisterer::RegisterExtensions() {
+ for (const wchar_t* subkey : kExtensionSubkeys) {
+ RegKey root(mClassRoot, subkey);
+ std::wstring currentHandler = root.GetString(nullptr);
+ if (!currentHandler.empty()) {
+ // If another extension is registered, don't overwrite it.
+ continue;
+ }
+
+ if (!root.SetString(nullptr, mClsId)) {
+ return false;
+ }
+ }
+
+ return true;
+}