Skip to content

Commit

Permalink
io: implement StreamedFd
Browse files Browse the repository at this point in the history
StreamedFd is a higher-level wrapper around an AsyncFd which provides the PollEvented optimization to users.

Fixes #5324.
  • Loading branch information
= committed Jan 6, 2023
1 parent 31c7e82 commit 83a398c
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tokio-util/src/io/mod.rs
Expand Up @@ -20,6 +20,11 @@ mod stream_reader;
cfg_io_util! {
mod sync_bridge;
pub use self::sync_bridge::SyncIoBridge;

#[cfg(unix)]
mod stream_fd;
#[cfg(unix)]
pub use stream_fd::*;
}

pub use self::copy_to_bytes::CopyToBytes;
Expand Down
132 changes: 132 additions & 0 deletions tokio-util/src/io/stream_fd.rs
@@ -0,0 +1,132 @@
use std::io;
use std::io::{Error, Read, Write};
use std::os::fd::AsRawFd;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io::unix::AsyncFd;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

/// Provides async reading and writing semantics to a pollable file descriptor that is a byte
/// stream.
///
/// [`AsyncFd`] provides a way to poll file descriptors for IO readiness, but leaves reading and
/// writing to the user. This is a higher-level utility which handles this for users.
///
/// # Warning
/// The underlying IO source this is constructed from must not be capable of nonblocking reads and
/// writes, and must be pollable.
///
/// The underlying IO source must also be a continuous stream of bytes in either direction. It must
/// be guaranteed that a partial read or write signals a loss of readiness.
///
/// The underlying IO source must also be self-flushing. This will assume that flushing is a no-op.
///
///
/// [`AsyncFd`]: struct@tokio::io::unix::AsyncFd
#[derive(Debug)]
pub struct StreamFd<T>
where
T: AsRawFd,
{
inner: AsyncFd<T>,
}

impl<T> StreamFd<T>
where
T: AsRawFd,
{
/// Construct a new StreamFd from an IO source.
///
/// # Panics
/// Panics if called from outside a tokio runtime context.
///
/// [`RawFd`]: struct@std::os::fd::RawFd
pub fn new(fd: T) -> io::Result<Self> {
let inner = AsyncFd::new(fd)?;

Ok(Self { inner })
}
}

// note: taken from PollEvented
impl<T> AsyncRead for StreamFd<T>
where
T: AsRawFd + Read + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();

loop {
let mut guard = ready!(this.inner.poll_read_ready_mut(cx))?;

// safety: we will not be reading these bytes
let b = unsafe {
&mut *(buf.unfilled_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8])
};
let len = b.len();

match guard.get_inner_mut().read(b) {
Ok(n) => {
if n > 0 && n < len {
guard.clear_ready();
}

// Safety: We trust `File::read` to have filled up `n` bytes in the
// buffer.
unsafe { buf.assume_init(n) };
buf.advance(n);
return Poll::Ready(Ok(()));
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
guard.clear_ready();
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}
}

impl<T> AsyncWrite for StreamFd<T>
where
T: AsRawFd + Write + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
let this = self.get_mut();

loop {
let mut guard = ready!(this.inner.poll_write_ready_mut(cx))?;

match guard.get_inner_mut().write(buf) {
Ok(n) => {
// if we write only part of our buffer, this is sufficient on unix to show
// that the socket buffer is full
if n > 0 && n < buf.len() {
guard.clear_ready();
}

return Poll::Ready(Ok(n));
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
guard.clear_ready();
}
Err(e) => return Poll::Ready(Err(e)),
}
}
}

fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
unimplemented!("Shutdown is not implemented for this type")
}
}

0 comments on commit 83a398c

Please sign in to comment.