diff --git a/src/helpers/buffers/ordering_sender.rs b/src/helpers/buffers/ordering_sender.rs index 5ecc21c7f..ee73b3cb4 100644 --- a/src/helpers/buffers/ordering_sender.rs +++ b/src/helpers/buffers/ordering_sender.rs @@ -78,12 +78,13 @@ impl State { M::Size::USIZE, self.spare.get() ); + let open = self.accept_writes(); let b = &mut self.buf[self.written..]; - if M::Size::USIZE <= b.len() { + if open && M::Size::USIZE <= b.len() { self.written += M::Size::USIZE; m.serialize(GenericArray::from_mut_slice(&mut b[..M::Size::USIZE])); - if self.written + self.spare.get() >= self.buf.len() { + if !self.accept_writes() { Self::wake(&mut self.stream_ready); } Poll::Ready(()) @@ -111,6 +112,15 @@ impl State { self.closed = true; Self::wake(&mut self.stream_ready); } + + /// Returns `true` if more writes can be accepted by this sender. + /// If message size exceeds the remaining capacity, [`write`] may + /// still return `Poll::Pending` even if sender is open for writes. + /// + /// [`write`]: Self::write + fn accept_writes(&self) -> bool { + self.written + self.spare.get() < self.buf.len() + } } /// An saved waker for a given index. @@ -259,6 +269,11 @@ impl Waiting { /// Data less than the `write_size` threshold only becomes available to /// the stream when the sender is closed (with [`close`]). /// +/// Once `write_size` threshold has been reached, no subsequent writes +/// are allowed, until stream is polled. `OrderingSender` guarantees equal +/// size chunks will be sent to the stream when it is used to buffer +/// same-sized messages. +/// /// The `spare` capacity determines the size of messages that can be sent; /// see [`send`] for details. /// @@ -465,14 +480,14 @@ impl + Unpin> Stream for OrderedStream { #[cfg(all(test, any(unit_test, feature = "shuttle")))] mod test { - use std::{iter::zip, num::NonZeroUsize}; + use std::{future::poll_fn, iter::zip, num::NonZeroUsize, pin::pin}; use futures::{ future::{join, join3, join_all}, stream::StreamExt, FutureExt, }; - use futures_util::future::try_join_all; + use futures_util::future::{poll_immediate, try_join_all}; use generic_array::GenericArray; use rand::Rng; #[cfg(feature = "shuttle")] @@ -677,4 +692,41 @@ mod test { .unwrap(); }); } + + /// If sender is at capacity, but still have some bytes inside spare, we block the sends + /// until the stream is flushed. That ensures `OrderingSender` yields the equal-sized + /// chunks. + /// + /// This behavior is important for channels working in parallel `[parallel_join]` and wrapped + /// inside a windowed execution [`seq_join`]. Not enforcing this leads to some channels moving + /// forward faster and eventually getting outside of active work window. See [`issue`] for + /// more details. + /// + /// [`seq_join`]: crate::seq_join::SeqJoin::try_join + /// [`parallel_join`]: crate::seq_join::SeqJoin::parallel_join + /// [`issue`]: https://github.com/private-attribution/ipa/issues/843 + #[test] + fn reader_blocks_writers() { + const SZ: usize = <::Size as Unsigned>::USIZE; + run(|| async { + const CAPACITY: usize = SZ + 1; + const SPARE: usize = 2 * SZ; + let sender = + OrderingSender::new(CAPACITY.try_into().unwrap(), SPARE.try_into().unwrap()); + + // enough bytes in the buffer to hold 2 items + for i in 0..2 { + sender + .send(i, Fp32BitPrime::truncate_from(u128::try_from(i).unwrap())) + .await; + } + + // spare has enough capacity, but buffer is considered full. + let mut f = pin!(sender.send(2, Fp32BitPrime::truncate_from(2_u128))); + assert_eq!(None, poll_immediate(&mut f).await); + + drop(poll_fn(|ctx| sender.take_next(ctx)).await); + assert_eq!(Some(()), poll_immediate(f).await); + }); + } }