From 881b510a072f5acd773a10d3be0debf74113404e Mon Sep 17 00:00:00 2001 From: Aaron Schweiger Date: Tue, 17 Oct 2023 05:01:41 -0400 Subject: [PATCH] sync: add `mpsc::Receiver::recv_many` (#6010) --- benches/sync_mpsc.rs | 130 +++++++++++++++++++++++++ tokio/src/sync/mpsc/bounded.rs | 82 +++++++++++++++- tokio/src/sync/mpsc/chan.rs | 100 +++++++++++++++++++ tokio/src/sync/mpsc/unbounded.rs | 77 ++++++++++++++- tokio/tests/sync_mpsc.rs | 161 +++++++++++++++++++++++++++++++ 5 files changed, 545 insertions(+), 5 deletions(-) diff --git a/benches/sync_mpsc.rs b/benches/sync_mpsc.rs index d6545e8047f..117b3babdde 100644 --- a/benches/sync_mpsc.rs +++ b/benches/sync_mpsc.rs @@ -73,6 +73,33 @@ fn contention_bounded(g: &mut BenchmarkGroup) { }); } +fn contention_bounded_recv_many(g: &mut BenchmarkGroup) { + let rt = rt(); + + g.bench_function("bounded_recv_many", |b| { + b.iter(|| { + rt.block_on(async move { + let (tx, mut rx) = mpsc::channel::(1_000_000); + + for _ in 0..5 { + let tx = tx.clone(); + tokio::spawn(async move { + for i in 0..1000 { + tx.send(i).await.unwrap(); + } + }); + } + + let mut buffer = Vec::::with_capacity(5_000); + let mut total = 0; + while total < 1_000 * 5 { + total += rx.recv_many(&mut buffer, 5_000).await; + } + }) + }) + }); +} + fn contention_bounded_full(g: &mut BenchmarkGroup) { let rt = rt(); @@ -98,6 +125,33 @@ fn contention_bounded_full(g: &mut BenchmarkGroup) { }); } +fn contention_bounded_full_recv_many(g: &mut BenchmarkGroup) { + let rt = rt(); + + g.bench_function("bounded_full_recv_many", |b| { + b.iter(|| { + rt.block_on(async move { + let (tx, mut rx) = mpsc::channel::(100); + + for _ in 0..5 { + let tx = tx.clone(); + tokio::spawn(async move { + for i in 0..1000 { + tx.send(i).await.unwrap(); + } + }); + } + + let mut buffer = Vec::::with_capacity(5_000); + let mut total = 0; + while total < 1_000 * 5 { + total += rx.recv_many(&mut buffer, 5_000).await; + } + }) + }) + }); +} + fn contention_unbounded(g: &mut BenchmarkGroup) { let rt = rt(); @@ -123,6 +177,33 @@ fn contention_unbounded(g: &mut BenchmarkGroup) { }); } +fn contention_unbounded_recv_many(g: &mut BenchmarkGroup) { + let rt = rt(); + + g.bench_function("unbounded_recv_many", |b| { + b.iter(|| { + rt.block_on(async move { + let (tx, mut rx) = mpsc::unbounded_channel::(); + + for _ in 0..5 { + let tx = tx.clone(); + tokio::spawn(async move { + for i in 0..1000 { + tx.send(i).unwrap(); + } + }); + } + + let mut buffer = Vec::::with_capacity(5_000); + let mut total = 0; + while total < 1_000 * 5 { + total += rx.recv_many(&mut buffer, 5_000).await; + } + }) + }) + }); +} + fn uncontented_bounded(g: &mut BenchmarkGroup) { let rt = rt(); @@ -143,6 +224,28 @@ fn uncontented_bounded(g: &mut BenchmarkGroup) { }); } +fn uncontented_bounded_recv_many(g: &mut BenchmarkGroup) { + let rt = rt(); + + g.bench_function("bounded_recv_many", |b| { + b.iter(|| { + rt.block_on(async move { + let (tx, mut rx) = mpsc::channel::(1_000_000); + + for i in 0..5000 { + tx.send(i).await.unwrap(); + } + + let mut buffer = Vec::::with_capacity(5_000); + let mut total = 0; + while total < 1_000 * 5 { + total += rx.recv_many(&mut buffer, 5_000).await; + } + }) + }) + }); +} + fn uncontented_unbounded(g: &mut BenchmarkGroup) { let rt = rt(); @@ -163,6 +266,28 @@ fn uncontented_unbounded(g: &mut BenchmarkGroup) { }); } +fn uncontented_unbounded_recv_many(g: &mut BenchmarkGroup) { + let rt = rt(); + + g.bench_function("unbounded_recv_many", |b| { + b.iter(|| { + rt.block_on(async move { + let (tx, mut rx) = mpsc::unbounded_channel::(); + + for i in 0..5000 { + tx.send(i).unwrap(); + } + + let mut buffer = Vec::::with_capacity(5_000); + let mut total = 0; + while total < 1_000 * 5 { + total += rx.recv_many(&mut buffer, 5_000).await; + } + }) + }) + }); +} + fn bench_create_medium(c: &mut Criterion) { let mut group = c.benchmark_group("create_medium"); create_medium::<1>(&mut group); @@ -181,15 +306,20 @@ fn bench_send(c: &mut Criterion) { fn bench_contention(c: &mut Criterion) { let mut group = c.benchmark_group("contention"); contention_bounded(&mut group); + contention_bounded_recv_many(&mut group); contention_bounded_full(&mut group); + contention_bounded_full_recv_many(&mut group); contention_unbounded(&mut group); + contention_unbounded_recv_many(&mut group); group.finish(); } fn bench_uncontented(c: &mut Criterion) { let mut group = c.benchmark_group("uncontented"); uncontented_bounded(&mut group); + uncontented_bounded_recv_many(&mut group); uncontented_unbounded(&mut group); + uncontented_unbounded_recv_many(&mut group); group.finish(); } diff --git a/tokio/src/sync/mpsc/bounded.rs b/tokio/src/sync/mpsc/bounded.rs index a9cd73ee3fc..5024e839107 100644 --- a/tokio/src/sync/mpsc/bounded.rs +++ b/tokio/src/sync/mpsc/bounded.rs @@ -44,10 +44,10 @@ pub struct Sender { /// async fn main() { /// let (tx, _rx) = channel::(15); /// let tx_weak = tx.downgrade(); -/// +/// /// // Upgrading will succeed because `tx` still exists. /// assert!(tx_weak.upgrade().is_some()); -/// +/// /// // If we drop `tx`, then it will fail. /// drop(tx); /// assert!(tx_weak.clone().upgrade().is_none()); @@ -230,6 +230,82 @@ impl Receiver { poll_fn(|cx| self.chan.recv(cx)).await } + /// Receives the next values for this receiver and extends `buffer`. + /// + /// This method extends `buffer` by no more than a fixed number of values + /// as specified by `limit`. If `limit` is zero, the function immediately + /// returns `0`. The return value is the number of values added to `buffer`. + /// + /// For `limit > 0`, if there are no messages in the channel's queue, but + /// the channel has not yet been closed, this method will sleep until a + /// message is sent or the channel is closed. Note that if [`close`] is + /// called, but there are still outstanding [`Permits`] from before it was + /// closed, the channel is not considered closed by `recv_many` until the + /// permits are released. + /// + /// For non-zero values of `limit`, this method will never return `0` unless + /// the channel has been closed and there are no remaining messages in the + /// channel's queue. This indicates that no further values can ever be + /// received from this `Receiver`. The channel is closed when all senders + /// have been dropped, or when [`close`] is called. + /// + /// The capacity of `buffer` is increased as needed. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_many` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`close`]: Self::close + /// [`Permits`]: struct@crate::sync::mpsc::Permit + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut buffer: Vec<&str> = Vec::with_capacity(2); + /// let limit = 2; + /// let (tx, mut rx) = mpsc::channel(100); + /// let tx2 = tx.clone(); + /// tx2.send("first").await.unwrap(); + /// tx2.send("second").await.unwrap(); + /// tx2.send("third").await.unwrap(); + /// + /// // Call `recv_many` to receive up to `limit` (2) values. + /// assert_eq!(2, rx.recv_many(&mut buffer, limit).await); + /// assert_eq!(vec!["first", "second"], buffer); + /// + /// // If the buffer is full, the next call to `recv_many` + /// // reserves additional capacity. + /// assert_eq!(1, rx.recv_many(&mut buffer, 1).await); + /// + /// tokio::spawn(async move { + /// tx.send("fourth").await.unwrap(); + /// }); + /// + /// // 'tx' is dropped, but `recv_many` + /// // is guaranteed not to return 0 as the channel + /// // is not yet closed. + /// assert_eq!(1, rx.recv_many(&mut buffer, 1).await); + /// assert_eq!(vec!["first", "second", "third", "fourth"], buffer); + /// + /// // Once the last sender is dropped, the channel is + /// // closed and `recv_many` returns 0, capacity unchanged. + /// drop(tx2); + /// assert_eq!(0, rx.recv_many(&mut buffer, limit).await); + /// assert_eq!(vec!["first", "second", "third", "fourth"], buffer); + /// } + /// ``` + pub async fn recv_many(&mut self, buffer: &mut Vec, limit: usize) -> usize { + use crate::future::poll_fn; + poll_fn(|cx| self.chan.recv_many(cx, buffer, limit)).await + } + /// Tries to receive the next value for this receiver. /// /// This method returns the [`Empty`] error if the channel is currently @@ -1072,7 +1148,7 @@ impl Sender { /// #[tokio::main] /// async fn main() { /// let (tx, _rx) = mpsc::channel::<()>(5); - /// + /// /// // both max capacity and capacity are the same at first /// assert_eq!(tx.max_capacity(), 5); /// assert_eq!(tx.capacity(), 5); diff --git a/tokio/src/sync/mpsc/chan.rs b/tokio/src/sync/mpsc/chan.rs index 2540e3c2ffd..c05a4abb7c0 100644 --- a/tokio/src/sync/mpsc/chan.rs +++ b/tokio/src/sync/mpsc/chan.rs @@ -41,6 +41,8 @@ pub(crate) trait Semaphore { fn add_permit(&self); + fn add_permits(&self, n: usize); + fn close(&self); fn is_closed(&self) -> bool; @@ -293,6 +295,91 @@ impl Rx { }) } + /// Receives up to `limit` values into `buffer` + /// + /// For `limit > 0`, receives up to limit values into `buffer`. + /// For `limit == 0`, immediately returns Ready(0). + pub(crate) fn recv_many( + &mut self, + cx: &mut Context<'_>, + buffer: &mut Vec, + limit: usize, + ) -> Poll { + use super::block::Read; + + ready!(crate::trace::trace_leaf(cx)); + + // Keep track of task budget + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + + if limit == 0 { + coop.made_progress(); + return Ready(0usize); + } + + let mut remaining = limit; + let initial_length = buffer.len(); + + self.inner.rx_fields.with_mut(|rx_fields_ptr| { + let rx_fields = unsafe { &mut *rx_fields_ptr }; + macro_rules! try_recv { + () => { + while remaining > 0 { + match rx_fields.list.pop(&self.inner.tx) { + Some(Read::Value(value)) => { + remaining -= 1; + buffer.push(value); + } + + Some(Read::Closed) => { + let number_added = buffer.len() - initial_length; + if number_added > 0 { + self.inner.semaphore.add_permits(number_added); + } + // TODO: This check may not be required as it most + // likely can only return `true` at this point. A + // channel is closed when all tx handles are + // dropped. Dropping a tx handle releases memory, + // which ensures that if dropping the tx handle is + // visible, then all messages sent are also visible. + assert!(self.inner.semaphore.is_idle()); + coop.made_progress(); + return Ready(number_added); + } + + None => { + break; // fall through + } + } + } + let number_added = buffer.len() - initial_length; + if number_added > 0 { + self.inner.semaphore.add_permits(number_added); + coop.made_progress(); + return Ready(number_added); + } + }; + } + + try_recv!(); + + self.inner.rx_waker.register_by_ref(cx.waker()); + + // It is possible that a value was pushed between attempting to read + // and registering the task, so we have to check the channel a + // second time here. + try_recv!(); + + if rx_fields.rx_closed && self.inner.semaphore.is_idle() { + assert!(buffer.is_empty()); + coop.made_progress(); + Ready(0usize) + } else { + Pending + } + }) + } + /// Try to receive the next value. pub(crate) fn try_recv(&mut self) -> Result { use super::list::TryPopResult; @@ -389,6 +476,10 @@ impl Semaphore for bounded::Semaphore { self.semaphore.release(1); } + fn add_permits(&self, n: usize) { + self.semaphore.release(n) + } + fn is_idle(&self) -> bool { self.semaphore.available_permits() == self.bound } @@ -414,6 +505,15 @@ impl Semaphore for unbounded::Semaphore { } } + fn add_permits(&self, n: usize) { + let prev = self.0.fetch_sub(n << 1, Release); + + if (prev >> 1) < n { + // Something went wrong + process::abort(); + } + } + fn is_idle(&self) -> bool { self.0.load(Acquire) >> 1 == 0 } diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index 7ec5faf5b05..81d3f701b3c 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -34,10 +34,10 @@ pub struct UnboundedSender { /// async fn main() { /// let (tx, _rx) = unbounded_channel::(); /// let tx_weak = tx.downgrade(); -/// +/// /// // Upgrading will succeed because `tx` still exists. /// assert!(tx_weak.upgrade().is_some()); -/// +/// /// // If we drop `tx`, then it will fail. /// drop(tx); /// assert!(tx_weak.clone().upgrade().is_none()); @@ -172,6 +172,79 @@ impl UnboundedReceiver { poll_fn(|cx| self.poll_recv(cx)).await } + /// Receives the next values for this receiver and extends `buffer`. + /// + /// This method extends `buffer` by no more than a fixed number of values + /// as specified by `limit`. If `limit` is zero, the function returns + /// immediately with `0`. The return value is the number of values added to + /// `buffer`. + /// + /// For `limit > 0`, if there are no messages in the channel's queue, + /// but the channel has not yet been closed, this method will sleep + /// until a message is sent or the channel is closed. + /// + /// For non-zero values of `limit`, this method will never return `0` unless + /// the channel has been closed and there are no remaining messages in the + /// channel's queue. This indicates that no further values can ever be + /// received from this `Receiver`. The channel is closed when all senders + /// have been dropped, or when [`close`] is called. + /// + /// The capacity of `buffer` is increased as needed. + /// + /// # Cancel safety + /// + /// This method is cancel safe. If `recv_many` is used as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, it is guaranteed that no messages were received on this + /// channel. + /// + /// [`close`]: Self::close + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut buffer: Vec<&str> = Vec::with_capacity(2); + /// let limit = 2; + /// let (tx, mut rx) = mpsc::unbounded_channel(); + /// let tx2 = tx.clone(); + /// tx2.send("first").unwrap(); + /// tx2.send("second").unwrap(); + /// tx2.send("third").unwrap(); + /// + /// // Call `recv_many` to receive up to `limit` (2) values. + /// assert_eq!(2, rx.recv_many(&mut buffer, limit).await); + /// assert_eq!(vec!["first", "second"], buffer); + /// + /// // If the buffer is full, the next call to `recv_many` + /// // reserves additional capacity. + /// assert_eq!(1, rx.recv_many(&mut buffer, limit).await); + /// + /// tokio::spawn(async move { + /// tx.send("fourth").unwrap(); + /// }); + /// + /// // 'tx' is dropped, but `recv_many` + /// // is guaranteed not to return 0 as the channel + /// // is not yet closed. + /// assert_eq!(1, rx.recv_many(&mut buffer, limit).await); + /// assert_eq!(vec!["first", "second", "third", "fourth"], buffer); + /// + /// // Once the last sender is dropped, the channel is + /// // closed and `recv_many` returns 0, capacity unchanged. + /// drop(tx2); + /// assert_eq!(0, rx.recv_many(&mut buffer, limit).await); + /// assert_eq!(vec!["first", "second", "third", "fourth"], buffer); + /// } + /// ``` + pub async fn recv_many(&mut self, buffer: &mut Vec, limit: usize) -> usize { + use crate::future::poll_fn; + poll_fn(|cx| self.chan.recv_many(cx, buffer, limit)).await + } + /// Tries to receive the next value for this receiver. /// /// This method returns the [`Empty`] error if the channel is currently diff --git a/tokio/tests/sync_mpsc.rs b/tokio/tests/sync_mpsc.rs index d2b7078b4ea..a5c15a4cfc6 100644 --- a/tokio/tests/sync_mpsc.rs +++ b/tokio/tests/sync_mpsc.rs @@ -120,6 +120,34 @@ async fn async_send_recv_with_buffer() { assert_eq!(None, rx.recv().await); } +#[tokio::test] +#[cfg(feature = "full")] +async fn async_send_recv_many_with_buffer() { + let (tx, mut rx) = mpsc::channel(2); + let mut buffer = Vec::::with_capacity(3); + + // With `limit=0` does not sleep, returns immediately + assert_eq!(0, rx.recv_many(&mut buffer, 0).await); + + let handle = tokio::spawn(async move { + assert_ok!(tx.send(1).await); + assert_ok!(tx.send(2).await); + assert_ok!(tx.send(7).await); + assert_ok!(tx.send(0).await); + }); + + let limit = 3; + let mut recv_count = 0usize; + while recv_count < 4 { + recv_count += rx.recv_many(&mut buffer, limit).await; + assert_eq!(buffer.len(), recv_count); + } + + assert_eq!(vec![1, 2, 7, 0], buffer); + assert_eq!(0, rx.recv_many(&mut buffer, limit).await); + handle.await.unwrap(); +} + #[tokio::test] #[cfg(feature = "full")] async fn start_send_past_cap() { @@ -176,6 +204,139 @@ async fn send_recv_unbounded() { assert!(rx.recv().await.is_none()); } +#[maybe_tokio_test] +async fn send_recv_many_unbounded() { + let (tx, mut rx) = mpsc::unbounded_channel::(); + + let mut buffer: Vec = Vec::new(); + + // With `limit=0` does not sleep, returns immediately + rx.recv_many(&mut buffer, 0).await; + assert_eq!(0, buffer.len()); + + assert_ok!(tx.send(7)); + assert_ok!(tx.send(13)); + assert_ok!(tx.send(100)); + assert_ok!(tx.send(1002)); + + rx.recv_many(&mut buffer, 0).await; + assert_eq!(0, buffer.len()); + + let mut count = 0; + while count < 4 { + count += rx.recv_many(&mut buffer, 1).await; + } + assert_eq!(count, 4); + assert_eq!(vec![7, 13, 100, 1002], buffer); + let final_capacity = buffer.capacity(); + assert!(final_capacity > 0); + + buffer.clear(); + + assert_ok!(tx.send(5)); + assert_ok!(tx.send(6)); + assert_ok!(tx.send(7)); + assert_ok!(tx.send(2)); + + // Re-use existing capacity + count = rx.recv_many(&mut buffer, 32).await; + + assert_eq!(final_capacity, buffer.capacity()); + assert_eq!(count, 4); + assert_eq!(vec![5, 6, 7, 2], buffer); + + drop(tx); + + // recv_many will immediately return zero if the channel + // is closed and no more messages are waiting + assert_eq!(0, rx.recv_many(&mut buffer, 4).await); + assert!(rx.recv().await.is_none()); +} + +#[tokio::test] +#[cfg(feature = "full")] +async fn send_recv_many_bounded_capacity() { + let mut buffer: Vec = Vec::with_capacity(9); + let limit = buffer.capacity(); + let (tx, mut rx) = mpsc::channel(100); + + let mut expected: Vec = (0..limit) + .map(|x: usize| format!("{x}")) + .collect::>(); + for x in expected.clone() { + tx.send(x).await.unwrap() + } + tx.send("one more".to_string()).await.unwrap(); + + // Here `recv_many` receives all but the last value; + // the initial capacity is adequate, so the buffer does + // not increase in side. + assert_eq!(buffer.capacity(), rx.recv_many(&mut buffer, limit).await); + assert_eq!(expected, buffer); + assert_eq!(limit, buffer.capacity()); + + // Receive up more values: + assert_eq!(1, rx.recv_many(&mut buffer, limit).await); + assert!(buffer.capacity() > limit); + expected.push("one more".to_string()); + assert_eq!(expected, buffer); + + tokio::spawn(async move { + tx.send("final".to_string()).await.unwrap(); + }); + + // 'tx' is dropped, but `recv_many` is guaranteed not + // to return 0 as the channel has outstanding permits + assert_eq!(1, rx.recv_many(&mut buffer, limit).await); + expected.push("final".to_string()); + assert_eq!(expected, buffer); + // The channel is now closed and `recv_many` returns 0. + assert_eq!(0, rx.recv_many(&mut buffer, limit).await); + assert_eq!(expected, buffer); +} + +#[tokio::test] +#[cfg(feature = "full")] +async fn send_recv_many_unbounded_capacity() { + let mut buffer: Vec = Vec::with_capacity(9); // capacity >= 9 + let limit = buffer.capacity(); + let (tx, mut rx) = mpsc::unbounded_channel(); + + let mut expected: Vec = (0..limit) + .map(|x: usize| format!("{x}")) + .collect::>(); + for x in expected.clone() { + tx.send(x).unwrap() + } + tx.send("one more".to_string()).unwrap(); + + // Here `recv_many` receives all but the last value; + // the initial capacity is adequate, so the buffer does + // not increase in side. + assert_eq!(buffer.capacity(), rx.recv_many(&mut buffer, limit).await); + assert_eq!(expected, buffer); + assert_eq!(limit, buffer.capacity()); + + // Receive up more values: + assert_eq!(1, rx.recv_many(&mut buffer, limit).await); + assert!(buffer.capacity() > limit); + expected.push("one more".to_string()); + assert_eq!(expected, buffer); + + tokio::spawn(async move { + tx.send("final".to_string()).unwrap(); + }); + + // 'tx' is dropped, but `recv_many` is guaranteed not + // to return 0 as the channel has outstanding permits + assert_eq!(1, rx.recv_many(&mut buffer, limit).await); + expected.push("final".to_string()); + assert_eq!(expected, buffer); + // The channel is now closed and `recv_many` returns 0. + assert_eq!(0, rx.recv_many(&mut buffer, limit).await); + assert_eq!(expected, buffer); +} + #[tokio::test] #[cfg(feature = "full")] async fn async_send_recv_unbounded() {