diff --git a/Cargo.lock b/Cargo.lock index 5d98dfc5..ef87a89e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,7 +142,7 @@ dependencies = [ [[package]] name = "audioserve" -version = "0.27.1" +version = "0.27.2" dependencies = [ "anyhow", "async-tar", diff --git a/Cargo.toml b/Cargo.toml index fdf99d89..a4aa7c43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "audioserve" -version = "0.27.1" +version = "0.27.2" authors = ["Ivan "] edition = "2021" rust-version = "1.70" diff --git a/src/services/mod.rs b/src/services/mod.rs index d03d7071..dd9c55df 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,6 +1,6 @@ use self::auth::{AuthResult, Authenticator}; use self::request::{QueryParams, RequestWrapper}; -use self::response::ResponseFuture; +use self::response::{ResponseFuture, ResponseResult}; use self::search::Search; use self::transcode::QualityLevel; use crate::config::get_config; @@ -27,7 +27,6 @@ use std::{ convert::Infallible, net::SocketAddr, path::{Path, PathBuf}, - pin::Pin, sync::{atomic::AtomicUsize, Arc}, task::Poll, }; @@ -87,11 +86,13 @@ impl ServiceFactory { is_ssl: bool, ) -> impl Future, Infallible>> { future::ok(MainService { + state: ServiceComponents { + search: self.search.clone(), + transcoding: self.transcoding.clone(), + collections: self.collections.clone(), + }, authenticator: self.authenticator.clone(), rate_limitter: self.rate_limitter.clone(), - search: self.search.clone(), - transcoding: self.transcoding.clone(), - collections: self.collections.clone(), remote_addr, is_ssl, }) @@ -99,12 +100,19 @@ impl ServiceFactory { } #[derive(Clone)] -pub struct MainService { - pub authenticator: Option>>, - pub rate_limitter: Option>, +pub struct ServiceComponents { pub search: Search, pub transcoding: TranscodingDetails, pub collections: Arc, +} + +type OptionalAuthenticatorType = Option>>; + +#[derive(Clone)] +pub struct MainService { + pub state: ServiceComponents, + pub authenticator: OptionalAuthenticatorType, + pub rate_limitter: Option>, pub remote_addr: SocketAddr, pub is_ssl: bool, } @@ -177,31 +185,24 @@ fn is_static_file(path: &str) -> bool { impl Service> for MainService { type Response = Response; type Error = error::Error; - type Future = Pin> + Send>>; + type Future = ResponseFuture; fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, req: Request) -> Self::Future { - Box::pin(self.process_request(req).or_else(|e| { - error!("Request processing error: {}", e); - future::ok(response::internal_error()) - })) - } -} + let state = self.state.clone(); -impl MainService { - fn process_request(&mut self, req: Request) -> ResponseFuture { //Limit rate of requests if configured - if let Some(limiter) = self.rate_limitter.as_ref() { + if let Some(ref limiter) = self.rate_limitter { if limiter.start_one().is_err() { debug!("Rejecting request due to rate limit"); return response::fut(response::too_many_requests); } } - // handle OPTIONS method for CORS preflightAtomicUsize + // handle OPTIONS method for CORS preflight if req.method() == Method::OPTIONS && RequestWrapper::is_cors_enabled_for_request(&req) { debug!( "Got OPTIONS request in CORS mode : {} {:?}", @@ -223,6 +224,24 @@ impl MainService { return response::fut(response::bad_request); } }; + + Box::pin( + MainService::::process_request(state, self.authenticator.clone(), req).or_else( + |e| { + error!("Request processing error: {}", e); + future::ok(response::internal_error()) + }, + ), + ) + } +} + +impl MainService { + async fn process_request( + subservices: ServiceComponents, + authenticator: OptionalAuthenticatorType, + req: RequestWrapper, + ) -> ResponseResult { //static files if req.method() == Method::GET { if req.path() == "/" || req.path() == "/index.html" { @@ -230,54 +249,49 @@ impl MainService { &get_config().client_dir, "index.html", get_config().static_resource_cache_age, - ); + ) + .await; } else if is_static_file(req.path()) { return files::send_static_file( &get_config().client_dir, &req.path()[1..], get_config().static_resource_cache_age, - ); + ) + .await; } } // from here everything must be authenticated - let searcher = self.search.clone(); - let transcoding = self.transcoding.clone(); let cors = req.is_cors_enabled(); let origin = req.headers().typed_get::(); - let resp = match self.authenticator { + let resp = match authenticator { Some(ref auth) => { - let collections = self.collections.clone(); Box::pin(auth.authenticate(req).and_then(move |result| match result { - AuthResult::Authenticated { request, .. } => MainService::::process_checked( - request, - searcher, - transcoding, - collections, - ), + AuthResult::Authenticated { request, .. } => { + MainService::::process_authenticated(request, subservices) + } AuthResult::LoggedIn(resp) | AuthResult::Rejected(resp) => { Box::pin(future::ok(resp)) } })) } - None => MainService::::process_checked( - req, - searcher, - transcoding, - self.collections.clone(), - ), + None => MainService::::process_authenticated(req, subservices), }; - Box::pin(resp.map_ok(move |r| add_cors_headers(r, origin, cors))) + resp.map_ok(move |r| add_cors_headers(r, origin, cors)) + .await } - fn process_checked( - #[allow(unused_mut)] mut req: RequestWrapper, - searcher: Search, - transcoding: TranscodingDetails, - collections: Arc, + fn process_authenticated( + mut req: RequestWrapper, + subservices: ServiceComponents, ) -> ResponseFuture { let params = req.params(); let path = req.path(); + let ServiceComponents { + search, + transcoding, + collections, + } = subservices; match *req.method() { Method::GET => { if path.starts_with("/collections") { @@ -402,7 +416,7 @@ impl MainService { let group = params.get_string("group"); api::search( colllection_index, - searcher, + search, search_string, ord, group, @@ -414,7 +428,7 @@ impl MainService { } } else if path.starts_with("/recent") { let group = params.get_string("group"); - api::recent(colllection_index, searcher, group, req.can_compress()) + api::recent(colllection_index, search, group, req.can_compress()) } else if path.starts_with("/cover/") { files::send_cover( base_dir, diff --git a/src/services/response.rs b/src/services/response.rs index cb7d38f3..ac0fdb41 100644 --- a/src/services/response.rs +++ b/src/services/response.rs @@ -20,7 +20,8 @@ const NOT_IMPLEMENTED_MSG: &str = "Not Implemented"; const INTERNAL_SERVER_ERROR: &str = "Internal server error"; const UNPROCESSABLE_ENTITY: &str = "Ignored"; -pub type ResponseFuture = Pin, Error>> + Send>>; +pub type ResponseResult = Result, Error>; +pub type ResponseFuture = Pin + Send>>; fn short_response(status: StatusCode, msg: &'static str) -> Response { Response::builder()