diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 19:33:14 +0000 |
commit | 36d22d82aa202bb199967e9512281e9a53db42c9 (patch) | |
tree | 105e8c98ddea1c1e4784a60a5a6410fa416be2de /toolkit/mozapps/defaultagent/rust | |
parent | Initial commit. (diff) | |
download | firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.tar.xz firefox-esr-36d22d82aa202bb199967e9512281e9a53db42c9.zip |
Adding upstream version 115.7.0esr.upstream/115.7.0esr
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'toolkit/mozapps/defaultagent/rust')
7 files changed, 597 insertions, 0 deletions
diff --git a/toolkit/mozapps/defaultagent/rust/Cargo.toml b/toolkit/mozapps/defaultagent/rust/Cargo.toml new file mode 100644 index 0000000000..4bd44a9df8 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "defaultagent-static" +version = "0.1.0" +authors = ["The Mozilla Install/Update Team <install-update@mozilla.com>"] +edition = "2018" +description = "FFI to Rust for use in Firefox's default browser agent." +repository = "https://github.com/mozilla/defaultagent-static" +license = "MPL-2.0" + +[dependencies] +log = { version = "0.4", features = ["std"] } +mozilla-central-workspace-hack = { path = "../../../../build/workspace-hack" } +serde = "1.0" +serde_derive = "1.0" +serde_json = "1.0" +url = "2.1" +viaduct = "0.1" +wineventlog = { path = "wineventlog"} +wio = "0.2" +winapi = { version = "0.3", features = ["errhandlingapi", "handleapi", "minwindef", "winerror", "wininet", "winuser"] } + +[lib] +crate-type = ["staticlib"] diff --git a/toolkit/mozapps/defaultagent/rust/moz.build b/toolkit/mozapps/defaultagent/rust/moz.build new file mode 100644 index 0000000000..accce52265 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/moz.build @@ -0,0 +1,7 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +RustLibrary("defaultagent-static") diff --git a/toolkit/mozapps/defaultagent/rust/src/lib.rs b/toolkit/mozapps/defaultagent/rust/src/lib.rs new file mode 100644 index 0000000000..de5f1ba03a --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/src/lib.rs @@ -0,0 +1,166 @@ +/* -*- Mode: rust; rust-indent-offset: 4 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#![allow(non_snake_case)] + +use std::ffi::{CStr, OsString}; +use std::os::raw::c_char; + +#[macro_use] +extern crate serde_derive; +#[macro_use] +extern crate log; +use winapi::shared::ntdef::HRESULT; +use winapi::shared::winerror::{HRESULT_FROM_WIN32, S_OK}; +use wio::wide::FromWide; + +mod viaduct_wininet; +use viaduct_wininet::WinInetBackend; + +// HRESULT with 0x80000000 is an error, 0x20000000 set is a customer error code. +#[allow(overflowing_literals)] +const HR_NETWORK_ERROR: HRESULT = 0x80000000 | 0x20000000 | 0x1; +#[allow(overflowing_literals)] +const HR_SETTINGS_ERROR: HRESULT = 0x80000000 | 0x20000000 | 0x2; + +#[derive(Debug, Deserialize)] +pub struct EnabledRecord { + // Unknown fields are ignored by serde: see the docs for `#[serde(deny_unknown_fields)]`. + pub(crate) enabled: bool, +} + +pub enum Error { + /// A backend error with an attached Windows error code from `GetLastError()`. + WindowsError(u32), + + /// A network or otherwise transient error. + NetworkError, + + /// A configuration or settings data error that probably requires code, configuration, or + /// server-side changes to address. + SettingsError, +} + +impl From<viaduct::UnexpectedStatus> for Error { + fn from(_err: viaduct::UnexpectedStatus) -> Self { + Error::NetworkError + } +} + +impl From<viaduct::Error> for Error { + fn from(err: viaduct::Error) -> Self { + match err { + viaduct::Error::NetworkError(_) => Error::NetworkError, + viaduct::Error::BackendError(raw) => { + // If we have a string that's a hex error code like + // "0xabcde", that's a Windows error. + if raw.starts_with("0x") { + let without_prefix = raw.trim_start_matches("0x"); + let parse_result = u32::from_str_radix(without_prefix, 16); + if let Ok(parsed) = parse_result { + return Error::WindowsError(parsed); + } + } + Error::SettingsError + } + _ => Error::SettingsError, + } + } +} + +impl From<serde_json::Error> for Error { + fn from(_err: serde_json::Error) -> Self { + Error::SettingsError + } +} + +impl From<url::ParseError> for Error { + fn from(_err: url::ParseError) -> Self { + Error::SettingsError + } +} + +fn is_agent_remote_disabled<S: AsRef<str>>(url: S) -> Result<bool, Error> { + // Be careful setting the viaduct backend twice. If the backend + // has been set already, assume that it's our backend: we may as + // well at least try to continue. + match viaduct::set_backend(&WinInetBackend) { + Ok(_) => {} + Err(viaduct::Error::SetBackendError) => {} + e => e?, + } + + let url = url::Url::parse(url.as_ref())?; + let req = viaduct::Request::new(viaduct::Method::Get, url); + let resp = req.send()?; + + let resp = resp.require_success()?; + + let body: serde_json::Value = resp.json()?; + let data = body.get("data").ok_or(Error::SettingsError)?; + let record: EnabledRecord = serde_json::from_value(data.clone())?; + + let disabled = !record.enabled; + Ok(disabled) +} + +// This is an easy way to consume `MOZ_APP_DISPLAYNAME` from Rust code. +extern "C" { + static gWinEventLogSourceName: *const u16; +} + +#[allow(dead_code)] +#[no_mangle] +extern "C" fn IsAgentRemoteDisabledRust(szUrl: *const c_char, lpdwDisabled: *mut u32) -> HRESULT { + let wineventlog_name = unsafe { OsString::from_wide_ptr_null(gWinEventLogSourceName) }; + let logger = wineventlog::EventLogger::new(&wineventlog_name); + // It's fine to initialize logging twice. + let _ = log::set_boxed_logger(Box::new(logger)); + log::set_max_level(log::LevelFilter::Info); + + // Use an IIFE for `?`. + let disabled_result = (|| { + if lpdwDisabled.is_null() { + return Err(Error::SettingsError); + } + + let url = unsafe { CStr::from_ptr(szUrl).to_str().map(|x| x.to_string()) } + .map_err(|_| Error::SettingsError)?; + + info!("Using remote settings URL: {}", url); + + is_agent_remote_disabled(url) + })(); + + match disabled_result { + Err(e) => { + return match e { + Error::WindowsError(errno) => { + let hr = HRESULT_FROM_WIN32(errno); + error!("Error::WindowsError({}) (HRESULT: 0x{:x})", errno, hr); + hr + } + Error::NetworkError => { + let hr = HR_NETWORK_ERROR; + error!("Error::NetworkError (HRESULT: 0x{:x})", hr); + hr + } + Error::SettingsError => { + let hr = HR_SETTINGS_ERROR; + error!("Error::SettingsError (HRESULT: 0x{:x})", hr); + hr + } + }; + } + + Ok(remote_disabled) => { + // We null-checked `lpdwDisabled` earlier, but just to be safe. + if !lpdwDisabled.is_null() { + unsafe { *lpdwDisabled = if remote_disabled { 1 } else { 0 } }; + } + return S_OK; + } + } +} diff --git a/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/internet_handle.rs b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/internet_handle.rs new file mode 100644 index 0000000000..85f4254c88 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/internet_handle.rs @@ -0,0 +1,53 @@ +// Licensed under the Apache License, Version 2.0 +// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option. +// All files in the project carrying such notice may not be copied, modified, or distributed +// except according to those terms. + +//! Wrapping and automatically closing Internet handles. Copy-pasted from +//! [comedy-rs](https://github.com/agashlin/comedy-rs/blob/c244b91e9237c887f6a7bc6cd03db98b51966494/src/handle.rs). + +use winapi::shared::minwindef::DWORD; +use winapi::shared::ntdef::NULL; +use winapi::um::errhandlingapi::GetLastError; +use winapi::um::wininet::{InternetCloseHandle, HINTERNET}; + +/// Check and automatically close a Windows `HINTERNET`. +#[repr(transparent)] +#[derive(Debug)] +pub struct InternetHandle(HINTERNET); + +impl InternetHandle { + /// Take ownership of a `HINTERNET`, which will be closed with `InternetCloseHandle` upon drop. + /// Returns an error in case of `NULL`. + /// + /// # Safety + /// + /// `h` should be the only copy of the handle. `GetLastError()` is called to + /// return an error, so the last Windows API called on this thread should have been + /// what produced the invalid handle. + pub unsafe fn new(h: HINTERNET) -> Result<InternetHandle, DWORD> { + if h == NULL { + Err(GetLastError()) + } else { + Ok(InternetHandle(h)) + } + } + + /// Obtains the raw `HINTERNET` without transferring ownership. + /// + /// Do __not__ close this handle because it is still owned by the `InternetHandle`. + /// + /// Do __not__ use this handle beyond the lifetime of the `InternetHandle`. + pub fn as_raw(&self) -> HINTERNET { + self.0 + } +} + +impl Drop for InternetHandle { + fn drop(&mut self) { + unsafe { + InternetCloseHandle(self.0); + } + } +} diff --git a/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/mod.rs b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/mod.rs new file mode 100644 index 0000000000..1abd170f8f --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/src/viaduct_wininet/mod.rs @@ -0,0 +1,257 @@ +// Licensed under the Apache License, Version 2.0 +// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license +// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option. +// All files in the project carrying such notice may not be copied, modified, or distributed +// except according to those terms. + +use winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER; +use winapi::um::errhandlingapi::GetLastError; +use winapi::um::wininet; +use wio::wide::ToWide; + +use viaduct::Backend; + +mod internet_handle; +use internet_handle::InternetHandle; + +pub struct WinInetBackend; + +/// Errors +fn to_viaduct_error(e: u32) -> viaduct::Error { + // Like "0xabcde". + viaduct::Error::BackendError(format!("{:#x}", e)) +} + +fn get_status(req: wininet::HINTERNET) -> Result<u16, viaduct::Error> { + let mut status: u32 = 0; + let mut size: u32 = std::mem::size_of::<u32>() as u32; + let result = unsafe { + wininet::HttpQueryInfoW( + req, + wininet::HTTP_QUERY_STATUS_CODE | wininet::HTTP_QUERY_FLAG_NUMBER, + &mut status as *mut _ as *mut _, + &mut size, + std::ptr::null_mut(), + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + + Ok(status as u16) +} + +fn get_headers(req: wininet::HINTERNET) -> Result<viaduct::Headers, viaduct::Error> { + // We follow https://docs.microsoft.com/en-us/windows/win32/wininet/retrieving-http-headers. + // + // Per + // https://docs.microsoft.com/en-us/windows/win32/api/wininet/nf-wininet-httpqueryinfoa: + // The `HttpQueryInfoA` function represents headers as ISO-8859-1 characters + // not ANSI characters. + let mut size: u32 = 0; + + let result = unsafe { + wininet::HttpQueryInfoA( + req, + wininet::HTTP_QUERY_RAW_HEADERS, + std::ptr::null_mut(), + &mut size, + std::ptr::null_mut(), + ) + }; + if 0 == result { + let error = unsafe { GetLastError() }; + if error == wininet::ERROR_HTTP_HEADER_NOT_FOUND { + return Ok(viaduct::Headers::new()); + } else if error != ERROR_INSUFFICIENT_BUFFER { + return Err(to_viaduct_error(error)); + } + } + + let mut buffer = vec![0 as u8; size as usize]; + let result = unsafe { + wininet::HttpQueryInfoA( + req, + wininet::HTTP_QUERY_RAW_HEADERS, + buffer.as_mut_ptr() as *mut _, + &mut size, + std::ptr::null_mut(), + ) + }; + if 0 == result { + let error = unsafe { GetLastError() }; + if error == wininet::ERROR_HTTP_HEADER_NOT_FOUND { + return Ok(viaduct::Headers::new()); + } else { + return Err(to_viaduct_error(error)); + } + } + + // The API returns all of the headers as a single char buffer in + // ISO-8859-1 encoding. Each header is terminated by '\0' and + // there's a trailing '\0' terminator as well. + // + // We want UTF-8. It's not worth include a non-trivial encoding + // library like `encoding_rs` just for these headers, so let's use + // the fact that ISO-8859-1 and UTF-8 intersect on the lower 7 bits + // and decode lossily. It will at least be reasonably clear when + // there is an encoding issue. + let allheaders = String::from_utf8_lossy(&buffer); + + let mut headers = viaduct::Headers::new(); + for header in allheaders.split(0 as char) { + let mut it = header.splitn(2, ":"); + if let (Some(name), Some(value)) = (it.next(), it.next()) { + headers.insert(name.trim().to_string(), value.trim().to_string())?; + } + } + + return Ok(headers); +} + +fn get_body(req: wininet::HINTERNET) -> Result<Vec<u8>, viaduct::Error> { + let mut body = Vec::new(); + + const BUFFER_SIZE: usize = 65535; + let mut buffer: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE]; + + loop { + let mut bytes_downloaded: u32 = 0; + let result = unsafe { + wininet::InternetReadFile( + req, + buffer.as_mut_ptr() as *mut _, + BUFFER_SIZE as u32, + &mut bytes_downloaded, + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + if bytes_downloaded == 0 { + break; + } + + body.extend_from_slice(&buffer[0..bytes_downloaded as usize]); + } + Ok(body) +} + +impl Backend for WinInetBackend { + fn send(&self, request: viaduct::Request) -> Result<viaduct::Response, viaduct::Error> { + viaduct::note_backend("wininet.dll"); + + let request_method = request.method; + let url = request.url; + + let session = unsafe { + InternetHandle::new(wininet::InternetOpenW( + "DefaultAgent/1.0".to_wide_null().as_ptr(), + wininet::INTERNET_OPEN_TYPE_PRECONFIG, + std::ptr::null_mut(), + std::ptr::null_mut(), + 0, + )) + } + .map_err(to_viaduct_error)?; + + // Consider asserting the scheme here too, for documentation purposes. + // Viaduct itself only allows HTTPS at this time, but that might change. + let host = url + .host_str() + .ok_or(viaduct::Error::BackendError("no host".to_string()))?; + + let conn = unsafe { + InternetHandle::new(wininet::InternetConnectW( + session.as_raw(), + host.to_wide_null().as_ptr(), + wininet::INTERNET_DEFAULT_HTTPS_PORT as u16, + std::ptr::null_mut(), + std::ptr::null_mut(), + wininet::INTERNET_SERVICE_HTTP, + 0, + 0, + )) + } + .map_err(to_viaduct_error)?; + + let path = url[url::Position::BeforePath..].to_string(); + let req = unsafe { + wininet::HttpOpenRequestW( + conn.as_raw(), + request_method.as_str().to_wide_null().as_ptr(), + path.to_wide_null().as_ptr(), + std::ptr::null_mut(), /* lpszVersion */ + std::ptr::null_mut(), /* lpszReferrer */ + std::ptr::null_mut(), /* lplpszAcceptTypes */ + // Avoid the cache as best we can. + wininet::INTERNET_FLAG_NO_AUTH + | wininet::INTERNET_FLAG_NO_CACHE_WRITE + | wininet::INTERNET_FLAG_NO_COOKIES + | wininet::INTERNET_FLAG_NO_UI + | wininet::INTERNET_FLAG_PRAGMA_NOCACHE + | wininet::INTERNET_FLAG_RELOAD + | wininet::INTERNET_FLAG_SECURE, + 0, + ) + }; + if req.is_null() { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + + for header in request.headers { + // Per + // https://docs.microsoft.com/en-us/windows/win32/api/wininet/nf-wininet-httpaddrequestheadersw, + // "Each header must be terminated by a CR/LF (carriage return/line + // feed) pair." + let h = format!("{}: {}\r\n", header.name(), header.value()); + let result = unsafe { + wininet::HttpAddRequestHeadersW( + req, + h.to_wide_null().as_ptr(), /* lpszHeaders */ + -1i32 as u32, /* dwHeadersLength */ + wininet::HTTP_ADDREQ_FLAG_ADD | wininet::HTTP_ADDREQ_FLAG_REPLACE, /* dwModifiers */ + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + } + + // Future work: support sending a body. + if request.body.is_some() { + return Err(viaduct::Error::BackendError( + "non-empty body is not yet supported".to_string(), + )); + } + + let result = unsafe { + wininet::HttpSendRequestW( + req, + std::ptr::null_mut(), /* lpszHeaders */ + 0, /* dwHeadersLength */ + std::ptr::null_mut(), /* lpOptional */ + 0, /* dwOptionalLength */ + ) + }; + if 0 == result { + return Err(to_viaduct_error(unsafe { GetLastError() })); + } + + let status = get_status(req)?; + let headers = get_headers(req)?; + + // Not all responses have a body. + let has_body = headers.get_header("content-type").is_some() + || headers.get_header("content-length").is_some(); + let body = if has_body { get_body(req)? } else { Vec::new() }; + + Ok(viaduct::Response { + request_method, + body, + url, + status, + headers, + }) + } +} diff --git a/toolkit/mozapps/defaultagent/rust/wineventlog/Cargo.toml b/toolkit/mozapps/defaultagent/rust/wineventlog/Cargo.toml new file mode 100644 index 0000000000..b38f704ca1 --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/wineventlog/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "wineventlog" +version = "0.1.0" +authors = ["The Mozilla Project Developers"] +license = "MPL-2.0" +autobins = false +edition = "2018" + +[target."cfg(windows)".dependencies] +log = "0.4" +wio = "0.2" + +[target."cfg(windows)".dependencies.winapi] +version = "0.3.7" +features = ["errhandlingapi", "minwindef", "ntdef", "oaidl", "oleauto", "sysinfoapi", "taskschd", "winbase", "winerror", "winnt", "winreg", "wtypes"] diff --git a/toolkit/mozapps/defaultagent/rust/wineventlog/src/lib.rs b/toolkit/mozapps/defaultagent/rust/wineventlog/src/lib.rs new file mode 100644 index 0000000000..c98b75c53a --- /dev/null +++ b/toolkit/mozapps/defaultagent/rust/wineventlog/src/lib.rs @@ -0,0 +1,76 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ + +//! Very simple implementation of logging via the Windows Event Log + +use std::ffi::OsStr; +use std::ptr; + +use log::{Level, Metadata, Record}; +use winapi::shared::minwindef::WORD; +use winapi::um::{winbase, winnt}; +use wio::wide::ToWide; + +pub struct EventLogger { + pub name: Vec<u16>, +} + +impl EventLogger { + pub fn new(name: impl AsRef<OsStr>) -> Self { + EventLogger { + name: name.to_wide_null(), + } + } +} + +impl log::Log for EventLogger { + fn enabled(&self, metadata: &Metadata) -> bool { + metadata.level() <= log::max_level() + } + + fn log(&self, record: &Record) { + if !self.enabled(record.metadata()) { + return; + } + + let name = self.name.as_ptr(); + let msg = format!("{} - {}", record.level(), record.args()).to_wide_null(); + + // Open and close the event log handle on every message, for simplicity. + let event_log; + unsafe { + event_log = winbase::RegisterEventSourceW(ptr::null(), name); + if event_log.is_null() { + return; + } + } + + let level = match record.level() { + Level::Error => winnt::EVENTLOG_ERROR_TYPE, + Level::Warn => winnt::EVENTLOG_WARNING_TYPE, + Level::Info | Level::Debug | Level::Trace => winnt::EVENTLOG_INFORMATION_TYPE, + }; + + unsafe { + // mut only to match the LPCWSTR* signature + let mut msg_array: [*const u16; 1] = [msg.as_ptr()]; + + let _ = winbase::ReportEventW( + event_log, + level, + 0, // no category + 0, // event id 0 + ptr::null_mut(), // no user sid + msg_array.len() as WORD, // string count + 0, // 0 bytes raw data + msg_array.as_mut_ptr(), // strings + ptr::null_mut(), // no raw data + ); + + let _ = winbase::DeregisterEventSource(event_log); + } + } + + fn flush(&self) {} +} |