Skip to content

Commit

Permalink
sync: Added WeakSender to sync::broadcast::channel
Browse files Browse the repository at this point in the history
  • Loading branch information
tglane committed Jan 14, 2025
1 parent a82bdee commit 2f7108e
Show file tree
Hide file tree
Showing 3 changed files with 342 additions and 1 deletion.
161 changes: 160 additions & 1 deletion tokio/src/sync/broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release, SeqCst};
use std::task::{ready, Context, Poll, Waker};

/// Sending-half of the [`broadcast`] channel.
Expand Down Expand Up @@ -166,6 +166,40 @@ pub struct Sender<T> {
shared: Arc<Shared<T>>,
}

/// A sender that does not prevent the channel from being closed.
///
/// If all [`Sender`] instances of a channel were dropped and only `WeakSender`
/// instances remain, the channel is closed.
///
/// In order to send messages, the `WeakSender` needs to be upgraded using
/// [`WeakSender::upgrade`], which returns `Option<Sender>`. It returns `None`
/// if all `Sender`s have been dropped, and otherwise it returns a `Sender`.
///
/// [`Sender`]: Sender
/// [`WeakSender::upgrade`]: WeakSender::upgrade
///
/// # Examples
///
/// ```
/// use tokio::sync::broadcast::channel;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, _rx) = channel::<i32>(15);
/// let tx_weak = tx.downgrade();
///
/// // Upgrading will succeed because `tx` still exists.
/// assert!(tx_weak.upgrade().is_some());
///
/// // If we drop `tx`, then it will fail.
/// drop(tx);
/// assert!(tx_weak.clone().upgrade().is_none());
/// }
/// ```
pub struct WeakSender<T> {
shared: Arc<Shared<T>>,
}

/// Receiving-half of the [`broadcast`] channel.
///
/// Must not be used concurrently. Messages may be retrieved using
Expand Down Expand Up @@ -317,6 +351,9 @@ struct Shared<T> {
/// Number of outstanding Sender handles.
num_tx: AtomicUsize,

/// Number of outstanding weak Sender handles.
num_weak_tx: AtomicUsize,

/// Notify when the last subscribed [`Receiver`] drops.
notify_last_rx_drop: Notify,
}
Expand Down Expand Up @@ -475,6 +512,9 @@ pub fn channel<T: Clone>(capacity: usize) -> (Sender<T>, Receiver<T>) {
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}

unsafe impl<T: Send> Send for WeakSender<T> {}
unsafe impl<T: Send> Sync for WeakSender<T> {}

unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}

Expand Down Expand Up @@ -533,6 +573,7 @@ impl<T> Sender<T> {
waiters: LinkedList::new(),
}),
num_tx: AtomicUsize::new(1),
num_weak_tx: AtomicUsize::new(0),
notify_last_rx_drop: Notify::new(),
});

Expand Down Expand Up @@ -656,6 +697,18 @@ impl<T> Sender<T> {
new_receiver(shared)
}

/// Converts the `Sender` to a [`WeakSender`] that does not count
/// towards RAII semantics, i.e. if all `Sender` instances of the
/// channel were dropped and only `WeakSender` instances remain,
/// the channel is closed.
#[must_use = "Downgrade creates a WeakSender without destroying the original non-weak sender."]
pub fn downgrade(&self) -> WeakSender<T> {
self.shared.num_weak_tx.fetch_add(1, SeqCst);
WeakSender {
shared: self.shared.clone(),
}
}

/// Returns the number of queued values.
///
/// A value is queued until it has either been seen by all receivers that were alive at the time
Expand Down Expand Up @@ -858,6 +911,16 @@ impl<T> Sender<T> {

self.shared.notify_rx(tail);
}

/// Returns the number of [`Sender`] handles.
pub fn strong_count(&self) -> usize {
self.shared.num_tx.load(SeqCst)
}

/// Returns the number of [`WeakSender`] handles.
pub fn weak_count(&self) -> usize {
self.shared.num_weak_tx.load(SeqCst)
}
}

/// Create a new `Receiver` which reads starting from the tail.
Expand Down Expand Up @@ -1012,6 +1075,60 @@ impl<T> Drop for Sender<T> {
}
}

impl<T> WeakSender<T> {
/// Tries to convert a `WeakSender` into a [`Sender`]. This will return `Some`
/// if there are other `Sender` instances alive and the channel wasn't
/// previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<Sender<T>> {
let mut tx_count = self.shared.num_tx.load(Acquire);

loop {
if tx_count == 0 {
// channel is closed so this WeakSender can not be upgraded
return None;
}

match self
.shared
.num_tx
.compare_exchange_weak(tx_count, tx_count + 1, AcqRel, Acquire)
{
Ok(_) => {
return Some(Sender {
shared: self.shared.clone(),
})
}
Err(prev_count) => tx_count = prev_count,
}
}
}

/// Returns the number of [`Sender`] handles.
pub fn strong_count(&self) -> usize {
self.shared.num_tx.load(SeqCst)
}

/// Returns the number of [`WeakSender`] handles.
pub fn weak_count(&self) -> usize {
self.shared.num_weak_tx.load(SeqCst)
}
}

impl<T> Clone for WeakSender<T> {
fn clone(&self) -> WeakSender<T> {
let shared = self.shared.clone();
shared.num_weak_tx.fetch_add(1, SeqCst);

Self { shared }
}
}

impl<T> Drop for WeakSender<T> {
fn drop(&mut self) {
self.shared.num_weak_tx.fetch_sub(1, SeqCst);
}
}

impl<T> Receiver<T> {
/// Returns the number of messages that were sent into the channel and that
/// this [`Receiver`] has yet to receive.
Expand Down Expand Up @@ -1213,6 +1330,42 @@ impl<T> Receiver<T> {

Ok(RecvGuard { slot })
}

/// Returns the number of [`Sender`] handles.
pub fn sender_strong_count(&self) -> usize {
self.shared.num_tx.load(SeqCst)
}

/// Returns the number of [`WeakSender`] handles.
pub fn sender_weak_count(&self) -> usize {
self.shared.num_weak_tx.load(SeqCst)
}

/// Checks if a channel is closed.
///
/// This method returns `true` if the channel has been closed. The channel is closed
/// when all [`Sender`] have been dropped.
///
/// [`Sender`]: crate::sync::broadcast::Sender
///
/// # Examples
/// ```
/// use tokio::sync::broadcast;
///
/// #[tokio::main]
/// async fn main() {
/// let (tx, rx) = broadcast::channel::<()>(10);
/// assert!(!rx.is_closed());
///
/// drop(tx);
///
/// assert!(rx.is_closed());
/// }
/// ```
pub fn is_closed(&self) -> bool {
// Channel is closed when there are no strong senders left active
self.shared.num_tx.load(Acquire) == 0
}
}

impl<T: Clone> Receiver<T> {
Expand Down Expand Up @@ -1534,6 +1687,12 @@ impl<T> fmt::Debug for Sender<T> {
}
}

impl<T> fmt::Debug for WeakSender<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "broadcast::WeakSender")
}
}

impl<T> fmt::Debug for Receiver<T> {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(fmt, "broadcast::Receiver")
Expand Down
1 change: 1 addition & 0 deletions tokio/tests/sync_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ macro_rules! assert_closed {
trait AssertSend: Send + Sync {}
impl AssertSend for broadcast::Sender<i32> {}
impl AssertSend for broadcast::Receiver<i32> {}
impl AssertSend for broadcast::WeakSender<i32> {}

#[test]
fn send_try_recv_bounded() {
Expand Down
Loading

0 comments on commit 2f7108e

Please sign in to comment.