implement multiplex

This commit is contained in:
2025-09-01 20:28:44 +09:00
parent 05ffd4cf15
commit b0812bfc87
6 changed files with 84 additions and 6 deletions

1
Cargo.lock generated
View File

@@ -1570,6 +1570,7 @@ dependencies = [
"reqwest",
"thiserror",
"tokio",
"tokio-util",
"tower",
"url",
]

View File

@@ -12,4 +12,5 @@ rand = {workspace = true}
thiserror = {workspace = true}
compact_str = {workspace = true}
bytes = {workspace = true}
regex = "1.11"
regex = "1.11"
tokio-util = "0.7"

View File

@@ -1,4 +1,5 @@
pub mod simple;
pub mod random;
pub mod simple;
pub use random::RandomUrlGenerator;
pub use simple::{SimpleHttpRequest, SimpleHttpSward};

View File

@@ -2,7 +2,6 @@ use bytes::Bytes;
use reqwest::header::HeaderMap;
use reqwest::{Client, Method};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::Service;
use url::Url;
@@ -12,13 +11,13 @@ use url::Url;
pub struct SimpleHttpSward {
client: Client,
method: Method,
headers: Arc<HeaderMap>,
headers: HeaderMap,
sent_count: usize,
}
impl SimpleHttpSward {
/// Create a new simple http sward.
pub fn new(client: Client, method: Method, headers: Arc<HeaderMap>) -> Self {
pub fn new(client: Client, method: Method, headers: HeaderMap) -> Self {
Self {
client,
method,
@@ -51,7 +50,7 @@ impl Service<SimpleHttpRequest> for SimpleHttpSward {
Box::pin(
self.client
.request(self.method.clone(), req.url)
.headers(self.headers.as_ref().clone())
.headers(self.headers.clone())
.send(),
)
}

View File

@@ -0,0 +1 @@
pub mod multiplexed;

View File

@@ -0,0 +1,75 @@
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower::Service;
pub struct MultiplexedSward<S> {
inner: S,
poll: PollSemaphore,
stash_permit: Option<OwnedSemaphorePermit>,
}
impl<S> MultiplexedSward<S> {
pub fn new(inner: S, capacity: usize) -> Self {
let semaphore = Arc::new(Semaphore::new(capacity));
Self {
inner,
poll: PollSemaphore::new(semaphore.clone()),
stash_permit: None,
}
}
}
impl<S, Req> Service<Req> for MultiplexedSward<S>
where
S: Service<Req>,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = MultiplexedSwardError<S::Error>;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.stash_permit.is_none() {
match self.poll.poll_acquire(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
return Poll::Ready(Err(MultiplexedSwardError::SemaphoreClosed));
}
Poll::Ready(Some(permit)) => {
self.stash_permit = Some(permit);
}
}
}
self.inner
.poll_ready(cx)
.map_err(MultiplexedSwardError::Inner)
}
fn call(&mut self, req: Req) -> Self::Future {
let Some(permit) = self.stash_permit.take() else {
return Box::pin(async { Err(MultiplexedSwardError::NotReady) });
};
let fut = self.inner.call(req);
Box::pin(async move {
let _ = permit;
let res = fut.await;
res.map_err(MultiplexedSwardError::Inner)
})
}
}
/// Error that can occur when multiplexing requests.
#[derive(Debug, thiserror::Error)]
pub enum MultiplexedSwardError<E> {
#[error("{0}")]
Inner(E),
#[error("Semaphore is closed.")]
SemaphoreClosed,
#[error("call() invoked without ready; call poll_ready first")]
NotReady,
}