use crate::codec::{Decoder, Encoder}; use futures_core::Stream; use tokio::{io::ReadBuf, net::UdpSocket}; use bytes::{BufMut, BytesMut}; use futures_core::ready; use futures_sink::Sink; use std::pin::Pin; use std::task::{Context, Poll}; use std::{ borrow::Borrow, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, }; use std::{io, mem::MaybeUninit}; /// A unified [`Stream`] and [`Sink`] interface to an underlying `UdpSocket`, using /// the `Encoder` and `Decoder` traits to encode and decode frames. /// /// Raw UDP sockets work with datagrams, but higher-level code usually wants to /// batch these into meaningful chunks, called "frames". This method layers /// framing on top of this socket by using the `Encoder` and `Decoder` traits to /// handle encoding and decoding of messages frames. Note that the incoming and /// outgoing frame types may be distinct. /// /// This function returns a *single* object that is both [`Stream`] and [`Sink`]; /// grouping this into a single object is often useful for layering things which /// require both read and write access to the underlying object. /// /// If you want to work more directly with the streams and sink, consider /// calling [`split`] on the `UdpFramed` returned by this method, which will break /// them into separate objects, allowing them to interact more easily. /// /// [`Stream`]: futures_core::Stream /// [`Sink`]: futures_sink::Sink /// [`split`]: https://docs.rs/futures/0.3/futures/stream/trait.StreamExt.html#method.split #[must_use = "sinks do nothing unless polled"] #[derive(Debug)] pub struct UdpFramed { socket: T, codec: C, rd: BytesMut, wr: BytesMut, out_addr: SocketAddr, flushed: bool, is_readable: bool, current_addr: Option, } const INITIAL_RD_CAPACITY: usize = 64 * 1024; const INITIAL_WR_CAPACITY: usize = 8 * 1024; impl Unpin for UdpFramed {} impl Stream for UdpFramed where T: Borrow, C: Decoder, { type Item = Result<(C::Item, SocketAddr), C::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let pin = self.get_mut(); pin.rd.reserve(INITIAL_RD_CAPACITY); loop { // Are there still bytes left in the read buffer to decode? if pin.is_readable { if let Some(frame) = pin.codec.decode_eof(&mut pin.rd)? { let current_addr = pin .current_addr .expect("will always be set before this line is called"); return Poll::Ready(Some(Ok((frame, current_addr)))); } // if this line has been reached then decode has returned `None`. pin.is_readable = false; pin.rd.clear(); } // We're out of data. Try and fetch more data to decode let addr = unsafe { // Convert `&mut [MaybeUnit]` to `&mut [u8]` because we will be // writing to it via `poll_recv_from` and therefore initializing the memory. let buf = &mut *(pin.rd.chunk_mut() as *mut _ as *mut [MaybeUninit]); let mut read = ReadBuf::uninit(buf); let ptr = read.filled().as_ptr(); let res = ready!(pin.socket.borrow().poll_recv_from(cx, &mut read)); assert_eq!(ptr, read.filled().as_ptr()); let addr = res?; pin.rd.advance_mut(read.filled().len()); addr }; pin.current_addr = Some(addr); pin.is_readable = true; } } } impl Sink<(I, SocketAddr)> for UdpFramed where T: Borrow, C: Encoder, { type Error = C::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if !self.flushed { match self.poll_flush(cx)? { Poll::Ready(()) => {} Poll::Pending => return Poll::Pending, } } Poll::Ready(Ok(())) } fn start_send(self: Pin<&mut Self>, item: (I, SocketAddr)) -> Result<(), Self::Error> { let (frame, out_addr) = item; let pin = self.get_mut(); pin.codec.encode(frame, &mut pin.wr)?; pin.out_addr = out_addr; pin.flushed = false; Ok(()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.flushed { return Poll::Ready(Ok(())); } let Self { ref socket, ref mut out_addr, ref mut wr, .. } = *self; let n = ready!(socket.borrow().poll_send_to(cx, wr, *out_addr))?; let wrote_all = n == self.wr.len(); self.wr.clear(); self.flushed = true; let res = if wrote_all { Ok(()) } else { Err(io::Error::new( io::ErrorKind::Other, "failed to write entire datagram to socket", ) .into()) }; Poll::Ready(res) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { ready!(self.poll_flush(cx))?; Poll::Ready(Ok(())) } } impl UdpFramed where T: Borrow, { /// Create a new `UdpFramed` backed by the given socket and codec. /// /// See struct level documentation for more details. pub fn new(socket: T, codec: C) -> UdpFramed { Self { socket, codec, out_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)), rd: BytesMut::with_capacity(INITIAL_RD_CAPACITY), wr: BytesMut::with_capacity(INITIAL_WR_CAPACITY), flushed: true, is_readable: false, current_addr: None, } } /// Returns a reference to the underlying I/O stream wrapped by `Framed`. /// /// # Note /// /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. pub fn get_ref(&self) -> &T { &self.socket } /// Returns a mutable reference to the underlying I/O stream wrapped by `Framed`. /// /// # Note /// /// Care should be taken to not tamper with the underlying stream of data /// coming in as it may corrupt the stream of frames otherwise being worked /// with. pub fn get_mut(&mut self) -> &mut T { &mut self.socket } /// Returns a reference to the underlying codec wrapped by /// `Framed`. /// /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec(&self) -> &C { &self.codec } /// Returns a mutable reference to the underlying codec wrapped by /// `UdpFramed`. /// /// Note that care should be taken to not tamper with the underlying codec /// as it may corrupt the stream of frames otherwise being worked with. pub fn codec_mut(&mut self) -> &mut C { &mut self.codec } /// Returns a reference to the read buffer. pub fn read_buffer(&self) -> &BytesMut { &self.rd } /// Returns a mutable reference to the read buffer. pub fn read_buffer_mut(&mut self) -> &mut BytesMut { &mut self.rd } /// Consumes the `Framed`, returning its underlying I/O stream. pub fn into_inner(self) -> T { self.socket } }