diff --git a/poem/Cargo.toml b/poem/Cargo.toml index 08e940691d..4277d84313 100644 --- a/poem/Cargo.toml +++ b/poem/Cargo.toml @@ -91,6 +91,7 @@ headers = "0.3.7" thiserror.workspace = true rfc7239 = "0.1.0" mime.workspace = true +wildmatch = "2" # Non-feature optional dependencies multer = { version = "2.1.0", features = ["tokio"], optional = true } diff --git a/poem/src/middleware/cors.rs b/poem/src/middleware/cors.rs index 0b0e92ae58..d682683137 100644 --- a/poem/src/middleware/cors.rs +++ b/poem/src/middleware/cors.rs @@ -3,6 +3,7 @@ use std::{collections::HashSet, str::FromStr, sync::Arc}; use headers::{ AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlExposeHeaders, HeaderMapExt, }; +use wildmatch::WildMatch; use crate::{ endpoint::Endpoint, @@ -39,6 +40,7 @@ use crate::{ pub struct Cors { allow_credentials: bool, allow_origins: HashSet, + allow_origins_wildcard: Vec, allow_origins_fn: Option bool + Send + Sync>>, allow_headers: HashSet, allow_methods: HashSet, @@ -135,6 +137,14 @@ impl Cors { self } + /// Add an allowed origin that supports '*' wildcard. + /// Example: `rust cors.allow_origin_regex("https://*.domain.url")` + fn allow_origin_regex(mut self, origin: impl AsRef) -> Self { + self.allow_origins_wildcard + .push(WildMatch::new(origin.as_ref())); + self + } + /// Add many allow origins. #[must_use] pub fn allow_origins(self, origins: I) -> Self @@ -203,6 +213,7 @@ impl Middleware for Cors { inner: ep, allow_credentials: self.allow_credentials, allow_origins: self.allow_origins.clone(), + allow_origins_wildcard: self.allow_origins_wildcard.clone(), allow_origins_fn: self.allow_origins_fn.clone(), allow_headers: self.allow_headers.clone(), allow_methods: self.allow_methods.clone(), @@ -221,6 +232,7 @@ pub struct CorsEndpoint { inner: E, allow_credentials: bool, allow_origins: HashSet, + allow_origins_wildcard: Vec, allow_origins_fn: Option bool + Send + Sync>>, allow_headers: HashSet, allow_methods: HashSet, @@ -237,6 +249,14 @@ impl CorsEndpoint { return (true, false); } + if self + .allow_origins_wildcard + .iter() + .any(|m| m.matches(origin.to_str().unwrap())) + { + return (true, true); + } + if let Some(allow_origins_fn) = &self.allow_origins_fn { if let Ok(origin) = origin.to_str() { if allow_origins_fn(origin) { @@ -246,7 +266,9 @@ impl CorsEndpoint { } ( - self.allow_origins.is_empty() && self.allow_origins_fn.is_none(), + self.allow_origins.is_empty() + && self.allow_origins_fn.is_none() + && self.allow_origins_wildcard.is_empty(), true, ) } @@ -541,6 +563,41 @@ mod tests { resp.assert_status(StatusCode::FORBIDDEN); } + #[tokio::test] + async fn allow_origins_fn_4() { + let ep = + make_sync(|_| "hello").with(Cors::new().allow_origin_regex("https://*example.com")); + let cli = TestClient::new(ep); + + let resp = cli + .get("/") + .header(header::ORIGIN, "https://example.mx") + .send() + .await; + resp.assert_status(StatusCode::FORBIDDEN); + + let resp = cli + .get("/") + .header(header::ORIGIN, "https://test.example.com") + .send() + .await; + resp.assert_status_is_ok(); + resp.assert_header( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + "https://test.example.com", + ); + resp.assert_header_is_not_exist(header::VARY); + + let resp = cli + .get("/") + .header(header::ORIGIN, "https://example.com") + .send() + .await; + resp.assert_status_is_ok(); + resp.assert_header(header::ACCESS_CONTROL_ALLOW_ORIGIN, "https://example.com"); + resp.assert_header(header::VARY, "Origin"); + } + #[tokio::test] async fn default_cors_middleware() { let ep = make_sync(|_| "hello").with(Cors::new());