Skip to content

Commit

Permalink
feat(client): use a middleware to retry closed connection instead
Browse files Browse the repository at this point in the history
  • Loading branch information
joelwurtz committed Jan 12, 2025
1 parent 1de2886 commit 5471925
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 17 deletions.
2 changes: 2 additions & 0 deletions client/src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! middleware offer extended functionality to http client.
mod redirect;
mod retry_closed_connection;

#[cfg(feature = "compress")]
mod decompress;
Expand All @@ -9,3 +10,4 @@ mod decompress;
pub use decompress::Decompress;

pub use redirect::FollowRedirect;
pub use retry_closed_connection::RetryClosedConnection;
65 changes: 65 additions & 0 deletions client/src/middleware/retry_closed_connection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::io;

use crate::{
error::Error,
response::Response,
service::{Service, ServiceRequest},
};

/// middleware for retrying closed connection
pub struct RetryClosedConnection<S> {
service: S,
}

impl<S> RetryClosedConnection<S> {
pub fn new(service: S) -> Self {
Self { service }
}
}

impl<'r, 'c, S> Service<ServiceRequest<'r, 'c>> for RetryClosedConnection<S>
where
S: for<'r2, 'c2> Service<ServiceRequest<'r2, 'c2>, Response = Response, Error = Error> + Send + Sync,
{
type Response = Response;
type Error = Error;

async fn call(&self, req: ServiceRequest<'r, 'c>) -> Result<Self::Response, Self::Error> {
let ServiceRequest { req, client, timeout } = req;
let headers = req.headers().clone();
let method = req.method().clone();
let uri = req.uri().clone();

loop {
let res = self.service.call(ServiceRequest { req, client, timeout }).await;

match res {
Err(Error::Io(err)) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::Io(err));
}
}
Err(Error::H1(crate::h1::Error::Io(err))) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::H1(crate::h1::Error::Io(err)));
}
}
Err(Error::H2(crate::h2::Error::Io(err))) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::H2(crate::h2::Error::Io(err)));
}
}
Err(Error::H3(crate::h3::Error::Io(err))) => {
if err.kind() != io::ErrorKind::UnexpectedEof {
return Err(Error::H3(crate::h3::Error::Io(err)));
}
}
res => return res,
}

*req.uri_mut() = uri.clone();
*req.method_mut() = method.clone();
*req.headers_mut() = headers.clone();
}
}
}
19 changes: 4 additions & 15 deletions client/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ pub(crate) fn base_service() -> HttpService {

let ServiceRequest { req, client, timeout } = req;

let connect_uri = req.uri().clone();
let uri = Uri::try_parse(&connect_uri)?;
let uri = Uri::try_parse(req.uri())?;

// temporary version to record possible version downgrade/upgrade happens when making connections.
// alpn protocol and alt-svc header are possible source of version change.
Expand All @@ -97,9 +96,9 @@ pub(crate) fn base_service() -> HttpService {

let mut connect = Connect::new(uri);

loop {
let _date = client.date_service.handle();
let _date = client.date_service.handle();

loop {
match version {
Version::HTTP_2 | Version::HTTP_3 => match client.shared_pool.acquire(&connect.uri).await {
shared::AcquireOutput::Conn(mut _conn) => {
Expand Down Expand Up @@ -236,17 +235,7 @@ pub(crate) fn base_service() -> HttpService {
}
Ok(Err(e)) => {
_conn.destroy_on_drop();

match e {
crate::h1::Error::Io(err) => {
if err.kind() == std::io::ErrorKind::UnexpectedEof {
continue;
}

Err(crate::h1::Error::Io(err).into())
}
_ => Err(e.into()),
}
Err(e.into())
}
Err(_) => {
_conn.destroy_on_drop();
Expand Down
7 changes: 5 additions & 2 deletions test/tests/h1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{
time::Duration,
};

use xitca_client::Client;
use xitca_client::{middleware::RetryClosedConnection, Client};
use xitca_http::{
body::{BoxBody, ResponseBody},
bytes::{Bytes, BytesMut},
Expand Down Expand Up @@ -48,7 +48,10 @@ async fn h1_get_connection_closed_by_server() -> Result<(), Error> {

let server_url = format!("http://{}/", ip_port);

let c = Client::builder().set_pool_capacity(1).finish();
let c = Client::builder()
.middleware(RetryClosedConnection::new)
.set_pool_capacity(1)
.finish();

let res = c.get(&server_url).version(Version::HTTP_11).send().await?;
assert_eq!(res.status().as_u16(), 200);
Expand Down

0 comments on commit 5471925

Please sign in to comment.