Skip to content

Commit

Permalink
take self as receiver types on downgrade and upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
b-naber committed Jul 26, 2022
1 parent ef6fabb commit fa8bdf9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 25 deletions.
9 changes: 5 additions & 4 deletions tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::loom::sync::Arc;
use crate::sync::batch_semaphore::{self as semaphore, TryAcquireError};
use crate::sync::mpsc::chan;
use crate::sync::mpsc::error::{SendError, TryRecvError, TrySendError};
Expand Down Expand Up @@ -54,7 +55,7 @@ pub struct Sender<T> {
///
/// ```
pub struct WeakSender<T> {
chan: chan::Tx<T, Semaphore>,
chan: Arc<chan::Chan<T, Semaphore>>,
}

/// Permits to send one value into the channel.
Expand Down Expand Up @@ -1031,7 +1032,7 @@ impl<T> Sender<T> {
/// towards RAII semantics, i.e. if all `Sender` instances of the
/// channel were dropped and only `WeakSender` instances remain,
/// the channel is closed.
pub fn downgrade(self) -> WeakSender<T> {
pub fn downgrade(&self) -> WeakSender<T> {
// Note: If this is the last `Sender` instance we want to close the
// channel when downgrading, so it's important to move into `self` here.

Expand Down Expand Up @@ -1069,8 +1070,8 @@ impl<T> WeakSender<T> {
/// Tries to convert a WeakSender into a [`Sender`]. This will return `Some`
/// if there are other `Sender` instances alive and the channel wasn't
/// previously dropped, otherwise `None` is returned.
pub fn upgrade(self) -> Option<Sender<T>> {
self.chan.upgrade().map(Sender::new)
pub fn upgrade(&self) -> Option<Sender<T>> {
chan::Tx::upgrade(self.chan.clone()).map(Sender::new)
}
}

Expand Down
23 changes: 7 additions & 16 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ pub(crate) trait Semaphore {
fn is_closed(&self) -> bool;
}

struct Chan<T, S> {
pub(crate) struct Chan<T, S> {
/// Notifies all tasks listening for the receiver being dropped.
notify_rx_closed: Notify,

Expand Down Expand Up @@ -130,34 +130,25 @@ impl<T, S> Tx<T, S> {
Tx { inner: chan }
}

pub(super) fn downgrade(self) -> Self {
if self.inner.tx_count.fetch_sub(1, AcqRel) == 1 {
// Close the list, which sends a `Close` message
self.inner.tx.close();

// Notify the receiver
self.wake_rx();
}

self
pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> {
self.inner.clone()
}

// Returns the upgraded channel or None if the upgrade failed.
pub(super) fn upgrade(self) -> Option<Self> {
let mut tx_count = self.inner.tx_count.load(Acquire);
pub(super) fn upgrade(chan: Arc<Chan<T, S>>) -> Option<Self> {
let mut tx_count = chan.tx_count.load(Acquire);

loop {
if tx_count == 0 {
// channel is closed
return None;
}

match self
.inner
match chan
.tx_count
.compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire)
{
Ok(_) => return Some(self),
Ok(_) => return Some(Tx { inner: chan }),
Err(prev_count) => tx_count = prev_count,
}
}
Expand Down
22 changes: 17 additions & 5 deletions tokio/tests/sync_mpsc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,10 @@ fn recv_timeout_panic() {
#[tokio::test]
async fn weak_sender() {
let (tx, mut rx) = channel(11);
let tx_weak = tx.clone().downgrade();

let tx_weak = tokio::spawn(async move {
let tx_weak = tx.clone().downgrade();

for i in 0..10 {
if tx.send(i).await.is_err() {
return None;
Expand Down Expand Up @@ -703,10 +704,9 @@ async fn weak_sender() {
}
}

if let Some(tx_weak) = tx_weak {
let upgraded = tx_weak.upgrade();
assert!(upgraded.is_none());
}
let tx_weak = tx_weak.unwrap();
let upgraded = tx_weak.upgrade();
assert!(upgraded.is_none());
}

#[tokio::test]
Expand Down Expand Up @@ -887,6 +887,7 @@ async fn downgrade_upgrade_sender_success() {
async fn downgrade_upgrade_sender_failure() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let weak_tx = tx.downgrade();
drop(tx);
assert!(weak_tx.upgrade().is_none());
}

Expand Down Expand Up @@ -921,3 +922,14 @@ async fn downgrade_upgrade_get_permit_no_senders() {
let weak_tx = tx2.downgrade();
assert!(weak_tx.upgrade().is_some());
}

// Tests that `Clone` of `WeakSender` doesn't decrement `tx_count`.
#[tokio::test]
async fn test_weak_sender_clone() {
let (tx, _rx) = mpsc::channel::<i32>(1);
let tx_weak = tx.downgrade();
let tx_weak2 = tx.downgrade();
drop(tx);

assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none());
}

0 comments on commit fa8bdf9

Please sign in to comment.