future_queue/
future_queue_grouped.rs

1// Copyright (c) The buffer-unordered-weighted Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{
5    global_weight::GlobalWeight, peekable_fused::PeekableFused, slots::SlotReservations,
6    FutureQueueContext,
7};
8use debug_ignore::DebugIgnore;
9use fnv::FnvHashMap;
10use futures_util::{
11    ready,
12    stream::{Fuse, FusedStream, FuturesUnordered},
13    Future, Stream, StreamExt,
14};
15use pin_project_lite::pin_project;
16use std::{
17    borrow::Borrow,
18    collections::VecDeque,
19    fmt,
20    hash::Hash,
21    pin::Pin,
22    task::{Context, Poll},
23};
24
pin_project! {
    /// Stream for the [`future_queue_grouped`](crate::StreamExt::future_queue_grouped) method.
    #[must_use = "streams do nothing unless polled"]
    pub struct FutureQueueGrouped<St, K>
    where
        St: Stream,
        St::Item: GroupedWeightedFuture,
     {
        // Source of weighted futures. It is peekable so that an item's weight can be checked
        // against the global limit before the item is taken out of the stream.
        #[pin]
        stream: PeekableFused<Fuse<St>>,
        // Futures that have been started and are being polled to completion.
        #[pin]
        in_progress_queue: PeekableFused<InProgressQueue<St>>,
        // Total weight of in-progress futures, checked against the maximum passed to `new`.
        global_weight: GlobalWeight,
        // Global slot numbers reserved for in-progress futures and handed to them through
        // `FutureQueueContext`; released when each future completes.
        slots: SlotReservations,
        // Per-group limits, current weights, slot numbers and queues of waiting futures.
        group_store: GroupStore<<St::Item as GroupedWeightedFuture>::Q, K, <St::Item as GroupedWeightedFuture>::F>,
    }
}
42
// Futures that have been started. Each is wrapped in a `FutureWithGW` carrying its weight,
// its global slot and, for grouped futures, its group ID and group slot, so that all of them
// can be released when the future completes.
type InProgressQueue<St> = FuturesUnordered<
    FutureWithGW<
        <<St as Stream>::Item as GroupedWeightedFuture>::Future,
        <<St as Stream>::Item as GroupedWeightedFuture>::Q,
    >,
>;
49
50impl<St, K> fmt::Debug for FutureQueueGrouped<St, K>
51where
52    St: Stream + fmt::Debug,
53    St::Item: GroupedWeightedFuture,
54    <St::Item as GroupedWeightedFuture>::Future: fmt::Debug,
55    <<St::Item as GroupedWeightedFuture>::Future as Future>::Output: fmt::Debug,
56    K: fmt::Debug,
57    <St::Item as GroupedWeightedFuture>::Q: fmt::Debug,
58{
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_struct("FutureQueueGrouped")
61            .field("stream", &self.stream)
62            .field("in_progress_queue", &self.in_progress_queue)
63            .field("global_weight", &self.global_weight)
64            .field("slots", &self.slots)
65            .field("group_store", &self.group_store)
66            .finish()
67    }
68}
69
70impl<St, K> FutureQueueGrouped<St, K>
71where
72    St: Stream,
73    St::Item: GroupedWeightedFuture,
74    <St::Item as GroupedWeightedFuture>::Q: Eq + Hash + fmt::Debug,
75    K: Eq + Hash + fmt::Debug + Borrow<<St::Item as GroupedWeightedFuture>::Q>,
76{
77    pub(super) fn new(
78        stream: St,
79        max_global_weight: usize,
80        id_data: impl IntoIterator<Item = (K, usize)>,
81    ) -> Self {
82        let id_data_store = GroupStore::new(id_data);
83        Self {
84            stream: PeekableFused::new(stream.fuse()),
85            in_progress_queue: PeekableFused::new(FuturesUnordered::new()),
86            global_weight: GlobalWeight::new(max_global_weight),
87            slots: SlotReservations::with_capacity(max_global_weight),
88            group_store: id_data_store,
89        }
90    }
91
92    /// Sets a mode where extra internal verifications are performed.
93    #[doc(hidden)]
94    pub fn set_extra_verify(&mut self, verify: bool) {
95        self.slots.set_check_reserved(verify);
96        for data in self.group_store.group_data.values_mut() {
97            data.slots.set_check_reserved(verify);
98        }
99    }
100
101    /// Returns the maximum weight of futures allowed to be run by this adaptor.
102    pub fn max_global_weight(&self) -> usize {
103        self.global_weight.max()
104    }
105
106    /// Returns the current global weight of futures.
107    pub fn current_global_weight(&self) -> usize {
108        self.global_weight.current()
109    }
110
111    /// Returns the maximum weight of futures allowed to be run within this group.
112    pub fn max_group_weight<Q>(&self, id: &Q) -> Option<usize>
113    where
114        Q: Eq + Hash + fmt::Debug + ?Sized,
115        K: Borrow<Q>,
116    {
117        self.group_store
118            .group_data
119            .get(id)
120            .map(|id_data| id_data.max_weight)
121    }
122
123    /// Returns the current weight of futures being run within this group.
124    pub fn current_group_weight<Q>(&self, id: &Q) -> Option<usize>
125    where
126        Q: Eq + Hash + fmt::Debug + ?Sized,
127        K: Borrow<Q>,
128    {
129        self.group_store
130            .group_data
131            .get(id)
132            .map(|id_data| id_data.max_weight)
133    }
134
135    /// Acquires a reference to the underlying sink or stream that this combinator is
136    /// pulling from.
137    pub fn get_ref(&self) -> &St {
138        self.stream.get_ref().get_ref()
139    }
140
141    /// Acquires a mutable reference to the underlying sink or stream that this
142    /// combinator is pulling from.
143    ///
144    /// Note that care must be taken to avoid tampering with the state of the
145    /// sink or stream which may otherwise confuse this combinator.
146    pub fn get_mut(&mut self) -> &mut St {
147        self.stream.get_mut().get_mut()
148    }
149
150    /// Acquires a pinned mutable reference to the underlying sink or stream that this
151    /// combinator is pulling from.
152    ///
153    /// Note that care must be taken to avoid tampering with the state of the
154    /// sink or stream which may otherwise confuse this combinator.
155    pub fn get_pin_mut(self: Pin<&mut Self>) -> core::pin::Pin<&mut St> {
156        self.project().stream.get_pin_mut().get_pin_mut()
157    }
158
159    /// Consumes this combinator, returning the underlying sink or stream.
160    ///
161    /// Note that this may discard intermediate state of this combinator, so
162    /// care should be taken to avoid losing resources when this is called.
163    pub fn into_inner(self) -> St {
164        self.stream.into_inner().into_inner()
165    }
166
167    // ---
168    // Helper methods
169    // ---
170
171    // This returns true if any new futures were added to the in_progress_queue.
172    #[allow(clippy::type_complexity)]
173    fn poll_pop_in_progress(
174        self: Pin<&mut Self>,
175        cx: &mut Context<'_>,
176    ) -> Poll<(
177        Option<<<St::Item as GroupedWeightedFuture>::Future as Future>::Output>,
178        bool,
179    )> {
180        let mut this = self.project();
181
182        match ready!(this.in_progress_queue.poll_next_unpin(cx)) {
183            Some((weight, global_slot, id_and_group_slot, output)) => {
184                this.global_weight.sub_weight(weight);
185                this.slots.release(global_slot);
186
187                let mut any_queued = false;
188
189                if let Some((id, group_slot)) = id_and_group_slot {
190                    let data = this.group_store.get_id_mut_or_unwrap(&id);
191                    data.sub_weight(&id, weight);
192                    data.slots.release(group_slot);
193
194                    // Can we queue up additional futures from the queued ones for this ID?
195                    while let Some(&(weight, _, _)) = data.queued.front() {
196                        if this.global_weight.has_space_for(weight) && data.has_space_for(weight) {
197                            // The future can be queued up.
198                            let (weight, id, future_fn) = data.queued.pop_front().unwrap();
199                            this.global_weight.add_weight(weight);
200                            data.add_weight(&id, weight);
201
202                            let global_slot = this.slots.reserve();
203                            let group_slot = data.slots.reserve();
204
205                            let cx = FutureQueueContext {
206                                global_slot,
207                                group_slot: Some(group_slot),
208                            };
209                            let future = future_fn.0(cx);
210
211                            this.in_progress_queue.get_ref().push(FutureWithGW::new(
212                                weight,
213                                global_slot,
214                                Some((id, group_slot)),
215                                future,
216                            ));
217                            any_queued = true;
218                        } else {
219                            // Further futures cannot be queued up since doing so would cause one or
220                            // both of the overall weights to be exceeded -- leave them alone and
221                            // exit the loop.
222                            break;
223                        }
224                    }
225                }
226
227                Poll::Ready((Some(output), any_queued))
228            }
229            None => Poll::Ready((None, false)),
230        }
231    }
232}
233
234impl<St, K> Stream for FutureQueueGrouped<St, K>
235where
236    St: Stream,
237    St::Item: GroupedWeightedFuture,
238    <St::Item as GroupedWeightedFuture>::Q: Eq + Hash + fmt::Debug,
239    K: Eq + Hash + fmt::Debug + Borrow<<St::Item as GroupedWeightedFuture>::Q>,
240{
241    type Item = <<St::Item as GroupedWeightedFuture>::Future as Future>::Output;
242
243    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
244        // First, attempt to pull the next value from the in progress queue.
245        let (return_output, mut any_queued) = ready!(self.as_mut().poll_pop_in_progress(cx));
246
247        let mut this = self.as_mut().project();
248
249        // Next, let's try to spawn off as many futures as possible by filling up our queue of
250        // futures.
251
252        while let Poll::Ready(Some(weighted_future)) = this.stream.as_mut().poll_peek(cx) {
253            let weight = weighted_future.weight();
254            if !this.global_weight.has_space_for(weight) {
255                // Global limits would be exceeded, break out of the loop. Consider this
256                // item next time.
257                break;
258            }
259            // We *do not* care about the group limit before pulling this item out. That's because
260            // if the group is full, it will be queued up in the group queue.
261
262            // Grab the next element from the queue.
263            let (weight, id, future_fn) = match this.stream.as_mut().poll_next(cx) {
264                Poll::Ready(Some(weighted_future)) => weighted_future.into_components(),
265                _ => unreachable!("we just peeked at this item"),
266            };
267
268            if let Some(id) = id {
269                // Is this group full?
270                let data = this.group_store.get_id_mut_or_unwrap(&id);
271                if data.has_space_for(weight) {
272                    this.global_weight.add_weight(weight);
273                    data.add_weight(&id, weight);
274
275                    let global_slot = this.slots.reserve();
276                    let group_slot = data.slots.reserve();
277
278                    let cx = FutureQueueContext {
279                        global_slot,
280                        group_slot: Some(group_slot),
281                    };
282                    let future = future_fn(cx);
283                    this.in_progress_queue.get_ref().push(FutureWithGW::new(
284                        weight,
285                        global_slot,
286                        Some((id, group_slot)),
287                        future,
288                    ));
289                    any_queued = true;
290                } else {
291                    data.queued.push_back((weight, id, DebugIgnore(future_fn)));
292                }
293            } else {
294                // No ID associated with this future.
295                this.global_weight.add_weight(weight);
296
297                let global_slot = this.slots.reserve();
298                let cx = FutureQueueContext {
299                    global_slot,
300                    group_slot: None,
301                };
302                let future = future_fn(cx);
303
304                this.in_progress_queue.get_ref().push(FutureWithGW::new(
305                    weight,
306                    global_slot,
307                    None,
308                    future,
309                ));
310                any_queued = true;
311            }
312        }
313
314        if any_queued {
315            // Start any futures that were just queued up. If this returns Pending, then that's fine --
316            // the task will be scheduled on the waker.
317            let _ = this.in_progress_queue.as_mut().poll_peek(cx);
318        }
319
320        if let Some(output) = return_output {
321            // A value was returned from the in-progress queue.
322            Poll::Ready(Some(output))
323        } else {
324            match (
325                self.stream.is_done(),
326                self.in_progress_queue.is_terminated(),
327            ) {
328                (true, true) => {
329                    // No more futures left to schedule. (Note that poll_pop_in_progress would have
330                    // drained all futures in any queue.)
331                    debug_assert_eq!(
332                        self.group_store.num_queued_futures(),
333                        0,
334                        "no futures should be left in the queue"
335                    );
336                    Poll::Ready(None)
337                }
338                (false, true) => {
339                    // The in-progress queue is empty, but the stream is still pending.
340                    // (Note that Poll::Pending is OK to return here because this can only happen in
341                    // the Poll::Pending case above.)
342                    Poll::Pending
343                }
344                (_, false) => {
345                    // There are still futures in the in-progress queue. We need to poll the
346                    // in-progress queue to start any futures in it.
347                    let (output, any_queued) = ready!(self.as_mut().poll_pop_in_progress(cx));
348                    if any_queued {
349                        // It's possible that poll_pop_in_progress might have added more futures to the queue.
350                        let this = self.project();
351                        let _ = this.in_progress_queue.poll_peek(cx);
352                    }
353                    Poll::Ready(output)
354                }
355            }
356        }
357    }
358
359    fn size_hint(&self) -> (usize, Option<usize>) {
360        // The minimum size is the in progress queue + any queued futures.
361        let queue_len =
362            self.in_progress_queue.size_hint().0 + self.group_store.num_queued_futures();
363        let (lower, upper) = self.stream.size_hint();
364        let lower = lower.saturating_add(queue_len);
365        let upper = match upper {
366            Some(x) => x.checked_add(queue_len),
367            None => None,
368        };
369        (lower, upper)
370    }
371}
372
// State for every group ID passed to `FutureQueueGrouped::new`, keyed by the owned key type
// `K`. Lookups use the borrowed form `Q` that each future carries (`K: Borrow<Q>`).
struct GroupStore<Q, K, F> {
    group_data: FnvHashMap<K, GroupData<Q, F>>,
}
376
377impl<Q: fmt::Debug, K: fmt::Debug, F> fmt::Debug for GroupStore<Q, K, F> {
378    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379        f.debug_struct("GroupStore")
380            .field("group_data", &self.group_data)
381            .finish()
382    }
383}
384
385impl<Q, K, F> GroupStore<Q, K, F>
386where
387    Q: Hash + Eq + fmt::Debug,
388    K: Eq + Hash + fmt::Debug + Borrow<Q>,
389{
390    fn new(ids: impl IntoIterator<Item = (K, usize)>) -> Self {
391        let id_data = ids
392            .into_iter()
393            .map(|(id, weight)| {
394                let data = GroupData {
395                    current_weight: 0,
396                    max_weight: weight,
397                    slots: SlotReservations::with_capacity(weight),
398                    queued: VecDeque::new(),
399                };
400                (id, data)
401            })
402            .collect();
403
404        Self {
405            group_data: id_data,
406        }
407    }
408
409    fn get_id_mut_or_unwrap(&mut self, id: &Q) -> &mut GroupData<Q, F> {
410        if self.group_data.contains_key(id) {
411            // Can't just use get_mut above because we're going to run into
412            // https://doc.rust-lang.org/nomicon/lifetime-mismatch.html#improperly-reduced-borrows
413            // with the else branch.
414            self.group_data.get_mut(id).unwrap()
415        } else {
416            panic!(
417                "unknown semaphore ID: {:?} (known IDs: {:?})",
418                id,
419                self.group_data.keys()
420            );
421        }
422    }
423
424    fn num_queued_futures(&self) -> usize {
425        self.group_data.values().map(|data| data.queued.len()).sum()
426    }
427}
428
// State for a single group.
struct GroupData<Q, F> {
    // Total weight of this group's in-progress futures; each future's weight is capped at
    // `max_weight` when added.
    current_weight: usize,
    // Maximum total weight for this group, as given to `GroupStore::new`.
    max_weight: usize,
    // Group slot numbers reserved for this group's in-progress futures.
    slots: SlotReservations,
    // Futures waiting for room in this group, oldest first: (weight, group ID, future fn).
    queued: VecDeque<(usize, Q, DebugIgnore<F>)>,
}
435
436impl<Q: fmt::Debug, F> fmt::Debug for GroupData<Q, F> {
437    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
438        f.debug_struct("GroupData")
439            .field("current_weight", &self.current_weight)
440            .field("max_weight", &self.max_weight)
441            .field("slots", &self.slots)
442            .field("queued", &self.queued)
443            .finish()
444    }
445}
446
447impl<Q: fmt::Debug, Fut> GroupData<Q, Fut> {
448    fn has_space_for(&self, weight: usize) -> bool {
449        let weight = weight.min(self.max_weight);
450        self.current_weight <= self.max_weight - weight
451    }
452
453    // The ID is passed in only for its Debug impl.
454    fn add_weight(&mut self, id: &Q, weight: usize) {
455        let weight = weight.min(self.max_weight);
456        self.current_weight = self.current_weight.checked_add(weight).unwrap_or_else(|| {
457            panic!(
458                "future_queue_grouped: for id `{:?}`, added weight {} to current {}, overflowed",
459                id, weight, self.current_weight,
460            )
461        });
462    }
463
464    fn sub_weight(&mut self, id: &Q, weight: usize) {
465        let weight = weight.min(self.max_weight);
466        self.current_weight = self.current_weight.checked_sub(weight).unwrap_or_else(|| {
467            panic!(
468                "future_queue_grouped: for id `{:?}`, sub weight {} from current {}, underflowed",
469                id, weight, self.current_weight,
470            )
471        });
472    }
473}
474
pin_project! {
    // A started future bundled with the bookkeeping that must be released once it completes.
    // Its weight, global slot, and optional group ID + group slot are returned alongside the
    // future's output.
    #[must_use = "futures do nothing unless polled"]
    struct FutureWithGW<Fut, Q> {
        #[pin]
        future: Fut,
        weight: usize,
        global_slot: u64,
        // The second parameter is the group slot.
        id_and_group_slot: Option<(Q, u64)>,
    }
}
486
487impl<Fut, Q> FutureWithGW<Fut, Q> {
488    pub fn new(
489        weight: usize,
490        global_slot: u64,
491        id_and_group_slot: Option<(Q, u64)>,
492        future: Fut,
493    ) -> Self {
494        Self {
495            future,
496            weight,
497            global_slot,
498            id_and_group_slot,
499        }
500    }
501}
502
503impl<Fut, Q> Future for FutureWithGW<Fut, Q>
504where
505    Fut: Future,
506{
507    type Output = (usize, u64, Option<(Q, u64)>, Fut::Output);
508    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
509        let this = self.project();
510
511        match this.future.poll(cx) {
512            Poll::Pending => Poll::Pending,
513            Poll::Ready(output) => Poll::Ready((
514                *this.weight,
515                *this.global_slot,
516                this.id_and_group_slot.take(),
517                output,
518            )),
519        }
520    }
521}
522
/// A trait for types which can be converted into a weight, an optional group, and a function
/// that returns a `Future`.
///
/// Provided in case it's necessary. This trait is sealed and is only implemented for
/// `(usize, Option<Q>, F)`, where `F: FnOnce(FutureQueueContext) -> impl Future`.
pub trait GroupedWeightedFuture: private::Sealed {
    /// The function to obtain the future from. It is passed a [`FutureQueueContext`] holding
    /// the global slot and, for grouped futures, the group slot reserved for the future.
    type F: FnOnce(FutureQueueContext) -> Self::Future;

    /// The associated `Future` type.
    type Future: Future;

    /// The associated key lookup type, used to find the future's group.
    type Q;

    /// Returns the weight.
    fn weight(&self) -> usize;

    /// Turns self into its components: the weight, the optional group key, and the function
    /// that creates the future.
    fn into_components(self) -> (usize, Option<Self::Q>, Self::F);
}
543
// The only `Sealed` implementation, and therefore the only type that can implement
// `GroupedWeightedFuture`.
impl<F, Fut, Q> private::Sealed for (usize, Option<Q>, F)
where
    F: FnOnce(FutureQueueContext) -> Fut,
    Fut: Future,
{
}
550
551impl<F, Fut, Q> GroupedWeightedFuture for (usize, Option<Q>, F)
552where
553    F: FnOnce(FutureQueueContext) -> Fut,
554    Fut: Future,
555{
556    type F = F;
557    type Future = Fut;
558    type Q = Q;
559
560    #[inline]
561    fn weight(&self) -> usize {
562        self.0
563    }
564
565    #[inline]
566    fn into_components(self) -> (usize, Option<Self::Q>, Self::F) {
567        self
568    }
569}
570
pub(crate) mod private {
    // Supertrait of `GroupedWeightedFuture`. It lives in a crate-private module, so code outside
    // this crate cannot implement it, and therefore cannot implement `GroupedWeightedFuture`.
    pub trait Sealed {}
}