use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::sync::{OwnedSemaphorePermit, Semaphore}; use tokio_util::sync::PollSemaphore; use tower::Service; pub struct MultiplexedSward { inner: S, poll: PollSemaphore, stash_permit: Option, } impl MultiplexedSward { pub fn new(inner: S, capacity: usize) -> Self { let semaphore = Arc::new(Semaphore::new(capacity)); Self { inner, poll: PollSemaphore::new(semaphore.clone()), stash_permit: None, } } } impl Service for MultiplexedSward where S: Service, S::Future: Send + 'static, { type Response = S::Response; type Error = MultiplexedSwardError; type Future = Pin> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { if self.stash_permit.is_none() { match self.poll.poll_acquire(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => { return Poll::Ready(Err(MultiplexedSwardError::SemaphoreClosed)); } Poll::Ready(Some(permit)) => { self.stash_permit = Some(permit); } } } self.inner .poll_ready(cx) .map_err(MultiplexedSwardError::Inner) } fn call(&mut self, req: Req) -> Self::Future { let Some(permit) = self.stash_permit.take() else { return Box::pin(async { Err(MultiplexedSwardError::NotReady) }); }; let fut = self.inner.call(req); Box::pin(async move { let _ = permit; let res = fut.await; res.map_err(MultiplexedSwardError::Inner) }) } } /// Error that can occur when multiplexing requests. #[derive(Debug, thiserror::Error)] pub enum MultiplexedSwardError { #[error("{0}")] Inner(E), #[error("Semaphore is closed.")] SemaphoreClosed, #[error("call() invoked without ready; call poll_ready first")] NotReady, }