diff --git a/Cargo.lock b/Cargo.lock index eee5b22..9f4cd61 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1570,6 +1570,7 @@ dependencies = [ "reqwest", "thiserror", "tokio", + "tokio-util", "tower", "url", ] diff --git a/ubw-sward/Cargo.toml b/ubw-sward/Cargo.toml index 391af72..b07eab4 100644 --- a/ubw-sward/Cargo.toml +++ b/ubw-sward/Cargo.toml @@ -12,4 +12,5 @@ rand = {workspace = true} thiserror = {workspace = true} compact_str = {workspace = true} bytes = {workspace = true} -regex = "1.11" \ No newline at end of file +regex = "1.11" +tokio-util = "0.7" \ No newline at end of file diff --git a/ubw-sward/src/http/mod.rs b/ubw-sward/src/http/mod.rs index d9e3e18..b78dd22 100644 --- a/ubw-sward/src/http/mod.rs +++ b/ubw-sward/src/http/mod.rs @@ -1,4 +1,5 @@ -pub mod simple; pub mod random; +pub mod simple; +pub use random::RandomUrlGenerator; pub use simple::{SimpleHttpRequest, SimpleHttpSward}; diff --git a/ubw-sward/src/http/simple.rs b/ubw-sward/src/http/simple.rs index e2ecfab..8da0372 100644 --- a/ubw-sward/src/http/simple.rs +++ b/ubw-sward/src/http/simple.rs @@ -2,7 +2,6 @@ use bytes::Bytes; use reqwest::header::HeaderMap; use reqwest::{Client, Method}; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; use tower::Service; use url::Url; @@ -12,13 +11,13 @@ use url::Url; pub struct SimpleHttpSward { client: Client, method: Method, - headers: Arc, + headers: HeaderMap, sent_count: usize, } impl SimpleHttpSward { /// Create a new simple http sward. - pub fn new(client: Client, method: Method, headers: Arc) -> Self { + pub fn new(client: Client, method: Method, headers: HeaderMap) -> Self { Self { client, method, @@ -51,7 +50,7 @@ impl Service for SimpleHttpSward { Box::pin( self.client .request(self.method.clone(), req.url) - .headers(self.headers.as_ref().clone()) + .headers(self.headers.clone()) .send(), ) } diff --git a/ubw-sward/src/utils/mod.rs b/ubw-sward/src/utils/mod.rs index e69de29..0f75dc6 100644 --- a/ubw-sward/src/utils/mod.rs +++ b/ubw-sward/src/utils/mod.rs @@ -0,0 +1 @@ +pub mod multiplexed; \ No newline at end of file diff --git a/ubw-sward/src/utils/multiplexed.rs b/ubw-sward/src/utils/multiplexed.rs new file mode 100644 index 0000000..a7145cd --- /dev/null +++ b/ubw-sward/src/utils/multiplexed.rs @@ -0,0 +1,75 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::{OwnedSemaphorePermit, Semaphore}; +use tokio_util::sync::PollSemaphore; +use tower::Service; + +pub struct MultiplexedSward { + inner: S, + poll: PollSemaphore, + stash_permit: Option, +} + +impl MultiplexedSward { + pub fn new(inner: S, capacity: usize) -> Self { + let semaphore = Arc::new(Semaphore::new(capacity)); + Self { + inner, + poll: PollSemaphore::new(semaphore.clone()), + stash_permit: None, + } + } +} + +impl Service for MultiplexedSward +where + S: Service, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = MultiplexedSwardError; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.stash_permit.is_none() { + match self.poll.poll_acquire(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => { + return Poll::Ready(Err(MultiplexedSwardError::SemaphoreClosed)); + } + Poll::Ready(Some(permit)) => { + self.stash_permit = Some(permit); + } + } + } + self.inner + .poll_ready(cx) + .map_err(MultiplexedSwardError::Inner) + } + + fn call(&mut self, req: Req) -> Self::Future { + let Some(permit) = self.stash_permit.take() else { + return Box::pin(async { Err(MultiplexedSwardError::NotReady) }); + }; + let fut = self.inner.call(req); + Box::pin(async move { + let _ = permit; + let res = fut.await; + res.map_err(MultiplexedSwardError::Inner) + }) + } +} + +/// Error that can occur when multiplexing requests. +#[derive(Debug, thiserror::Error)] +pub enum MultiplexedSwardError { + #[error("{0}")] + Inner(E), + + #[error("Semaphore is closed.")] + SemaphoreClosed, + + #[error("call() invoked without ready; call poll_ready first")] + NotReady, +}