diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 65250e68ecb..35ec33d6093 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -1026,11 +1026,19 @@ impl Clone for Receiver { fn clone(&self) -> Self { let next = self.next; let shared = self.shared.clone(); + let mut tail = shared.tail.lock(); - // register interest in the slot that next points to - // let this be lock-free since we're not yet operating on the tail. - let tail_pos = shared.tail.lock().pos; - for n in next..tail_pos { + // register the new receiver with `Tail` + if tail.rx_cnt == MAX_RECEIVERS { + panic!("max receivers"); + } + tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); + + // Register interest in the slots from next to tail.pos. + + // We need to hold the lock here to prevent a race with send2 where send2 overwrites + // next or moves past tail before we register interest in the slot. + for n in next..tail.pos { let idx = (n & shared.mask as u64) as usize; let slot = shared.buffer[idx].read().unwrap(); @@ -1040,21 +1048,6 @@ impl Clone for Receiver { // called concurrently. slot.rem.fetch_add(1, SeqCst); } - // tail pos may have changed, we need a locked section here to prevent a race with `Sender::send2` - let mut n = tail_pos.wrapping_sub(1); - let mut tail = shared.tail.lock(); - while n <= tail.pos { - let idx = (n & shared.mask as u64) as usize; - let slot = self.shared.buffer[idx].read().unwrap(); - slot.rem.fetch_add(1, SeqCst); - n = n.wrapping_add(1); - } - - // register the new receiver with `Tail` - if tail.rx_cnt == MAX_RECEIVERS { - panic!("max receivers"); - } - tail.rx_cnt = tail.rx_cnt.checked_add(1).expect("overflow"); drop(tail); diff --git a/tokio/src/sync/tests/loom_broadcast.rs b/tokio/src/sync/tests/loom_broadcast.rs index 088b151c6ce..0aa9db65fc0 100644 --- a/tokio/src/sync/tests/loom_broadcast.rs +++ b/tokio/src/sync/tests/loom_broadcast.rs @@ -92,6 +92,52 @@ fn broadcast_two() { }); } +// An `Arc` is used as the value in order to detect memory leaks. +#[test] +fn broadcast_two_cloned() { + loom::model(|| { + let (tx, mut rx1) = broadcast::channel::>(16); + let mut rx2 = rx1.clone(); + + let th1 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx1.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx1.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + let th2 = thread::spawn(move || { + block_on(async { + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "hello"); + + let v = assert_ok!(rx2.recv().await); + assert_eq!(*v, "world"); + + match assert_err!(rx2.recv().await) { + Closed => {} + _ => panic!(), + } + }); + }); + + assert_ok!(tx.send(Arc::new("hello"))); + assert_ok!(tx.send(Arc::new("world"))); + drop(tx); + + assert_ok!(th1.join()); + assert_ok!(th2.join()); + }); +} + #[test] fn broadcast_wrap() { loom::model(|| { @@ -274,7 +320,7 @@ fn drop_multiple_cloned_rx_with_overflow() { #[test] fn send_and_rx_clone() { - // test the interraction of Sender::send and Rx::clone + // test the interaction of Sender::send and Rx::clone loom::model(move || { let (tx, mut rx) = broadcast::channel(2);