Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

io: make duplex stream cooperative (#4470) #4478

Merged
merged 6 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 16 additions & 0 deletions tokio/src/io/util/mem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ impl AsyncRead for Pipe {
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
ready!(poll_proceed_and_make_progress(cx));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't make progress if it returns Poll::Pending.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, looks like instead of calling make_progress always, we need to poll proceed at the start, and only call make_progress in the two branches below that would return Poll::Ready.

if self.buffer.has_remaining() {
let max = self.buffer.remaining().min(buf.remaining());
buf.put_slice(&self.buffer[..max]);
Expand Down Expand Up @@ -212,6 +213,7 @@ impl AsyncWrite for Pipe {
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
ready!(poll_proceed_and_make_progress(cx));
if self.is_closed {
return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
}
Expand Down Expand Up @@ -241,3 +243,17 @@ impl AsyncWrite for Pipe {
Poll::Ready(Ok(()))
}
}

cfg_coop! {
fn poll_proceed_and_make_progress(cx: &mut task::Context<'_>) -> Poll<()> {
let coop = ready!(crate::coop::poll_proceed(cx));
coop.made_progress();
Poll::Ready(())
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We aren't doing this conditional thing in src/io/util/empty.rs. Why is it needed here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you mean. It looks like it's in there:

cfg_coop! {

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! Thanks for the review!

Maybe I read it wrong but I thought this is added in https://github.com/tokio-rs/tokio/blob/master/tokio/src/io/util/empty.rs#L78-L90 so I thought we only want this under coop feature. I can change it real quick if this is not the plan.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, seems like I'm just blind.


cfg_not_coop! {
fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> {
Poll::Ready(())
}
}
19 changes: 19 additions & 0 deletions tokio/tests/io_mem_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,22 @@ async fn max_write_size() {
// drop b only after task t1 finishes writing
drop(b);
}

#[tokio::test]
async fn duplex_is_cooperative() {
let (mut tx, mut rx) = tokio::io::duplex(1024 * 8);

tokio::select! {
biased;

_ = async {
loop {
let buf = [3u8; 4096];
let _ = tx.write_all(&buf).await;
let mut buf = [0u8; 4096];
let _ = rx.read(&mut buf).await;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let buf = [3u8; 4096];
let _ = tx.write_all(&buf).await;
let mut buf = [0u8; 4096];
let _ = rx.read(&mut buf).await;
let buf = [3u8; 4096];
tx.write_all(&buf).await.unwrap();
let mut buf = [0u8; 4096];
rx.read(&mut buf).await.unwrap();

}
} => {},
_ = tokio::task::yield_now() => {}
}
}