diff --git a/ubw-sward/src/socket_stream/mod.rs b/ubw-sward/src/socket_stream/mod.rs index 962e4af..3b67972 100644 --- a/ubw-sward/src/socket_stream/mod.rs +++ b/ubw-sward/src/socket_stream/mod.rs @@ -3,6 +3,7 @@ use std::pin::Pin; pub mod counter; pub mod generator; pub mod integrated; +pub mod tcp; pub mod udp; #[derive(Clone, Copy)] diff --git a/ubw-sward/src/socket_stream/tcp.rs b/ubw-sward/src/socket_stream/tcp.rs new file mode 100644 index 0000000..814fb73 --- /dev/null +++ b/ubw-sward/src/socket_stream/tcp.rs @@ -0,0 +1,95 @@ +use std::{ + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{io::AsyncWriteExt, net::TcpStream, sync::Mutex}; +use tower::Service; + +use crate::socket_stream::{ + BoxedStreamWorkload, SizedStreamWorkload, StreamSendFuture, StreamSward, +}; + +pub struct TcpSward { + stream: Arc>, + sent_count: usize, +} + +impl TcpSward { + pub fn new(stream: Arc>) -> Self { + Self { + stream, + sent_count: 0, + } + } +} + +impl StreamSward for TcpSward { + async fn connect(addr: std::net::SocketAddr) -> Result { + let stream = TcpStream::connect(addr).await?; + let stream = Arc::new(Mutex::new(stream)); + Ok(Self { + stream, + sent_count: 0, + }) + } + fn add_request_count(&mut self) { + self.sent_count += 1; + } + fn send_sized( + &self, + workload: SizedStreamWorkload, + ) -> impl Future> + Send + 'static { + let stream = self.stream.clone(); + async move { + let mut stream = stream.lock().await; + stream.write_all(&workload.bytes).await?; + Ok(N) + } + } + fn send_boxed( + &self, + workload: BoxedStreamWorkload, + ) -> impl Future> + Send + 'static { + let stream = self.stream.clone(); + async move { + let len = workload.0.len(); + let mut stream = stream.lock().await; + stream.write_all(&workload.0).await?; + Ok(len) + } + } +} + +impl Service> for TcpSward { + type Response = usize; + type Error = std::io::Error; + type Future = StreamSendFuture; + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: SizedStreamWorkload) -> Self::Future { + self.add_request_count(); + Box::pin(self.send_sized(req)) + } +} + +impl Service for TcpSward { + type Response = usize; + type Error = std::io::Error; + type Future = StreamSendFuture; + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: BoxedStreamWorkload) -> Self::Future { + self.add_request_count(); + Box::pin(self.send_boxed(req)) + } +} + +impl tower::load::Load for TcpSward { + type Metric = usize; + fn load(&self) -> Self::Metric { + self.sent_count + } +} +