summaryrefslogtreecommitdiffstats
path: root/third_party/rust/neqo-crypto/src/agentio.rs
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/rust/neqo-crypto/src/agentio.rs')
-rw-r--r--third_party/rust/neqo-crypto/src/agentio.rs393
1 files changed, 393 insertions, 0 deletions
diff --git a/third_party/rust/neqo-crypto/src/agentio.rs b/third_party/rust/neqo-crypto/src/agentio.rs
new file mode 100644
index 0000000000..1d39b2398a
--- /dev/null
+++ b/third_party/rust/neqo-crypto/src/agentio.rs
@@ -0,0 +1,393 @@
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+use crate::constants::{ContentType, Epoch};
+use crate::err::{nspr, Error, PR_SetError, Res};
+use crate::prio;
+use crate::ssl;
+
+use neqo_common::{hex, hex_with_len, qtrace};
+use std::cmp::min;
+use std::convert::{TryFrom, TryInto};
+use std::fmt;
+use std::mem;
+use std::ops::Deref;
+use std::os::raw::{c_uint, c_void};
+use std::pin::Pin;
+use std::ptr::{null, null_mut};
+use std::vec::Vec;
+
+// Alias common types.
+type PrFd = *mut prio::PRFileDesc;
+type PrStatus = prio::PRStatus::Type;
+const PR_SUCCESS: PrStatus = prio::PRStatus::PR_SUCCESS;
+const PR_FAILURE: PrStatus = prio::PRStatus::PR_FAILURE;
+
+/// Convert a pinned, boxed object into a void pointer.
+pub fn as_c_void<T: Unpin>(pin: &mut Pin<Box<T>>) -> *mut c_void {
+ (Pin::into_inner(pin.as_mut()) as *mut T).cast()
+}
+
+/// A slice of the output.
+#[derive(Default)]
+pub struct Record {
+ pub epoch: Epoch,
+ pub ct: ContentType,
+ pub data: Vec<u8>,
+}
+
+impl Record {
+ #[must_use]
+ pub fn new(epoch: Epoch, ct: ContentType, data: &[u8]) -> Self {
+ Self {
+ epoch,
+ ct,
+ data: data.to_vec(),
+ }
+ }
+
+ // Shoves this record into the socket, returns true if blocked.
+ pub(crate) fn write(self, fd: *mut ssl::PRFileDesc) -> Res<()> {
+ qtrace!("write {:?}", self);
+ unsafe {
+ ssl::SSL_RecordLayerData(
+ fd,
+ self.epoch,
+ ssl::SSLContentType::Type::from(self.ct),
+ self.data.as_ptr(),
+ c_uint::try_from(self.data.len())?,
+ )
+ }
+ }
+}
+
+impl fmt::Debug for Record {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(
+ f,
+ "Record {:?}:{:?} {}",
+ self.epoch,
+ self.ct,
+ hex_with_len(&self.data[..])
+ )
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct RecordList {
+ records: Vec<Record>,
+}
+
+impl RecordList {
+ fn append(&mut self, epoch: Epoch, ct: ContentType, data: &[u8]) {
+ self.records.push(Record::new(epoch, ct, data));
+ }
+
+ #[allow(clippy::unused_self)]
+ unsafe extern "C" fn ingest(
+ _fd: *mut ssl::PRFileDesc,
+ epoch: ssl::PRUint16,
+ ct: ssl::SSLContentType::Type,
+ data: *const ssl::PRUint8,
+ len: c_uint,
+ arg: *mut c_void,
+ ) -> ssl::SECStatus {
+ let records = arg.cast::<Self>().as_mut().unwrap();
+
+ let slice = std::slice::from_raw_parts(data, len as usize);
+ records.append(epoch, ContentType::try_from(ct).unwrap(), slice);
+ ssl::SECSuccess
+ }
+
+ /// Create a new record list.
+ pub(crate) fn setup(fd: *mut ssl::PRFileDesc) -> Res<Pin<Box<Self>>> {
+ let mut records = Box::pin(Self::default());
+ unsafe {
+ ssl::SSL_RecordLayerWriteCallback(fd, Some(Self::ingest), as_c_void(&mut records))
+ }?;
+ Ok(records)
+ }
+}
+
+impl Deref for RecordList {
+ type Target = Vec<Record>;
+ #[must_use]
+ fn deref(&self) -> &Vec<Record> {
+ &self.records
+ }
+}
+
+pub struct RecordListIter(std::vec::IntoIter<Record>);
+
+impl Iterator for RecordListIter {
+ type Item = Record;
+ fn next(&mut self) -> Option<Self::Item> {
+ self.0.next()
+ }
+}
+
+impl IntoIterator for RecordList {
+ type Item = Record;
+ type IntoIter = RecordListIter;
+ #[must_use]
+ fn into_iter(self) -> Self::IntoIter {
+ RecordListIter(self.records.into_iter())
+ }
+}
+
+pub struct AgentIoInputContext<'a> {
+ input: &'a mut AgentIoInput,
+}
+
+impl<'a> Drop for AgentIoInputContext<'a> {
+ fn drop(&mut self) {
+ self.input.reset();
+ }
+}
+
+#[derive(Debug)]
+struct AgentIoInput {
+ // input is data that is read by TLS.
+ input: *const u8,
+ // input_available is how much data is left for reading.
+ available: usize,
+}
+
+impl AgentIoInput {
+ fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
+ assert!(self.input.is_null());
+ self.input = input.as_ptr();
+ self.available = input.len();
+ qtrace!("AgentIoInput wrap {:p}", self.input);
+ AgentIoInputContext { input: self }
+ }
+
+ // Take the data provided as input and provide it to the TLS stack.
+ fn read_input(&mut self, buf: *mut u8, count: usize) -> Res<usize> {
+ let amount = min(self.available, count);
+ if amount == 0 {
+ unsafe {
+ PR_SetError(nspr::PR_WOULD_BLOCK_ERROR, 0);
+ }
+ return Err(Error::NoDataAvailable);
+ }
+
+ let src = unsafe { std::slice::from_raw_parts(self.input, amount) };
+ qtrace!([self], "read {}", hex(src));
+ let dst = unsafe { std::slice::from_raw_parts_mut(buf, amount) };
+ dst.copy_from_slice(src);
+ self.input = self.input.wrapping_add(amount);
+ self.available -= amount;
+ Ok(amount)
+ }
+
+ fn reset(&mut self) {
+ qtrace!([self], "reset");
+ self.input = null();
+ self.available = 0;
+ }
+}
+
+impl ::std::fmt::Display for AgentIoInput {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "AgentIoInput {:p}", self.input)
+ }
+}
+
+#[derive(Debug)]
+pub struct AgentIo {
+ // input collects the input we might provide to TLS.
+ input: AgentIoInput,
+
+ // output contains data that is written by TLS.
+ output: Vec<u8>,
+}
+
+impl AgentIo {
+ pub fn new() -> Self {
+ Self {
+ input: AgentIoInput {
+ input: null(),
+ available: 0,
+ },
+ output: Vec::new(),
+ }
+ }
+
+ unsafe fn borrow(fd: &mut PrFd) -> &mut Self {
+ #[allow(clippy::cast_ptr_alignment)]
+ (**fd).secret.cast::<Self>().as_mut().unwrap()
+ }
+
+ pub fn wrap<'a: 'c, 'b: 'c, 'c>(&'a mut self, input: &'b [u8]) -> AgentIoInputContext<'c> {
+ assert_eq!(self.output.len(), 0);
+ self.input.wrap(input)
+ }
+
+ // Stage output from TLS into the output buffer.
+ fn save_output(&mut self, buf: *const u8, count: usize) {
+ let slice = unsafe { std::slice::from_raw_parts(buf, count) };
+ qtrace!([self], "save output {}", hex(slice));
+ self.output.extend_from_slice(slice);
+ }
+
+ pub fn take_output(&mut self) -> Vec<u8> {
+ qtrace!([self], "take output");
+ mem::take(&mut self.output)
+ }
+}
+
+impl ::std::fmt::Display for AgentIo {
+ fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
+ write!(f, "AgentIo")
+ }
+}
+
+unsafe extern "C" fn agent_close(fd: PrFd) -> PrStatus {
+ (*fd).secret = null_mut();
+ if let Some(dtor) = (*fd).dtor {
+ dtor(fd);
+ }
+ PR_SUCCESS
+}
+
+unsafe extern "C" fn agent_read(mut fd: PrFd, buf: *mut c_void, amount: prio::PRInt32) -> PrStatus {
+ let io = AgentIo::borrow(&mut fd);
+ if let Ok(a) = usize::try_from(amount) {
+ match io.input.read_input(buf.cast(), a) {
+ Ok(_) => PR_SUCCESS,
+ Err(_) => PR_FAILURE,
+ }
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_recv(
+ mut fd: PrFd,
+ buf: *mut c_void,
+ amount: prio::PRInt32,
+ flags: prio::PRIntn,
+ _timeout: prio::PRIntervalTime,
+) -> prio::PRInt32 {
+ let io = AgentIo::borrow(&mut fd);
+ if flags != 0 {
+ return PR_FAILURE;
+ }
+ if let Ok(a) = usize::try_from(amount) {
+ match io.input.read_input(buf.cast(), a) {
+ Ok(v) => prio::PRInt32::try_from(v).unwrap_or(PR_FAILURE),
+ Err(_) => PR_FAILURE,
+ }
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_write(
+ mut fd: PrFd,
+ buf: *const c_void,
+ amount: prio::PRInt32,
+) -> PrStatus {
+ let io = AgentIo::borrow(&mut fd);
+ if let Ok(a) = usize::try_from(amount) {
+ io.save_output(buf.cast(), a);
+ amount
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_send(
+ mut fd: PrFd,
+ buf: *const c_void,
+ amount: prio::PRInt32,
+ flags: prio::PRIntn,
+ _timeout: prio::PRIntervalTime,
+) -> prio::PRInt32 {
+ let io = AgentIo::borrow(&mut fd);
+
+ if flags != 0 {
+ return PR_FAILURE;
+ }
+ if let Ok(a) = usize::try_from(amount) {
+ io.save_output(buf.cast(), a);
+ amount
+ } else {
+ PR_FAILURE
+ }
+}
+
+unsafe extern "C" fn agent_available(mut fd: PrFd) -> prio::PRInt32 {
+ let io = AgentIo::borrow(&mut fd);
+ io.input.available.try_into().unwrap_or(PR_FAILURE)
+}
+
+unsafe extern "C" fn agent_available64(mut fd: PrFd) -> prio::PRInt64 {
+ let io = AgentIo::borrow(&mut fd);
+ io.input
+ .available
+ .try_into()
+ .unwrap_or_else(|_| PR_FAILURE.into())
+}
+
+#[allow(clippy::cast_possible_truncation)]
+unsafe extern "C" fn agent_getname(_fd: PrFd, addr: *mut prio::PRNetAddr) -> PrStatus {
+ let a = addr.as_mut().unwrap();
+ // Cast is safe because prio::PR_AF_INET is 2
+ a.inet.family = prio::PR_AF_INET as prio::PRUint16;
+ a.inet.port = 0;
+ a.inet.ip = 0;
+ PR_SUCCESS
+}
+
+unsafe extern "C" fn agent_getsockopt(_fd: PrFd, opt: *mut prio::PRSocketOptionData) -> PrStatus {
+ let o = opt.as_mut().unwrap();
+ if o.option == prio::PRSockOption::PR_SockOpt_Nonblocking {
+ o.value.non_blocking = 1;
+ return PR_SUCCESS;
+ }
+ PR_FAILURE
+}
+
+pub const METHODS: &prio::PRIOMethods = &prio::PRIOMethods {
+ file_type: prio::PRDescType::PR_DESC_LAYERED,
+ close: Some(agent_close),
+ read: Some(agent_read),
+ write: Some(agent_write),
+ available: Some(agent_available),
+ available64: Some(agent_available64),
+ fsync: None,
+ seek: None,
+ seek64: None,
+ fileInfo: None,
+ fileInfo64: None,
+ writev: None,
+ connect: None,
+ accept: None,
+ bind: None,
+ listen: None,
+ shutdown: None,
+ recv: Some(agent_recv),
+ send: Some(agent_send),
+ recvfrom: None,
+ sendto: None,
+ poll: None,
+ acceptread: None,
+ transmitfile: None,
+ getsockname: Some(agent_getname),
+ getpeername: Some(agent_getname),
+ reserved_fn_6: None,
+ reserved_fn_5: None,
+ getsocketoption: Some(agent_getsockopt),
+ setsocketoption: None,
+ sendfile: None,
+ connectcontinue: None,
+ reserved_fn_3: None,
+ reserved_fn_2: None,
+ reserved_fn_1: None,
+ reserved_fn_0: None,
+};