diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:44:51 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:44:51 +0000 |
commit | 9e3c08db40b8916968b9f30096c7be3f00ce9647 (patch) | |
tree | a68f146d7fa01f0134297619fbe7e33db084e0aa /netwerk/socket | |
parent | Initial commit. (diff) | |
download | thunderbird-9e3c08db40b8916968b9f30096c7be3f00ce9647.tar.xz thunderbird-9e3c08db40b8916968b9f30096c7be3f00ce9647.zip |
Adding upstream version 1:115.7.0.upstream/1%115.7.0upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'netwerk/socket')
22 files changed, 4913 insertions, 0 deletions
diff --git a/netwerk/socket/moz.build b/netwerk/socket/moz.build new file mode 100644 index 0000000000..d8ce8d0b2b --- /dev/null +++ b/netwerk/socket/moz.build @@ -0,0 +1,46 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +DIRS += [ + "neqo_glue", +] + +XPIDL_SOURCES += [ + "nsISocketProvider.idl", + "nsISocketProviderService.idl", +] + +XPIDL_MODULE = "necko_socket" + +LOCAL_INCLUDES += [ + "/netwerk/base", +] + +include("/ipc/chromium/chromium-config.mozbuild") + +EXPORTS += [ + "nsSocketProviderService.h", +] + +UNIFIED_SOURCES += [ + "nsSocketProviderService.cpp", + "nsSOCKSIOLayer.cpp", + "nsSOCKSSocketProvider.cpp", + "nsUDPSocketProvider.cpp", +] + +if CONFIG["MOZ_WIDGET_TOOLKIT"] == "windows": + XPIDL_SOURCES += [ + "nsINamedPipeService.idl", + ] + EXPORTS += [ + "nsNamedPipeService.h", + ] + UNIFIED_SOURCES += ["nsNamedPipeIOLayer.cpp", "nsNamedPipeService.cpp"] + +FINAL_LIBRARY = "xul" + +CONFIGURE_SUBST_FILES += ["neqo/extra-bindgen-flags"] diff --git a/netwerk/socket/neqo/extra-bindgen-flags.in b/netwerk/socket/neqo/extra-bindgen-flags.in new file mode 100644 index 0000000000..69f7dabf26 --- /dev/null +++ b/netwerk/socket/neqo/extra-bindgen-flags.in @@ -0,0 +1 @@ +@BINDGEN_SYSTEM_FLAGS@ @NSPR_CFLAGS@ @NSS_CFLAGS@ diff --git a/netwerk/socket/neqo_glue/Cargo.toml b/netwerk/socket/neqo_glue/Cargo.toml new file mode 100644 index 0000000000..4336bfe21f --- /dev/null +++ b/netwerk/socket/neqo_glue/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "neqo_glue" +version = "0.1.0" +authors = ["Dragana Damjanovic <dd.mozilla@gmail.com>"] +edition = "2018" +license = "MPL-2.0" + +[lib] +name = "neqo_glue" + +[dependencies] +neqo-http3 = { tag = "v0.6.4", git = "https://github.com/mozilla/neqo" } +neqo-transport = { tag = "v0.6.4", git = "https://github.com/mozilla/neqo" } +neqo-common = { tag = "v0.6.4", git = "https://github.com/mozilla/neqo" } +neqo-qpack = { tag = "v0.6.4", git = "https://github.com/mozilla/neqo" } +nserror = { path = "../../../xpcom/rust/nserror" } +nsstring = { path = "../../../xpcom/rust/nsstring" } +xpcom = { path = "../../../xpcom/rust/xpcom" } +thin-vec = { version = "0.2.1", features = ["gecko-ffi"] } +log = "0.4.0" +qlog = "0.4.0" +libc = "0.2.0" +static_prefs = { path = "../../../modules/libpref/init/static_prefs", optional = true } + +[target.'cfg(target_os = "windows")'.dependencies] +winapi = {version = "0.3", features = ["ws2def"] } + +[dependencies.neqo-crypto] +tag = "v0.6.4" +git = "https://github.com/mozilla/neqo" +default-features = false +features = ["gecko"] + +[features] +fuzzing = ["neqo-http3/fuzzing", "static_prefs"] diff --git a/netwerk/socket/neqo_glue/NeqoHttp3Conn.h b/netwerk/socket/neqo_glue/NeqoHttp3Conn.h new file mode 100644 index 0000000000..6ea47e7830 --- /dev/null +++ b/netwerk/socket/neqo_glue/NeqoHttp3Conn.h @@ -0,0 +1,162 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef NeqoHttp3Conn_h__ +#define NeqoHttp3Conn_h__ + +#include <cstdint> +#include "mozilla/net/neqo_glue_ffi_generated.h" + +namespace mozilla { +namespace net { + +class NeqoHttp3Conn final { + public: + static nsresult Init(const nsACString& aOrigin, const nsACString& aAlpn, + const NetAddr& aLocalAddr, const NetAddr& aRemoteAddr, + uint32_t aMaxTableSize, uint16_t aMaxBlockedStreams, + uint64_t aMaxData, uint64_t aMaxStreamData, + bool aVersionNegotiation, bool aWebTransport, + const nsACString& aQlogDir, uint32_t aDatagramSize, + NeqoHttp3Conn** aConn) { + return neqo_http3conn_new(&aOrigin, &aAlpn, &aLocalAddr, &aRemoteAddr, + aMaxTableSize, aMaxBlockedStreams, aMaxData, + aMaxStreamData, aVersionNegotiation, + aWebTransport, &aQlogDir, aDatagramSize, + (const mozilla::net::NeqoHttp3Conn**)aConn); + } + + void Close(uint64_t aError) { neqo_http3conn_close(this, aError); } + + nsresult GetSecInfo(NeqoSecretInfo* aSecInfo) { + return neqo_http3conn_tls_info(this, aSecInfo); + } + + nsresult PeerCertificateInfo(NeqoCertificateInfo* aCertInfo) { + return neqo_http3conn_peer_certificate_info(this, aCertInfo); + } + + void PeerAuthenticated(PRErrorCode aError) { + neqo_http3conn_authenticated(this, aError); + } + + nsresult ProcessInput(const NetAddr& aRemoteAddr, + const nsTArray<uint8_t>& aPacket) { + return neqo_http3conn_process_input(this, &aRemoteAddr, &aPacket); + } + + bool ProcessOutput(nsACString* aRemoteAddr, uint16_t* aPort, + nsTArray<uint8_t>& aData, uint64_t* aTimeout) { + aData.TruncateLength(0); + return neqo_http3conn_process_output(this, aRemoteAddr, aPort, &aData, + aTimeout); + } + + nsresult GetEvent(Http3Event* aEvent, nsTArray<uint8_t>& aData) { + return neqo_http3conn_event(this, aEvent, &aData); + } + + nsresult Fetch(const nsACString& aMethod, const nsACString& aScheme, + const nsACString& aHost, const nsACString& aPath, + const nsACString& aHeaders, uint64_t* aStreamId, + uint8_t aUrgency, bool aIncremental) { + return neqo_http3conn_fetch(this, &aMethod, &aScheme, &aHost, &aPath, + &aHeaders, aStreamId, aUrgency, aIncremental); + } + + nsresult PriorityUpdate(uint64_t aStreamId, uint8_t aUrgency, + bool aIncremental) { + return neqo_http3conn_priority_update(this, aStreamId, aUrgency, + aIncremental); + } + + nsresult SendRequestBody(uint64_t aStreamId, const uint8_t* aBuf, + uint32_t aCount, uint32_t* aCountRead) { + return neqo_htttp3conn_send_request_body(this, aStreamId, aBuf, aCount, + aCountRead); + } + + // This closes only the sending side of a stream. + nsresult CloseStream(uint64_t aStreamId) { + return neqo_http3conn_close_stream(this, aStreamId); + } + + nsresult ReadResponseData(uint64_t aStreamId, uint8_t* aBuf, uint32_t aLen, + uint32_t* aRead, bool* aFin) { + return neqo_http3conn_read_response_data(this, aStreamId, aBuf, aLen, aRead, + aFin); + } + + void CancelFetch(uint64_t aStreamId, uint64_t aError) { + neqo_http3conn_cancel_fetch(this, aStreamId, aError); + } + + void ResetStream(uint64_t aStreamId, uint64_t aError) { + neqo_http3conn_reset_stream(this, aStreamId, aError); + } + + void StreamStopSending(uint64_t aStreamId, uint64_t aError) { + neqo_http3conn_stream_stop_sending(this, aStreamId, aError); + } + + void SetResumptionToken(nsTArray<uint8_t>& aToken) { + neqo_http3conn_set_resumption_token(this, &aToken); + } + + void SetEchConfig(nsTArray<uint8_t>& aEchConfig) { + neqo_http3conn_set_ech_config(this, &aEchConfig); + } + + bool IsZeroRtt() { return neqo_http3conn_is_zero_rtt(this); } + + void AddRef() { neqo_http3conn_addref(this); } + void Release() { neqo_http3conn_release(this); } + + void GetStats(Http3Stats* aStats) { + return neqo_http3conn_get_stats(this, aStats); + } + + nsresult CreateWebTransport(const nsACString& aHost, const nsACString& aPath, + const nsACString& aHeaders, + uint64_t* aSessionId) { + return neqo_http3conn_webtransport_create_session(this, &aHost, &aPath, + &aHeaders, aSessionId); + } + + nsresult CloseWebTransport(uint64_t aSessionId, uint32_t aError, + const nsACString& aMessage) { + return neqo_http3conn_webtransport_close_session(this, aSessionId, aError, + &aMessage); + } + + nsresult CreateWebTransportStream(uint64_t aSessionId, + WebTransportStreamType aStreamType, + uint64_t* aStreamId) { + return neqo_http3conn_webtransport_create_stream(this, aSessionId, + aStreamType, aStreamId); + } + + nsresult WebTransportSendDatagram(uint64_t aSessionId, + nsTArray<uint8_t>& aData, + uint64_t aTrackingId) { + return neqo_http3conn_webtransport_send_datagram(this, aSessionId, &aData, + aTrackingId); + } + + nsresult WebTransportMaxDatagramSize(uint64_t aSessionId, uint64_t* aResult) { + return neqo_http3conn_webtransport_max_datagram_size(this, aSessionId, + aResult); + } + + private: + NeqoHttp3Conn() = delete; + ~NeqoHttp3Conn() = delete; + NeqoHttp3Conn(const NeqoHttp3Conn&) = delete; + NeqoHttp3Conn& operator=(const NeqoHttp3Conn&) = delete; +}; + +} // namespace net +} // namespace mozilla + +#endif diff --git a/netwerk/socket/neqo_glue/cbindgen.toml b/netwerk/socket/neqo_glue/cbindgen.toml new file mode 100644 index 0000000000..312e1731e5 --- /dev/null +++ b/netwerk/socket/neqo_glue/cbindgen.toml @@ -0,0 +1,27 @@ +header = """/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */""" +autogen_warning = """/* DO NOT MODIFY THIS MANUALLY! This file was generated using cbindgen. + */ + +namespace mozilla { +namespace net { +class NeqoHttp3Conn; +union NetAddr; +} // namespace net +} // namespace mozilla + """ +include_version = true +braces = "SameLine" +line_length = 100 +tab_width = 2 +language = "C++" +namespaces = ["mozilla", "net"] +includes = ["certt.h", "prerror.h"] + +[export] +exclude = ["NeqoHttp3Conn", "NetAddr"] +item_types = ["globals", "enums", "structs", "unions", "typedefs", "opaque", "functions", "constants"] + +[export.rename] +"ThinVec" = "nsTArray" diff --git a/netwerk/socket/neqo_glue/moz.build b/netwerk/socket/neqo_glue/moz.build new file mode 100644 index 0000000000..b3123dde4f --- /dev/null +++ b/netwerk/socket/neqo_glue/moz.build @@ -0,0 +1,21 @@ +# -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*- +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +EXPORTS.mozilla.net += [ + "NeqoHttp3Conn.h", +] + +LOCAL_INCLUDES += [ + "/security/manager/ssl", + "/security/nss/lib/ssl", +] + +if CONFIG["COMPILE_ENVIRONMENT"]: + CbindgenHeader("neqo_glue_ffi_generated.h", inputs=["/netwerk/socket/neqo_glue"]) + + EXPORTS.mozilla.net += [ + "!neqo_glue_ffi_generated.h", + ] diff --git a/netwerk/socket/neqo_glue/src/lib.rs b/netwerk/socket/neqo_glue/src/lib.rs new file mode 100644 index 0000000000..adc55042c5 --- /dev/null +++ b/netwerk/socket/neqo_glue/src/lib.rs @@ -0,0 +1,1330 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at https://mozilla.org/MPL/2.0/. */ + +#[cfg(not(windows))] +use libc::{AF_INET, AF_INET6}; +use neqo_common::event::Provider; +use neqo_common::{self as common, qlog::NeqoQlog, qwarn, Datagram, Header, Role}; +use neqo_crypto::{init, PRErrorCode}; +use neqo_http3::{ + features::extended_connect::SessionCloseReason, Error as Http3Error, Http3Client, + Http3ClientEvent, Http3Parameters, Http3State, Priority, WebTransportEvent, +}; +use neqo_transport::{ + stream_id::StreamType, CongestionControlAlgorithm, ConnectionParameters, + Error as TransportError, Output, RandomConnectionIdGenerator, StreamId, Version, +}; +use nserror::*; +use nsstring::*; +use qlog::QlogStreamer; +use std::borrow::Cow; +use std::cell::RefCell; +use std::convert::TryFrom; +use std::convert::TryInto; +use std::fs::OpenOptions; +use std::net::SocketAddr; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::path::PathBuf; +use std::ptr; +use std::rc::Rc; +use std::slice; +use std::str; +#[cfg(feature = "fuzzing")] +use std::time::Duration; +use std::time::Instant; +use thin_vec::ThinVec; +#[cfg(windows)] +use winapi::shared::ws2def::{AF_INET, AF_INET6}; +use xpcom::{AtomicRefcnt, RefCounted, RefPtr}; + +#[repr(C)] +pub struct NeqoHttp3Conn { + conn: Http3Client, + local_addr: SocketAddr, + refcnt: AtomicRefcnt, +} + +// Opaque interface to mozilla::net::NetAddr defined in DNS.h +#[repr(C)] +pub union NetAddr { + private: [u8; 0], +} + +extern "C" { + pub fn moz_netaddr_get_family(arg: *const NetAddr) -> u16; + pub fn moz_netaddr_get_network_order_ip(arg: *const NetAddr) -> u32; + pub fn moz_netaddr_get_ipv6(arg: *const NetAddr) -> *const u8; + pub fn moz_netaddr_get_network_order_port(arg: *const NetAddr) -> u16; +} + +fn netaddr_to_socket_addr(arg: *const NetAddr) -> Result<SocketAddr, nsresult> { + if arg == ptr::null() { + return Err(NS_ERROR_INVALID_ARG); + } + + unsafe { + let family = moz_netaddr_get_family(arg) as i32; + if family == AF_INET { + let port = u16::from_be(moz_netaddr_get_network_order_port(arg)); + let ipv4 = Ipv4Addr::from(u32::from_be(moz_netaddr_get_network_order_ip(arg))); + return Ok(SocketAddr::new(IpAddr::V4(ipv4), port)); + } + + if family == AF_INET6 { + let port = u16::from_be(moz_netaddr_get_network_order_port(arg)); + let ipv6_slice: [u8; 16] = slice::from_raw_parts(moz_netaddr_get_ipv6(arg), 16) + .try_into() + .expect("slice with incorrect length"); + let ipv6 = Ipv6Addr::from(ipv6_slice); + return Ok(SocketAddr::new(IpAddr::V6(ipv6), port)); + } + } + + Err(NS_ERROR_UNEXPECTED) +} + +impl NeqoHttp3Conn { + fn new( + origin: &nsACString, + alpn: &nsACString, + local_addr: *const NetAddr, + remote_addr: *const NetAddr, + max_table_size: u64, + max_blocked_streams: u16, + max_data: u64, + max_stream_data: u64, + version_negotiation: bool, + webtransport: bool, + qlog_dir: &nsACString, + webtransport_datagram_size: u32, + ) -> Result<RefPtr<NeqoHttp3Conn>, nsresult> { + // Nss init. + init(); + + let origin_conv = str::from_utf8(origin).map_err(|_| NS_ERROR_INVALID_ARG)?; + + let alpn_conv = str::from_utf8(alpn).map_err(|_| NS_ERROR_INVALID_ARG)?; + + let local: SocketAddr = netaddr_to_socket_addr(local_addr)?; + + let remote: SocketAddr = netaddr_to_socket_addr(remote_addr)?; + + let quic_version = match alpn_conv { + "h3-32" => Version::Draft32, + "h3-31" => Version::Draft31, + "h3-30" => Version::Draft30, + "h3-29" => Version::Draft29, + "h3" => Version::Version1, + _ => return Err(NS_ERROR_INVALID_ARG), + }; + + let version_list = if version_negotiation { + Version::all() + } else { + vec![quic_version] + }; + #[allow(unused_mut)] + let mut params = ConnectionParameters::default() + .versions(quic_version, version_list) + .cc_algorithm(CongestionControlAlgorithm::Cubic) + .max_data(max_data) + .max_stream_data(StreamType::BiDi, false, max_stream_data); + + // Set a short timeout when fuzzing. + #[cfg(feature = "fuzzing")] + if static_prefs::pref!("fuzzing.necko.http3") { + params = params.idle_timeout(Duration::from_millis(10)); + } + + if webtransport_datagram_size > 0 { + params = params.datagram_size(webtransport_datagram_size.into()); + } + + let http3_settings = Http3Parameters::default() + .max_table_size_encoder(max_table_size) + .max_table_size_decoder(max_table_size) + .max_blocked_streams(max_blocked_streams) + .max_concurrent_push_streams(0) + .connection_parameters(params) + .webtransport(webtransport) + .http3_datagram(webtransport); + + let mut conn = match Http3Client::new( + origin_conv, + Rc::new(RefCell::new(RandomConnectionIdGenerator::new(3))), + local, + remote, + http3_settings, + Instant::now(), + ) { + Ok(c) => c, + Err(_) => return Err(NS_ERROR_INVALID_ARG), + }; + + if !qlog_dir.is_empty() { + let qlog_dir_conv = str::from_utf8(qlog_dir).map_err(|_| NS_ERROR_INVALID_ARG)?; + let mut qlog_path = PathBuf::from(qlog_dir_conv); + qlog_path.push(format!("{}.qlog", origin)); + + // Emit warnings but to not return an error if qlog initialization + // fails. + match OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&qlog_path) + { + Err(_) => qwarn!("Could not open qlog path: {}", qlog_path.display()), + Ok(f) => { + let streamer = QlogStreamer::new( + qlog::QLOG_VERSION.to_string(), + Some("Firefox Client qlog".to_string()), + Some("Firefox Client qlog".to_string()), + None, + std::time::Instant::now(), + common::qlog::new_trace(Role::Client), + Box::new(f), + ); + + match NeqoQlog::enabled(streamer, &qlog_path) { + Err(_) => qwarn!("Could not write to qlog path: {}", qlog_path.display()), + Ok(nq) => conn.set_qlog(nq), + } + } + } + } + + let conn = Box::into_raw(Box::new(NeqoHttp3Conn { + conn, + local_addr: local, + refcnt: unsafe { AtomicRefcnt::new() }, + })); + unsafe { Ok(RefPtr::from_raw(conn).unwrap()) } + } +} + +#[no_mangle] +pub unsafe extern "C" fn neqo_http3conn_addref(conn: &NeqoHttp3Conn) { + conn.refcnt.inc(); +} + +#[no_mangle] +pub unsafe extern "C" fn neqo_http3conn_release(conn: &NeqoHttp3Conn) { + let rc = conn.refcnt.dec(); + if rc == 0 { + std::mem::drop(Box::from_raw(conn as *const _ as *mut NeqoHttp3Conn)); + } +} + +// xpcom::RefPtr support +unsafe impl RefCounted for NeqoHttp3Conn { + unsafe fn addref(&self) { + neqo_http3conn_addref(self); + } + unsafe fn release(&self) { + neqo_http3conn_release(self); + } +} + +// Allocate a new NeqoHttp3Conn object. +#[no_mangle] +pub extern "C" fn neqo_http3conn_new( + origin: &nsACString, + alpn: &nsACString, + local_addr: *const NetAddr, + remote_addr: *const NetAddr, + max_table_size: u64, + max_blocked_streams: u16, + max_data: u64, + max_stream_data: u64, + version_negotiation: bool, + webtransport: bool, + qlog_dir: &nsACString, + webtransport_datagram_size: u32, + result: &mut *const NeqoHttp3Conn, +) -> nsresult { + *result = ptr::null_mut(); + + match NeqoHttp3Conn::new( + origin, + alpn, + local_addr, + remote_addr, + max_table_size, + max_blocked_streams, + max_data, + max_stream_data, + version_negotiation, + webtransport, + qlog_dir, + webtransport_datagram_size, + ) { + Ok(http3_conn) => { + http3_conn.forget(result); + NS_OK + } + Err(e) => e, + } +} + +/* Process a packet. + * packet holds packet data. + */ +#[no_mangle] +pub unsafe extern "C" fn neqo_http3conn_process_input( + conn: &mut NeqoHttp3Conn, + remote_addr: *const NetAddr, + packet: *const ThinVec<u8>, +) -> nsresult { + let remote = match netaddr_to_socket_addr(remote_addr) { + Ok(addr) => addr, + Err(result) => return result, + }; + conn.conn.process_input( + Datagram::new(remote, conn.local_addr, (*packet).to_vec()), + Instant::now(), + ); + return NS_OK; +} + +/* Process output: + * this may return a packet that needs to be sent or a timeout. + * if it returns a packet the function returns true, otherwise it returns false. + */ +#[no_mangle] +pub extern "C" fn neqo_http3conn_process_output( + conn: &mut NeqoHttp3Conn, + remote_addr: &mut nsACString, + remote_port: &mut u16, + packet: &mut ThinVec<u8>, + timeout: &mut u64, +) -> bool { + match conn.conn.process_output(Instant::now()) { + Output::Datagram(dg) => { + packet.extend_from_slice(&dg); + remote_addr.append(&dg.destination().ip().to_string()); + *remote_port = dg.destination().port(); + true + } + Output::Callback(to) => { + *timeout = to.as_millis() as u64; + // Necko resolution is in milliseconds whereas neqo resolution + // is in nanoseconds. If we called process_output too soon due + // to this difference, we might do few unnecessary loops until + // we waste the remaining time. To avoid it, we return 1ms when + // the timeout is less than 1ms. + if *timeout == 0 { + *timeout = 1; + } + false + } + Output::None => { + *timeout = std::u64::MAX; + false + } + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_close(conn: &mut NeqoHttp3Conn, error: u64) { + conn.conn.close(Instant::now(), error, ""); +} + +fn is_excluded_header(name: &str) -> bool { + if (name == "connection") + || (name == "host") + || (name == "keep-alive") + || (name == "proxy-connection") + || (name == "te") + || (name == "transfer-encoding") + || (name == "upgrade") + || (name == "sec-websocket-key") + { + true + } else { + false + } +} + +fn parse_headers(headers: &nsACString) -> Result<Vec<Header>, nsresult> { + let mut hdrs = Vec::new(); + // this is only used for headers built by Firefox. + // Firefox supplies all headers already prepared for sending over http1. + // They need to be split into (String, String) pairs. + match str::from_utf8(headers) { + Err(_) => { + return Err(NS_ERROR_INVALID_ARG); + } + Ok(h) => { + for elem in h.split("\r\n").skip(1) { + if elem.starts_with(':') { + // colon headers are for http/2 and 3 and this is http/1 + // input, so that is probably a smuggling attack of some + // kind. + continue; + } + if elem.len() == 0 { + continue; + } + let hdr_str: Vec<_> = elem.splitn(2, ":").collect(); + let name = hdr_str[0].trim().to_lowercase(); + if is_excluded_header(&name) { + continue; + } + let value = if hdr_str.len() > 1 { + String::from(hdr_str[1].trim()) + } else { + String::new() + }; + + hdrs.push(Header::new(name, value)); + } + } + } + Ok(hdrs) +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_fetch( + conn: &mut NeqoHttp3Conn, + method: &nsACString, + scheme: &nsACString, + host: &nsACString, + path: &nsACString, + headers: &nsACString, + stream_id: &mut u64, + urgency: u8, + incremental: bool, +) -> nsresult { + let hdrs = match parse_headers(headers) { + Err(e) => { + return e; + } + Ok(h) => h, + }; + let method_tmp = match str::from_utf8(method) { + Ok(m) => m, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + let scheme_tmp = match str::from_utf8(scheme) { + Ok(s) => s, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + let host_tmp = match str::from_utf8(host) { + Ok(h) => h, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + let path_tmp = match str::from_utf8(path) { + Ok(p) => p, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + if urgency >= 8 { + return NS_ERROR_INVALID_ARG; + } + let priority = Priority::new(urgency, incremental); + match conn.conn.fetch( + Instant::now(), + method_tmp, + &(scheme_tmp, host_tmp, path_tmp), + &hdrs, + priority, + ) { + Ok(id) => { + *stream_id = id.as_u64(); + NS_OK + } + Err(Http3Error::StreamLimitError) => NS_BASE_STREAM_WOULD_BLOCK, + Err(_) => NS_ERROR_UNEXPECTED, + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_priority_update( + conn: &mut NeqoHttp3Conn, + stream_id: u64, + urgency: u8, + incremental: bool, +) -> nsresult { + if urgency >= 8 { + return NS_ERROR_INVALID_ARG; + } + let priority = Priority::new(urgency, incremental); + match conn + .conn + .priority_update(StreamId::from(stream_id), priority) + { + Ok(_) => NS_OK, + Err(_) => NS_ERROR_UNEXPECTED, + } +} + +#[no_mangle] +pub unsafe extern "C" fn neqo_htttp3conn_send_request_body( + conn: &mut NeqoHttp3Conn, + stream_id: u64, + buf: *const u8, + len: u32, + read: &mut u32, +) -> nsresult { + let array = slice::from_raw_parts(buf, len as usize); + match conn.conn.send_data(StreamId::from(stream_id), array) { + Ok(amount) => { + *read = u32::try_from(amount).unwrap(); + if amount == 0 { + NS_BASE_STREAM_WOULD_BLOCK + } else { + NS_OK + } + } + Err(_) => NS_ERROR_UNEXPECTED, + } +} + +fn crypto_error_code(err: neqo_crypto::Error) -> u64 { + match err { + neqo_crypto::Error::AeadError => 1, + neqo_crypto::Error::CertificateLoading => 2, + neqo_crypto::Error::CreateSslSocket => 3, + neqo_crypto::Error::HkdfError => 4, + neqo_crypto::Error::InternalError => 5, + neqo_crypto::Error::IntegerOverflow => 6, + neqo_crypto::Error::InvalidEpoch => 7, + neqo_crypto::Error::MixedHandshakeMethod => 8, + neqo_crypto::Error::NoDataAvailable => 9, + neqo_crypto::Error::NssError { .. } => 10, + neqo_crypto::Error::OverrunError => 11, + neqo_crypto::Error::SelfEncryptFailure => 12, + neqo_crypto::Error::TimeTravelError => 13, + neqo_crypto::Error::UnsupportedCipher => 14, + neqo_crypto::Error::UnsupportedVersion => 15, + neqo_crypto::Error::StringError => 16, + neqo_crypto::Error::EchRetry(_) => 17, + neqo_crypto::Error::CipherInitFailure => 18, + } +} + +// This is only used for telemetry. Therefore we only return error code +// numbers and do not label them. Recording telemetry is easier with a +// number. +#[repr(C)] +pub enum CloseError { + TransportInternalError(u16), + TransportInternalErrorOther(u16), + TransportError(u64), + CryptoError(u64), + CryptoAlert(u8), + PeerAppError(u64), + PeerError(u64), + AppError(u64), + EchRetry, +} + +impl From<TransportError> for CloseError { + fn from(error: TransportError) -> CloseError { + match error { + TransportError::InternalError(c) => CloseError::TransportInternalError(c), + TransportError::CryptoError(neqo_crypto::Error::EchRetry(_)) => CloseError::EchRetry, + TransportError::CryptoError(c) => CloseError::CryptoError(crypto_error_code(c)), + TransportError::CryptoAlert(c) => CloseError::CryptoAlert(c), + TransportError::PeerApplicationError(c) => CloseError::PeerAppError(c), + TransportError::PeerError(c) => CloseError::PeerError(c), + TransportError::NoError + | TransportError::IdleTimeout + | TransportError::ConnectionRefused + | TransportError::FlowControlError + | TransportError::StreamLimitError + | TransportError::StreamStateError + | TransportError::FinalSizeError + | TransportError::FrameEncodingError + | TransportError::TransportParameterError + | TransportError::ProtocolViolation + | TransportError::InvalidToken + | TransportError::KeysExhausted + | TransportError::ApplicationError + | TransportError::NoAvailablePath => CloseError::TransportError(error.code()), + TransportError::EchRetry(_) => CloseError::EchRetry, + TransportError::AckedUnsentPacket => CloseError::TransportInternalErrorOther(0), + TransportError::ConnectionIdLimitExceeded => CloseError::TransportInternalErrorOther(1), + TransportError::ConnectionIdsExhausted => CloseError::TransportInternalErrorOther(2), + TransportError::ConnectionState => CloseError::TransportInternalErrorOther(3), + TransportError::DecodingFrame => CloseError::TransportInternalErrorOther(4), + TransportError::DecryptError => CloseError::TransportInternalErrorOther(5), + TransportError::HandshakeFailed => CloseError::TransportInternalErrorOther(6), + TransportError::IntegerOverflow => CloseError::TransportInternalErrorOther(7), + TransportError::InvalidInput => CloseError::TransportInternalErrorOther(8), + TransportError::InvalidMigration => CloseError::TransportInternalErrorOther(9), + TransportError::InvalidPacket => CloseError::TransportInternalErrorOther(10), + TransportError::InvalidResumptionToken => CloseError::TransportInternalErrorOther(11), + TransportError::InvalidRetry => CloseError::TransportInternalErrorOther(12), + TransportError::InvalidStreamId => CloseError::TransportInternalErrorOther(13), + TransportError::KeysDiscarded(_) => CloseError::TransportInternalErrorOther(14), + TransportError::KeysPending(_) => CloseError::TransportInternalErrorOther(15), + TransportError::KeyUpdateBlocked => CloseError::TransportInternalErrorOther(16), + TransportError::NoMoreData => CloseError::TransportInternalErrorOther(17), + TransportError::NotConnected => CloseError::TransportInternalErrorOther(18), + TransportError::PacketNumberOverlap => CloseError::TransportInternalErrorOther(19), + TransportError::StatelessReset => CloseError::TransportInternalErrorOther(20), + TransportError::TooMuchData => CloseError::TransportInternalErrorOther(21), + TransportError::UnexpectedMessage => CloseError::TransportInternalErrorOther(22), + TransportError::UnknownConnectionId => CloseError::TransportInternalErrorOther(23), + TransportError::UnknownFrameType => CloseError::TransportInternalErrorOther(24), + TransportError::VersionNegotiation => CloseError::TransportInternalErrorOther(25), + TransportError::WrongRole => CloseError::TransportInternalErrorOther(26), + TransportError::QlogError => CloseError::TransportInternalErrorOther(27), + TransportError::NotAvailable => CloseError::TransportInternalErrorOther(28), + TransportError::DisabledVersion => CloseError::TransportInternalErrorOther(29), + } + } +} + +impl From<neqo_transport::ConnectionError> for CloseError { + fn from(error: neqo_transport::ConnectionError) -> CloseError { + match error { + neqo_transport::ConnectionError::Transport(c) => c.into(), + neqo_transport::ConnectionError::Application(c) => CloseError::AppError(c), + } + } +} + +// Reset a stream with streamId. +#[no_mangle] +pub extern "C" fn neqo_http3conn_cancel_fetch( + conn: &mut NeqoHttp3Conn, + stream_id: u64, + error: u64, +) -> nsresult { + match conn.conn.cancel_fetch(StreamId::from(stream_id), error) { + Ok(()) => NS_OK, + Err(_) => NS_ERROR_INVALID_ARG, + } +} + +// Reset a stream with streamId. +#[no_mangle] +pub extern "C" fn neqo_http3conn_reset_stream( + conn: &mut NeqoHttp3Conn, + stream_id: u64, + error: u64, +) -> nsresult { + match conn + .conn + .stream_reset_send(StreamId::from(stream_id), error) + { + Ok(()) => NS_OK, + Err(_) => NS_ERROR_INVALID_ARG, + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_stream_stop_sending( + conn: &mut NeqoHttp3Conn, + stream_id: u64, + error: u64, +) -> nsresult { + match conn + .conn + .stream_stop_sending(StreamId::from(stream_id), error) + { + Ok(()) => NS_OK, + Err(_) => NS_ERROR_INVALID_ARG, + } +} + +// Close sending side of a stream with stream_id +#[no_mangle] +pub extern "C" fn neqo_http3conn_close_stream( + conn: &mut NeqoHttp3Conn, + stream_id: u64, +) -> nsresult { + match conn.conn.stream_close_send(StreamId::from(stream_id)) { + Ok(()) => NS_OK, + Err(_) => NS_ERROR_INVALID_ARG, + } +} + +// WebTransport streams can be unidirectional and bidirectional. +// It is mapped to and from neqo's StreamType enum. +#[repr(C)] +pub enum WebTransportStreamType { + BiDi, + UniDi, +} + +impl From<StreamType> for WebTransportStreamType { + fn from(t: StreamType) -> WebTransportStreamType { + match t { + StreamType::BiDi => WebTransportStreamType::BiDi, + StreamType::UniDi => WebTransportStreamType::UniDi, + } + } +} + +impl From<WebTransportStreamType> for StreamType { + fn from(t: WebTransportStreamType) -> StreamType { + match t { + WebTransportStreamType::BiDi => StreamType::BiDi, + WebTransportStreamType::UniDi => StreamType::UniDi, + } + } +} + +#[repr(C)] +pub enum SessionCloseReasonExternal { + Error(u64), + Status(u16), + Clean(u32), +} + +impl SessionCloseReasonExternal { + fn new(reason: SessionCloseReason, data: &mut ThinVec<u8>) -> SessionCloseReasonExternal { + match reason { + SessionCloseReason::Error(e) => SessionCloseReasonExternal::Error(e), + SessionCloseReason::Status(s) => SessionCloseReasonExternal::Status(s), + SessionCloseReason::Clean { error, message } => { + data.extend_from_slice(message.as_ref()); + SessionCloseReasonExternal::Clean(error) + } + } + } +} + +#[repr(C)] +pub enum WebTransportEventExternal { + Negotiated(bool), + Session(u64), + SessionClosed { + stream_id: u64, + reason: SessionCloseReasonExternal, + }, + NewStream { + stream_id: u64, + stream_type: WebTransportStreamType, + session_id: u64, + }, + Datagram { + session_id: u64, + }, +} + +impl WebTransportEventExternal { + fn new(event: WebTransportEvent, data: &mut ThinVec<u8>) -> WebTransportEventExternal { + match event { + WebTransportEvent::Negotiated(n) => WebTransportEventExternal::Negotiated(n), + WebTransportEvent::Session { + stream_id, + status, + headers: _, + } => { + data.extend_from_slice(b"HTTP/3 "); + data.extend_from_slice(&status.to_string().as_bytes()); + data.extend_from_slice(b"\r\n\r\n"); + WebTransportEventExternal::Session(stream_id.as_u64()) + } + WebTransportEvent::SessionClosed { + stream_id, + reason, + headers: _, + } => match reason { + SessionCloseReason::Status(status) => { + data.extend_from_slice(b"HTTP/3 "); + data.extend_from_slice(&status.to_string().as_bytes()); + data.extend_from_slice(b"\r\n\r\n"); + WebTransportEventExternal::Session(stream_id.as_u64()) + } + _ => WebTransportEventExternal::SessionClosed { + stream_id: stream_id.as_u64(), + reason: SessionCloseReasonExternal::new(reason, data), + }, + }, + WebTransportEvent::NewStream { + stream_id, + session_id, + } => WebTransportEventExternal::NewStream { + stream_id: stream_id.as_u64(), + stream_type: stream_id.stream_type().into(), + session_id: session_id.as_u64(), + }, + WebTransportEvent::Datagram { + session_id, + datagram, + } => { + data.extend_from_slice(datagram.as_ref()); + WebTransportEventExternal::Datagram { + session_id: session_id.as_u64(), + } + } + } + } +} + +#[repr(C)] +pub enum Http3Event { + /// A request stream has space for more data to be sent. + DataWritable { + stream_id: u64, + }, + /// A server has send STOP_SENDING frame. + StopSending { + stream_id: u64, + error: u64, + }, + HeaderReady { + stream_id: u64, + fin: bool, + interim: bool, + }, + /// New bytes available for reading. + DataReadable { + stream_id: u64, + }, + /// Peer reset the stream. + Reset { + stream_id: u64, + error: u64, + local: bool, + }, + /// A PushPromise + PushPromise { + push_id: u64, + request_stream_id: u64, + }, + /// A push response headers are ready. + PushHeaderReady { + push_id: u64, + fin: bool, + }, + /// New bytes are available on a push stream for reading. + PushDataReadable { + push_id: u64, + }, + /// A push has been canceled. + PushCanceled { + push_id: u64, + }, + PushReset { + push_id: u64, + error: u64, + }, + RequestsCreatable, + AuthenticationNeeded, + ZeroRttRejected, + ConnectionConnected, + GoawayReceived, + ConnectionClosing { + error: CloseError, + }, + ConnectionClosed { + error: CloseError, + }, + ResumptionToken { + expire_in: u64, // microseconds + }, + EchFallbackAuthenticationNeeded, + WebTransport(WebTransportEventExternal), + NoEvent, +} + +fn sanitize_header(mut y: Cow<[u8]>) -> Cow<[u8]> { + for i in 0..y.len() { + if matches!(y[i], b'\n' | b'\r' | b'\0') { + y.to_mut()[i] = b' '; + } + } + y +} + +fn convert_h3_to_h1_headers(headers: Vec<Header>, ret_headers: &mut ThinVec<u8>) -> nsresult { + if headers.iter().filter(|&h| h.name() == ":status").count() != 1 { + return NS_ERROR_ILLEGAL_VALUE; + } + + let status_val = headers + .iter() + .find(|&h| h.name() == ":status") + .expect("must be one") + .value(); + + ret_headers.extend_from_slice(b"HTTP/3 "); + ret_headers.extend_from_slice(status_val.as_bytes()); + ret_headers.extend_from_slice(b"\r\n"); + + for hdr in headers.iter().filter(|&h| h.name() != ":status") { + ret_headers.extend_from_slice(&sanitize_header(Cow::from(hdr.name().as_bytes()))); + ret_headers.extend_from_slice(b": "); + ret_headers.extend_from_slice(&sanitize_header(Cow::from(hdr.value().as_bytes()))); + ret_headers.extend_from_slice(b"\r\n"); + } + ret_headers.extend_from_slice(b"\r\n"); + return NS_OK; +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_event( + conn: &mut NeqoHttp3Conn, + ret_event: &mut Http3Event, + data: &mut ThinVec<u8>, +) -> nsresult { + while let Some(evt) = conn.conn.next_event() { + let fe = match evt { + Http3ClientEvent::DataWritable { stream_id } => Http3Event::DataWritable { + stream_id: stream_id.as_u64(), + }, + Http3ClientEvent::StopSending { stream_id, error } => Http3Event::StopSending { + stream_id: stream_id.as_u64(), + error, + }, + Http3ClientEvent::HeaderReady { + stream_id, + headers, + fin, + interim, + } => { + let res = convert_h3_to_h1_headers(headers, data); + if res != NS_OK { + return res; + } + Http3Event::HeaderReady { + stream_id: stream_id.as_u64(), + fin, + interim, + } + } + Http3ClientEvent::DataReadable { stream_id } => Http3Event::DataReadable { + stream_id: stream_id.as_u64(), + }, + Http3ClientEvent::Reset { + stream_id, + error, + local, + } => Http3Event::Reset { + stream_id: stream_id.as_u64(), + error, + local, + }, + Http3ClientEvent::PushPromise { + push_id, + request_stream_id, + headers, + } => { + let res = convert_h3_to_h1_headers(headers, data); + if res != NS_OK { + return res; + } + Http3Event::PushPromise { + push_id, + request_stream_id: request_stream_id.as_u64(), + } + } + Http3ClientEvent::PushHeaderReady { + push_id, + headers, + fin, + interim, + } => { + if interim { + Http3Event::NoEvent + } else { + let res = convert_h3_to_h1_headers(headers, data); + if res != NS_OK { + return res; + } + Http3Event::PushHeaderReady { push_id, fin } + } + } + Http3ClientEvent::PushDataReadable { push_id } => { + Http3Event::PushDataReadable { push_id } + } + Http3ClientEvent::PushCanceled { push_id } => Http3Event::PushCanceled { push_id }, + Http3ClientEvent::PushReset { push_id, error } => { + Http3Event::PushReset { push_id, error } + } + Http3ClientEvent::RequestsCreatable => Http3Event::RequestsCreatable, + Http3ClientEvent::AuthenticationNeeded => Http3Event::AuthenticationNeeded, + Http3ClientEvent::ZeroRttRejected => Http3Event::ZeroRttRejected, + Http3ClientEvent::ResumptionToken(token) => { + // expiration_time time is Instant, transform it into microseconds it will + // be valid for. Necko code will add the value to PR_Now() to get the expiration + // time in PRTime. + if token.expiration_time() > Instant::now() { + let e = (token.expiration_time() - Instant::now()).as_micros(); + if let Ok(expire_in) = u64::try_from(e) { + data.extend_from_slice(token.as_ref()); + Http3Event::ResumptionToken { expire_in } + } else { + Http3Event::NoEvent + } + } else { + Http3Event::NoEvent + } + } + Http3ClientEvent::GoawayReceived => Http3Event::GoawayReceived, + Http3ClientEvent::StateChange(state) => match state { + Http3State::Connected => Http3Event::ConnectionConnected, + Http3State::Closing(error_code) => { + match error_code { + neqo_transport::ConnectionError::Transport( + TransportError::CryptoError(neqo_crypto::Error::EchRetry(ref c)), + ) + | neqo_transport::ConnectionError::Transport(TransportError::EchRetry( + ref c, + )) => { + data.extend_from_slice(c.as_ref()); + } + _ => {} + } + Http3Event::ConnectionClosing { + error: error_code.into(), + } + } + Http3State::Closed(error_code) => { + match error_code { + neqo_transport::ConnectionError::Transport( + TransportError::CryptoError(neqo_crypto::Error::EchRetry(ref c)), + ) + | neqo_transport::ConnectionError::Transport(TransportError::EchRetry( + ref c, + )) => { + data.extend_from_slice(c.as_ref()); + } + _ => {} + } + Http3Event::ConnectionClosed { + error: error_code.into(), + } + } + _ => Http3Event::NoEvent, + }, + Http3ClientEvent::EchFallbackAuthenticationNeeded { public_name } => { + data.extend_from_slice(public_name.as_ref()); + Http3Event::EchFallbackAuthenticationNeeded + } + Http3ClientEvent::WebTransport(e) => { + Http3Event::WebTransport(WebTransportEventExternal::new(e, data)) + } + }; + + if !matches!(fe, Http3Event::NoEvent) { + *ret_event = fe; + return NS_OK; + } + } + + *ret_event = Http3Event::NoEvent; + NS_OK +} + +// Read response data into buf. +#[no_mangle] +pub unsafe extern "C" fn neqo_http3conn_read_response_data( + conn: &mut NeqoHttp3Conn, + stream_id: u64, + buf: *mut u8, + len: u32, + read: &mut u32, + fin: &mut bool, +) -> nsresult { + let array = slice::from_raw_parts_mut(buf, len as usize); + match conn + .conn + .read_data(Instant::now(), StreamId::from(stream_id), &mut array[..]) + { + Ok((amount, fin_recvd)) => { + *read = u32::try_from(amount).unwrap(); + *fin = fin_recvd; + if (amount == 0) && !fin_recvd { + NS_BASE_STREAM_WOULD_BLOCK + } else { + NS_OK + } + } + Err(Http3Error::InvalidStreamId) + | Err(Http3Error::TransportError(TransportError::NoMoreData)) => NS_ERROR_INVALID_ARG, + Err(_) => NS_ERROR_NET_HTTP3_PROTOCOL_ERROR, + } +} + +#[repr(C)] +pub struct NeqoSecretInfo { + set: bool, + version: u16, + cipher: u16, + group: u16, + resumed: bool, + early_data: bool, + alpn: nsCString, + signature_scheme: u16, + ech_accepted: bool, +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_tls_info( + conn: &mut NeqoHttp3Conn, + sec_info: &mut NeqoSecretInfo, +) -> nsresult { + match conn.conn.tls_info() { + Some(info) => { + sec_info.set = true; + sec_info.version = info.version(); + sec_info.cipher = info.cipher_suite(); + sec_info.group = info.key_exchange(); + sec_info.resumed = info.resumed(); + sec_info.early_data = info.early_data_accepted(); + sec_info.alpn = match info.alpn() { + Some(a) => nsCString::from(a), + None => nsCString::new(), + }; + sec_info.signature_scheme = info.signature_scheme(); + sec_info.ech_accepted = info.ech_accepted(); + NS_OK + } + None => NS_ERROR_NOT_AVAILABLE, + } +} + +#[repr(C)] +pub struct NeqoCertificateInfo { + certs: ThinVec<ThinVec<u8>>, + stapled_ocsp_responses_present: bool, + stapled_ocsp_responses: ThinVec<ThinVec<u8>>, + signed_cert_timestamp_present: bool, + signed_cert_timestamp: ThinVec<u8>, +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_peer_certificate_info( + conn: &mut NeqoHttp3Conn, + neqo_certs_info: &mut NeqoCertificateInfo, +) -> nsresult { + let mut certs_info = match conn.conn.peer_certificate() { + Some(certs) => certs, + None => return NS_ERROR_NOT_AVAILABLE, + }; + + neqo_certs_info.certs = certs_info + .map(|cert| cert.iter().cloned().collect()) + .collect(); + + match &mut certs_info.stapled_ocsp_responses() { + Some(ocsp_val) => { + neqo_certs_info.stapled_ocsp_responses_present = true; + neqo_certs_info.stapled_ocsp_responses = ocsp_val + .iter() + .map(|ocsp| ocsp.iter().cloned().collect()) + .collect(); + } + None => { + neqo_certs_info.stapled_ocsp_responses_present = false; + } + }; + + match certs_info.signed_cert_timestamp() { + Some(sct_val) => { + neqo_certs_info.signed_cert_timestamp_present = true; + neqo_certs_info + .signed_cert_timestamp + .extend_from_slice(sct_val); + } + None => { + neqo_certs_info.signed_cert_timestamp_present = false; + } + }; + + NS_OK +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_authenticated(conn: &mut NeqoHttp3Conn, error: PRErrorCode) { + conn.conn.authenticated(error.into(), Instant::now()); +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_set_resumption_token( + conn: &mut NeqoHttp3Conn, + token: &mut ThinVec<u8>, +) { + let _ = conn.conn.enable_resumption(Instant::now(), token); +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_set_ech_config( + conn: &mut NeqoHttp3Conn, + ech_config: &mut ThinVec<u8>, +) { + let _ = conn.conn.enable_ech(ech_config); +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_is_zero_rtt(conn: &mut NeqoHttp3Conn) -> bool { + conn.conn.state() == Http3State::ZeroRtt +} + +#[repr(C)] +#[derive(Default)] +pub struct Http3Stats { + /// Total packets received, including all the bad ones. + pub packets_rx: usize, + /// Duplicate packets received. + pub dups_rx: usize, + /// Dropped packets or dropped garbage. + pub dropped_rx: usize, + /// The number of packet that were saved for later processing. + pub saved_datagrams: usize, + /// Total packets sent. + pub packets_tx: usize, + /// Total number of packets that are declared lost. + pub lost: usize, + /// Late acknowledgments, for packets that were declared lost already. + pub late_ack: usize, + /// Acknowledgments for packets that contained data that was marked + /// for retransmission when the PTO timer popped. + pub pto_ack: usize, + /// Count PTOs. Single PTOs, 2 PTOs in a row, 3 PTOs in row, etc. are counted + /// separately. + pub pto_counts: [usize; 16], +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_get_stats(conn: &mut NeqoHttp3Conn, stats: &mut Http3Stats) { + let t_stats = conn.conn.transport_stats(); + stats.packets_rx = t_stats.packets_rx; + stats.dups_rx = t_stats.dups_rx; + stats.dropped_rx = t_stats.dropped_rx; + stats.saved_datagrams = t_stats.saved_datagrams; + stats.packets_tx = t_stats.packets_tx; + stats.lost = t_stats.lost; + stats.late_ack = t_stats.late_ack; + stats.pto_ack = t_stats.pto_ack; + stats.pto_counts = t_stats.pto_counts; +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_webtransport_create_session( + conn: &mut NeqoHttp3Conn, + host: &nsACString, + path: &nsACString, + headers: &nsACString, + stream_id: &mut u64, +) -> nsresult { + let hdrs = match parse_headers(headers) { + Err(e) => { + return e; + } + Ok(h) => h, + }; + let host_tmp = match str::from_utf8(host) { + Ok(h) => h, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + let path_tmp = match str::from_utf8(path) { + Ok(p) => p, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + + match conn.conn.webtransport_create_session( + Instant::now(), + &("https", host_tmp, path_tmp), + &hdrs, + ) { + Ok(id) => { + *stream_id = id.as_u64(); + NS_OK + } + Err(Http3Error::StreamLimitError) => NS_BASE_STREAM_WOULD_BLOCK, + Err(_) => NS_ERROR_UNEXPECTED, + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_webtransport_close_session( + conn: &mut NeqoHttp3Conn, + session_id: u64, + error: u32, + message: &nsACString, +) -> nsresult { + let message_tmp = match str::from_utf8(message) { + Ok(p) => p, + Err(_) => { + return NS_ERROR_INVALID_ARG; + } + }; + match conn + .conn + .webtransport_close_session(StreamId::from(session_id), error, message_tmp) + { + Ok(()) => NS_OK, + Err(_) => NS_ERROR_INVALID_ARG, + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_webtransport_create_stream( + conn: &mut NeqoHttp3Conn, + session_id: u64, + stream_type: WebTransportStreamType, + stream_id: &mut u64, +) -> nsresult { + match conn + .conn + .webtransport_create_stream(StreamId::from(session_id), stream_type.into()) + { + Ok(id) => { + *stream_id = id.as_u64(); + NS_OK + } + Err(Http3Error::StreamLimitError) => NS_BASE_STREAM_WOULD_BLOCK, + Err(_) => NS_ERROR_UNEXPECTED, + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_webtransport_send_datagram( + conn: &mut NeqoHttp3Conn, + session_id: u64, + data: &mut ThinVec<u8>, + tracking_id: u64, +) -> nsresult { + let id = if tracking_id == 0 { + None + } else { + Some(tracking_id) + }; + match conn + .conn + .webtransport_send_datagram(StreamId::from(session_id), data, id) + { + Ok(()) => NS_OK, + Err(Http3Error::TransportError(TransportError::TooMuchData)) => NS_ERROR_NOT_AVAILABLE, + Err(_) => NS_ERROR_UNEXPECTED, + } +} + +#[no_mangle] +pub extern "C" fn neqo_http3conn_webtransport_max_datagram_size( + conn: &mut NeqoHttp3Conn, + session_id: u64, + result: &mut u64, +) -> nsresult { + match conn + .conn + .webtransport_max_datagram_size(StreamId::from(session_id)) + { + Ok(size) => { + *result = size; + NS_OK + } + Err(_) => NS_ERROR_UNEXPECTED, + } +} diff --git a/netwerk/socket/nsINamedPipeService.idl b/netwerk/socket/nsINamedPipeService.idl new file mode 100644 index 0000000000..9db557577d --- /dev/null +++ b/netwerk/socket/nsINamedPipeService.idl @@ -0,0 +1,77 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsISupports.idl" +#include "nsrootidl.idl" + +/** + * nsINamedPipeDataObserver + * + * This is the callback interface for nsINamedPipeService. + * The functions are called by the internal thread in the nsINamedPipeService. + */ +[scriptable, uuid(de4f460b-94fd-442c-9002-1637beb2185a)] +interface nsINamedPipeDataObserver : nsISupports +{ + /** + * onDataAvailable + * + * @param aBytesTransferred + * Transfered bytes during last transmission. + * @param aOverlapped + * Corresponding overlapped structure used by the async I/O + */ + void onDataAvailable(in unsigned long aBytesTransferred, + in voidPtr aOverlapped); + + /** + * onError + * + * @param aError + * Error code of the error. + * @param aOverlapped + * Corresponding overlapped structure used by the async I/O + */ + void onError(in unsigned long aError, + in voidPtr aOverlapped); +}; + +/** + * nsINamedPipeService + */ +[scriptable, uuid(1bf19133-5625-4ac8-836a-80b1c215f72b)] +interface nsINamedPipeService : nsISupports +{ + /** + * addDataObserver + * + * @param aHandle + * The handle that is going to be monitored for read/write operations. + * Only handles that are opened with overlapped IO are supported. + * @param aObserver + * The observer of the handle, the service strong-refs of the observer. + */ + void addDataObserver(in voidPtr aHandle, + in nsINamedPipeDataObserver aObserver); + + /** + * removeDataObserver + * + * @param aHandle + The handle associated to the observer, and will be closed by the + service. + * @param aObserver + * The observer to be removed. + */ + void removeDataObserver(in voidPtr aHandle, + in nsINamedPipeDataObserver aObserver); + + /** + * isOnCurrentThread + * + * @return the caller runs within the internal thread. + */ + boolean isOnCurrentThread(); +}; diff --git a/netwerk/socket/nsISocketProvider.idl b/netwerk/socket/nsISocketProvider.idl new file mode 100644 index 0000000000..1f19b932f9 --- /dev/null +++ b/netwerk/socket/nsISocketProvider.idl @@ -0,0 +1,145 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsISupports.idl" + +interface nsIProxyInfo; +interface nsITLSSocketControl; +[ptr] native PRFileDescStar(struct PRFileDesc); +native OriginAttributes(mozilla::OriginAttributes); +[ref] native const_OriginAttributesRef(const mozilla::OriginAttributes); + +%{ C++ +#include "mozilla/BasePrincipal.h" +%} + +/** + * nsISocketProvider + */ +[scriptable, uuid(508d5469-9e1e-4a08-b5b0-7cfebba1e51a)] +interface nsISocketProvider : nsISupports +{ + /** + * newSocket + * + * @param aFamily + * The address family for this socket (PR_AF_INET or PR_AF_INET6). + * @param aHost + * The origin hostname for this connection. + * @param aPort + * The origin port for this connection. + * @param aProxyHost + * If non-null, the proxy hostname for this connection. + * @param aProxyPort + * The proxy port for this connection. + * @param aFlags + * Control flags that govern this connection (see below.) + * @param aTlsFlags + * An opaque flags for non-standard behavior of the TLS system. + * It is unlikely this will need to be set outside of telemetry + * studies relating to the TLS implementation. + * @param aFileDesc + * The resulting PRFileDesc. + * @param aTLSSocketControl + * TLS socket control object that should be associated with + * aFileDesc, if applicable. + */ + [noscript] + void newSocket(in long aFamily, + in string aHost, + in long aPort, + in nsIProxyInfo aProxy, + in const_OriginAttributesRef aOriginAttributes, + in unsigned long aFlags, + in unsigned long aTlsFlags, + out PRFileDescStar aFileDesc, + out nsITLSSocketControl aTLSSocketControl); + + /** + * addToSocket + * + * This function is called to allow the socket provider to layer a + * PRFileDesc on top of another PRFileDesc. For example, SSL via a SOCKS + * proxy. + * + * Parameters are the same as newSocket with the exception of aFileDesc, + * which is an in-param instead. + */ + [noscript] + void addToSocket(in long aFamily, + in string aHost, + in long aPort, + in nsIProxyInfo aProxy, + in const_OriginAttributesRef aOriginAttributes, + in unsigned long aFlags, + in unsigned long aTlsFlags, + in PRFileDescStar aFileDesc, + out nsITLSSocketControl aTLSSocketControl); + + /** + * PROXY_RESOLVES_HOST + * + * This flag is set if the proxy is to perform hostname resolution instead + * of the client. When set, the hostname parameter passed when in this + * interface will be used instead of the address structure passed for a + * later connect et al. request. + */ + const long PROXY_RESOLVES_HOST = 1 << 0; + + /** + * When setting this flag, the socket will not apply any + * credentials when establishing a connection. For example, + * an SSL connection would not send any client-certificates + * if this flag is set. + */ + const long ANONYMOUS_CONNECT = 1 << 1; + + /** + * If set, indicates that the connection was initiated from a source + * defined as being private in the sense of Private Browsing. Generally, + * there should be no state shared between connections that are private + * and those that are not; it is OK for multiple private connections + * to share state with each other, and it is OK for multiple non-private + * connections to share state with each other. + */ + const unsigned long NO_PERMANENT_STORAGE = 1 << 2; + + /** + * If set, do not use newer protocol features that might have interop problems + * on the Internet. Intended only for use with critical infra like the updater. + * default is false. + */ + const unsigned long BE_CONSERVATIVE = 1 << 3; + + /** + * This is used for a temporary workaround for a web-compat issue. The flag is + * only set on CORS preflight request to allowed sending client certificates + * on a connection for an anonymous request. + */ + const long ANONYMOUS_CONNECT_ALLOW_CLIENT_CERT = 1 << 4; + + /** + * If set, indicates that this is a speculative connection. + */ + const unsigned long IS_SPECULATIVE_CONNECTION = 1 << 5; + + /** + * If set, do not send an ECH extension (whether GREASE or 'real'). + * Currently false by default and is set when retrying failed connections. + */ + const unsigned long DONT_TRY_ECH = (1 << 10); + + /** + * If set, indicates that the connection is a retry. + */ + const unsigned long IS_RETRY = (1 << 11); + + /** + * If set, indicates that the connection used a privacy-preserving DNS + * transport such as DoH, DoQ or similar. Currently this field is + * set only when DoH is used via the TRR. + */ + const unsigned long USED_PRIVATE_DNS = (1 << 12); +}; diff --git a/netwerk/socket/nsISocketProviderService.idl b/netwerk/socket/nsISocketProviderService.idl new file mode 100644 index 0000000000..db38d299a5 --- /dev/null +++ b/netwerk/socket/nsISocketProviderService.idl @@ -0,0 +1,20 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 4 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsISupports.idl" + +interface nsISocketProvider; + +/** + * nsISocketProviderService + * + * Provides a mapping between a socket type and its associated socket provider + * instance. One could also use the service manager directly. + */ +[scriptable, uuid(8f8a23d0-5472-11d3-bbc8-0000861d1237)] +interface nsISocketProviderService : nsISupports +{ + nsISocketProvider getSocketProvider(in string socketType); +}; diff --git a/netwerk/socket/nsNamedPipeIOLayer.cpp b/netwerk/socket/nsNamedPipeIOLayer.cpp new file mode 100644 index 0000000000..d229f0e23e --- /dev/null +++ b/netwerk/socket/nsNamedPipeIOLayer.cpp @@ -0,0 +1,864 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsNamedPipeIOLayer.h" + +#include <algorithm> +#include <utility> + +#include "mozilla/Atomics.h" +#include "mozilla/DebugOnly.h" +#include "mozilla/Logging.h" +#include "mozilla/RefPtr.h" +#include "mozilla/Unused.h" +#include "mozilla/net/DNS.h" +#include "nsISupportsImpl.h" +#include "nsNamedPipeService.h" +#include "nsNativeCharsetUtils.h" +#include "nsNetCID.h" +#include "nsServiceManagerUtils.h" +#include "nsSocketTransportService2.h" +#include "nsString.h" +#include "nsThreadUtils.h" +#include "nspr.h" +#include "private/pprio.h" + +namespace mozilla { +namespace net { + +static mozilla::LazyLogModule gNamedPipeLog("NamedPipeWin"); +#define LOG_NPIO_DEBUG(...) \ + MOZ_LOG(gNamedPipeLog, mozilla::LogLevel::Debug, (__VA_ARGS__)) +#define LOG_NPIO_ERROR(...) \ + MOZ_LOG(gNamedPipeLog, mozilla::LogLevel::Error, (__VA_ARGS__)) + +PRDescIdentity nsNamedPipeLayerIdentity; +static PRIOMethods nsNamedPipeLayerMethods; + +class NamedPipeInfo final : public nsINamedPipeDataObserver { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSINAMEDPIPEDATAOBSERVER + + explicit NamedPipeInfo(); + + nsresult Connect(const nsAString& aPath); + nsresult Disconnect(); + + /** + * Both blocking/non-blocking mode are supported in this class. + * The default mode is non-blocking mode, however, the client may change its + * mode to blocking mode during hand-shaking (e.g. nsSOCKSSocketInfo). + * + * In non-blocking mode, |Read| and |Write| should be called by clients only + * when |GetPollFlags| reports data availability. That is, the client calls + * |GetPollFlags| with |PR_POLL_READ| and/or |PR_POLL_WRITE| set, and + * according to the flags that set, |GetPollFlags| will check buffers status + * and decide corresponding actions: + * + * ------------------------------------------------------------------- + * | | data in buffer | empty buffer | + * |---------------+-------------------------+-----------------------| + * | PR_POLL_READ | out: PR_POLL_READ | DoRead/DoReadContinue | + * |---------------+-------------------------+-----------------------| + * | PR_POLL_WRITE | DoWrite/DoWriteContinue | out: PR_POLL_WRITE | + * ------------------------------------------+------------------------ + * + * |DoRead| and |DoWrite| initiate read/write operations asynchronously, and + * the |DoReadContinue| and |DoWriteContinue| are used to check the amount + * of the data are read/written to/from buffers. + * + * The output parameter and the return value of |GetPollFlags| are identical + * because we don't rely on the low-level select function to wait for data + * availability, we instead use nsNamedPipeService to poll I/O completeness. + * + * When client get |PR_POLL_READ| or |PR_POLL_WRITE| from |GetPollFlags|, + * they are able to use |Read| or |Write| to access the data in the buffer, + * and this is supposed to be very fast because no network traffic is + * involved. + * + * In blocking mode, the flow is quite similar to non-blocking mode, but + * |DoReadContinue| and |DoWriteContinue| are never been used since the + * operations are done synchronously, which could lead to slow responses. + */ + int32_t Read(void* aBuffer, int32_t aSize); + int32_t Write(const void* aBuffer, int32_t aSize); + + // Like Read, but doesn't remove data in internal buffer. + uint32_t Peek(void* aBuffer, int32_t aSize); + + // Number of bytes available to read in internal buffer. + int32_t Available() const; + + // Flush write buffer + // + // @return whether the buffer has been flushed + bool Sync(uint32_t aTimeout); + void SetNonblocking(bool nonblocking); + + bool IsConnected() const; + bool IsNonblocking() const; + HANDLE GetHandle() const; + + // Initiate and check current status for read/write operations. + int16_t GetPollFlags(int16_t aInFlags, int16_t* aOutFlags); + + private: + virtual ~NamedPipeInfo(); + + /** + * DoRead/DoWrite starts a read/write call synchronously or asynchronously + * depending on |mNonblocking|. In blocking mode, they return when the action + * has been done and in non-blocking mode it returns the number of bytes that + * were read/written if the operation is done immediately. If it takes some + * time to finish the operation, zero is returned and + * DoReadContinue/DoWriteContinue must be called to get async I/O result. + */ + int32_t DoRead(); + int32_t DoReadContinue(); + int32_t DoWrite(); + int32_t DoWriteContinue(); + + /** + * There was a write size limitation of named pipe, + * see https://support.microsoft.com/en-us/kb/119218 for more information. + * The limitation no longer exists, so feel free to change the value. + */ + static const uint32_t kBufferSize = 65536; + + nsCOMPtr<nsINamedPipeService> mNamedPipeService; + + HANDLE mPipe; // the handle to the named pipe. + OVERLAPPED mReadOverlapped; // used for asynchronous read operations. + OVERLAPPED mWriteOverlapped; // used for asynchronous write operations. + + uint8_t mReadBuffer[kBufferSize]; // octets read from pipe. + + /** + * These indicates the [begin, end) position of the data in the buffer. + */ + DWORD mReadBegin; + DWORD mReadEnd; + + bool mHasPendingRead; // previous asynchronous read is not finished yet. + + uint8_t mWriteBuffer[kBufferSize]; // octets to be written to pipe. + + /** + * These indicates the [begin, end) position of the data in the buffer. + */ + DWORD mWriteBegin; // how many bytes are already written. + DWORD mWriteEnd; // valid amount of data in the buffer. + + bool mHasPendingWrite; // previous asynchronous write is not finished yet. + + /** + * current blocking mode is non-blocking or not, accessed only in socket + * thread. + */ + bool mNonblocking; + + Atomic<DWORD> mErrorCode; // error code from Named Pipe Service. +}; + +NS_IMPL_ISUPPORTS(NamedPipeInfo, nsINamedPipeDataObserver) + +NamedPipeInfo::NamedPipeInfo() + : mNamedPipeService(NamedPipeService::GetOrCreate()), + mPipe(INVALID_HANDLE_VALUE), + mReadBegin(0), + mReadEnd(0), + mHasPendingRead(false), + mWriteBegin(0), + mWriteEnd(0), + mHasPendingWrite(false), + mNonblocking(true), + mErrorCode(0) { + MOZ_ASSERT(mNamedPipeService); + + ZeroMemory(&mReadOverlapped, sizeof(OVERLAPPED)); + ZeroMemory(&mWriteOverlapped, sizeof(OVERLAPPED)); +} + +NamedPipeInfo::~NamedPipeInfo() { MOZ_ASSERT(!mPipe); } + +// nsINamedPipeDataObserver + +NS_IMETHODIMP +NamedPipeInfo::OnDataAvailable(uint32_t aBytesTransferred, void* aOverlapped) { + DebugOnly<bool> isOnPipeServiceThread; + MOZ_ASSERT(NS_SUCCEEDED(mNamedPipeService->IsOnCurrentThread( + &isOnPipeServiceThread)) && + isOnPipeServiceThread); + + if (aOverlapped == &mReadOverlapped) { + LOG_NPIO_DEBUG("[%s] %p read %d bytes", __func__, this, aBytesTransferred); + } else if (aOverlapped == &mWriteOverlapped) { + LOG_NPIO_DEBUG("[%s] %p write %d bytes", __func__, this, aBytesTransferred); + } else { + MOZ_ASSERT(false, "invalid callback"); + mErrorCode = ERROR_INVALID_DATA; + return NS_ERROR_FAILURE; + } + + mErrorCode = ERROR_SUCCESS; + + // dispatch an empty event to trigger STS thread + gSocketTransportService->Dispatch( + NS_NewRunnableFunction("NamedPipeInfo::OnDataAvailable", [] {}), + NS_DISPATCH_NORMAL); + + return NS_OK; +} + +NS_IMETHODIMP +NamedPipeInfo::OnError(uint32_t aError, void* aOverlapped) { + DebugOnly<bool> isOnPipeServiceThread; + MOZ_ASSERT(NS_SUCCEEDED(mNamedPipeService->IsOnCurrentThread( + &isOnPipeServiceThread)) && + isOnPipeServiceThread); + + LOG_NPIO_ERROR("[%s] error code=%d", __func__, aError); + mErrorCode = aError; + + // dispatch an empty event to trigger STS thread + gSocketTransportService->Dispatch( + NS_NewRunnableFunction("NamedPipeInfo::OnError", [] {}), + NS_DISPATCH_NORMAL); + + return NS_OK; +} + +// Named pipe operations + +nsresult NamedPipeInfo::Connect(const nsAString& aPath) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + HANDLE pipe = + CreateFileW(PromiseFlatString(aPath).get(), GENERIC_READ | GENERIC_WRITE, + FILE_SHARE_READ | FILE_SHARE_WRITE, nullptr, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, nullptr); + + if (pipe == INVALID_HANDLE_VALUE) { + LOG_NPIO_ERROR("[%p] CreateFile error (%lu)", this, GetLastError()); + return NS_ERROR_FAILURE; + } + + DWORD pipeMode = PIPE_READMODE_MESSAGE; + if (!SetNamedPipeHandleState(pipe, &pipeMode, nullptr, nullptr)) { + LOG_NPIO_ERROR("[%p] SetNamedPipeHandleState error (%lu)", this, + GetLastError()); + CloseHandle(pipe); + return NS_ERROR_FAILURE; + } + + nsresult rv = mNamedPipeService->AddDataObserver(pipe, this); + if (NS_WARN_IF(NS_FAILED(rv))) { + CloseHandle(pipe); + return rv; + } + + HANDLE readEvent = CreateEventA(nullptr, TRUE, TRUE, "NamedPipeRead"); + if (NS_WARN_IF(!readEvent || readEvent == INVALID_HANDLE_VALUE)) { + CloseHandle(pipe); + return NS_ERROR_FAILURE; + } + + HANDLE writeEvent = CreateEventA(nullptr, TRUE, TRUE, "NamedPipeWrite"); + if (NS_WARN_IF(!writeEvent || writeEvent == INVALID_HANDLE_VALUE)) { + CloseHandle(pipe); + CloseHandle(readEvent); + return NS_ERROR_FAILURE; + } + + mPipe = pipe; + mReadOverlapped.hEvent = readEvent; + mWriteOverlapped.hEvent = writeEvent; + return NS_OK; +} + +nsresult NamedPipeInfo::Disconnect() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + nsresult rv = mNamedPipeService->RemoveDataObserver(mPipe, this); + Unused << NS_WARN_IF(NS_FAILED(rv)); + + mPipe = nullptr; + + if (mReadOverlapped.hEvent && + mReadOverlapped.hEvent != INVALID_HANDLE_VALUE) { + CloseHandle(mReadOverlapped.hEvent); + mReadOverlapped.hEvent = nullptr; + } + + if (mWriteOverlapped.hEvent && + mWriteOverlapped.hEvent != INVALID_HANDLE_VALUE) { + CloseHandle(mWriteOverlapped.hEvent); + mWriteOverlapped.hEvent = nullptr; + } + + return rv; +} + +int32_t NamedPipeInfo::Read(void* aBuffer, int32_t aSize) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + int32_t bytesRead = Peek(aBuffer, aSize); + + if (bytesRead > 0) { + mReadBegin += bytesRead; + } + + return bytesRead; +} + +int32_t NamedPipeInfo::Write(const void* aBuffer, int32_t aSize) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(mWriteBegin <= mWriteEnd); + + if (!IsConnected()) { + // pipe unconnected + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + if (mWriteBegin == mWriteEnd) { + mWriteBegin = mWriteEnd = 0; + } + + int32_t bytesToWrite = + std::min<int32_t>(aSize, sizeof(mWriteBuffer) - mWriteEnd); + MOZ_ASSERT(bytesToWrite >= 0); + + if (bytesToWrite == 0) { + PR_SetError(IsNonblocking() ? PR_WOULD_BLOCK_ERROR : PR_IO_PENDING_ERROR, + 0); + return -1; + } + + memcpy(&mWriteBuffer[mWriteEnd], aBuffer, bytesToWrite); + mWriteEnd += bytesToWrite; + + /** + * Triggers internal write operation by calling |GetPollFlags|. + * This is required for callers that use blocking I/O because they don't call + * |GetPollFlags| to write data, but this also works for non-blocking I/O. + */ + int16_t outFlag; + GetPollFlags(PR_POLL_WRITE, &outFlag); + + return bytesToWrite; +} + +uint32_t NamedPipeInfo::Peek(void* aBuffer, int32_t aSize) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(mReadBegin <= mReadEnd); + + if (!IsConnected()) { + // pipe unconnected + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + /** + * If there's nothing in the read buffer, try to trigger internal read + * operation by calling |GetPollFlags|. This is required for callers that + * use blocking I/O because they don't call |GetPollFlags| to read data, + * but this also works for non-blocking I/O. + */ + if (!Available()) { + int16_t outFlag; + GetPollFlags(PR_POLL_READ, &outFlag); + + if (!(outFlag & PR_POLL_READ)) { + PR_SetError(IsNonblocking() ? PR_WOULD_BLOCK_ERROR : PR_IO_PENDING_ERROR, + 0); + return -1; + } + } + + // Available() can't return more than what fits to the buffer at the read + // offset. + int32_t bytesRead = std::min<int32_t>(aSize, Available()); + MOZ_ASSERT(bytesRead >= 0); + MOZ_ASSERT(mReadBegin + bytesRead <= mReadEnd); + memcpy(aBuffer, &mReadBuffer[mReadBegin], bytesRead); + return bytesRead; +} + +int32_t NamedPipeInfo::Available() const { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(mReadBegin <= mReadEnd); + MOZ_ASSERT(mReadEnd - mReadBegin <= 0x7FFFFFFF); // no more than int32_max + return mReadEnd - mReadBegin; +} + +bool NamedPipeInfo::Sync(uint32_t aTimeout) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + if (!mHasPendingWrite) { + return true; + } + return WaitForSingleObject(mWriteOverlapped.hEvent, aTimeout) == + WAIT_OBJECT_0; +} + +void NamedPipeInfo::SetNonblocking(bool nonblocking) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + mNonblocking = nonblocking; +} + +bool NamedPipeInfo::IsConnected() const { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + return mPipe && mPipe != INVALID_HANDLE_VALUE; +} + +bool NamedPipeInfo::IsNonblocking() const { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + return mNonblocking; +} + +HANDLE +NamedPipeInfo::GetHandle() const { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + return mPipe; +} + +int16_t NamedPipeInfo::GetPollFlags(int16_t aInFlags, int16_t* aOutFlags) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + *aOutFlags = 0; + + if (aInFlags & PR_POLL_READ) { + int32_t bytesToRead = 0; + if (mReadBegin < mReadEnd) { // data in buffer and is ready to be read + bytesToRead = Available(); + } else if (mHasPendingRead) { // nonblocking I/O and has pending task + bytesToRead = DoReadContinue(); + } else { // read bufer is empty. + bytesToRead = DoRead(); + } + + if (bytesToRead > 0) { + *aOutFlags |= PR_POLL_READ; + } else if (bytesToRead < 0) { + *aOutFlags |= PR_POLL_ERR; + } + } + + if (aInFlags & PR_POLL_WRITE) { + int32_t bytesWritten = 0; + if (mHasPendingWrite) { // nonblocking I/O and has pending task. + bytesWritten = DoWriteContinue(); + } else if (mWriteBegin < mWriteEnd) { // data in buffer, ready to write + bytesWritten = DoWrite(); + } else { // write buffer is empty. + *aOutFlags |= PR_POLL_WRITE; + } + + if (bytesWritten < 0) { + *aOutFlags |= PR_POLL_ERR; + } else if (bytesWritten && !mHasPendingWrite && mWriteBegin == mWriteEnd) { + *aOutFlags |= PR_POLL_WRITE; + } + } + + return *aOutFlags; +} + +// @return: data has been read and is available +int32_t NamedPipeInfo::DoRead() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(!mHasPendingRead); + MOZ_ASSERT(mReadBegin == mReadEnd); // the buffer should be empty + + mReadBegin = 0; + mReadEnd = 0; + + BOOL success = ReadFile(mPipe, mReadBuffer, sizeof(mReadBuffer), &mReadEnd, + IsNonblocking() ? &mReadOverlapped : nullptr); + + if (success) { + LOG_NPIO_DEBUG("[%s][%p] %lu bytes read", __func__, this, mReadEnd); + return mReadEnd; + } + + switch (GetLastError()) { + case ERROR_MORE_DATA: // has more data to read + mHasPendingRead = true; + return DoReadContinue(); + + case ERROR_IO_PENDING: // read is pending + mHasPendingRead = true; + break; + + default: + LOG_NPIO_ERROR("[%s] ReadFile failed (%lu)", __func__, GetLastError()); + Disconnect(); + PR_SetError(PR_IO_ERROR, 0); + return -1; + } + + return 0; +} + +int32_t NamedPipeInfo::DoReadContinue() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(mHasPendingRead); + MOZ_ASSERT(mReadBegin == 0 && mReadEnd == 0); + + BOOL success; + success = GetOverlappedResult(mPipe, &mReadOverlapped, &mReadEnd, FALSE); + if (success) { + mHasPendingRead = false; + if (mReadEnd == 0) { + Disconnect(); + PR_SetError(PR_NOT_CONNECTED_ERROR, 0); + return -1; + } + + LOG_NPIO_DEBUG("[%s][%p] %lu bytes read", __func__, this, mReadEnd); + return mReadEnd; + } + + switch (GetLastError()) { + case ERROR_MORE_DATA: + mHasPendingRead = false; + LOG_NPIO_DEBUG("[%s][%p] %lu bytes read", __func__, this, mReadEnd); + return mReadEnd; + case ERROR_IO_INCOMPLETE: // still in progress + break; + default: + LOG_NPIO_ERROR("[%s]: GetOverlappedResult failed (%lu)", __func__, + GetLastError()); + Disconnect(); + PR_SetError(PR_IO_ERROR, 0); + return -1; + } + + return 0; +} + +int32_t NamedPipeInfo::DoWrite() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(!mHasPendingWrite); + MOZ_ASSERT(mWriteBegin < mWriteEnd); + + DWORD bytesWritten = 0; + BOOL success = + WriteFile(mPipe, &mWriteBuffer[mWriteBegin], mWriteEnd - mWriteBegin, + &bytesWritten, IsNonblocking() ? &mWriteOverlapped : nullptr); + + if (success) { + mWriteBegin += bytesWritten; + LOG_NPIO_DEBUG("[%s][%p] %lu bytes written", __func__, this, bytesWritten); + return bytesWritten; + } + + if (GetLastError() != ERROR_IO_PENDING) { + LOG_NPIO_ERROR("[%s] WriteFile failed (%lu)", __func__, GetLastError()); + Disconnect(); + PR_SetError(PR_IO_ERROR, 0); + return -1; + } + + mHasPendingWrite = true; + + return 0; +} + +int32_t NamedPipeInfo::DoWriteContinue() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_ASSERT(mHasPendingWrite); + + DWORD bytesWritten = 0; + BOOL success = + GetOverlappedResult(mPipe, &mWriteOverlapped, &bytesWritten, FALSE); + + if (!success) { + if (GetLastError() == ERROR_IO_INCOMPLETE) { + // still in progress + return 0; + } + + LOG_NPIO_ERROR("[%s] GetOverlappedResult failed (%lu)", __func__, + GetLastError()); + Disconnect(); + PR_SetError(PR_IO_ERROR, 0); + return -1; + } + + mHasPendingWrite = false; + mWriteBegin += bytesWritten; + LOG_NPIO_DEBUG("[%s][%p] %lu bytes written", __func__, this, bytesWritten); + return bytesWritten; +} + +static inline NamedPipeInfo* GetNamedPipeInfo(PRFileDesc* aFd) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + MOZ_DIAGNOSTIC_ASSERT(aFd); + MOZ_DIAGNOSTIC_ASSERT(aFd->secret); + MOZ_DIAGNOSTIC_ASSERT(PR_GetLayersIdentity(aFd) == nsNamedPipeLayerIdentity); + + if (!aFd || !aFd->secret || + PR_GetLayersIdentity(aFd) != nsNamedPipeLayerIdentity) { + LOG_NPIO_ERROR("cannot get named pipe info"); + return nullptr; + } + + return reinterpret_cast<NamedPipeInfo*>(aFd->secret); +} + +static PRStatus nsNamedPipeConnect(PRFileDesc* aFd, const PRNetAddr* aAddr, + PRIntervalTime aTimeout) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return PR_FAILURE; + } + + nsAutoString path; + if (NS_FAILED(NS_CopyNativeToUnicode(nsDependentCString(aAddr->local.path), + path))) { + PR_SetError(PR_OUT_OF_MEMORY_ERROR, 0); + return PR_FAILURE; + } + if (NS_WARN_IF(NS_FAILED(info->Connect(path)))) { + return PR_FAILURE; + } + + return PR_SUCCESS; +} + +static PRStatus nsNamedPipeConnectContinue(PRFileDesc* aFd, PRInt16 aOutFlags) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + return PR_SUCCESS; +} + +static PRStatus nsNamedPipeClose(PRFileDesc* aFd) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + if (aFd->secret && PR_GetLayersIdentity(aFd) == nsNamedPipeLayerIdentity) { + RefPtr<NamedPipeInfo> info = dont_AddRef(GetNamedPipeInfo(aFd)); + info->Disconnect(); + aFd->secret = nullptr; + aFd->identity = PR_INVALID_IO_LAYER; + } + + MOZ_ASSERT(!aFd->lower); + PR_Free(aFd); // PRFileDescs are allocated with PR_Malloc(). + + return PR_SUCCESS; +} + +static PRInt32 nsNamedPipeSend(PRFileDesc* aFd, const void* aBuffer, + PRInt32 aAmount, PRIntn aFlags, + PRIntervalTime aTimeout) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + Unused << aFlags; + Unused << aTimeout; + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return -1; + } + return info->Write(aBuffer, aAmount); +} + +static PRInt32 nsNamedPipeRecv(PRFileDesc* aFd, void* aBuffer, PRInt32 aAmount, + PRIntn aFlags, PRIntervalTime aTimeout) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + Unused << aTimeout; + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return -1; + } + + if (aFlags) { + if (aFlags != PR_MSG_PEEK) { + PR_SetError(PR_UNKNOWN_ERROR, 0); + return -1; + } + return info->Peek(aBuffer, aAmount); + } + + return info->Read(aBuffer, aAmount); +} + +static inline PRInt32 nsNamedPipeRead(PRFileDesc* aFd, void* aBuffer, + PRInt32 aAmount) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return -1; + } + return info->Read(aBuffer, aAmount); +} + +static inline PRInt32 nsNamedPipeWrite(PRFileDesc* aFd, const void* aBuffer, + PRInt32 aAmount) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return -1; + } + return info->Write(aBuffer, aAmount); +} + +static PRInt32 nsNamedPipeAvailable(PRFileDesc* aFd) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return -1; + } + return static_cast<PRInt32>(info->Available()); +} + +static PRInt64 nsNamedPipeAvailable64(PRFileDesc* aFd) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return -1; + } + return static_cast<PRInt64>(info->Available()); +} + +static PRStatus nsNamedPipeSync(PRFileDesc* aFd) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return PR_FAILURE; + } + return info->Sync(0) ? PR_SUCCESS : PR_FAILURE; +} + +static PRInt16 nsNamedPipePoll(PRFileDesc* aFd, PRInt16 aInFlags, + PRInt16* aOutFlags) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + NamedPipeInfo* info = GetNamedPipeInfo(aFd); + if (!info) { + PR_SetError(PR_BAD_DESCRIPTOR_ERROR, 0); + return 0; + } + return info->GetPollFlags(aInFlags, aOutFlags); +} + +// FIXME: remove socket option functions? +static PRStatus nsNamedPipeGetSocketOption(PRFileDesc* aFd, + PRSocketOptionData* aData) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + MOZ_ASSERT(aFd); + MOZ_ASSERT(aData); + + switch (aData->option) { + case PR_SockOpt_Nonblocking: + aData->value.non_blocking = + GetNamedPipeInfo(aFd)->IsNonblocking() ? PR_TRUE : PR_FALSE; + break; + case PR_SockOpt_Keepalive: + aData->value.keep_alive = PR_TRUE; + break; + case PR_SockOpt_NoDelay: + aData->value.no_delay = PR_TRUE; + break; + default: + PR_SetError(PR_INVALID_METHOD_ERROR, 0); + return PR_FAILURE; + } + + return PR_SUCCESS; +} + +static PRStatus nsNamedPipeSetSocketOption(PRFileDesc* aFd, + const PRSocketOptionData* aData) { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + MOZ_ASSERT(aFd); + MOZ_ASSERT(aData); + + switch (aData->option) { + case PR_SockOpt_Nonblocking: + GetNamedPipeInfo(aFd)->SetNonblocking(aData->value.non_blocking); + break; + case PR_SockOpt_Keepalive: + case PR_SockOpt_NoDelay: + break; + default: + PR_SetError(PR_INVALID_METHOD_ERROR, 0); + return PR_FAILURE; + } + + return PR_SUCCESS; +} + +static void Initialize() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + + static bool initialized = false; + if (initialized) { + return; + } + + nsNamedPipeLayerIdentity = PR_GetUniqueIdentity("Named Pipe layer"); + nsNamedPipeLayerMethods = *PR_GetDefaultIOMethods(); + nsNamedPipeLayerMethods.close = nsNamedPipeClose; + nsNamedPipeLayerMethods.read = nsNamedPipeRead; + nsNamedPipeLayerMethods.write = nsNamedPipeWrite; + nsNamedPipeLayerMethods.available = nsNamedPipeAvailable; + nsNamedPipeLayerMethods.available64 = nsNamedPipeAvailable64; + nsNamedPipeLayerMethods.fsync = nsNamedPipeSync; + nsNamedPipeLayerMethods.connect = nsNamedPipeConnect; + nsNamedPipeLayerMethods.recv = nsNamedPipeRecv; + nsNamedPipeLayerMethods.send = nsNamedPipeSend; + nsNamedPipeLayerMethods.poll = nsNamedPipePoll; + nsNamedPipeLayerMethods.getsocketoption = nsNamedPipeGetSocketOption; + nsNamedPipeLayerMethods.setsocketoption = nsNamedPipeSetSocketOption; + nsNamedPipeLayerMethods.connectcontinue = nsNamedPipeConnectContinue; + + initialized = true; +} + +bool IsNamedPipePath(const nsACString& aPath) { + return StringBeginsWith(aPath, "\\\\.\\pipe\\"_ns); +} + +PRFileDesc* CreateNamedPipeLayer() { + MOZ_ASSERT(OnSocketThread(), "not on socket thread"); + Initialize(); + + PRFileDesc* layer = + PR_CreateIOLayerStub(nsNamedPipeLayerIdentity, &nsNamedPipeLayerMethods); + if (NS_WARN_IF(!layer)) { + LOG_NPIO_ERROR("CreateNamedPipeLayer() failed."); + return nullptr; + } + + RefPtr<NamedPipeInfo> info = new NamedPipeInfo(); + layer->secret = reinterpret_cast<PRFilePrivate*>(info.forget().take()); + + return layer; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/socket/nsNamedPipeIOLayer.h b/netwerk/socket/nsNamedPipeIOLayer.h new file mode 100644 index 0000000000..87598324b6 --- /dev/null +++ b/netwerk/socket/nsNamedPipeIOLayer.h @@ -0,0 +1,24 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_netwerk_socket_nsNamedPipeIOLayer_h +#define mozilla_netwerk_socket_nsNamedPipeIOLayer_h + +#include "nscore.h" +#include "nsStringFwd.h" +#include "prio.h" + +namespace mozilla { +namespace net { + +bool IsNamedPipePath(const nsACString& aPath); +PRFileDesc* CreateNamedPipeLayer(); + +extern PRDescIdentity nsNamedPipeLayerIdentity; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_netwerk_socket_nsNamedPipeIOLayer_h diff --git a/netwerk/socket/nsNamedPipeService.cpp b/netwerk/socket/nsNamedPipeService.cpp new file mode 100644 index 0000000000..4cf41680d5 --- /dev/null +++ b/netwerk/socket/nsNamedPipeService.cpp @@ -0,0 +1,319 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "mozilla/Services.h" +#include "nsCOMPtr.h" +#include "nsIObserverService.h" +#include "nsNamedPipeService.h" +#include "nsNetCID.h" +#include "nsThreadUtils.h" +#include "mozilla/ClearOnShutdown.h" +#include "mozilla/Logging.h" + +namespace mozilla { +namespace net { + +static mozilla::LazyLogModule gNamedPipeServiceLog("NamedPipeWin"); +#define LOG_NPS_DEBUG(...) \ + MOZ_LOG(gNamedPipeServiceLog, mozilla::LogLevel::Debug, (__VA_ARGS__)) +#define LOG_NPS_ERROR(...) \ + MOZ_LOG(gNamedPipeServiceLog, mozilla::LogLevel::Error, (__VA_ARGS__)) + +StaticRefPtr<NamedPipeService> NamedPipeService::gSingleton; + +NS_IMPL_ISUPPORTS(NamedPipeService, nsINamedPipeService, nsIObserver, + nsIRunnable) + +NamedPipeService::NamedPipeService() + : mIocp(nullptr), mIsShutdown(false), mLock("NamedPipeServiceLock") {} + +nsresult NamedPipeService::Init() { + MOZ_ASSERT(!mIsShutdown); + + nsresult rv; + + // nsIObserverService must be accessed in main thread. + // register shutdown event to stop NamedPipeSrv thread. + nsCOMPtr<nsIObserver> self(this); + nsCOMPtr<nsIRunnable> r = NS_NewRunnableFunction( + "NamedPipeService::Init", [self = std::move(self)]() -> void { + MOZ_ASSERT(NS_IsMainThread()); + + nsCOMPtr<nsIObserverService> svc = + mozilla::services::GetObserverService(); + + if (NS_WARN_IF(!svc)) { + return; + } + + if (NS_WARN_IF(NS_FAILED(svc->AddObserver( + self, NS_XPCOM_SHUTDOWN_OBSERVER_ID, false)))) { + return; + } + }); + + if (NS_IsMainThread()) { + rv = r->Run(); + } else { + rv = NS_DispatchToMainThread(r); + } + if (NS_WARN_IF(NS_FAILED(rv))) { + return rv; + } + + mIocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, 0, 1); + if (NS_WARN_IF(!mIocp || mIocp == INVALID_HANDLE_VALUE)) { + Shutdown(); + return NS_ERROR_FAILURE; + } + + rv = NS_NewNamedThread("NamedPipeSrv", getter_AddRefs(mThread)); + if (NS_WARN_IF(NS_FAILED(rv))) { + Shutdown(); + return rv; + } + + return NS_OK; +} + +// static +already_AddRefed<nsINamedPipeService> NamedPipeService::GetOrCreate() { + MOZ_ASSERT(NS_IsMainThread()); + + RefPtr<NamedPipeService> inst; + if (gSingleton) { + inst = gSingleton; + } else { + inst = new NamedPipeService(); + nsresult rv = inst->Init(); + NS_ENSURE_SUCCESS(rv, nullptr); + gSingleton = inst; + ClearOnShutdown(&gSingleton); + } + + return inst.forget(); +} + +void NamedPipeService::Shutdown() { + MOZ_ASSERT(NS_IsMainThread()); + + // remove observer + nsCOMPtr<nsIObserverService> obs = mozilla::services::GetObserverService(); + if (obs) { + obs->RemoveObserver(this, NS_XPCOM_SHUTDOWN_OBSERVER_ID); + } + + // stop thread + if (mThread && !mIsShutdown) { + mIsShutdown = true; + + // invoke ERROR_ABANDONED_WAIT_0 to |GetQueuedCompletionStatus| + CloseHandle(mIocp); + mIocp = nullptr; + + mThread->Shutdown(); + } + + // close I/O Completion Port + if (mIocp && mIocp != INVALID_HANDLE_VALUE) { + CloseHandle(mIocp); + mIocp = nullptr; + } +} + +void NamedPipeService::RemoveRetiredObjects() { + MOZ_ASSERT(NS_GetCurrentThread() == mThread); + mLock.AssertCurrentThreadOwns(); + + if (!mRetiredHandles.IsEmpty()) { + for (auto& handle : mRetiredHandles) { + CloseHandle(handle); + } + mRetiredHandles.Clear(); + } + + mRetiredObservers.Clear(); +} + +/** + * Implement nsINamedPipeService + */ + +NS_IMETHODIMP +NamedPipeService::AddDataObserver(void* aHandle, + nsINamedPipeDataObserver* aObserver) { + if (!aHandle || aHandle == INVALID_HANDLE_VALUE || !aObserver) { + return NS_ERROR_ILLEGAL_VALUE; + } + + nsresult rv; + + HANDLE h = CreateIoCompletionPort(aHandle, mIocp, + reinterpret_cast<ULONG_PTR>(aObserver), 1); + if (NS_WARN_IF(!h)) { + LOG_NPS_ERROR("CreateIoCompletionPort error (%lu)", GetLastError()); + return NS_ERROR_FAILURE; + } + if (NS_WARN_IF(h != mIocp)) { + LOG_NPS_ERROR( + "CreateIoCompletionPort got unexpected value %p (should be %p)", h, + mIocp); + CloseHandle(h); + return NS_ERROR_FAILURE; + } + + { + MutexAutoLock lock(mLock); + MOZ_ASSERT(!mObservers.Contains(aObserver)); + + mObservers.AppendElement(aObserver); + + // start event loop + if (mObservers.Length() == 1) { + rv = mThread->Dispatch(this, NS_DISPATCH_NORMAL); + if (NS_WARN_IF(NS_FAILED(rv))) { + LOG_NPS_ERROR("Dispatch to thread failed (%08x)", rv); + mObservers.Clear(); + return rv; + } + } + } + + return NS_OK; +} + +NS_IMETHODIMP +NamedPipeService::RemoveDataObserver(void* aHandle, + nsINamedPipeDataObserver* aObserver) { + MutexAutoLock lock(mLock); + mObservers.RemoveElement(aObserver); + + mRetiredHandles.AppendElement(aHandle); + mRetiredObservers.AppendElement(aObserver); + + return NS_OK; +} + +NS_IMETHODIMP +NamedPipeService::IsOnCurrentThread(bool* aRetVal) { + MOZ_ASSERT(mThread); + MOZ_ASSERT(aRetVal); + + if (!mThread) { + *aRetVal = false; + return NS_OK; + } + + return mThread->IsOnCurrentThread(aRetVal); +} + +/** + * Implement nsIObserver + */ + +NS_IMETHODIMP +NamedPipeService::Observe(nsISupports* aSubject, const char* aTopic, + const char16_t* aData) { + MOZ_ASSERT(NS_IsMainThread()); + + if (!strcmp(NS_XPCOM_SHUTDOWN_OBSERVER_ID, aTopic)) { + Shutdown(); + } + + return NS_OK; +} + +/** + * Implement nsIRunnable + */ + +NS_IMETHODIMP +NamedPipeService::Run() { + MOZ_ASSERT(NS_GetCurrentThread() == mThread); + MOZ_ASSERT(mIocp && mIocp != INVALID_HANDLE_VALUE); + + while (!mIsShutdown) { + { + MutexAutoLock lock(mLock); + if (mObservers.IsEmpty()) { + LOG_NPS_DEBUG("no observer, stop loop"); + break; + } + + RemoveRetiredObjects(); + } + + DWORD bytesTransferred = 0; + ULONG_PTR key = 0; + LPOVERLAPPED overlapped = nullptr; + BOOL success = + GetQueuedCompletionStatus(mIocp, &bytesTransferred, &key, &overlapped, + 1000); // timeout, 1s + auto err = GetLastError(); + if (!success) { + if (err == WAIT_TIMEOUT) { + continue; + } else if (err == ERROR_ABANDONED_WAIT_0) { // mIocp was closed + break; + } else if (!overlapped) { + /** + * Did not dequeue a completion packet from the completion port, and + * bytesTransferred/key are meaningless. + * See remarks of |GetQueuedCompletionStatus| API. + */ + + LOG_NPS_ERROR("invalid overlapped (%lu)", err); + continue; + } + + MOZ_ASSERT(key); + } + + /** + * Windows doesn't provide a method to remove created I/O Completion Port, + * all we can do is just close the handle we monitored before. + * In some cases, there's race condition that the monitored handle has an + * I/O status after the observer is being removed and destroyed. + * To avoid changing the ref-count of a dangling pointer, don't use nsCOMPtr + * here. + */ + nsINamedPipeDataObserver* target = + reinterpret_cast<nsINamedPipeDataObserver*>(key); + + nsCOMPtr<nsINamedPipeDataObserver> obs; + { + MutexAutoLock lock(mLock); + + auto idx = mObservers.IndexOf(target); + if (idx == decltype(mObservers)::NoIndex) { + LOG_NPS_ERROR("observer %p not found", target); + continue; + } + obs = target; + } + + MOZ_ASSERT(obs.get()); + + if (success) { + LOG_NPS_DEBUG("OnDataAvailable: obs=%p, bytes=%lu", obs.get(), + bytesTransferred); + obs->OnDataAvailable(bytesTransferred, overlapped); + } else { + LOG_NPS_ERROR("GetQueuedCompletionStatus %p failed, error=%lu", obs.get(), + err); + obs->OnError(err, overlapped); + } + } + + { + MutexAutoLock lock(mLock); + RemoveRetiredObjects(); + } + + return NS_OK; +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/socket/nsNamedPipeService.h b/netwerk/socket/nsNamedPipeService.h new file mode 100644 index 0000000000..cbe6ff9631 --- /dev/null +++ b/netwerk/socket/nsNamedPipeService.h @@ -0,0 +1,67 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef mozilla_netwerk_socket_nsNamedPipeService_h +#define mozilla_netwerk_socket_nsNamedPipeService_h + +#include <windows.h> +#include "mozilla/Atomics.h" +#include "mozilla/Mutex.h" +#include "nsINamedPipeService.h" +#include "nsIObserver.h" +#include "nsIRunnable.h" +#include "nsIThread.h" +#include "nsTArray.h" +#include "mozilla/StaticPtr.h" + +namespace mozilla { +namespace net { + +class NamedPipeService final : public nsINamedPipeService, + public nsIObserver, + public nsIRunnable { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSINAMEDPIPESERVICE + NS_DECL_NSIOBSERVER + NS_DECL_NSIRUNNABLE + + static already_AddRefed<nsINamedPipeService> GetOrCreate(); + + private: + explicit NamedPipeService(); + virtual ~NamedPipeService() = default; + + nsresult Init(); + + void Shutdown(); + void RemoveRetiredObjects(); + + HANDLE mIocp; // native handle to the I/O completion port. + Atomic<bool> + mIsShutdown; // set to true to stop the event loop running by mThread. + nsCOMPtr<nsIThread> mThread; // worker thread to get I/O events. + + /** + * The observers is maintained in |mObservers| to ensure valid life-cycle. + * We don't remove the handle and corresponding observer directly, instead + * the handle and observer into a "retired" list and close/remove them in + * the worker thread to avoid a race condition that might happen between + * |CloseHandle()| and |GetQueuedCompletionStatus()|. + */ + Mutex mLock MOZ_UNANNOTATED; + nsTArray<nsCOMPtr<nsINamedPipeDataObserver>> + mObservers; // protected by mLock + nsTArray<nsCOMPtr<nsINamedPipeDataObserver>> + mRetiredObservers; // protected by mLock + nsTArray<HANDLE> mRetiredHandles; // protected by mLock + + static StaticRefPtr<NamedPipeService> gSingleton; +}; + +} // namespace net +} // namespace mozilla + +#endif // mozilla_netwerk_socket_nsNamedPipeService_h diff --git a/netwerk/socket/nsSOCKSIOLayer.cpp b/netwerk/socket/nsSOCKSIOLayer.cpp new file mode 100644 index 0000000000..b92de4e87a --- /dev/null +++ b/netwerk/socket/nsSOCKSIOLayer.cpp @@ -0,0 +1,1471 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim:set expandtab ts=4 sw=2 sts=2 cin: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nspr.h" +#include "private/pprio.h" +#include "nsString.h" +#include "nsCRT.h" + +#include "nsIDNSService.h" +#include "nsIDNSRecord.h" +#include "nsISocketProvider.h" +#include "nsNamedPipeIOLayer.h" +#include "nsSOCKSIOLayer.h" +#include "nsNetCID.h" +#include "nsIDNSListener.h" +#include "nsICancelable.h" +#include "nsThreadUtils.h" +#include "nsIFile.h" +#include "nsIFileProtocolHandler.h" +#include "mozilla/Logging.h" +#include "mozilla/net/DNS.h" +#include "mozilla/Unused.h" + +using mozilla::LogLevel; +using namespace mozilla::net; + +static PRDescIdentity nsSOCKSIOLayerIdentity; +static PRIOMethods nsSOCKSIOLayerMethods; +static bool firstTime = true; +static bool ipv6Supported = true; + +static mozilla::LazyLogModule gSOCKSLog("SOCKS"); +#define LOGDEBUG(args) MOZ_LOG(gSOCKSLog, mozilla::LogLevel::Debug, args) +#define LOGERROR(args) MOZ_LOG(gSOCKSLog, mozilla::LogLevel::Error, args) + +class nsSOCKSSocketInfo : public nsIDNSListener { + enum State { + SOCKS_INITIAL, + SOCKS_DNS_IN_PROGRESS, + SOCKS_DNS_COMPLETE, + SOCKS_CONNECTING_TO_PROXY, + SOCKS4_WRITE_CONNECT_REQUEST, + SOCKS4_READ_CONNECT_RESPONSE, + SOCKS5_WRITE_AUTH_REQUEST, + SOCKS5_READ_AUTH_RESPONSE, + SOCKS5_WRITE_USERNAME_REQUEST, + SOCKS5_READ_USERNAME_RESPONSE, + SOCKS5_WRITE_CONNECT_REQUEST, + SOCKS5_READ_CONNECT_RESPONSE_TOP, + SOCKS5_READ_CONNECT_RESPONSE_BOTTOM, + SOCKS_CONNECTED, + SOCKS_FAILED + }; + + // A buffer of 520 bytes should be enough for any request and response + // in case of SOCKS4 as well as SOCKS5 + static const uint32_t BUFFER_SIZE = 520; + static const uint32_t MAX_HOSTNAME_LEN = 255; + static const uint32_t MAX_USERNAME_LEN = 255; + static const uint32_t MAX_PASSWORD_LEN = 255; + + public: + nsSOCKSSocketInfo(); + + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSIDNSLISTENER + + void Init(int32_t version, int32_t family, nsIProxyInfo* proxy, + const char* destinationHost, uint32_t flags, uint32_t tlsFlags); + + void SetConnectTimeout(PRIntervalTime to); + PRStatus DoHandshake(PRFileDesc* fd, int16_t oflags = -1); + int16_t GetPollFlags() const; + bool IsConnected() const { return mState == SOCKS_CONNECTED; } + void ForgetFD() { mFD = nullptr; } + void SetNamedPipeFD(PRFileDesc* fd) { mFD = fd; } + + void GetExternalProxyAddr(NetAddr& aExternalProxyAddr); + void GetDestinationAddr(NetAddr& aDestinationAddr); + void SetDestinationAddr(const NetAddr& aDestinationAddr); + + private: + virtual ~nsSOCKSSocketInfo() { + ForgetFD(); + HandshakeFinished(); + } + + void HandshakeFinished(PRErrorCode err = 0); + PRStatus StartDNS(PRFileDesc* fd); + PRStatus ConnectToProxy(PRFileDesc* fd); + void FixupAddressFamily(PRFileDesc* fd, NetAddr* proxy); + PRStatus ContinueConnectingToProxy(PRFileDesc* fd, int16_t oflags); + PRStatus WriteV4ConnectRequest(); + PRStatus ReadV4ConnectResponse(); + PRStatus WriteV5AuthRequest(); + PRStatus ReadV5AuthResponse(); + PRStatus WriteV5UsernameRequest(); + PRStatus ReadV5UsernameResponse(); + PRStatus WriteV5ConnectRequest(); + PRStatus ReadV5AddrTypeAndLength(uint8_t* type, uint32_t* len); + PRStatus ReadV5ConnectResponseTop(); + PRStatus ReadV5ConnectResponseBottom(); + + uint8_t ReadUint8(); + uint16_t ReadUint16(); + uint32_t ReadUint32(); + void ReadNetAddr(NetAddr* addr, uint16_t fam); + void ReadNetPort(NetAddr* addr); + + void WantRead(uint32_t sz); + PRStatus ReadFromSocket(PRFileDesc* fd); + PRStatus WriteToSocket(PRFileDesc* fd); + + bool IsLocalProxy() { + nsAutoCString proxyHost; + mProxy->GetHost(proxyHost); + return IsHostLocalTarget(proxyHost); + } + + nsresult SetLocalProxyPath(const nsACString& aLocalProxyPath, + NetAddr* aProxyAddr) { +#ifdef XP_UNIX + nsresult rv; + MOZ_ASSERT(aProxyAddr); + + nsCOMPtr<nsIProtocolHandler> protocolHandler( + do_GetService(NS_NETWORK_PROTOCOL_CONTRACTID_PREFIX "file", &rv)); + if (NS_WARN_IF(NS_FAILED(rv))) { + return rv; + } + + nsCOMPtr<nsIFileProtocolHandler> fileHandler( + do_QueryInterface(protocolHandler, &rv)); + if (NS_WARN_IF(NS_FAILED(rv))) { + return rv; + } + + nsCOMPtr<nsIFile> socketFile; + rv = fileHandler->GetFileFromURLSpec(aLocalProxyPath, + getter_AddRefs(socketFile)); + if (NS_WARN_IF(NS_FAILED(rv))) { + return rv; + } + + nsAutoCString path; + if (NS_WARN_IF(NS_FAILED(rv = socketFile->GetNativePath(path)))) { + return rv; + } + + if (sizeof(aProxyAddr->local.path) <= path.Length()) { + NS_WARNING("domain socket path too long."); + return NS_ERROR_FAILURE; + } + + aProxyAddr->raw.family = AF_UNIX; + strcpy(aProxyAddr->local.path, path.get()); + + return NS_OK; +#elif defined(XP_WIN) + MOZ_ASSERT(aProxyAddr); + + if (sizeof(aProxyAddr->local.path) <= aLocalProxyPath.Length()) { + NS_WARNING("pipe path too long."); + return NS_ERROR_FAILURE; + } + + aProxyAddr->raw.family = AF_LOCAL; + strcpy(aProxyAddr->local.path, PromiseFlatCString(aLocalProxyPath).get()); + return NS_OK; +#else + mozilla::Unused << aLocalProxyPath; + mozilla::Unused << aProxyAddr; + return NS_ERROR_NOT_IMPLEMENTED; +#endif + } + + bool SetupNamedPipeLayer(PRFileDesc* fd) { +#if defined(XP_WIN) + if (IsLocalProxy()) { + // nsSOCKSIOLayer handshaking only works under blocking mode + // unfortunately. Remember named pipe's FD to switch between modes. + SetNamedPipeFD(fd->lower); + return true; + } +#endif + return false; + } + + private: + State mState{SOCKS_INITIAL}; + uint8_t* mData{nullptr}; + uint8_t* mDataIoPtr{nullptr}; + uint32_t mDataLength{0}; + uint32_t mReadOffset{0}; + uint32_t mAmountToRead{0}; + nsCOMPtr<nsIDNSRecord> mDnsRec; + nsCOMPtr<nsICancelable> mLookup; + nsresult mLookupStatus{NS_ERROR_NOT_INITIALIZED}; + PRFileDesc* mFD{nullptr}; + + nsCString mDestinationHost; + nsCOMPtr<nsIProxyInfo> mProxy; + int32_t mVersion{-1}; // SOCKS version 4 or 5 + int32_t mDestinationFamily{AF_INET}; + uint32_t mFlags{0}; + uint32_t mTlsFlags{0}; + NetAddr mInternalProxyAddr; + NetAddr mExternalProxyAddr; + NetAddr mDestinationAddr; + PRIntervalTime mTimeout{PR_INTERVAL_NO_TIMEOUT}; + nsCString mProxyUsername; // Cache, from mProxy +}; + +nsSOCKSSocketInfo::nsSOCKSSocketInfo() { + mData = new uint8_t[BUFFER_SIZE]; + + mInternalProxyAddr.raw.family = AF_INET; + mInternalProxyAddr.inet.ip = htonl(INADDR_ANY); + mInternalProxyAddr.inet.port = htons(0); + + mExternalProxyAddr.raw.family = AF_INET; + mExternalProxyAddr.inet.ip = htonl(INADDR_ANY); + mExternalProxyAddr.inet.port = htons(0); + + mDestinationAddr.raw.family = AF_INET; + mDestinationAddr.inet.ip = htonl(INADDR_ANY); + mDestinationAddr.inet.port = htons(0); +} + +/* Helper template class to statically check that writes to a fixed-size + * buffer are not going to overflow. + * + * Example usage: + * uint8_t real_buf[TOTAL_SIZE]; + * Buffer<TOTAL_SIZE> buf(&real_buf); + * auto buf2 = buf.WriteUint16(1); + * auto buf3 = buf2.WriteUint8(2); + * + * It is possible to chain them, to limit the number of (error-prone) + * intermediate variables: + * auto buf = Buffer<TOTAL_SIZE>(&real_buf) + * .WriteUint16(1) + * .WriteUint8(2); + * + * Debug builds assert when intermediate variables are reused: + * Buffer<TOTAL_SIZE> buf(&real_buf); + * auto buf2 = buf.WriteUint16(1); + * auto buf3 = buf.WriteUint8(2); // Asserts + * + * Strings can be written, given an explicit maximum length. + * buf.WriteString<MAX_STRING_LENGTH>(str); + * + * The Written() method returns how many bytes have been written so far: + * Buffer<TOTAL_SIZE> buf(&real_buf); + * auto buf2 = buf.WriteUint16(1); + * auto buf3 = buf2.WriteUint8(2); + * buf3.Written(); // returns 3. + */ +template <size_t Size> +class Buffer { + public: + Buffer() = default; + + explicit Buffer(uint8_t* aBuf, size_t aLength = 0) + : mBuf(aBuf), mLength(aLength) {} + + template <size_t Size2> + MOZ_IMPLICIT Buffer(const Buffer<Size2>& aBuf) + : mBuf(aBuf.mBuf), mLength(aBuf.mLength) { + static_assert(Size2 > Size, "Cannot cast buffer"); + } + + Buffer<Size - sizeof(uint8_t)> WriteUint8(uint8_t aValue) { + return Write(aValue); + } + + Buffer<Size - sizeof(uint16_t)> WriteUint16(uint16_t aValue) { + return Write(aValue); + } + + Buffer<Size - sizeof(uint32_t)> WriteUint32(uint32_t aValue) { + return Write(aValue); + } + + Buffer<Size - sizeof(uint16_t)> WriteNetPort(const NetAddr* aAddr) { + return WriteUint16(aAddr->inet.port); + } + + Buffer<Size - sizeof(IPv6Addr)> WriteNetAddr(const NetAddr* aAddr) { + if (aAddr->raw.family == AF_INET) { + return Write(aAddr->inet.ip); + } + if (aAddr->raw.family == AF_INET6) { + return Write(aAddr->inet6.ip.u8); + } + MOZ_ASSERT_UNREACHABLE("Unknown address family"); + return *this; + } + + template <size_t MaxLength> + Buffer<Size - MaxLength> WriteString(const nsACString& aStr) { + if (aStr.Length() > MaxLength) { + return Buffer<Size - MaxLength>(nullptr); + } + return WritePtr<char, MaxLength>(aStr.Data(), aStr.Length()); + } + + size_t Written() { + MOZ_ASSERT(mBuf); + return mLength; + } + + explicit operator bool() { return !!mBuf; } + + private: + template <size_t Size2> + friend class Buffer; + + template <typename T> + Buffer<Size - sizeof(T)> Write(T& aValue) { + return WritePtr<T, sizeof(T)>(&aValue, sizeof(T)); + } + + template <typename T, size_t Length> + Buffer<Size - Length> WritePtr(const T* aValue, size_t aCopyLength) { + static_assert(Size >= Length, "Cannot write that much"); + MOZ_ASSERT(aCopyLength <= Length); + MOZ_ASSERT(mBuf); + memcpy(mBuf, aValue, aCopyLength); + Buffer<Size - Length> result(mBuf + aCopyLength, mLength + aCopyLength); + mBuf = nullptr; + mLength = 0; + return result; + } + + uint8_t* mBuf{nullptr}; + size_t mLength{0}; +}; + +void nsSOCKSSocketInfo::Init(int32_t version, int32_t family, + nsIProxyInfo* proxy, const char* host, + uint32_t flags, uint32_t tlsFlags) { + mVersion = version; + mDestinationFamily = family; + mProxy = proxy; + mDestinationHost = host; + mFlags = flags; + mTlsFlags = tlsFlags; + mProxy->GetUsername(mProxyUsername); // cache +} + +NS_IMPL_ISUPPORTS(nsSOCKSSocketInfo, nsIDNSListener) + +void nsSOCKSSocketInfo::GetExternalProxyAddr(NetAddr& aExternalProxyAddr) { + aExternalProxyAddr = mExternalProxyAddr; +} + +void nsSOCKSSocketInfo::GetDestinationAddr(NetAddr& aDestinationAddr) { + aDestinationAddr = mDestinationAddr; +} + +void nsSOCKSSocketInfo::SetDestinationAddr(const NetAddr& aDestinationAddr) { + mDestinationAddr = aDestinationAddr; +} + +// There needs to be a means of distinguishing between connection errors +// that the SOCKS server reports when it rejects a connection request, and +// connection errors that happen while attempting to connect to the SOCKS +// server. Otherwise, Firefox will report incorrectly that the proxy server +// is refusing connections when a SOCKS request is rejected by the proxy. +// When a SOCKS handshake failure occurs, the PR error is set to +// PR_UNKNOWN_ERROR, and the real error code is returned via the OS error. +void nsSOCKSSocketInfo::HandshakeFinished(PRErrorCode err) { + if (err == 0) { + mState = SOCKS_CONNECTED; +#if defined(XP_WIN) + // Switch back to nonblocking mode after finishing handshaking. + if (IsLocalProxy() && mFD) { + PRSocketOptionData opt_nonblock; + opt_nonblock.option = PR_SockOpt_Nonblocking; + opt_nonblock.value.non_blocking = PR_TRUE; + PR_SetSocketOption(mFD, &opt_nonblock); + mFD = nullptr; + } +#endif + } else { + mState = SOCKS_FAILED; + PR_SetError(PR_UNKNOWN_ERROR, err); + } + + // We don't need the buffer any longer, so free it. + delete[] mData; + mData = nullptr; + mDataIoPtr = nullptr; + mDataLength = 0; + mReadOffset = 0; + mAmountToRead = 0; + if (mLookup) { + mLookup->Cancel(NS_ERROR_FAILURE); + mLookup = nullptr; + } +} + +PRStatus nsSOCKSSocketInfo::StartDNS(PRFileDesc* fd) { + MOZ_ASSERT(!mDnsRec && mState == SOCKS_INITIAL, + "Must be in initial state to make DNS Lookup"); + + nsCOMPtr<nsIDNSService> dns = do_GetService(NS_DNSSERVICE_CONTRACTID); + if (!dns) return PR_FAILURE; + + nsCString proxyHost; + mProxy->GetHost(proxyHost); + + mozilla::OriginAttributes attrs; + + mFD = fd; + nsresult rv = dns->AsyncResolveNative( + proxyHost, nsIDNSService::RESOLVE_TYPE_DEFAULT, + nsIDNSService::RESOLVE_IGNORE_SOCKS_DNS, nullptr, this, + mozilla::GetCurrentSerialEventTarget(), attrs, getter_AddRefs(mLookup)); + + if (NS_FAILED(rv)) { + LOGERROR(("socks: DNS lookup for SOCKS proxy %s failed", proxyHost.get())); + return PR_FAILURE; + } + mState = SOCKS_DNS_IN_PROGRESS; + PR_SetError(PR_IN_PROGRESS_ERROR, 0); + return PR_FAILURE; +} + +NS_IMETHODIMP +nsSOCKSSocketInfo::OnLookupComplete(nsICancelable* aRequest, + nsIDNSRecord* aRecord, nsresult aStatus) { + MOZ_ASSERT(aRequest == mLookup, "wrong DNS query"); + mLookup = nullptr; + mLookupStatus = aStatus; + mDnsRec = aRecord; + mState = SOCKS_DNS_COMPLETE; + if (mFD) { + ConnectToProxy(mFD); + ForgetFD(); + } + return NS_OK; +} + +PRStatus nsSOCKSSocketInfo::ConnectToProxy(PRFileDesc* fd) { + PRStatus status; + nsresult rv; + + MOZ_ASSERT(mState == SOCKS_DNS_COMPLETE, "Must have DNS to make connection!"); + + if (NS_FAILED(mLookupStatus)) { + PR_SetError(PR_BAD_ADDRESS_ERROR, 0); + return PR_FAILURE; + } + + // Try socks5 if the destination addrress is IPv6 + if (mVersion == 4 && mDestinationAddr.raw.family == AF_INET6) { + mVersion = 5; + } + + nsAutoCString proxyHost; + mProxy->GetHost(proxyHost); + + int32_t proxyPort; + mProxy->GetPort(&proxyPort); + + int32_t addresses = 0; + do { + if (IsLocalProxy()) { + rv = SetLocalProxyPath(proxyHost, &mInternalProxyAddr); + if (NS_FAILED(rv)) { + LOGERROR( + ("socks: unable to connect to SOCKS proxy, %s", proxyHost.get())); + return PR_FAILURE; + } + } else { + nsCOMPtr<nsIDNSAddrRecord> record = do_QueryInterface(mDnsRec); + MOZ_ASSERT(record); + if (addresses++) { + record->ReportUnusable(proxyPort); + } + + rv = record->GetNextAddr(proxyPort, &mInternalProxyAddr); + // No more addresses to try? If so, we'll need to bail + if (NS_FAILED(rv)) { + LOGERROR( + ("socks: unable to connect to SOCKS proxy, %s", proxyHost.get())); + return PR_FAILURE; + } + + if (MOZ_LOG_TEST(gSOCKSLog, LogLevel::Debug)) { + char buf[kIPv6CStrBufSize]; + mInternalProxyAddr.ToStringBuffer(buf, sizeof(buf)); + LOGDEBUG(("socks: trying proxy server, %s:%hu", buf, + ntohs(mInternalProxyAddr.inet.port))); + } + } + + NetAddr proxy = mInternalProxyAddr; + FixupAddressFamily(fd, &proxy); + PRNetAddr prProxy; + NetAddrToPRNetAddr(&proxy, &prProxy); + status = fd->lower->methods->connect(fd->lower, &prProxy, mTimeout); + if (status != PR_SUCCESS) { + PRErrorCode c = PR_GetError(); + + // If EINPROGRESS, return now and check back later after polling + if (c == PR_WOULD_BLOCK_ERROR || c == PR_IN_PROGRESS_ERROR) { + mState = SOCKS_CONNECTING_TO_PROXY; + return status; + } + if (IsLocalProxy()) { + LOGERROR(("socks: connect to domain socket failed (%d)", c)); + PR_SetError(PR_CONNECT_REFUSED_ERROR, 0); + mState = SOCKS_FAILED; + return status; + } + } + } while (status != PR_SUCCESS); + +#if defined(XP_WIN) + // Switch to blocking mode during handshaking + if (IsLocalProxy() && mFD) { + PRSocketOptionData opt_nonblock; + opt_nonblock.option = PR_SockOpt_Nonblocking; + opt_nonblock.value.non_blocking = PR_FALSE; + PR_SetSocketOption(mFD, &opt_nonblock); + } +#endif + + // Connected now, start SOCKS + if (mVersion == 4) return WriteV4ConnectRequest(); + return WriteV5AuthRequest(); +} + +void nsSOCKSSocketInfo::FixupAddressFamily(PRFileDesc* fd, NetAddr* proxy) { + int32_t proxyFamily = mInternalProxyAddr.raw.family; + // Do nothing if the address family is already matched + if (proxyFamily == mDestinationFamily) { + return; + } + // If the system does not support IPv6 and the proxy address is IPv6, + // We can do nothing here. + if (proxyFamily == AF_INET6 && !ipv6Supported) { + return; + } + // If the system does not support IPv6 and the destination address is + // IPv6, convert IPv4 address to IPv4-mapped IPv6 address to satisfy + // the emulation layer + if (mDestinationFamily == AF_INET6 && !ipv6Supported) { + proxy->inet6.family = AF_INET6; + proxy->inet6.port = mInternalProxyAddr.inet.port; + uint8_t* proxyp = proxy->inet6.ip.u8; + memset(proxyp, 0, 10); + memset(proxyp + 10, 0xff, 2); + memcpy(proxyp + 12, (char*)&mInternalProxyAddr.inet.ip, 4); + // mDestinationFamily should not be updated + return; + } + // There's no PR_NSPR_IO_LAYER required when using named pipe, + // we simply ignore the TCP family here. + if (SetupNamedPipeLayer(fd)) { + return; + } + + // Get an OS native handle from a specified FileDesc + PROsfd osfd = PR_FileDesc2NativeHandle(fd); + if (osfd == -1) { + return; + } + + // Create a new FileDesc with a specified family + PRFileDesc* tmpfd = PR_OpenTCPSocket(proxyFamily); + if (!tmpfd) { + return; + } + PROsfd newsd = PR_FileDesc2NativeHandle(tmpfd); + if (newsd == -1) { + PR_Close(tmpfd); + return; + } + // Must succeed because PR_FileDesc2NativeHandle succeeded + fd = PR_GetIdentitiesLayer(fd, PR_NSPR_IO_LAYER); + MOZ_ASSERT(fd); + // Swap OS native handles + PR_ChangeFileDescNativeHandle(fd, newsd); + PR_ChangeFileDescNativeHandle(tmpfd, osfd); + // Close temporary FileDesc which is now associated with + // old OS native handle + PR_Close(tmpfd); + mDestinationFamily = proxyFamily; +} + +PRStatus nsSOCKSSocketInfo::ContinueConnectingToProxy(PRFileDesc* fd, + int16_t oflags) { + PRStatus status; + + MOZ_ASSERT(mState == SOCKS_CONNECTING_TO_PROXY, + "Continuing connection in wrong state!"); + + LOGDEBUG(("socks: continuing connection to proxy")); + + status = fd->lower->methods->connectcontinue(fd->lower, oflags); + if (status != PR_SUCCESS) { + PRErrorCode c = PR_GetError(); + if (c != PR_WOULD_BLOCK_ERROR && c != PR_IN_PROGRESS_ERROR) { + // A connection failure occured, try another address + mState = SOCKS_DNS_COMPLETE; + return ConnectToProxy(fd); + } + + // We're still connecting + return PR_FAILURE; + } + + // Connected now, start SOCKS + if (mVersion == 4) return WriteV4ConnectRequest(); + return WriteV5AuthRequest(); +} + +PRStatus nsSOCKSSocketInfo::WriteV4ConnectRequest() { + if (mProxyUsername.Length() > MAX_USERNAME_LEN) { + LOGERROR(("socks username is too long")); + HandshakeFinished(PR_UNKNOWN_ERROR); + return PR_FAILURE; + } + + NetAddr* addr = &mDestinationAddr; + int32_t proxy_resolve; + + MOZ_ASSERT(mState == SOCKS_CONNECTING_TO_PROXY, "Invalid state!"); + + proxy_resolve = mFlags & nsISocketProvider::PROXY_RESOLVES_HOST; + + mDataLength = 0; + mState = SOCKS4_WRITE_CONNECT_REQUEST; + + LOGDEBUG(("socks4: sending connection request (socks4a resolve? %s)", + proxy_resolve ? "yes" : "no")); + + // Send a SOCKS 4 connect request. + auto buf = Buffer<BUFFER_SIZE>(mData) + .WriteUint8(0x04) // version -- 4 + .WriteUint8(0x01) // command -- connect + .WriteNetPort(addr); + + // We don't have anything more to write after the if, so we can + // use a buffer with no further writes allowed. + Buffer<0> buf3; + if (proxy_resolve) { + // Add the full name, null-terminated, to the request + // according to SOCKS 4a. A fake IP address, with the first + // four bytes set to 0 and the last byte set to something other + // than 0, is used to notify the proxy that this is a SOCKS 4a + // request. This request type works for Tor and perhaps others. + // Passwords not supported by V4. + auto buf2 = + buf.WriteUint32(htonl(0x00000001)) // Fake IP + .WriteString<MAX_USERNAME_LEN>(mProxyUsername) + .WriteUint8(0x00) // Null-terminate username + .WriteString<MAX_HOSTNAME_LEN>(mDestinationHost); // Hostname + if (!buf2) { + LOGERROR(("socks4: destination host name is too long!")); + HandshakeFinished(PR_BAD_ADDRESS_ERROR); + return PR_FAILURE; + } + buf3 = buf2.WriteUint8(0x00); + } else if (addr->raw.family == AF_INET) { + // Passwords not supported by V4. + buf3 = buf.WriteNetAddr(addr) // Add the IPv4 address + .WriteString<MAX_USERNAME_LEN>(mProxyUsername) + .WriteUint8(0x00); // Null-terminate username + } else { + LOGERROR(("socks: SOCKS 4 can only handle IPv4 addresses!")); + HandshakeFinished(PR_BAD_ADDRESS_ERROR); + return PR_FAILURE; + } + + mDataLength = buf3.Written(); + return PR_SUCCESS; +} + +PRStatus nsSOCKSSocketInfo::ReadV4ConnectResponse() { + MOZ_ASSERT(mState == SOCKS4_READ_CONNECT_RESPONSE, + "Handling SOCKS 4 connection reply in wrong state!"); + MOZ_ASSERT(mDataLength == 8, "SOCKS 4 connection reply must be 8 bytes!"); + + LOGDEBUG(("socks4: checking connection reply")); + + if (ReadUint8() != 0x00) { + LOGERROR(("socks4: wrong connection reply")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; + } + + // See if our connection request was granted + if (ReadUint8() == 90) { + LOGDEBUG(("socks4: connection successful!")); + HandshakeFinished(); + return PR_SUCCESS; + } + + LOGERROR(("socks4: unable to connect")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; +} + +PRStatus nsSOCKSSocketInfo::WriteV5AuthRequest() { + MOZ_ASSERT(mVersion == 5, "SOCKS version must be 5!"); + + mDataLength = 0; + mState = SOCKS5_WRITE_AUTH_REQUEST; + + // Send an initial SOCKS 5 greeting + LOGDEBUG(("socks5: sending auth methods")); + mDataLength = Buffer<BUFFER_SIZE>(mData) + .WriteUint8(0x05) // version -- 5 + .WriteUint8(0x01) // # of auth methods -- 1 + // Use authenticate iff we have a proxy username. + .WriteUint8(mProxyUsername.IsEmpty() ? 0x00 : 0x02) + .Written(); + + return PR_SUCCESS; +} + +PRStatus nsSOCKSSocketInfo::ReadV5AuthResponse() { + MOZ_ASSERT(mState == SOCKS5_READ_AUTH_RESPONSE, + "Handling SOCKS 5 auth method reply in wrong state!"); + MOZ_ASSERT(mDataLength == 2, "SOCKS 5 auth method reply must be 2 bytes!"); + + LOGDEBUG(("socks5: checking auth method reply")); + + // Check version number + if (ReadUint8() != 0x05) { + LOGERROR(("socks5: unexpected version in the reply")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; + } + + // Make sure our authentication choice was accepted, + // and continue accordingly + uint8_t authMethod = ReadUint8(); + if (mProxyUsername.IsEmpty() && authMethod == 0x00) { // no auth + LOGDEBUG(("socks5: server allows connection without authentication")); + return WriteV5ConnectRequest(); + } + if (!mProxyUsername.IsEmpty() && authMethod == 0x02) { // username/pw + LOGDEBUG(("socks5: auth method accepted by server")); + return WriteV5UsernameRequest(); + } // 0xFF signals error + LOGERROR(("socks5: server did not accept our authentication method")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; +} + +PRStatus nsSOCKSSocketInfo::WriteV5UsernameRequest() { + MOZ_ASSERT(mVersion == 5, "SOCKS version must be 5!"); + + if (mProxyUsername.Length() > MAX_USERNAME_LEN) { + LOGERROR(("socks username is too long")); + HandshakeFinished(PR_UNKNOWN_ERROR); + return PR_FAILURE; + } + + nsCString password; + mProxy->GetPassword(password); + if (password.Length() > MAX_PASSWORD_LEN) { + LOGERROR(("socks password is too long")); + HandshakeFinished(PR_UNKNOWN_ERROR); + return PR_FAILURE; + } + + mDataLength = 0; + mState = SOCKS5_WRITE_USERNAME_REQUEST; + + // RFC 1929 Username/password auth for SOCKS 5 + LOGDEBUG(("socks5: sending username and password")); + mDataLength = Buffer<BUFFER_SIZE>(mData) + .WriteUint8(0x01) // version 1 (not 5) + .WriteUint8(mProxyUsername.Length()) // username length + .WriteString<MAX_USERNAME_LEN>(mProxyUsername) // username + .WriteUint8(password.Length()) // password length + .WriteString<MAX_PASSWORD_LEN>( + password) // password. WARNING: Sent unencrypted! + .Written(); + + return PR_SUCCESS; +} + +PRStatus nsSOCKSSocketInfo::ReadV5UsernameResponse() { + MOZ_ASSERT(mState == SOCKS5_READ_USERNAME_RESPONSE, + "Handling SOCKS 5 username/password reply in wrong state!"); + + MOZ_ASSERT(mDataLength == 2, "SOCKS 5 username reply must be 2 bytes"); + + // Check version number, must be 1 (not 5) + if (ReadUint8() != 0x01) { + LOGERROR(("socks5: unexpected version in the reply")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; + } + + // Check whether username/password were accepted + if (ReadUint8() != 0x00) { // 0 = success + LOGERROR(("socks5: username/password not accepted")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; + } + + LOGDEBUG(("socks5: username/password accepted by server")); + + return WriteV5ConnectRequest(); +} + +PRStatus nsSOCKSSocketInfo::WriteV5ConnectRequest() { + // Send SOCKS 5 connect request + NetAddr* addr = &mDestinationAddr; + int32_t proxy_resolve; + proxy_resolve = mFlags & nsISocketProvider::PROXY_RESOLVES_HOST; + + LOGDEBUG(("socks5: sending connection request (socks5 resolve? %s)", + proxy_resolve ? "yes" : "no")); + + mDataLength = 0; + mState = SOCKS5_WRITE_CONNECT_REQUEST; + + auto buf = Buffer<BUFFER_SIZE>(mData) + .WriteUint8(0x05) // version -- 5 + .WriteUint8(0x01) // command -- connect + .WriteUint8(0x00); // reserved + + // We're writing a net port after the if, so we need a buffer allowing + // to write that much. + Buffer<sizeof(uint16_t)> buf2; + // Add the address to the SOCKS 5 request. SOCKS 5 supports several + // address types, so we pick the one that works best for us. + if (proxy_resolve) { + // Add the host name. Only a single byte is used to store the length, + // so we must prevent long names from being used. + buf2 = buf.WriteUint8(0x03) // addr type -- domainname + .WriteUint8(mDestinationHost.Length()) // name length + .WriteString<MAX_HOSTNAME_LEN>(mDestinationHost); // Hostname + if (!buf2) { + LOGERROR(("socks5: destination host name is too long!")); + HandshakeFinished(PR_BAD_ADDRESS_ERROR); + return PR_FAILURE; + } + } else if (addr->raw.family == AF_INET) { + buf2 = buf.WriteUint8(0x01) // addr type -- IPv4 + .WriteNetAddr(addr); + } else if (addr->raw.family == AF_INET6) { + buf2 = buf.WriteUint8(0x04) // addr type -- IPv6 + .WriteNetAddr(addr); + } else { + LOGERROR(("socks5: destination address of unknown type!")); + HandshakeFinished(PR_BAD_ADDRESS_ERROR); + return PR_FAILURE; + } + + auto buf3 = buf2.WriteNetPort(addr); // port + mDataLength = buf3.Written(); + + return PR_SUCCESS; +} + +PRStatus nsSOCKSSocketInfo::ReadV5AddrTypeAndLength(uint8_t* type, + uint32_t* len) { + MOZ_ASSERT(mState == SOCKS5_READ_CONNECT_RESPONSE_TOP || + mState == SOCKS5_READ_CONNECT_RESPONSE_BOTTOM, + "Invalid state!"); + MOZ_ASSERT(mDataLength >= 5, + "SOCKS 5 connection reply must be at least 5 bytes!"); + + // Seek to the address location + mReadOffset = 3; + + *type = ReadUint8(); + + switch (*type) { + case 0x01: // ipv4 + *len = 4 - 1; + break; + case 0x04: // ipv6 + *len = 16 - 1; + break; + case 0x03: // fqdn + *len = ReadUint8(); + break; + default: // wrong address type + LOGERROR(("socks5: wrong address type in connection reply!")); + return PR_FAILURE; + } + + return PR_SUCCESS; +} + +PRStatus nsSOCKSSocketInfo::ReadV5ConnectResponseTop() { + uint8_t res; + uint32_t len; + + MOZ_ASSERT(mState == SOCKS5_READ_CONNECT_RESPONSE_TOP, "Invalid state!"); + MOZ_ASSERT(mDataLength == 5, + "SOCKS 5 connection reply must be exactly 5 bytes!"); + + LOGDEBUG(("socks5: checking connection reply")); + + // Check version number + if (ReadUint8() != 0x05) { + LOGERROR(("socks5: unexpected version in the reply")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; + } + + // Check response + res = ReadUint8(); + if (res != 0x00) { + PRErrorCode c = PR_CONNECT_REFUSED_ERROR; + + switch (res) { + case 0x01: + LOGERROR( + ("socks5: connect failed: " + "01, General SOCKS server failure.")); + break; + case 0x02: + LOGERROR( + ("socks5: connect failed: " + "02, Connection not allowed by ruleset.")); + break; + case 0x03: + LOGERROR(("socks5: connect failed: 03, Network unreachable.")); + c = PR_NETWORK_UNREACHABLE_ERROR; + break; + case 0x04: + LOGERROR(("socks5: connect failed: 04, Host unreachable.")); + c = PR_BAD_ADDRESS_ERROR; + break; + case 0x05: + LOGERROR(("socks5: connect failed: 05, Connection refused.")); + break; + case 0x06: + LOGERROR(("socks5: connect failed: 06, TTL expired.")); + c = PR_CONNECT_TIMEOUT_ERROR; + break; + case 0x07: + LOGERROR( + ("socks5: connect failed: " + "07, Command not supported.")); + break; + case 0x08: + LOGERROR( + ("socks5: connect failed: " + "08, Address type not supported.")); + c = PR_BAD_ADDRESS_ERROR; + break; + default: + LOGERROR(("socks5: connect failed.")); + break; + } + + HandshakeFinished(c); + return PR_FAILURE; + } + + if (ReadV5AddrTypeAndLength(&res, &len) != PR_SUCCESS) { + HandshakeFinished(PR_BAD_ADDRESS_ERROR); + return PR_FAILURE; + } + + mState = SOCKS5_READ_CONNECT_RESPONSE_BOTTOM; + WantRead(len + 2); + + return PR_SUCCESS; +} + +PRStatus nsSOCKSSocketInfo::ReadV5ConnectResponseBottom() { + uint8_t type; + uint32_t len; + + MOZ_ASSERT(mState == SOCKS5_READ_CONNECT_RESPONSE_BOTTOM, "Invalid state!"); + + if (ReadV5AddrTypeAndLength(&type, &len) != PR_SUCCESS) { + HandshakeFinished(PR_BAD_ADDRESS_ERROR); + return PR_FAILURE; + } + + MOZ_ASSERT(mDataLength == 7 + len, + "SOCKS 5 unexpected length of connection reply!"); + + LOGDEBUG(("socks5: loading source addr and port")); + // Read what the proxy says is our source address + switch (type) { + case 0x01: // ipv4 + ReadNetAddr(&mExternalProxyAddr, AF_INET); + break; + case 0x04: // ipv6 + ReadNetAddr(&mExternalProxyAddr, AF_INET6); + break; + case 0x03: // fqdn (skip) + mReadOffset += len; + mExternalProxyAddr.raw.family = AF_INET; + break; + } + + ReadNetPort(&mExternalProxyAddr); + + LOGDEBUG(("socks5: connected!")); + HandshakeFinished(); + + return PR_SUCCESS; +} + +void nsSOCKSSocketInfo::SetConnectTimeout(PRIntervalTime to) { mTimeout = to; } + +PRStatus nsSOCKSSocketInfo::DoHandshake(PRFileDesc* fd, int16_t oflags) { + LOGDEBUG(("socks: DoHandshake(), state = %d", mState)); + + switch (mState) { + case SOCKS_INITIAL: + if (IsLocalProxy()) { + mState = SOCKS_DNS_COMPLETE; + mLookupStatus = NS_OK; + return ConnectToProxy(fd); + } + + return StartDNS(fd); + case SOCKS_DNS_IN_PROGRESS: + PR_SetError(PR_IN_PROGRESS_ERROR, 0); + return PR_FAILURE; + case SOCKS_DNS_COMPLETE: + return ConnectToProxy(fd); + case SOCKS_CONNECTING_TO_PROXY: + return ContinueConnectingToProxy(fd, oflags); + case SOCKS4_WRITE_CONNECT_REQUEST: + if (WriteToSocket(fd) != PR_SUCCESS) return PR_FAILURE; + WantRead(8); + mState = SOCKS4_READ_CONNECT_RESPONSE; + return PR_SUCCESS; + case SOCKS4_READ_CONNECT_RESPONSE: + if (ReadFromSocket(fd) != PR_SUCCESS) return PR_FAILURE; + return ReadV4ConnectResponse(); + + case SOCKS5_WRITE_AUTH_REQUEST: + if (WriteToSocket(fd) != PR_SUCCESS) return PR_FAILURE; + WantRead(2); + mState = SOCKS5_READ_AUTH_RESPONSE; + return PR_SUCCESS; + case SOCKS5_READ_AUTH_RESPONSE: + if (ReadFromSocket(fd) != PR_SUCCESS) return PR_FAILURE; + return ReadV5AuthResponse(); + case SOCKS5_WRITE_USERNAME_REQUEST: + if (WriteToSocket(fd) != PR_SUCCESS) return PR_FAILURE; + WantRead(2); + mState = SOCKS5_READ_USERNAME_RESPONSE; + return PR_SUCCESS; + case SOCKS5_READ_USERNAME_RESPONSE: + if (ReadFromSocket(fd) != PR_SUCCESS) return PR_FAILURE; + return ReadV5UsernameResponse(); + case SOCKS5_WRITE_CONNECT_REQUEST: + if (WriteToSocket(fd) != PR_SUCCESS) return PR_FAILURE; + + // The SOCKS 5 response to the connection request is variable + // length. First, we'll read enough to tell how long the response + // is, and will read the rest later. + WantRead(5); + mState = SOCKS5_READ_CONNECT_RESPONSE_TOP; + return PR_SUCCESS; + case SOCKS5_READ_CONNECT_RESPONSE_TOP: + if (ReadFromSocket(fd) != PR_SUCCESS) return PR_FAILURE; + return ReadV5ConnectResponseTop(); + case SOCKS5_READ_CONNECT_RESPONSE_BOTTOM: + if (ReadFromSocket(fd) != PR_SUCCESS) return PR_FAILURE; + return ReadV5ConnectResponseBottom(); + + case SOCKS_CONNECTED: + LOGERROR(("socks: already connected")); + HandshakeFinished(PR_IS_CONNECTED_ERROR); + return PR_FAILURE; + case SOCKS_FAILED: + LOGERROR(("socks: already failed")); + return PR_FAILURE; + } + + LOGERROR(("socks: executing handshake in invalid state, %d", mState)); + HandshakeFinished(PR_INVALID_STATE_ERROR); + + return PR_FAILURE; +} + +int16_t nsSOCKSSocketInfo::GetPollFlags() const { + switch (mState) { + case SOCKS_DNS_IN_PROGRESS: + case SOCKS_DNS_COMPLETE: + case SOCKS_CONNECTING_TO_PROXY: + return PR_POLL_EXCEPT | PR_POLL_WRITE; + case SOCKS4_WRITE_CONNECT_REQUEST: + case SOCKS5_WRITE_AUTH_REQUEST: + case SOCKS5_WRITE_USERNAME_REQUEST: + case SOCKS5_WRITE_CONNECT_REQUEST: + return PR_POLL_WRITE; + case SOCKS4_READ_CONNECT_RESPONSE: + case SOCKS5_READ_AUTH_RESPONSE: + case SOCKS5_READ_USERNAME_RESPONSE: + case SOCKS5_READ_CONNECT_RESPONSE_TOP: + case SOCKS5_READ_CONNECT_RESPONSE_BOTTOM: + return PR_POLL_READ; + default: + break; + } + + return 0; +} + +inline uint8_t nsSOCKSSocketInfo::ReadUint8() { + uint8_t rv; + MOZ_ASSERT(mReadOffset + sizeof(rv) <= mDataLength, + "Not enough space to pop a uint8_t!"); + rv = mData[mReadOffset]; + mReadOffset += sizeof(rv); + return rv; +} + +inline uint16_t nsSOCKSSocketInfo::ReadUint16() { + uint16_t rv; + MOZ_ASSERT(mReadOffset + sizeof(rv) <= mDataLength, + "Not enough space to pop a uint16_t!"); + memcpy(&rv, mData + mReadOffset, sizeof(rv)); + mReadOffset += sizeof(rv); + return rv; +} + +inline uint32_t nsSOCKSSocketInfo::ReadUint32() { + uint32_t rv; + MOZ_ASSERT(mReadOffset + sizeof(rv) <= mDataLength, + "Not enough space to pop a uint32_t!"); + memcpy(&rv, mData + mReadOffset, sizeof(rv)); + mReadOffset += sizeof(rv); + return rv; +} + +void nsSOCKSSocketInfo::ReadNetAddr(NetAddr* addr, uint16_t fam) { + uint32_t amt = 0; + const uint8_t* ip = mData + mReadOffset; + + addr->raw.family = fam; + if (fam == AF_INET) { + amt = sizeof(addr->inet.ip); + MOZ_ASSERT(mReadOffset + amt <= mDataLength, + "Not enough space to pop an ipv4 addr!"); + memcpy(&addr->inet.ip, ip, amt); + } else if (fam == AF_INET6) { + amt = sizeof(addr->inet6.ip.u8); + MOZ_ASSERT(mReadOffset + amt <= mDataLength, + "Not enough space to pop an ipv6 addr!"); + memcpy(addr->inet6.ip.u8, ip, amt); + } + + mReadOffset += amt; +} + +void nsSOCKSSocketInfo::ReadNetPort(NetAddr* addr) { + addr->inet.port = ReadUint16(); +} + +void nsSOCKSSocketInfo::WantRead(uint32_t sz) { + MOZ_ASSERT(mDataIoPtr == nullptr, + "WantRead() called while I/O already in progress!"); + MOZ_ASSERT(mDataLength + sz <= BUFFER_SIZE, "Can't read that much data!"); + mAmountToRead = sz; +} + +PRStatus nsSOCKSSocketInfo::ReadFromSocket(PRFileDesc* fd) { + int32_t rc; + const uint8_t* end; + + if (!mAmountToRead) { + LOGDEBUG(("socks: ReadFromSocket(), nothing to do")); + return PR_SUCCESS; + } + + if (!mDataIoPtr) { + mDataIoPtr = mData + mDataLength; + mDataLength += mAmountToRead; + } + + end = mData + mDataLength; + + while (mDataIoPtr < end) { + rc = PR_Read(fd, mDataIoPtr, end - mDataIoPtr); + if (rc <= 0) { + if (rc == 0) { + LOGERROR(("socks: proxy server closed connection")); + HandshakeFinished(PR_CONNECT_REFUSED_ERROR); + return PR_FAILURE; + } + if (PR_GetError() == PR_WOULD_BLOCK_ERROR) { + LOGDEBUG(("socks: ReadFromSocket(), want read")); + } + break; + } + + mDataIoPtr += rc; + } + + LOGDEBUG(("socks: ReadFromSocket(), have %u bytes total", + unsigned(mDataIoPtr - mData))); + if (mDataIoPtr == end) { + mDataIoPtr = nullptr; + mAmountToRead = 0; + mReadOffset = 0; + return PR_SUCCESS; + } + + return PR_FAILURE; +} + +PRStatus nsSOCKSSocketInfo::WriteToSocket(PRFileDesc* fd) { + int32_t rc; + const uint8_t* end; + + if (!mDataLength) { + LOGDEBUG(("socks: WriteToSocket(), nothing to do")); + return PR_SUCCESS; + } + + if (!mDataIoPtr) mDataIoPtr = mData; + + end = mData + mDataLength; + + while (mDataIoPtr < end) { + rc = PR_Write(fd, mDataIoPtr, end - mDataIoPtr); + if (rc < 0) { + if (PR_GetError() == PR_WOULD_BLOCK_ERROR) { + LOGDEBUG(("socks: WriteToSocket(), want write")); + } + break; + } + + mDataIoPtr += rc; + } + + if (mDataIoPtr == end) { + mDataIoPtr = nullptr; + mDataLength = 0; + mReadOffset = 0; + return PR_SUCCESS; + } + + return PR_FAILURE; +} + +static PRStatus nsSOCKSIOLayerConnect(PRFileDesc* fd, const PRNetAddr* addr, + PRIntervalTime to) { + PRStatus status; + NetAddr dst; + + nsSOCKSSocketInfo* info = (nsSOCKSSocketInfo*)fd->secret; + if (info == nullptr) return PR_FAILURE; + + if (addr->raw.family == PR_AF_INET6 && + PR_IsNetAddrType(addr, PR_IpAddrV4Mapped)) { + const uint8_t* srcp; + + LOGDEBUG(("socks: converting ipv4-mapped ipv6 address to ipv4")); + + // copied from _PR_ConvertToIpv4NetAddr() + dst.raw.family = AF_INET; + dst.inet.ip = htonl(INADDR_ANY); + dst.inet.port = htons(0); + srcp = addr->ipv6.ip.pr_s6_addr; + memcpy(&dst.inet.ip, srcp + 12, 4); + dst.inet.family = AF_INET; + dst.inet.port = addr->ipv6.port; + } else { + memcpy(&dst, addr, sizeof(dst)); + } + + info->SetDestinationAddr(dst); + info->SetConnectTimeout(to); + + do { + status = info->DoHandshake(fd, -1); + } while (status == PR_SUCCESS && !info->IsConnected()); + + return status; +} + +static PRStatus nsSOCKSIOLayerConnectContinue(PRFileDesc* fd, int16_t oflags) { + PRStatus status; + + nsSOCKSSocketInfo* info = (nsSOCKSSocketInfo*)fd->secret; + if (info == nullptr) return PR_FAILURE; + + do { + status = info->DoHandshake(fd, oflags); + } while (status == PR_SUCCESS && !info->IsConnected()); + + return status; +} + +static int16_t nsSOCKSIOLayerPoll(PRFileDesc* fd, int16_t in_flags, + int16_t* out_flags) { + nsSOCKSSocketInfo* info = (nsSOCKSSocketInfo*)fd->secret; + if (info == nullptr) return PR_FAILURE; + + if (!info->IsConnected()) { + *out_flags = 0; + return info->GetPollFlags(); + } + + return fd->lower->methods->poll(fd->lower, in_flags, out_flags); +} + +static PRStatus nsSOCKSIOLayerClose(PRFileDesc* fd) { + nsSOCKSSocketInfo* info = (nsSOCKSSocketInfo*)fd->secret; + PRDescIdentity id = PR_GetLayersIdentity(fd); + + if (info && id == nsSOCKSIOLayerIdentity) { + info->ForgetFD(); + NS_RELEASE(info); + fd->identity = PR_INVALID_IO_LAYER; + } + + return fd->lower->methods->close(fd->lower); +} + +static PRFileDesc* nsSOCKSIOLayerAccept(PRFileDesc* fd, PRNetAddr* addr, + PRIntervalTime timeout) { + // TODO: implement SOCKS support for accept + return fd->lower->methods->accept(fd->lower, addr, timeout); +} + +static int32_t nsSOCKSIOLayerAcceptRead(PRFileDesc* sd, PRFileDesc** nd, + PRNetAddr** raddr, void* buf, + int32_t amount, + PRIntervalTime timeout) { + // TODO: implement SOCKS support for accept, then read from it + return sd->lower->methods->acceptread(sd->lower, nd, raddr, buf, amount, + timeout); +} + +static PRStatus nsSOCKSIOLayerBind(PRFileDesc* fd, const PRNetAddr* addr) { + // TODO: implement SOCKS support for bind (very similar to connect) + return fd->lower->methods->bind(fd->lower, addr); +} + +static PRStatus nsSOCKSIOLayerGetName(PRFileDesc* fd, PRNetAddr* addr) { + nsSOCKSSocketInfo* info = (nsSOCKSSocketInfo*)fd->secret; + + if (info != nullptr && addr != nullptr) { + NetAddr temp; + info->GetExternalProxyAddr(temp); + NetAddrToPRNetAddr(&temp, addr); + return PR_SUCCESS; + } + + return PR_FAILURE; +} + +static PRStatus nsSOCKSIOLayerGetPeerName(PRFileDesc* fd, PRNetAddr* addr) { + nsSOCKSSocketInfo* info = (nsSOCKSSocketInfo*)fd->secret; + + if (info != nullptr && addr != nullptr) { + NetAddr temp; + info->GetDestinationAddr(temp); + NetAddrToPRNetAddr(&temp, addr); + return PR_SUCCESS; + } + + return PR_FAILURE; +} + +static PRStatus nsSOCKSIOLayerListen(PRFileDesc* fd, int backlog) { + // TODO: implement SOCKS support for listen + return fd->lower->methods->listen(fd->lower, backlog); +} + +// add SOCKS IO layer to an existing socket +nsresult nsSOCKSIOLayerAddToSocket(int32_t family, const char* host, + int32_t port, nsIProxyInfo* proxy, + int32_t socksVersion, uint32_t flags, + uint32_t tlsFlags, PRFileDesc* fd) { + NS_ENSURE_TRUE((socksVersion == 4) || (socksVersion == 5), + NS_ERROR_NOT_INITIALIZED); + + if (firstTime) { + // XXX hack until NSPR provides an official way to detect system IPv6 + // support (bug 388519) + PRFileDesc* tmpfd = PR_OpenTCPSocket(PR_AF_INET6); + if (!tmpfd) { + ipv6Supported = false; + } else { + // If the system does not support IPv6, NSPR will push + // IPv6-to-IPv4 emulation layer onto the native layer + ipv6Supported = PR_GetIdentitiesLayer(tmpfd, PR_NSPR_IO_LAYER) == tmpfd; + PR_Close(tmpfd); + } + + nsSOCKSIOLayerIdentity = PR_GetUniqueIdentity("SOCKS layer"); + nsSOCKSIOLayerMethods = *PR_GetDefaultIOMethods(); + + nsSOCKSIOLayerMethods.connect = nsSOCKSIOLayerConnect; + nsSOCKSIOLayerMethods.connectcontinue = nsSOCKSIOLayerConnectContinue; + nsSOCKSIOLayerMethods.poll = nsSOCKSIOLayerPoll; + nsSOCKSIOLayerMethods.bind = nsSOCKSIOLayerBind; + nsSOCKSIOLayerMethods.acceptread = nsSOCKSIOLayerAcceptRead; + nsSOCKSIOLayerMethods.getsockname = nsSOCKSIOLayerGetName; + nsSOCKSIOLayerMethods.getpeername = nsSOCKSIOLayerGetPeerName; + nsSOCKSIOLayerMethods.accept = nsSOCKSIOLayerAccept; + nsSOCKSIOLayerMethods.listen = nsSOCKSIOLayerListen; + nsSOCKSIOLayerMethods.close = nsSOCKSIOLayerClose; + + firstTime = false; + } + + LOGDEBUG(("Entering nsSOCKSIOLayerAddToSocket().")); + + PRFileDesc* layer; + PRStatus rv; + + layer = PR_CreateIOLayerStub(nsSOCKSIOLayerIdentity, &nsSOCKSIOLayerMethods); + if (!layer) { + LOGERROR(("PR_CreateIOLayerStub() failed.")); + return NS_ERROR_FAILURE; + } + + nsSOCKSSocketInfo* infoObject = new nsSOCKSSocketInfo(); + if (!infoObject) { + // clean up IOLayerStub + LOGERROR(("Failed to create nsSOCKSSocketInfo().")); + PR_Free(layer); // PR_CreateIOLayerStub() uses PR_Malloc(). + return NS_ERROR_FAILURE; + } + + NS_ADDREF(infoObject); + infoObject->Init(socksVersion, family, proxy, host, flags, tlsFlags); + layer->secret = (PRFilePrivate*)infoObject; + + PRDescIdentity fdIdentity = PR_GetLayersIdentity(fd); +#if defined(XP_WIN) + if (fdIdentity == mozilla::net::nsNamedPipeLayerIdentity) { + // remember named pipe fd on the info object so that we can switch + // blocking and non-blocking mode on the pipe later. + infoObject->SetNamedPipeFD(fd); + } +#endif + rv = PR_PushIOLayer(fd, fdIdentity, layer); + + if (rv == PR_FAILURE) { + LOGERROR(("PR_PushIOLayer() failed. rv = %x.", rv)); + NS_RELEASE(infoObject); + PR_Free(layer); // PR_CreateIOLayerStub() uses PR_Malloc(). + return NS_ERROR_FAILURE; + } + + return NS_OK; +} + +bool IsHostLocalTarget(const nsACString& aHost) { +#if defined(XP_UNIX) + return StringBeginsWith(aHost, "file:"_ns); +#elif defined(XP_WIN) + return IsNamedPipePath(aHost); +#else + return false; +#endif // XP_UNIX +} diff --git a/netwerk/socket/nsSOCKSIOLayer.h b/netwerk/socket/nsSOCKSIOLayer.h new file mode 100644 index 0000000000..576c96ea7f --- /dev/null +++ b/netwerk/socket/nsSOCKSIOLayer.h @@ -0,0 +1,21 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*- + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef nsSOCKSIOLayer_h__ +#define nsSOCKSIOLayer_h__ + +#include "prio.h" +#include "nscore.h" +#include "nsIProxyInfo.h" + +nsresult nsSOCKSIOLayerAddToSocket(int32_t family, const char* host, + int32_t port, nsIProxyInfo* proxyInfo, + int32_t socksVersion, uint32_t flags, + uint32_t tlsFlags, PRFileDesc* fd); + +bool IsHostLocalTarget(const nsACString& aHost); + +#endif /* nsSOCKSIOLayer_h__ */ diff --git a/netwerk/socket/nsSOCKSSocketProvider.cpp b/netwerk/socket/nsSOCKSSocketProvider.cpp new file mode 100644 index 0000000000..fc18e8e788 --- /dev/null +++ b/netwerk/socket/nsSOCKSSocketProvider.cpp @@ -0,0 +1,98 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*- + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsNamedPipeIOLayer.h" +#include "nsSOCKSSocketProvider.h" +#include "nsSOCKSIOLayer.h" +#include "nsCOMPtr.h" +#include "nsError.h" + +using mozilla::OriginAttributes; +using namespace mozilla::net; + +////////////////////////////////////////////////////////////////////////// + +NS_IMPL_ISUPPORTS(nsSOCKSSocketProvider, nsISocketProvider) + +// Per-platform implemenation of OpenTCPSocket helper function +// Different platforms have special cases to handle + +#if defined(XP_WIN) +// The proxy host on Windows may be a named pipe uri, in which +// case a named-pipe (rather than a socket) should be returned +static PRFileDesc* OpenTCPSocket(int32_t family, nsIProxyInfo* proxy) { + PRFileDesc* sock = nullptr; + + nsAutoCString proxyHost; + proxy->GetHost(proxyHost); + if (IsNamedPipePath(proxyHost)) { + sock = CreateNamedPipeLayer(); + } else { + sock = PR_OpenTCPSocket(family); + } + + return sock; +} +#elif defined(XP_UNIX) +// The proxy host on UNIX systems may point to a local file uri +// in which case we should create an AF_LOCAL (UNIX Domain) socket +// instead of the requested AF_INET or AF_INET6 socket. + +// Normally,this socket would get thrown out and recreated later on +// with the proper family, but we want to do it early here so that +// we can enforce seccomp policy to blacklist socket(AF_INET) calls +// to prevent the content sandbox from creating network requests +static PRFileDesc* OpenTCPSocket(int32_t family, nsIProxyInfo* proxy) { + nsAutoCString proxyHost; + proxy->GetHost(proxyHost); + if (StringBeginsWith(proxyHost, "file://"_ns)) { + family = AF_LOCAL; + } + + return PR_OpenTCPSocket(family); +} +#else +// Default, pass-through to PR_OpenTCPSocket +static PRFileDesc* OpenTCPSocket(int32_t family, nsIProxyInfo*) { + return PR_OpenTCPSocket(family); +} +#endif + +NS_IMETHODIMP +nsSOCKSSocketProvider::NewSocket(int32_t family, const char* host, int32_t port, + nsIProxyInfo* proxy, + const OriginAttributes& originAttributes, + uint32_t flags, uint32_t tlsFlags, + PRFileDesc** result, + nsITLSSocketControl** tlsSocketControl) { + PRFileDesc* sock = OpenTCPSocket(family, proxy); + if (!sock) { + return NS_ERROR_OUT_OF_MEMORY; + } + + nsresult rv = nsSOCKSIOLayerAddToSocket(family, host, port, proxy, mVersion, + flags, tlsFlags, sock); + if (NS_SUCCEEDED(rv)) { + *result = sock; + return NS_OK; + } + + return NS_ERROR_SOCKET_CREATE_FAILED; +} + +NS_IMETHODIMP +nsSOCKSSocketProvider::AddToSocket(int32_t family, const char* host, + int32_t port, nsIProxyInfo* proxy, + const OriginAttributes& originAttributes, + uint32_t flags, uint32_t tlsFlags, + PRFileDesc* sock, + nsITLSSocketControl** tlsSocketControl) { + nsresult rv = nsSOCKSIOLayerAddToSocket(family, host, port, proxy, mVersion, + flags, tlsFlags, sock); + + if (NS_FAILED(rv)) rv = NS_ERROR_SOCKET_CREATE_FAILED; + return rv; +} diff --git a/netwerk/socket/nsSOCKSSocketProvider.h b/netwerk/socket/nsSOCKSSocketProvider.h new file mode 100644 index 0000000000..2733856733 --- /dev/null +++ b/netwerk/socket/nsSOCKSSocketProvider.h @@ -0,0 +1,28 @@ +/* -*- Mode: C++; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 2 -*- + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef nsSOCKSSocketProvider_h__ +#define nsSOCKSSocketProvider_h__ + +#include "nsISocketProvider.h" + +// values for ctor's |version| argument +enum { NS_SOCKS_VERSION_4 = 4, NS_SOCKS_VERSION_5 = 5 }; + +class nsSOCKSSocketProvider : public nsISocketProvider { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSISOCKETPROVIDER + + explicit nsSOCKSSocketProvider(uint32_t version) : mVersion(version) {} + + private: + virtual ~nsSOCKSSocketProvider() = default; + + uint32_t mVersion; // NS_SOCKS_VERSION_4 or 5 +}; + +#endif /* nsSOCKSSocketProvider_h__ */ diff --git a/netwerk/socket/nsSocketProviderService.cpp b/netwerk/socket/nsSocketProviderService.cpp new file mode 100644 index 0000000000..737820890e --- /dev/null +++ b/netwerk/socket/nsSocketProviderService.cpp @@ -0,0 +1,72 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsString.h" +#include "nsISocketProvider.h" +#include "nsError.h" +#include "nsNSSComponent.h" +#include "nsSOCKSSocketProvider.h" +#include "nsSocketProviderService.h" +#include "nsSSLSocketProvider.h" +#include "nsTLSSocketProvider.h" +#include "nsUDPSocketProvider.h" +#include "mozilla/ClearOnShutdown.h" +#include "nsThreadUtils.h" +#include "nsCRT.h" + +mozilla::StaticRefPtr<nsSocketProviderService> + nsSocketProviderService::gSingleton; + +//////////////////////////////////////////////////////////////////////////////// + +already_AddRefed<nsISocketProviderService> +nsSocketProviderService::GetOrCreate() { + RefPtr<nsSocketProviderService> inst; + if (gSingleton) { + inst = gSingleton; + } else { + inst = new nsSocketProviderService(); + gSingleton = inst; + if (NS_IsMainThread()) { + mozilla::ClearOnShutdown(&gSingleton); + } else { + NS_DispatchToMainThread(NS_NewRunnableFunction( + "net::nsSocketProviderService::GetOrCreate", + []() -> void { mozilla::ClearOnShutdown(&gSingleton); })); + } + } + return inst.forget(); +} + +NS_IMPL_ISUPPORTS(nsSocketProviderService, nsISocketProviderService) + +//////////////////////////////////////////////////////////////////////////////// + +NS_IMETHODIMP +nsSocketProviderService::GetSocketProvider(const char* type, + nsISocketProvider** result) { + nsCOMPtr<nsISocketProvider> inst; + if (!nsCRT::strcmp(type, "ssl") && + (XRE_IsParentProcess() || XRE_IsSocketProcess()) && + EnsureNSSInitializedChromeOrContent()) { + inst = new nsSSLSocketProvider(); + } else if (!nsCRT::strcmp(type, "starttls") && + (XRE_IsParentProcess() || XRE_IsSocketProcess()) && + EnsureNSSInitializedChromeOrContent()) { + inst = new nsTLSSocketProvider(); + } else if (!nsCRT::strcmp(type, "socks")) { + inst = new nsSOCKSSocketProvider(NS_SOCKS_VERSION_5); + } else if (!nsCRT::strcmp(type, "socks4")) { + inst = new nsSOCKSSocketProvider(NS_SOCKS_VERSION_4); + } else if (!nsCRT::strcmp(type, "udp")) { + inst = new nsUDPSocketProvider(); + } else { + return NS_ERROR_UNKNOWN_SOCKET_TYPE; + } + inst.forget(result); + return NS_OK; +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/netwerk/socket/nsSocketProviderService.h b/netwerk/socket/nsSocketProviderService.h new file mode 100644 index 0000000000..85b2ba04a4 --- /dev/null +++ b/netwerk/socket/nsSocketProviderService.h @@ -0,0 +1,26 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef nsSocketProviderService_h__ +#define nsSocketProviderService_h__ + +#include "nsISocketProviderService.h" +#include "mozilla/StaticPtr.h" + +class nsSocketProviderService : public nsISocketProviderService { + nsSocketProviderService() = default; + virtual ~nsSocketProviderService() = default; + + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSISOCKETPROVIDERSERVICE + + static already_AddRefed<nsISocketProviderService> GetOrCreate(); + + private: + static mozilla::StaticRefPtr<nsSocketProviderService> gSingleton; +}; + +#endif /* nsSocketProviderService_h__ */ diff --git a/netwerk/socket/nsUDPSocketProvider.cpp b/netwerk/socket/nsUDPSocketProvider.cpp new file mode 100644 index 0000000000..e99ef36a6d --- /dev/null +++ b/netwerk/socket/nsUDPSocketProvider.cpp @@ -0,0 +1,39 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nsUDPSocketProvider.h" + +#include "nspr.h" + +using mozilla::OriginAttributes; + +NS_IMPL_ISUPPORTS(nsUDPSocketProvider, nsISocketProvider) + +NS_IMETHODIMP +nsUDPSocketProvider::NewSocket(int32_t aFamily, const char* aHost, + int32_t aPort, nsIProxyInfo* aProxy, + const OriginAttributes& originAttributes, + uint32_t aFlags, uint32_t aTlsFlags, + PRFileDesc** aFileDesc, + nsITLSSocketControl** aTLSSocketControl) { + NS_ENSURE_ARG_POINTER(aFileDesc); + + PRFileDesc* udpFD = PR_OpenUDPSocket(aFamily); + if (!udpFD) return NS_ERROR_FAILURE; + + *aFileDesc = udpFD; + return NS_OK; +} + +NS_IMETHODIMP +nsUDPSocketProvider::AddToSocket(int32_t aFamily, const char* aHost, + int32_t aPort, nsIProxyInfo* aProxy, + const OriginAttributes& originAttributes, + uint32_t aFlags, uint32_t aTlsFlags, + struct PRFileDesc* aFileDesc, + nsITLSSocketControl** aTLSSocketControl) { + // does not make sense to strap a UDP socket onto an existing socket + MOZ_ASSERT_UNREACHABLE("Cannot layer UDP socket on an existing socket"); + return NS_ERROR_UNEXPECTED; +} diff --git a/netwerk/socket/nsUDPSocketProvider.h b/netwerk/socket/nsUDPSocketProvider.h new file mode 100644 index 0000000000..fb1fe749d9 --- /dev/null +++ b/netwerk/socket/nsUDPSocketProvider.h @@ -0,0 +1,20 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef nsUDPSocketProvider_h__ +#define nsUDPSocketProvider_h__ + +#include "nsISocketProvider.h" +#include "mozilla/Attributes.h" + +class nsUDPSocketProvider final : public nsISocketProvider { + public: + NS_DECL_THREADSAFE_ISUPPORTS + NS_DECL_NSISOCKETPROVIDER + + private: + ~nsUDPSocketProvider() = default; +}; + +#endif /* nsUDPSocketProvider_h__ */ |