diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs
index 2347f6f0..c2904aca 100644
--- a/src/proto/streams/prioritize.rs
+++ b/src/proto/streams/prioritize.rs
@@ -51,6 +51,9 @@ pub(super) struct Prioritize {
 
     /// What `DATA` frame is currently being sent in the codec.
     in_flight_data_frame: InFlightData,
+
+    /// The maximum amount of bytes a stream should buffer.
+    max_buffer_size: usize,
 }
 
 #[derive(Debug, Eq, PartialEq)]
@@ -93,9 +96,14 @@ impl Prioritize {
             flow,
             last_opened_id: StreamId::ZERO,
             in_flight_data_frame: InFlightData::Nothing,
+            max_buffer_size: config.local_max_buffer_size,
         }
     }
 
+    pub(crate) fn max_buffer_size(&self) -> usize {
+        self.max_buffer_size
+    }
+
     /// Queue a frame to be sent to the remote
     pub fn queue_frame<B>(
         &mut self,
@@ -424,7 +432,7 @@ impl Prioritize {
             tracing::trace!(capacity = assign, "assigning");
 
             // Assign the capacity to the stream
-            stream.assign_capacity(assign);
+            stream.assign_capacity(assign, self.max_buffer_size);
 
             // Claim the capacity from the connection
             self.flow.claim_capacity(assign);
@@ -744,7 +752,7 @@ impl Prioritize {
                                 // If the capacity was limited because of the
                                 // max_send_buffer_size, then consider waking
                                 // the send task again...
-                                stream.notify_if_can_buffer_more();
+                                stream.notify_if_can_buffer_more(self.max_buffer_size);
 
                                 // Assign the capacity back to the connection that
                                 // was just consumed from the stream in the previous
diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs
index b7230030..2c5a38c8 100644
--- a/src/proto/streams/send.rs
+++ b/src/proto/streams/send.rs
@@ -28,9 +28,6 @@ pub(super) struct Send {
     /// > the identified last stream.
     max_stream_id: StreamId,
 
-    /// The maximum amount of bytes a stream should buffer.
-    max_buffer_size: usize,
-
     /// Initial window size of locally initiated streams
     init_window_sz: WindowSize,
 
@@ -55,7 +52,6 @@ impl Send {
     pub fn new(config: &Config) -> Self {
         Send {
             init_window_sz: config.remote_init_window_sz,
-            max_buffer_size: config.local_max_buffer_size,
             max_stream_id: StreamId::MAX,
             next_stream_id: Ok(config.local_next_stream_id),
             prioritize: Prioritize::new(config),
@@ -340,7 +336,9 @@ impl Send {
         let available = stream.send_flow.available().as_size() as usize;
         let buffered = stream.buffered_send_data;
 
-        available.min(self.max_buffer_size).saturating_sub(buffered) as WindowSize
+        available
+            .min(self.prioritize.max_buffer_size())
+            .saturating_sub(buffered) as WindowSize
     }
 
     pub fn poll_reset(
diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs
index f67cc364..36d515ba 100644
--- a/src/proto/streams/stream.rs
+++ b/src/proto/streams/stream.rs
@@ -260,30 +260,29 @@ impl Stream {
         self.ref_count == 0 && !self.state.is_closed()
     }
 
-    pub fn assign_capacity(&mut self, capacity: WindowSize) {
+    pub fn assign_capacity(&mut self, capacity: WindowSize, max_buffer_size: usize) {
         debug_assert!(capacity > 0);
-        self.send_capacity_inc = true;
         self.send_flow.assign_capacity(capacity);
 
         tracing::trace!(
-            "  assigned capacity to stream; available={}; buffered={}; id={:?}",
+            "  assigned capacity to stream; available={}; buffered={}; id={:?}; max_buffer_size={}",
             self.send_flow.available(),
             self.buffered_send_data,
-            self.id
+            self.id,
+            max_buffer_size
         );
 
-        // Only notify if the capacity exceeds the amount of buffered data
-        if self.send_flow.available() > self.buffered_send_data {
-            tracing::trace!("  notifying task");
-            self.notify_send();
-        }
+        self.notify_if_can_buffer_more(max_buffer_size);
     }
 
     /// If the capacity was limited because of the max_send_buffer_size,
     /// then consider waking the send task again...
-    pub fn notify_if_can_buffer_more(&mut self) {
+    pub fn notify_if_can_buffer_more(&mut self, max_buffer_size: usize) {
+        let available = self.send_flow.available().as_size() as usize;
+        let buffered = self.buffered_send_data;
+
         // Only notify if the capacity exceeds the amount of buffered data
-        if self.send_flow.available() > self.buffered_send_data {
+        if available.min(max_buffer_size) > buffered {
             self.send_capacity_inc = true;
             tracing::trace!("  notifying task");
             self.notify_send();
diff --git a/tests/h2-support/src/util.rs b/tests/h2-support/src/util.rs
index b3322c4d..1150d592 100644
--- a/tests/h2-support/src/util.rs
+++ b/tests/h2-support/src/util.rs
@@ -32,6 +32,7 @@ pub async fn yield_once() {
     .await;
 }
 
+/// Should only be called after a non-0 capacity was requested for the stream.
 pub fn wait_for_capacity(stream: h2::SendStream<Bytes>, target: usize) -> WaitForCapacity {
     WaitForCapacity {
         stream: Some(stream),
@@ -59,6 +60,11 @@ impl Future for WaitForCapacity {
 
         let act = self.stream().capacity();
 
+        // If a non-0 capacity was requested for the stream before calling
+        // wait_for_capacity, then poll_capacity should return Pending
+        // until there is a non-0 capacity.
+        assert_ne!(act, 0);
+
         if act >= self.target {
             return Poll::Ready(self.stream.take().unwrap().into());
         }
diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs
index 7adb3d73..92e7a532 100644
--- a/tests/h2-tests/tests/flow_control.rs
+++ b/tests/h2-tests/tests/flow_control.rs
@@ -1600,7 +1600,62 @@ async fn poll_capacity_after_send_data_and_reserve() {
         // Initial window size was 5 so current capacity is 0 even if we just reserved.
         assert_eq!(stream.capacity(), 0);
 
-        // The first call to `poll_capacity` in `wait_for_capacity` will return 0.
+        // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity.
+        let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await;
+
+        stream.send_data("".into(), true).unwrap();
+
+        // Wait for the connection to close
+        h2.await.unwrap();
+    };
+
+    join(srv, h2).await;
+}
+
+#[tokio::test]
+async fn poll_capacity_after_send_data_and_reserve_with_max_send_buffer_size() {
+    h2_support::trace_init!();
+    let (io, mut srv) = mock::new();
+
+    let srv = async move {
+        let settings = srv
+            .assert_client_handshake_with_settings(frames::settings().initial_window_size(10))
+            .await;
+        assert_default_settings!(settings);
+        srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/"))
+            .await;
+        srv.send_frame(frames::headers(1).response(200)).await;
+        srv.recv_frame(frames::data(1, &b"abcde"[..])).await;
+        srv.send_frame(frames::window_update(1, 10)).await;
+        srv.recv_frame(frames::data(1, &b""[..]).eos()).await;
+    };
+
+    let h2 = async move {
+        let (mut client, mut h2) = client::Builder::new()
+            .max_send_buffer_size(5)
+            .handshake::<_, Bytes>(io)
+            .await
+            .unwrap();
+        let request = Request::builder()
+            .method(Method::POST)
+            .uri("https://www.example.com/")
+            .body(())
+            .unwrap();
+
+        let (response, mut stream) = client.send_request(request, false).unwrap();
+
+        let response = h2.drive(response).await.unwrap();
+        assert_eq!(response.status(), StatusCode::OK);
+
+        stream.send_data("abcde".into(), false).unwrap();
+
+        stream.reserve_capacity(5);
+
+        // Initial window size was 10 but with a max send buffer size of 5 in the client,
+        // so current capacity is 0 even if we just reserved.
+        assert_eq!(stream.capacity(), 0);
+
+        // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity.
         let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await;
 
         stream.send_data("".into(), true).unwrap();
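Taken together, the send.rs and stream.rs changes apply one clamping rule: the stream's available flow-control window is capped at the configured max send buffer size before it is reduced by (or compared against) the bytes already buffered, so the send task is only woken when it can actually buffer more and `poll_capacity` no longer reports a spurious `Ok(0)`. A minimal standalone sketch of that rule, using made-up free functions rather than h2's internal types:

    /// User-visible capacity: window space, clamped to the configured
    /// max send buffer size, minus what is already buffered.
    fn user_capacity(available: usize, buffered: usize, max_buffer_size: usize) -> usize {
        available.min(max_buffer_size).saturating_sub(buffered)
    }

    /// Wake the sender only when the clamped capacity exceeds the buffered
    /// amount, so the task is never woken just to observe zero capacity.
    fn should_notify(available: usize, buffered: usize, max_buffer_size: usize) -> bool {
        available.min(max_buffer_size) > buffered
    }

    fn main() {
        // Echoes the new test: window of 10, max_send_buffer_size of 5, 5 bytes buffered.
        assert_eq!(user_capacity(10, 5, 5), 0);
        assert!(!should_notify(10, 5, 5));

        // Once the buffered bytes drain, the clamped capacity is positive again.
        assert_eq!(user_capacity(10, 0, 5), 5);
        assert!(should_notify(10, 0, 5));
    }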