diff --git a/tokio/src/io/util/read_to_end.rs b/tokio/src/io/util/read_to_end.rs
index f4a564d7dd4..8edba2a1711 100644
--- a/tokio/src/io/util/read_to_end.rs
+++ b/tokio/src/io/util/read_to_end.rs
@@ -1,11 +1,11 @@
 use crate::io::util::vec_with_initialized::{into_read_buf_parts, VecU8, VecWithInitialized};
-use crate::io::AsyncRead;
+use crate::io::{AsyncRead, ReadBuf};
 
 use pin_project_lite::pin_project;
 use std::future::Future;
 use std::io;
 use std::marker::PhantomPinned;
-use std::mem;
+use std::mem::{self, MaybeUninit};
 use std::pin::Pin;
 use std::task::{Context, Poll};
 
@@ -67,16 +67,47 @@ fn poll_read_to_end<V: VecU8, R: AsyncRead + ?Sized>(
     // has 4 bytes while still making large reads if the reader does have a ton
     // of data to return. Simply tacking on an extra DEFAULT_BUF_SIZE space every
     // time is 4,500 times (!) slower than this if the reader has a very small
-    // amount of data to return.
-    buf.reserve(32);
+    // amount of data to return. When the vector is full with its starting
+    // capacity, we first try to read into a small buffer to see if we reached
+    // an EOF. This only happens when the starting capacity is >= NUM_BYTES, since
+    // we allocate at least NUM_BYTES each time. This avoids the unnecessary
+    // allocation that we attempt before reading into the vector.
+
+    const NUM_BYTES: usize = 32;
+    let try_small_read = buf.try_small_read_first(NUM_BYTES);
 
     // Get a ReadBuf into the vector.
-    let mut read_buf = buf.get_read_buf();
+    let mut read_buf;
+    let poll_result;
+
+    let n = if try_small_read {
+        // Read some bytes using a small read.
+        let mut small_buf: [MaybeUninit<u8>; NUM_BYTES] = [MaybeUninit::uninit(); NUM_BYTES];
+        let mut small_read_buf = ReadBuf::uninit(&mut small_buf);
+        poll_result = read.poll_read(cx, &mut small_read_buf);
+        let to_write = small_read_buf.filled();
+
+        // Ensure we have enough space to fill our vector with what we read.
+        read_buf = buf.get_read_buf();
+        if to_write.len() > read_buf.remaining() {
+            buf.reserve(NUM_BYTES);
+            read_buf = buf.get_read_buf();
+        }
+        read_buf.put_slice(to_write);
+
+        to_write.len()
+    } else {
+        // Ensure we have enough space for reading.
+        buf.reserve(NUM_BYTES);
+        read_buf = buf.get_read_buf();
+
+        // Read data directly into vector.
+        let filled_before = read_buf.filled().len();
+        poll_result = read.poll_read(cx, &mut read_buf);
 
-    let filled_before = read_buf.filled().len();
-    let poll_result = read.poll_read(cx, &mut read_buf);
-    let filled_after = read_buf.filled().len();
-    let n = filled_after - filled_before;
+        // Compute the number of bytes read.
+        read_buf.filled().len() - filled_before
+    };
 
     // Update the length of the vector using the result of poll_read.
     let read_buf_parts = into_read_buf_parts(read_buf);
@@ -87,11 +118,11 @@ fn poll_read_to_end<V: VecU8, R: AsyncRead + ?Sized>(
             // In this case, nothing should have been read. However we still
             // update the vector in case the poll_read call initialized parts of
             // the vector's unused capacity.
-            debug_assert_eq!(filled_before, filled_after);
+            debug_assert_eq!(n, 0);
             Poll::Pending
         }
         Poll::Ready(Err(err)) => {
-            debug_assert_eq!(filled_before, filled_after);
+            debug_assert_eq!(n, 0);
             Poll::Ready(Err(err))
         }
         Poll::Ready(Ok(())) => Poll::Ready(Ok(n)),
diff --git a/tokio/src/io/util/vec_with_initialized.rs b/tokio/src/io/util/vec_with_initialized.rs
index a9b94e39d57..f527492dcd9 100644
--- a/tokio/src/io/util/vec_with_initialized.rs
+++ b/tokio/src/io/util/vec_with_initialized.rs
@@ -28,6 +28,7 @@ pub(crate) struct VecWithInitialized<V> {
     // The number of initialized bytes in the vector.
     // Always between `vec.len()` and `vec.capacity()`.
     num_initialized: usize,
+    starting_capacity: usize,
 }
 
 impl VecWithInitialized<Vec<u8>> {
@@ -47,6 +48,7 @@ where
        // to its length are initialized.
        Self {
            num_initialized: vec.as_mut().len(),
+            starting_capacity: vec.as_ref().capacity(),
            vec,
        }
    }
@@ -111,6 +113,15 @@ where
             vec.set_len(parts.len);
         }
     }
+
+    // Returns a boolean telling the caller to try reading into a small local buffer first if true.
+    // Doing so would avoid overallocating when vec is filled to capacity and we reached EOF.
+    pub(crate) fn try_small_read_first(&self, num_bytes: usize) -> bool {
+        let vec = self.vec.as_ref();
+        vec.capacity() - vec.len() < num_bytes
+            && self.starting_capacity == vec.capacity()
+            && self.starting_capacity >= num_bytes
+    }
 }
 
 pub(crate) struct ReadBufParts {
diff --git a/tokio/tests/io_read_to_end.rs b/tokio/tests/io_read_to_end.rs
index 171e6d6480e..29b60becec4 100644
--- a/tokio/tests/io_read_to_end.rs
+++ b/tokio/tests/io_read_to_end.rs
@@ -5,6 +5,7 @@ use std::pin::Pin;
 use std::task::{Context, Poll};
 use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf};
 use tokio_test::assert_ok;
+use tokio_test::io::Builder;
 
 #[tokio::test]
 async fn read_to_end() {
@@ -76,3 +77,48 @@ async fn read_to_end_uninit() {
     test.read_to_end(&mut buf).await.unwrap();
     assert_eq!(buf.len(), 33);
 }
+
+#[tokio::test]
+async fn read_to_end_doesnt_grow_with_capacity() {
+    let arr: Vec<u8> = (0..100).collect();
+
+    // We only test from 32 since we allocate at least 32 bytes each time
+    for len in 32..100 {
+        let bytes = &arr[..len];
+        for split in 0..len {
+            for cap in 0..101 {
+                let mut mock = if split == 0 {
+                    Builder::new().read(bytes).build()
+                } else {
+                    Builder::new()
+                        .read(&bytes[..split])
+                        .read(&bytes[split..])
+                        .build()
+                };
+                let mut buf = Vec::with_capacity(cap);
+                AsyncReadExt::read_to_end(&mut mock, &mut buf)
+                    .await
+                    .unwrap();
+                // It has the right data.
+                assert_eq!(buf.as_slice(), bytes);
+                // Unless cap was smaller than length, then we did not reallocate.
+                if cap >= len {
+                    assert_eq!(buf.capacity(), cap);
+                }
+            }
+        }
+    }
+}
+
+#[tokio::test]
+async fn read_to_end_grows_capacity_if_unfit() {
+    let bytes = b"the_vector_startingcap_will_be_smaller";
+    let mut mock = Builder::new().read(bytes).build();
+    let initial_capacity = bytes.len() - 4;
+    let mut buf = Vec::with_capacity(initial_capacity);
+    AsyncReadExt::read_to_end(&mut mock, &mut buf)
+        .await
+        .unwrap();
+    // *4 since it doubles when it doesn't fit and again when reaching EOF
+    assert_eq!(buf.capacity(), initial_capacity * 4);
+}