Skip to content

Commit

Permalink
feat(client): retry http1 connection if closed by server
Browse files Browse the repository at this point in the history
  • Loading branch information
joelwurtz committed Jan 10, 2025
1 parent 99f7c97 commit 1de2886
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 29 deletions.
6 changes: 3 additions & 3 deletions client/src/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ impl ResponseBody {
}
}

pub(crate) fn can_destroy_on_drop(&mut self) -> bool {
pub(crate) fn can_destroy_on_drop(&self) -> bool {
#[cfg(feature = "http1")]
if let Self::H1(ref mut body) = *self {
return body.conn_mut().is_destroy_on_drop();
if let Self::H1(ref body) = *self {
return body.conn().is_destroy_on_drop();
}

false
Expand Down
4 changes: 2 additions & 2 deletions client/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ impl<const PAYLOAD_LIMIT: usize> Response<PAYLOAD_LIMIT> {
/// Public API for test purpose.
///
/// Used for testing server implementation to make sure it follows spec.
pub fn can_close_connection(&mut self) -> bool {
self.res.body_mut().can_destroy_on_drop()
pub fn can_close_connection(&self) -> bool {
self.res.body().can_destroy_on_drop()
}
}

Expand Down
19 changes: 15 additions & 4 deletions client/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ pub(crate) fn base_service() -> HttpService {

let ServiceRequest { req, client, timeout } = req;

let uri = Uri::try_parse(req.uri())?;
let connect_uri = req.uri().clone();
let uri = Uri::try_parse(&connect_uri)?;

// temporary version to record possible version downgrade/upgrade happens when making connections.
// alpn protocol and alt-svc header are possible source of version change.
Expand All @@ -96,9 +97,9 @@ pub(crate) fn base_service() -> HttpService {

let mut connect = Connect::new(uri);

let _date = client.date_service.handle();

loop {
let _date = client.date_service.handle();

match version {
Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await {
shared::AcquireOutput::Conn(mut _conn) => {
Expand Down Expand Up @@ -235,7 +236,17 @@ pub(crate) fn base_service() -> HttpService {
}
Ok(Err(e)) => {
_conn.destroy_on_drop();
Err(e.into())

match e {
crate::h1::Error::Io(err) => {
if err.kind() == std::io::ErrorKind::UnexpectedEof {
continue;
}

Err(crate::h1::Error::Io(err).into())
}
_ => Err(e.into()),
}
}
Err(_) => {
_conn.destroy_on_drop();
Expand Down
46 changes: 41 additions & 5 deletions test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@ use std::{
error, fmt, fs,
future::Future,
io,
net::SocketAddr,
net::TcpListener,
net::{SocketAddr, TcpListener, ToSocketAddrs},
pin::Pin,
task::{Context, Poll},
time::Duration,
Expand Down Expand Up @@ -36,9 +35,18 @@ where
T::Response: ReadyService + Service<Req>,
Req: TryFrom<NetStream> + 'static,
{
let lst = TcpListener::bind("127.0.0.1:0")?;
test_server_with_addr(service, "127.0.0.1:0")
}

let addr = lst.local_addr()?;
pub fn test_server_with_addr<T, Req, A>(service: T, addr: A) -> Result<TestServerHandle, Error>
where
T: Service + Send + Sync + 'static,
T::Response: ReadyService + Service<Req>,
Req: TryFrom<NetStream> + 'static,
A: ToSocketAddrs,
{
let lst = TcpListener::bind(addr)?;
let local_addr = lst.local_addr()?;

let handle = Builder::new()
.worker_threads(1)
Expand All @@ -47,7 +55,35 @@ where
.listen::<_, _, _, Req>("test_server", lst, service)
.build();

Ok(TestServerHandle { addr, handle })
Ok(TestServerHandle {
addr: local_addr,
handle,
})
}

/// A specialized http/1 server on top of [test_server]
pub fn test_h1_server_with_addr<T, B, E, A>(service: T, addr: A) -> Result<TestServerHandle, Error>
where
T: Service + Send + Sync + 'static,
T::Response: ReadyService + Service<Request<RequestExt<h1::RequestBody>>, Response = HResponse<B>> + 'static,
<T::Response as Service<Request<RequestExt<h1::RequestBody>>>>::Error: fmt::Debug,
T::Error: error::Error + 'static,
B: Stream<Item = Result<Bytes, E>> + 'static,
E: fmt::Debug + 'static,
A: ToSocketAddrs,
{
#[cfg(not(feature = "io-uring"))]
{
test_server_with_addr::<_, (TcpStream, SocketAddr), A>(service.enclosed(HttpServiceBuilder::h1()), addr)
}

#[cfg(feature = "io-uring")]
{
test_server_with_addr::<_, (xitca_io::net::io_uring::TcpStream, SocketAddr), A>(
service.enclosed(HttpServiceBuilder::h1().io_uring()),
addr,
)
}
}

/// A specialized http/1 server on top of [test_server]
Expand Down
48 changes: 40 additions & 8 deletions test/tests/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use xitca_http::{
},
};
use xitca_service::fn_service;
use xitca_test::{test_h1_server, Error};
use xitca_test::{test_h1_server, test_h1_server_with_addr, Error};

#[tokio::test]
async fn h1_get() -> Result<(), Error> {
Expand All @@ -27,7 +27,7 @@ async fn h1_get() -> Result<(), Error> {
let c = Client::new();

for _ in 0..3 {
let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -41,6 +41,38 @@ async fn h1_get() -> Result<(), Error> {
Ok(())
}

#[tokio::test]
async fn h1_get_connection_closed_by_server() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;
let ip_port = handle.ip_port_string();

let server_url = format!("http://{}/", ip_port);

let c = Client::builder().set_pool_capacity(1).finish();

let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(false);
handle.await?;

let mut handle = test_h1_server_with_addr(fn_service(crate::handle), ip_port)?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;

assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
assert_eq!("GET Response", body);

handle.try_handle()?.stop(true);
handle.await?;

Ok(())
}

#[tokio::test]
async fn h1_head() -> Result<(), Error> {
let mut handle = test_h1_server(fn_service(handle))?;
Expand All @@ -50,7 +82,7 @@ async fn h1_head() -> Result<(), Error> {
let c = Client::new();

for _ in 0..3 {
let mut res = c.head(&server_url).version(Version::HTTP_11).send().await?;
let res = c.head(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand Down Expand Up @@ -79,7 +111,7 @@ async fn h1_post() -> Result<(), Error> {
}
let body_len = body.len();

let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.limit::<{ 12 * 1024 }>().string().await?;
Expand Down Expand Up @@ -107,7 +139,7 @@ async fn h1_drop_body_read() -> Result<(), Error> {
body.extend_from_slice(b"Hello,World!");
}

let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(res.can_close_connection());
}
Expand All @@ -133,7 +165,7 @@ async fn h1_partial_body_read() -> Result<(), Error> {
body.extend_from_slice(b"Hello,World!");
}

let mut res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_11).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(res.can_close_connection());
}
Expand All @@ -153,7 +185,7 @@ async fn h1_close_connection() -> Result<(), Error> {

let c = Client::new();

let mut res = c.get(&server_url).version(Version::HTTP_11).send().await?;
let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(res.can_close_connection());

Expand Down Expand Up @@ -190,7 +222,7 @@ async fn h1_request_too_large() -> Result<(), Error> {
req.headers_mut()
.insert("large-header", HeaderValue::try_from(body).unwrap());

let mut res = req.send().await?;
let res = req.send().await?;
assert_eq!(res.status().as_u16(), 431);
assert!(res.can_close_connection());

Expand Down
8 changes: 4 additions & 4 deletions test/tests/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn h2_get() -> Result<(), Error> {
let c = Client::new();

for _ in 0..3 {
let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?;
let res = c.get(&server_url).version(Version::HTTP_2).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -46,7 +46,7 @@ async fn h2_no_host_header() -> Result<(), Error> {
let mut req = c.get(&server_url).version(Version::HTTP_2);
req.headers_mut().insert(header::HOST, "localhost".parse().unwrap());

let mut res = req.send().await?;
let res = req.send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -73,7 +73,7 @@ async fn h2_post() -> Result<(), Error> {
for _ in 0..1024 * 1024 {
body.extend_from_slice(b"Hello,World!");
}
let mut res = c.post(&server_url).version(Version::HTTP_2).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_2).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let _ = res.body().await;
Expand Down Expand Up @@ -142,7 +142,7 @@ async fn h2_keepalive() -> Result<(), Error> {
.block_on(async move {
let c = Client::new();

let mut res = c.get(&server_url).version(Version::HTTP_2).send().await?;
let res = c.get(&server_url).version(Version::HTTP_2).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand Down
6 changes: 3 additions & 3 deletions test/tests/h3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async fn h3_get() -> Result<(), Error> {
let server_url = format!("https://localhost:{}/", handle.addr().port());

for _ in 0..3 {
let mut res = c.get(&server_url).version(Version::HTTP_3).send().await?;
let res = c.get(&server_url).version(Version::HTTP_3).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -43,7 +43,7 @@ async fn h3_no_host_header() -> Result<(), Error> {
let mut req = c.get(&server_url).version(Version::HTTP_3);
req.headers_mut().insert(header::HOST, "localhost".parse().unwrap());

let mut res = req.send().await?;
let res = req.send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
let body = res.string().await?;
Expand All @@ -70,7 +70,7 @@ async fn h3_post() -> Result<(), Error> {
for _ in 0..1024 * 1024 {
body.extend_from_slice(b"Hello,World!");
}
let mut res = c.post(&server_url).version(Version::HTTP_3).text(body).send().await?;
let res = c.post(&server_url).version(Version::HTTP_3).text(body).send().await?;
assert_eq!(res.status().as_u16(), 200);
assert!(!res.can_close_connection());
}
Expand Down

0 comments on commit 1de2886

Please sign in to comment.