Skip to content

Commit

Permalink
fix: avoid race condition between pending frames and closing stream (#…
Browse files Browse the repository at this point in the history
…156)

Currently, we have a `garbage_collect` function that checks whether any of our
streams have been dropped. This can cause a race condition where the channel
between a `Stream` and the `Connection` still has pending frames for a stream
but dropping a stream causes us to already send a `FIN` flag for the stream.

We fix this by maintaining a single channel for each stream. When a stream gets
dropped, the `Receiver` becomes disconnected. We use this information to queue
the correct frame (`FIN` vs `RST`) into the buffer. At this point, all previous
frames have already been processed and the race condition is thus not present.

Additionally, this also allows us to implement `Stream::poll_flush` by
forwarding to the underlying `Sender`. Note that at present day, this only
checks whether there is _space_ in the channel, not whether the items have been
emitted by the `Receiver`.

We have a PR upstream that might fix this:
rust-lang/futures-rs#2746

Fixes: #117.
  • Loading branch information
thomaseizinger committed May 24, 2023
1 parent 88ed4df commit 52c725b
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 167 deletions.
1 change: 0 additions & 1 deletion test-harness/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ log = "0.4.17"
[dev-dependencies]
env_logger = "0.10"
constrained-connection = "0.1"

2 changes: 2 additions & 0 deletions yamux/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ nohash-hasher = "0.2"
parking_lot = "0.12"
rand = "0.8.3"
static_assertions = "1"
pin-project = "1.1.0"

[dev-dependencies]
anyhow = "1"
Expand All @@ -26,6 +27,7 @@ quickcheck = "1.0"
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
constrained-connection = "0.1"
futures_ringbuf = "0.3.1"

[[bench]]
name = "concurrent"
Expand Down
225 changes: 108 additions & 117 deletions yamux/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,18 @@ use crate::{
error::ConnectionError,
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
frame::{self, Frame},
Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG,
Config, WindowUpdateMode, DEFAULT_CREDIT,
};
use cleanup::Cleanup;
use closing::Closing;
use futures::stream::SelectAll;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use std::collections::VecDeque;
use std::task::Context;
use std::task::{Context, Waker};
use std::{fmt, sync::Arc, task::Poll};

use crate::tagged_stream::TaggedStream;
pub use stream::{Packet, State, Stream};

/// How the connection is used.
Expand Down Expand Up @@ -347,10 +349,11 @@ struct Active<T> {
config: Arc<Config>,
socket: Fuse<frame::Io<T>>,
next_id: u32,

streams: IntMap<StreamId, Stream>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
dropped_streams: Vec<StreamId>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,

pending_frames: VecDeque<Frame<()>>,
}

Expand All @@ -360,7 +363,7 @@ pub(crate) enum StreamCommand {
/// A new frame should be sent to the remote.
SendFrame(Frame<Either<Data, WindowUpdate>>),
/// Close a stream.
CloseStream { id: StreamId, ack: bool },
CloseStream { ack: bool },
}

/// Possible actions as a result of incoming frame handling.
Expand Down Expand Up @@ -408,28 +411,26 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
let id = Id::random();
log::debug!("new connection: {} ({:?})", id, mode);
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
Active {
id,
mode,
config: Arc::new(cfg),
socket,
streams: IntMap::default(),
stream_sender,
stream_receiver,
stream_receivers: SelectAll::default(),
no_streams_waker: None,
next_id: match mode {
Mode::Client => 1,
Mode::Server => 2,
},
dropped_streams: Vec::new(),
pending_frames: VecDeque::default(),
}
}

/// Gracefully close the connection to the remote.
fn close(self) -> Closing<T> {
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
}

/// Cleanup all our resources.
Expand All @@ -438,13 +439,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
self.drop_all_streams();

Cleanup::new(self.stream_receiver, error)
Cleanup::new(self.stream_receivers, error)
}

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
loop {
self.garbage_collect();

if self.socket.poll_ready_unpin(cx).is_ready() {
if let Some(frame) = self.pending_frames.pop_front() {
self.socket.start_send_unpin(frame)?;
Expand All @@ -457,17 +456,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Poll::Pending => {}
}

match self.stream_receiver.poll_next_unpin(cx) {
Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
self.on_send_frame(frame);
match self.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
self.on_send_frame(frame.into());
continue;
}
Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
self.on_close_stream(id, ack);
continue;
}
Poll::Ready(Some((id, None))) => {
self.on_drop_stream(id);
continue;
}
Poll::Ready(None) => {
debug_assert!(false, "Only closed during shutdown")
self.no_streams_waker = Some(cx.waker().clone());
}
Poll::Pending => {}
}
Expand Down Expand Up @@ -508,16 +511,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
self.pending_frames.push_back(frame.into());
}

let stream = {
let config = self.config.clone();
let sender = self.stream_sender.clone();
let window = self.config.receive_window;
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
stream
};
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);

if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}

log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone());
Expand All @@ -541,6 +539,69 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.push_back(Frame::close_stream(id, ack).into());
}

fn on_drop_stream(&mut self, id: StreamId) {
let stream = self.streams.remove(&id).expect("stream not found");

log::trace!("{}: removing dropped {}", self.id, stream);
let stream_id = stream.id();
let frame = {
let mut shared = stream.shared();
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
State::Open => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
}
// The stream was dropped without calling `poll_close`.
// We have already received a FIN from remote and send one
// back which closes the stream for good.
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::new(header))
}
// The stream was properly closed. We already sent our FIN frame.
// The remote may be out of credit though and blocked on
// writing more data. We may need to reset the stream.
State::SendClosed => {
if self.config.window_update_mode == WindowUpdateMode::OnRead
&& shared.window == 0
{
// The remote may be waiting for a window update
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
// mode is `OnReceive`) or we already have inbound frames in
// the socket buffer which will be processed later. In any
// case we will reply with an RST in `Connection::on_data`
// because the stream will no longer be known.
None
}
}
// The stream was properly closed. We already have sent our FIN frame. The
// remote end has already done so in the past.
State::Closed => None,
};
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
frame
};
if let Some(f) = frame {
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
}

/// Process the result of reading from the socket.
///
/// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
Expand Down Expand Up @@ -628,12 +689,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::internal_error());
}
let mut stream = {
let config = self.config.clone();
let credit = DEFAULT_CREDIT;
let sender = self.stream_sender.clone();
Stream::new(stream_id, self.id, config, credit, credit, sender)
};
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
let mut window_update = None;
{
let mut shared = stream.shared();
Expand Down Expand Up @@ -748,15 +804,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::protocol_error());
}
let stream = {
let credit = frame.header().credit() + DEFAULT_CREDIT;
let config = self.config.clone();
let sender = self.stream_sender.clone();
let mut stream =
Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
stream.set_flag(stream::Flag::Ack);
stream
};

let credit = frame.header().credit() + DEFAULT_CREDIT;
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
stream.set_flag(stream::Flag::Ack);

if is_finish {
stream
.shared()
Expand Down Expand Up @@ -821,6 +873,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Action::None
}

fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
let config = self.config.clone();

let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
self.stream_receivers.push(TaggedStream::new(id, receiver));
if let Some(waker) = self.no_streams_waker.take() {
waker.wake();
}

Stream::new(id, self.id, config, window, credit, sender)
}

fn next_stream_id(&mut self) -> Result<StreamId> {
let proposed = StreamId::new(self.next_id);
self.next_id = self
Expand All @@ -844,79 +908,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Mode::Server => id.is_client(),
}
}

/// Remove stale streams and create necessary messages to be sent to the remote.
fn garbage_collect(&mut self) {
let conn_id = self.id;
let win_update_mode = self.config.window_update_mode;
for stream in self.streams.values_mut() {
if stream.strong_count() > 1 {
continue;
}
log::trace!("{}: removing dropped {}", conn_id, stream);
let stream_id = stream.id();
let frame = {
let mut shared = stream.shared();
let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
// The stream was dropped without calling `poll_close`.
// We reset the stream to inform the remote of the closure.
State::Open => {
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
}
// The stream was dropped without calling `poll_close`.
// We have already received a FIN from remote and send one
// back which closes the stream for good.
State::RecvClosed => {
let mut header = Header::data(stream_id, 0);
header.fin();
Some(Frame::new(header))
}
// The stream was properly closed. We either already have
// or will at some later point send our FIN frame.
// The remote may be out of credit though and blocked on
// writing more data. We may need to reset the stream.
State::SendClosed => {
if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
// The remote may be waiting for a window update
// which we will never send, so reset the stream now.
let mut header = Header::data(stream_id, 0);
header.rst();
Some(Frame::new(header))
} else {
// The remote has either still credit or will be given more
// (due to an enqueued window update or because the update
// mode is `OnReceive`) or we already have inbound frames in
// the socket buffer which will be processed later. In any
// case we will reply with an RST in `Connection::on_data`
// because the stream will no longer be known.
None
}
}
// The stream was properly closed. We either already have
// or will at some later point send our FIN frame. The
// remote end has already done so in the past.
State::Closed => None,
};
if let Some(w) = shared.reader.take() {
w.wake()
}
if let Some(w) = shared.writer.take() {
w.wake()
}
frame
};
if let Some(f) = frame {
log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
self.pending_frames.push_back(f.into());
}
self.dropped_streams.push(stream_id)
}
for id in self.dropped_streams.drain(..) {
self.streams.remove(&id);
}
}
}

impl<T> Active<T> {
Expand Down

0 comments on commit 52c725b

Please sign in to comment.