use crate::{global_weight::GlobalWeight, peekable_fused::PeekableFused};
use futures_util::{
stream::{Fuse, FuturesUnordered},
Future, Stream, StreamExt as _,
};
use pin_project_lite::pin_project;
use std::{
fmt,
pin::Pin,
task::{Context, Poll},
};
pin_project! {
    /// Stream that runs weighted futures pulled from an underlying stream concurrently and
    /// yields their outputs in completion order.
    ///
    /// A new future is only taken from the underlying stream when the shared weight budget
    /// (`global_weight`) has room for its weight.
    #[must_use = "streams do nothing unless polled"]
    pub struct FutureQueue<St>
    where
        St: Stream,
        St::Item: WeightedFuture,
    {
        // Source of weighted futures. Fused, and wrapped in `PeekableFused` so the next item's
        // weight can be inspected before the item is actually taken off the stream.
        #[pin]
        stream: PeekableFused<Fuse<St>>,
        // Futures that have been admitted and are currently running, each tagged with the
        // weight it was admitted with so that weight can be released when it completes.
        in_progress_queue: FuturesUnordered<FutureWithWeight<<St::Item as WeightedFuture>::Future>>,
        // Weight budget: charged when a future is admitted, credited when it completes.
        global_weight: GlobalWeight,
    }
}
impl<St> fmt::Debug for FutureQueue<St>
where
    St: Stream + fmt::Debug,
    St::Item: WeightedFuture,
{
    /// Formats the queue as a struct showing its underlying stream, the set of in-flight
    /// futures, and the weight budget.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            stream,
            in_progress_queue,
            global_weight,
        } = self;

        let mut builder = f.debug_struct("FutureQueue");
        builder.field("stream", stream);
        builder.field("in_progress_queue", in_progress_queue);
        builder.field("global_weight", global_weight);
        builder.finish()
    }
}
impl<St> FutureQueue<St>
where
    St: Stream,
    St::Item: WeightedFuture,
{
    /// Creates a queue over `stream` whose weight budget is built from `max_weight`.
    ///
    /// The stream is fused, so it is not polled again after it has returned `None`.
    pub(crate) fn new(stream: St, max_weight: usize) -> Self {
        Self {
            stream: PeekableFused::new(stream.fuse()),
            in_progress_queue: FuturesUnordered::new(),
            global_weight: GlobalWeight::new(max_weight),
        }
    }

    /// Returns the maximum weight this queue was configured with.
    pub fn max_weight(&self) -> usize {
        self.global_weight.max()
    }

    /// Returns the weight currently charged against the budget: the weight added for
    /// admitted futures that has not yet been released by their completion.
    pub fn current_weight(&self) -> usize {
        self.global_weight.current()
    }

    /// Returns a shared reference to the underlying stream.
    ///
    /// Items the queue has already peeked at or admitted are not visible through this
    /// reference.
    pub fn get_ref(&self) -> &St {
        self.stream.get_ref().get_ref()
    }

    /// Returns a mutable reference to the underlying stream.
    ///
    /// Care must be taken not to tamper with the stream's state. Polling it directly can, for
    /// example, bypass an item the queue has already peeked at.
    pub fn get_mut(&mut self) -> &mut St {
        self.stream.get_mut().get_mut()
    }

    /// Returns a pinned mutable reference to the underlying stream.
    ///
    /// The same caveats as [`get_mut`](Self::get_mut) apply.
    pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut St> {
        self.project().stream.get_pin_mut().get_pin_mut()
    }

    /// Consumes the queue and returns the underlying stream.
    ///
    /// Futures that are still in flight are dropped together with the queue. An item that was
    /// peeked but not yet admitted is presumably dropped as well; this depends on
    /// `PeekableFused::into_inner`, so confirm before relying on it.
    pub fn into_inner(self) -> St {
        self.stream.into_inner().into_inner()
    }
}
impl<St> Stream for FutureQueue<St>
where
    St: Stream,
    St::Item: WeightedFuture,
{
    type Item = <<St::Item as WeightedFuture>::Future as Future>::Output;

    /// Admits as many pending futures as the weight budget allows, then polls the in-flight set.
    ///
    /// Outputs are yielded in completion order, because the in-flight set is a
    /// `FuturesUnordered`. They are not yielded in the order the underlying stream produced
    /// the futures. A completed future's weight is released from the budget before its output
    /// is returned.
    ///
    /// Returns `Poll::Ready(None)` only when no futures are in flight and the underlying
    /// stream reports that it is done.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // Peek before taking: an item whose weight does not fit stays buffered in the stream
        // until a running future completes and frees capacity.
        while let Poll::Ready(Some(weighted_future)) = this.stream.as_mut().poll_peek(cx) {
            if !this.global_weight.has_space_for(weighted_future.weight()) {
                // NOTE(review): if nothing is in flight at this point, the match below returns
                // `Pending` and no waker is registered, because `poll_peek` returned `Ready`.
                // Termination therefore relies on `has_space_for` admitting any item while the
                // current weight is zero. Confirm this in `GlobalWeight`.
                break;
            }

            let (weight, future) = match this.stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(weighted_future)) => weighted_future.into_components(),
                _ => unreachable!("we just peeked at this item"),
            };
            this.global_weight.add_weight(weight);
            this.in_progress_queue.push(FutureWithWeight::new(weight, future));
        }

        match this.in_progress_queue.poll_next_unpin(cx) {
            Poll::Ready(Some((weight, output))) => {
                this.global_weight.sub_weight(weight);
                Poll::Ready(Some(output))
            }
            // Nothing is in flight, so we are finished only if the stream has nothing left.
            Poll::Ready(None) if this.stream.is_done() => Poll::Ready(None),
            // Either futures are still running, or the stream has not finished yet.
            Poll::Ready(None) | Poll::Pending => Poll::Pending,
        }
    }

    /// Bounds the number of outputs still to come.
    ///
    /// Each in-flight future and each item remaining in the stream produces exactly one
    /// output. This assumes `PeekableFused::size_hint` accounts for a peeked item.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let in_flight = self.in_progress_queue.len();
        let (lower, upper) = self.stream.size_hint();
        (
            lower.saturating_add(in_flight),
            // An unknown or overflowing upper bound stays unknown.
            upper.and_then(|upper| upper.checked_add(in_flight)),
        )
    }
}
/// A future paired with a weight, as consumed by [`FutureQueue`].
///
/// This trait is sealed: its supertrait lives in a private module, so it cannot be
/// implemented outside this module. It is implemented for `(usize, Fut)` tuples, where the
/// `usize` is the weight.
pub trait WeightedFuture: private::Sealed {
    /// The future that is run once the item is admitted into the queue.
    type Future: Future;

    /// Returns this item's weight, which is checked against the queue's budget before the
    /// item is admitted.
    fn weight(&self) -> usize;

    /// Splits the item into its weight and its future.
    fn into_components(self) -> (usize, Self::Future);
}
// Holds the seal for `WeightedFuture`. Because this module is not public, code outside the
// enclosing module cannot name `Sealed`, and therefore cannot implement `WeightedFuture`.
mod private {
    // Supertrait of `WeightedFuture`; intentionally has no items.
    pub trait Sealed {}
}
// Unseal `(weight, future)` tuples so that they can implement `WeightedFuture`.
impl<Fut> private::Sealed for (usize, Fut) where Fut: Future {}
/// `(weight, future)` tuples are the weighted-future type accepted by the queue.
impl<Fut> WeightedFuture for (usize, Fut)
where
    Fut: Future,
{
    type Future = Fut;

    /// Returns the tuple's first element, which is the weight.
    #[inline]
    fn weight(&self) -> usize {
        let (weight, _) = self;
        *weight
    }

    /// Returns the tuple's parts unchanged; it already has the `(weight, future)` shape.
    #[inline]
    fn into_components(self) -> (usize, Self::Future) {
        let (weight, future) = self;
        (weight, future)
    }
}
pin_project! {
    /// Wraps a running future together with the weight it was admitted with. When the future
    /// completes, the weight is returned alongside its output so the queue can release it from
    /// the budget.
    #[must_use = "futures do nothing unless polled"]
    struct FutureWithWeight<Fut> {
        // The wrapped future; structurally pinned.
        #[pin]
        future: Fut,
        // Weight charged to the budget when this future was admitted.
        weight: usize,
    }
}
impl<Fut> FutureWithWeight<Fut> {
pub fn new(weight: usize, future: Fut) -> Self {
Self { future, weight }
}
}
impl<Fut> Future for FutureWithWeight<Fut>
where
    Fut: Future,
{
    type Output = (usize, Fut::Output);

    /// Polls the inner future. Once it completes, its output is tagged with the stored weight.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let weight = *projected.weight;
        projected.future.poll(cx).map(|output| (weight, output))
    }
}