From 1de28864258f16bf91ad498ca91b12db1e82ec67 Mon Sep 17 00:00:00 2001 From: Joel Wurtz Date: Fri, 10 Jan 2025 14:32:48 +0100 Subject: [PATCH] feat(client): retry http1 connection if closed by server --- client/src/body.rs | 6 +++--- client/src/response.rs | 4 ++-- client/src/service.rs | 19 +++++++++++++---- test/src/lib.rs | 46 +++++++++++++++++++++++++++++++++++----- test/tests/h1.rs | 48 +++++++++++++++++++++++++++++++++++------- test/tests/h2.rs | 8 +++---- test/tests/h3.rs | 6 +++--- 7 files changed, 108 insertions(+), 29 deletions(-) diff --git a/client/src/body.rs b/client/src/body.rs index c2fef9c7..72d41bdf 100644 --- a/client/src/body.rs +++ b/client/src/body.rs @@ -56,10 +56,10 @@ impl ResponseBody { } } - pub(crate) fn can_destroy_on_drop(&mut self) -> bool { + pub(crate) fn can_destroy_on_drop(&self) -> bool { #[cfg(feature = "http1")] - if let Self::H1(ref mut body) = *self { - return body.conn_mut().is_destroy_on_drop(); + if let Self::H1(ref body) = *self { + return body.conn().is_destroy_on_drop(); } false diff --git a/client/src/response.rs b/client/src/response.rs index 4a4183dc..cc564b1b 100644 --- a/client/src/response.rs +++ b/client/src/response.rs @@ -177,8 +177,8 @@ impl Response { /// Public API for test purpose. /// /// Used for testing server implementation to make sure it follows spec. - pub fn can_close_connection(&mut self) -> bool { - self.res.body_mut().can_destroy_on_drop() + pub fn can_close_connection(&self) -> bool { + self.res.body().can_destroy_on_drop() } } diff --git a/client/src/service.rs b/client/src/service.rs index 0751f7c3..1003d523 100644 --- a/client/src/service.rs +++ b/client/src/service.rs @@ -87,7 +87,8 @@ pub(crate) fn base_service() -> HttpService { let ServiceRequest { req, client, timeout } = req; - let uri = Uri::try_parse(req.uri())?; + let connect_uri = req.uri().clone(); + let uri = Uri::try_parse(&connect_uri)?; // temporary version to record possible version downgrade/upgrade happens when making connections. // alpn protocol and alt-svc header are possible source of version change. @@ -96,9 +97,9 @@ pub(crate) fn base_service() -> HttpService { let mut connect = Connect::new(uri); - let _date = client.date_service.handle(); - loop { + let _date = client.date_service.handle(); + match version { Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await { shared::AcquireOutput::Conn(mut _conn) => { @@ -235,7 +236,17 @@ pub(crate) fn base_service() -> HttpService { } Ok(Err(e)) => { _conn.destroy_on_drop(); - Err(e.into()) + + match e { + crate::h1::Error::Io(err) => { + if err.kind() == std::io::ErrorKind::UnexpectedEof { + continue; + } + + Err(crate::h1::Error::Io(err).into()) + } + _ => Err(e.into()), + } } Err(_) => { _conn.destroy_on_drop(); diff --git a/test/src/lib.rs b/test/src/lib.rs index c2f896cb..89bc835c 100644 --- a/test/src/lib.rs +++ b/test/src/lib.rs @@ -2,8 +2,7 @@ use std::{ error, fmt, fs, future::Future, io, - net::SocketAddr, - net::TcpListener, + net::{SocketAddr, TcpListener, ToSocketAddrs}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -36,9 +35,18 @@ where T::Response: ReadyService + Service, Req: TryFrom + 'static, { - let lst = TcpListener::bind("127.0.0.1:0")?; + test_server_with_addr(service, "127.0.0.1:0") +} - let addr = lst.local_addr()?; +pub fn test_server_with_addr(service: T, addr: A) -> Result +where + T: Service + Send + Sync + 'static, + T::Response: ReadyService + Service, + Req: TryFrom + 'static, + A: ToSocketAddrs, +{ + let lst = TcpListener::bind(addr)?; + let local_addr = lst.local_addr()?; let handle = Builder::new() .worker_threads(1) @@ -47,7 +55,35 @@ where .listen::<_, _, _, Req>("test_server", lst, service) .build(); - Ok(TestServerHandle { addr, handle }) + Ok(TestServerHandle { + addr: local_addr, + handle, + }) +} + +/// A specialized http/1 server on top of [test_server] +pub fn test_h1_server_with_addr(service: T, addr: A) -> Result +where + T: Service + Send + Sync + 'static, + T::Response: ReadyService + Service>, Response = HResponse> + 'static, + >>>::Error: fmt::Debug, + T::Error: error::Error + 'static, + B: Stream> + 'static, + E: fmt::Debug + 'static, + A: ToSocketAddrs, +{ + #[cfg(not(feature = "io-uring"))] + { + test_server_with_addr::<_, (TcpStream, SocketAddr), A>(service.enclosed(HttpServiceBuilder::h1()), addr) + } + + #[cfg(feature = "io-uring")] + { + test_server_with_addr::<_, (xitca_io::net::io_uring::TcpStream, SocketAddr), A>( + service.enclosed(HttpServiceBuilder::h1().io_uring()), + addr, + ) + } } /// A specialized http/1 server on top of [test_server] diff --git a/test/tests/h1.rs b/test/tests/h1.rs index ab892c6a..cfbdd6c5 100644 --- a/test/tests/h1.rs +++ b/test/tests/h1.rs @@ -16,7 +16,7 @@ use xitca_http::{ }, }; use xitca_service::fn_service; -use xitca_test::{test_h1_server, Error}; +use xitca_test::{test_h1_server, test_h1_server_with_addr, Error}; #[tokio::test] async fn h1_get() -> Result<(), Error> { @@ -27,7 +27,7 @@ async fn h1_get() -> Result<(), Error> { let c = Client::new(); for _ in 0..3 { - let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?; + let res = c.get(&server_url).version(Version::HTTP_11).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; @@ -41,6 +41,38 @@ async fn h1_get() -> Result<(), Error> { Ok(()) } +#[tokio::test] +async fn h1_get_connection_closed_by_server() -> Result<(), Error> { + let mut handle = test_h1_server(fn_service(handle))?; + let ip_port = handle.ip_port_string(); + + let server_url = format!("http://{}/", ip_port); + + let c = Client::builder().set_pool_capacity(1).finish(); + + let res = c.get(&server_url).version(Version::HTTP_11).send().await?; + assert_eq!(res.status().as_u16(), 200); + assert!(!res.can_close_connection()); + let body = res.string().await?; + assert_eq!("GET Response", body); + + handle.try_handle()?.stop(false); + handle.await?; + + let mut handle = test_h1_server_with_addr(fn_service(crate::handle), ip_port)?; + let res = c.get(&server_url).version(Version::HTTP_11).send().await?; + + assert_eq!(res.status().as_u16(), 200); + assert!(!res.can_close_connection()); + let body = res.string().await?; + assert_eq!("GET Response", body); + + handle.try_handle()?.stop(true); + handle.await?; + + Ok(()) +} + #[tokio::test] async fn h1_head() -> Result<(), Error> { let mut handle = test_h1_server(fn_service(handle))?; @@ -50,7 +82,7 @@ async fn h1_head() -> Result<(), Error> { let c = Client::new(); for _ in 0..3 { - let mut res = c.head(&server_url).version(Version::HTTP_11).send().await?; + let res = c.head(&server_url).version(Version::HTTP_11).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; @@ -79,7 +111,7 @@ async fn h1_post() -> Result<(), Error> { } let body_len = body.len(); - let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?; + let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.limit::<{ 12 * 1024 }>().string().await?; @@ -107,7 +139,7 @@ async fn h1_drop_body_read() -> Result<(), Error> { body.extend_from_slice(b"Hello,World!"); } - let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?; + let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(res.can_close_connection()); } @@ -133,7 +165,7 @@ async fn h1_partial_body_read() -> Result<(), Error> { body.extend_from_slice(b"Hello,World!"); } - let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?; + let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(res.can_close_connection()); } @@ -153,7 +185,7 @@ async fn h1_close_connection() -> Result<(), Error> { let c = Client::new(); - let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?; + let res = c.get(&server_url).version(Version::HTTP_11).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(res.can_close_connection()); @@ -190,7 +222,7 @@ async fn h1_request_too_large() -> Result<(), Error> { req.headers_mut() .insert("large-header", HeaderValue::try_from(body).unwrap()); - let mut res = req.send().await?; + let res = req.send().await?; assert_eq!(res.status().as_u16(), 431); assert!(res.can_close_connection()); diff --git a/test/tests/h2.rs b/test/tests/h2.rs index bcf33262..a822dcfe 100644 --- a/test/tests/h2.rs +++ b/test/tests/h2.rs @@ -20,7 +20,7 @@ async fn h2_get() -> Result<(), Error> { let c = Client::new(); for _ in 0..3 { - let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?; + let res = c.get(&server_url).version(Version::HTTP_2).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; @@ -46,7 +46,7 @@ async fn h2_no_host_header() -> Result<(), Error> { let mut req = c.get(&server_url).version(Version::HTTP_2); req.headers_mut().insert(header::HOST, "localhost".parse().unwrap()); - let mut res = req.send().await?; + let res = req.send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; @@ -73,7 +73,7 @@ async fn h2_post() -> Result<(), Error> { for _ in 0..1024 * 1024 { body.extend_from_slice(b"Hello,World!"); } - let mut res = c.post(&server_url).version(Version::HTTP_2).text(body).send().await?; + let res = c.post(&server_url).version(Version::HTTP_2).text(body).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let _ = res.body().await; @@ -142,7 +142,7 @@ async fn h2_keepalive() -> Result<(), Error> { .block_on(async move { let c = Client::new(); - let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?; + let res = c.get(&server_url).version(Version::HTTP_2).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; diff --git a/test/tests/h3.rs b/test/tests/h3.rs index 0154bc42..286fb149 100644 --- a/test/tests/h3.rs +++ b/test/tests/h3.rs @@ -17,7 +17,7 @@ async fn h3_get() -> Result<(), Error> { let server_url = format!("https://localhost:{}/", handle.addr().port()); for _ in 0..3 { - let mut res = c.get(&server_url).version(Version::HTTP_3).send().await?; + let res = c.get(&server_url).version(Version::HTTP_3).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; @@ -43,7 +43,7 @@ async fn h3_no_host_header() -> Result<(), Error> { let mut req = c.get(&server_url).version(Version::HTTP_3); req.headers_mut().insert(header::HOST, "localhost".parse().unwrap()); - let mut res = req.send().await?; + let res = req.send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); let body = res.string().await?; @@ -70,7 +70,7 @@ async fn h3_post() -> Result<(), Error> { for _ in 0..1024 * 1024 { body.extend_from_slice(b"Hello,World!"); } - let mut res = c.post(&server_url).version(Version::HTTP_3).text(body).send().await?; + let res = c.post(&server_url).version(Version::HTTP_3).text(body).send().await?; assert_eq!(res.status().as_u16(), 200); assert!(!res.can_close_connection()); }