use crate::{global_weight::GlobalWeight, peekable_fused::PeekableFused};
use fnv::FnvHashMap;
use futures_util::{
ready,
stream::{Fuse, FusedStream, FuturesUnordered},
Future, Stream, StreamExt,
};
use pin_project_lite::pin_project;
use std::{
borrow::Borrow,
collections::VecDeque,
fmt,
hash::Hash,
pin::Pin,
task::{Context, Poll},
};
pin_project! {
/// Stream adapter that runs the weighted futures produced by an inner stream
/// concurrently and yields their outputs in completion order.
///
/// A future is started only while the total weight of running futures stays within a
/// global limit. A future that carries a group id must additionally fit within that
/// group's limit; if it does not, it waits in a per-group FIFO queue.
#[must_use = "streams do nothing unless polled"]
pub struct FutureQueueGrouped<St, K>
where
St: Stream,
St::Item: GroupedWeightedFuture,
{
// Source of weighted futures. It is peekable so that an item is only taken off the
// stream once it fits within the global limit.
#[pin]
stream: PeekableFused<Fuse<St>>,
// Futures that have been started and are being polled to completion.
#[pin]
in_progress_queue: PeekableFused<InProgressQueue<St>>,
// Weight budget shared by all running futures.
global_weight: GlobalWeight,
// Per-group limits, current weights and waiting queues, keyed by `K`.
group_store: GroupStore<<St::Item as GroupedWeightedFuture>::Q, K, <St::Item as GroupedWeightedFuture>::Future>,
}
}
/// Set of running futures. Each future is wrapped with the weight and optional group id
/// that must be released when it completes.
type InProgressQueue<St> = FuturesUnordered<
FutureWithGW<
<<St as Stream>::Item as GroupedWeightedFuture>::Future,
<<St as Stream>::Item as GroupedWeightedFuture>::Q,
>,
>;
impl<St, K> fmt::Debug for FutureQueueGrouped<St, K>
where
St: Stream + fmt::Debug,
St::Item: GroupedWeightedFuture,
<St::Item as GroupedWeightedFuture>::Future: fmt::Debug,
<<St::Item as GroupedWeightedFuture>::Future as Future>::Output: fmt::Debug,
K: fmt::Debug,
<St::Item as GroupedWeightedFuture>::Q: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FutureQueueGrouped")
.field("stream", &self.stream)
.field("in_progress_queue", &self.in_progress_queue)
.field("global_weight", &self.global_weight)
.field("group_store", &self.group_store)
.finish()
}
}
impl<St, K> FutureQueueGrouped<St, K>
where
St: Stream,
St::Item: GroupedWeightedFuture,
<St::Item as GroupedWeightedFuture>::Q: Eq + Hash + fmt::Debug,
K: Eq + Hash + fmt::Debug + Borrow<<St::Item as GroupedWeightedFuture>::Q>,
{
pub(super) fn new(
stream: St,
max_global_weight: usize,
id_data: impl IntoIterator<Item = (K, usize)>,
) -> Self {
let id_data_store = GroupStore::new(id_data);
Self {
stream: PeekableFused::new(stream.fuse()),
in_progress_queue: PeekableFused::new(FuturesUnordered::new()),
global_weight: GlobalWeight::new(max_global_weight),
group_store: id_data_store,
}
}
pub fn max_global_weight(&self) -> usize {
self.global_weight.max()
}
pub fn current_global_weight(&self) -> usize {
self.global_weight.current()
}
pub fn max_group_weight<Q>(&self, id: &Q) -> Option<usize>
where
Q: Eq + Hash + fmt::Debug + ?Sized,
K: Borrow<Q>,
{
self.group_store
.group_data
.get(id)
.map(|id_data| id_data.max_weight)
}
pub fn current_group_weight<Q>(&self, id: &Q) -> Option<usize>
where
Q: Eq + Hash + fmt::Debug + ?Sized,
K: Borrow<Q>,
{
self.group_store
.group_data
.get(id)
.map(|id_data| id_data.max_weight)
}
pub fn get_ref(&self) -> &St {
self.stream.get_ref().get_ref()
}
pub fn get_mut(&mut self) -> &mut St {
self.stream.get_mut().get_mut()
}
pub fn get_pin_mut(self: Pin<&mut Self>) -> core::pin::Pin<&mut St> {
self.project().stream.get_pin_mut().get_pin_mut()
}
pub fn into_inner(self) -> St {
self.stream.into_inner().into_inner()
}
#[allow(clippy::type_complexity)]
fn poll_pop_in_progress(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<(
Option<<<St::Item as GroupedWeightedFuture>::Future as Future>::Output>,
bool,
)> {
let mut this = self.project();
match ready!(this.in_progress_queue.poll_next_unpin(cx)) {
Some((weight, id, output)) => {
this.global_weight.sub_weight(weight);
let mut any_queued = false;
if let Some(id) = id {
let data = this.group_store.get_id_mut_or_unwrap(&id);
data.sub_weight(&id, weight);
while let Some(&(weight, _, _)) = data.queued.front() {
if this.global_weight.has_space_for(weight) && data.has_space_for(weight) {
let (weight, id, future) = data.queued.pop_front().unwrap();
this.global_weight.add_weight(weight);
data.add_weight(&id, weight);
this.in_progress_queue
.as_mut()
.get_pin_mut()
.push(FutureWithGW::new(weight, Some(id), future));
any_queued = true;
} else {
break;
}
}
}
Poll::Ready((Some(output), any_queued))
}
None => Poll::Ready((None, false)),
}
}
}
impl<St, K> Stream for FutureQueueGrouped<St, K>
where
    St: Stream,
    St::Item: GroupedWeightedFuture,
    <St::Item as GroupedWeightedFuture>::Q: Eq + Hash + fmt::Debug,
    K: Eq + Hash + fmt::Debug + Borrow<<St::Item as GroupedWeightedFuture>::Q>,
{
    type Item = <<St::Item as GroupedWeightedFuture>::Future as Future>::Output;

    /// Yields the output of the next future to complete.
    ///
    /// Each call does three things:
    /// 1. Collects at most one completed future and releases its weight.
    /// 2. Pulls as many items from the underlying stream as the global limit allows.
    ///    Each item is started immediately, or parked in its group's queue if the
    ///    group is full.
    /// 3. Returns a completed output if one is available.
    ///
    /// The stream ends once the underlying stream is exhausted and the in-progress
    /// queue is empty.
    ///
    /// # Panics
    ///
    /// Panics if an item's group id is not registered.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Release capacity from a finished future first, so the loop below has as much
        // room as possible. `Pending` here only means that nothing has finished yet.
        // Bailing out on it (as `ready!` did) skipped polling the underlying stream, so
        // newly available items sat unstarted until some running future completed.
        let (return_output, mut any_queued) = match self.as_mut().poll_pop_in_progress(cx) {
            Poll::Ready(popped) => popped,
            Poll::Pending => (None, false),
        };

        let mut this = self.as_mut().project();
        // Peek first so an item that does not fit the global limit stays in the stream.
        while let Poll::Ready(Some(weighted_future)) = this.stream.as_mut().poll_peek(cx) {
            let weight = weighted_future.weight();
            if !this.global_weight.has_space_for(weight) {
                break;
            }
            let (weight, id, future) = match this.stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(weighted_future)) => weighted_future.into_components(),
                _ => unreachable!("we just peeked at this item"),
            };
            match id {
                Some(id) => {
                    let data = this.group_store.get_id_mut_or_unwrap(&id);
                    if data.has_space_for(weight) {
                        this.global_weight.add_weight(weight);
                        data.add_weight(&id, weight);
                        this.in_progress_queue
                            .as_mut()
                            .get_pin_mut()
                            .push(FutureWithGW::new(weight, Some(id), future));
                        any_queued = true;
                    } else {
                        // The group is full. The future waits in the group's queue without
                        // being charged any weight.
                        data.queued.push_back((weight, id, future));
                    }
                }
                None => {
                    this.global_weight.add_weight(weight);
                    this.in_progress_queue
                        .as_mut()
                        .get_pin_mut()
                        .push(FutureWithGW::new(weight, None, future));
                    any_queued = true;
                }
            }
        }

        if any_queued {
            // Poll newly pushed futures once, so that they start running and register
            // their wakers.
            let _ = this.in_progress_queue.as_mut().poll_peek(cx);
        }

        if let Some(output) = return_output {
            return Poll::Ready(Some(output));
        }

        match (
            self.stream.is_done(),
            self.in_progress_queue.is_terminated(),
        ) {
            (true, true) => {
                debug_assert_eq!(
                    self.group_store.num_queued_futures(),
                    0,
                    "no futures should be left in the queue"
                );
                Poll::Ready(None)
            }
            // The stream registered our waker when its peek returned `Pending`.
            (false, true) => Poll::Pending,
            (_, false) => {
                let (output, any_queued) = ready!(self.as_mut().poll_pop_in_progress(cx));
                if any_queued {
                    let this = self.project();
                    let _ = this.in_progress_queue.poll_peek(cx);
                }
                Poll::Ready(output)
            }
        }
    }

    /// Lower bound: items left in the stream plus running and group-queued futures.
    /// The upper bound is `None` if the stream's upper bound is unknown or the sum
    /// overflows.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let queue_len =
            self.in_progress_queue.size_hint().0 + self.group_store.num_queued_futures();
        let (lower, upper) = self.stream.size_hint();
        let lower = lower.saturating_add(queue_len);
        let upper = upper.and_then(|x| x.checked_add(queue_len));
        (lower, upper)
    }
}
/// Per-group state for every group registered at construction, keyed by group id.
#[derive(Debug)]
struct GroupStore<Q, K, Fut> {
/// Limits, current weight and waiting queue for each registered group.
group_data: FnvHashMap<K, GroupData<Q, Fut>>,
}
impl<Q, K, Fut> GroupStore<Q, K, Fut>
where
    Q: Hash + Eq + fmt::Debug,
    K: Eq + Hash + fmt::Debug + Borrow<Q>,
{
    /// Registers each `(id, max_weight)` pair with zero current weight and an empty
    /// queue. If an id appears more than once, the last entry wins.
    fn new(ids: impl IntoIterator<Item = (K, usize)>) -> Self {
        let mut store = Self {
            group_data: Default::default(),
        };
        for (id, max_weight) in ids {
            store.group_data.insert(
                id,
                GroupData {
                    current_weight: 0,
                    max_weight,
                    queued: VecDeque::new(),
                },
            );
        }
        store
    }

    /// Returns the mutable state of group `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` was not registered. The message lists the known ids.
    fn get_id_mut_or_unwrap(&mut self, id: &Q) -> &mut GroupData<Q, Fut> {
        // Check membership before borrowing mutably. Returning `get_mut`'s result from
        // one match arm while reading `keys()` in the other is rejected by the current
        // borrow checker.
        if !self.group_data.contains_key(id) {
            panic!(
                "unknown semaphore ID: {:?} (known IDs: {:?})",
                id,
                self.group_data.keys()
            );
        }
        self.group_data.get_mut(id).unwrap()
    }

    /// Total number of futures waiting in group queues, across all groups.
    fn num_queued_futures(&self) -> usize {
        self.group_data
            .values()
            .fold(0, |total, data| total + data.queued.len())
    }
}
/// Bookkeeping for a single group.
#[derive(Debug)]
struct GroupData<Q, Fut> {
/// Sum of the clamped weights of this group's futures in the in-progress queue.
current_weight: usize,
/// Limit for `current_weight`. Individual weights are clamped to this value.
max_weight: usize,
/// Futures of this group that did not fit the group's limit when they arrived, in
/// arrival order, as `(weight, group id, future)`.
queued: VecDeque<(usize, Q, Fut)>,
}
impl<Q: fmt::Debug, Fut> GroupData<Q, Fut> {
    /// Clamps `weight` to this group's maximum. A future heavier than the limit is
    /// therefore admitted whenever the group has nothing else charged to it.
    fn effective_weight(&self, weight: usize) -> usize {
        if weight > self.max_weight {
            self.max_weight
        } else {
            weight
        }
    }

    /// Returns true if a future of `weight` (clamped) fits alongside the weight
    /// already charged to this group.
    fn has_space_for(&self, weight: usize) -> bool {
        let needed = self.effective_weight(weight);
        // `needed <= max_weight`, so the subtraction cannot underflow.
        self.current_weight <= self.max_weight - needed
    }

    /// Charges `weight` (clamped) to the group.
    ///
    /// # Panics
    ///
    /// Panics if the running total would overflow `usize`.
    fn add_weight(&mut self, id: &Q, weight: usize) {
        let weight = self.effective_weight(weight);
        match self.current_weight.checked_add(weight) {
            Some(total) => self.current_weight = total,
            None => panic!(
                "future_queue_grouped: for id `{:?}`, added weight {} to current {}, overflowed",
                id, weight, self.current_weight,
            ),
        }
    }

    /// Releases `weight` (clamped) from the group.
    ///
    /// # Panics
    ///
    /// Panics if more weight is released than is currently charged.
    fn sub_weight(&mut self, id: &Q, weight: usize) {
        let weight = self.effective_weight(weight);
        match self.current_weight.checked_sub(weight) {
            Some(total) => self.current_weight = total,
            None => panic!(
                "future_queue_grouped: for id `{:?}`, sub weight {} from current {}, underflowed",
                id, weight, self.current_weight,
            ),
        }
    }
}
pin_project! {
/// A future together with the weight it holds and the group (if any) it was charged to,
/// so that both can be released when it completes.
#[must_use = "futures do nothing unless polled"]
struct FutureWithGW<Fut, Q> {
// The wrapped future.
#[pin]
future: Fut,
// Weight recorded when the future was started.
weight: usize,
// Group id. It is moved out (leaving `None`) when the future completes.
id: Option<Q>,
}
}
impl<Fut, Q> FutureWithGW<Fut, Q> {
pub fn new(weight: usize, id: Option<Q>, future: Fut) -> Self {
Self { future, weight, id }
}
}
impl<Fut, Q> Future for FutureWithGW<Fut, Q>
where
    Fut: Future,
{
    type Output = (usize, Option<Q>, Fut::Output);

    /// Polls the wrapped future. Once it finishes, yields `(weight, group id, output)`.
    /// The group id is moved out of `self`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projection = self.project();
        let weight = *projection.weight;
        let group = projection.id;
        projection
            .future
            .poll(cx)
            .map(move |output| (weight, group.take(), output))
    }
}
/// An item that the grouped future queue can schedule: a future plus the weight it
/// holds while running and an optional group id.
///
/// This trait is sealed. It is implemented for `(usize, Option<Q>, Fut)` tuples.
pub trait GroupedWeightedFuture: private::Sealed {
    /// The future to run.
    type Future: Future;

    /// The group identifier type. Group ids are looked up in the queue's group table.
    type Q;

    /// Returns the weight this item holds while running.
    fn weight(&self) -> usize;

    /// Splits the item into `(weight, group id, future)`.
    fn into_components(self) -> (usize, Option<Self::Q>, Self::Future);
}

impl<Fut, Q> private::Sealed for (usize, Option<Q>, Fut) where Fut: Future {}

impl<Fut: Future, Q> GroupedWeightedFuture for (usize, Option<Q>, Fut) {
    type Future = Fut;
    type Q = Q;

    #[inline]
    fn weight(&self) -> usize {
        let (weight, _, _) = self;
        *weight
    }

    #[inline]
    fn into_components(self) -> (usize, Option<Self::Q>, Self::Future) {
        let (weight, group, future) = self;
        (weight, group, future)
    }
}

/// Holds the supertrait that seals `GroupedWeightedFuture`. The module is only
/// visible inside this crate.
pub(crate) mod private {
    pub trait Sealed {}
}