From d68648cac617ba19395a793b6ba5b7394aa7e674 Mon Sep 17 00:00:00 2001 From: Gareth Clay Date: Tue, 16 Apr 2024 20:42:26 +0100 Subject: [PATCH] Remove Origin header when forwarding (#3357) This prevents forwarded requests, such as those from circuit breaker fallbacks, from failing in CORS checks, which require a fully populated scheme and host. Fixes gh-3350 --- .../support/ServerWebExchangeUtils.java | 5 ++++ .../filter/ForwardRoutingFilterTests.java | 6 ++-- ...CloudCircuitBreakerFilterFactoryTests.java | 7 +++++ .../support/ServerWebExchangeUtilsTests.java | 29 +++++++++++++++++++ 4 files changed, 44 insertions(+), 3 deletions(-) diff --git a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java index a2792394ba..64118914d2 100644 --- a/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java +++ b/spring-cloud-gateway-server/src/main/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtils.java @@ -431,6 +431,11 @@ public static Mono handle(DispatcherHandler handler, ServerWebExchange exc // remove attributes that may disrupt the forwarded request exchange.getAttributes().remove(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR); + // CORS check is applied to the original request, but should not be applied to + // internally forwarded requests. + // See https://github.com/spring-cloud/spring-cloud-gateway/issues/3350. + exchange = exchange.mutate().request(request -> request.headers(headers -> headers.setOrigin(null))).build(); + return handler.handle(exchange); } diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ForwardRoutingFilterTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ForwardRoutingFilterTests.java index ab97965c1d..0274bbab4b 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ForwardRoutingFilterTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ForwardRoutingFilterTests.java @@ -37,6 +37,7 @@ import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.assertArg; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -93,9 +94,8 @@ public void shouldFilterWhenGatewayRequestUrlSchemeIsForward() { forwardRoutingFilter.filter(exchange, chain); verifyNoMoreInteractions(chain); - verify(dispatcherHandler).handle(exchange); - - assertThat(exchange.getAttributes().get(GATEWAY_ALREADY_ROUTED_ATTR)).isNull(); + verify(dispatcherHandler).handle( + assertArg(exchange -> assertThat(exchange.getAttributes().get(GATEWAY_ALREADY_ROUTED_ATTR)).isNull())); } @Test diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/SpringCloudCircuitBreakerFilterFactoryTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/SpringCloudCircuitBreakerFilterFactoryTests.java index a5de908932..76aebf1e8e 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/SpringCloudCircuitBreakerFilterFactoryTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/factory/SpringCloudCircuitBreakerFilterFactoryTests.java @@ -106,6 +106,13 @@ public void filterFallbackForward() { .isOk().expectBody().json("{\"from\":\"circuitbreakerfallbackcontroller3\"}"); } + @Test + public void filterFallbackForwardWithCORS() { + testClient.get().uri("/delay/3?a=b").header("Host", "www.circuitbreakerforward.org") + .header("Origin", "https://cors.withcircuitbreaker.org").exchange().expectStatus().isOk().expectBody() + .json("{\"from\":\"circuitbreakerfallbackcontroller3\"}"); + } + @Test public void filterStatusCodeFallback() { testClient.get().uri("/status/500").header("Host", "www.circuitbreakerstatuscode.org").exchange().expectStatus() diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtilsTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtilsTests.java index cf415e61c3..e50b840c19 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtilsTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/support/ServerWebExchangeUtilsTests.java @@ -23,18 +23,27 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBuffer; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.mock.http.server.reactive.MockServerHttpRequest; import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.web.reactive.DispatcherHandler; import org.springframework.web.reactive.function.server.HandlerStrategies; import org.springframework.web.reactive.function.server.ServerRequest; +import org.springframework.web.server.ServerWebExchange; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.assertArg; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_REQUEST_BODY_ATTR; +import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.GATEWAY_PREDICATE_PATH_CONTAINER_ATTR; import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.expand; +import static org.springframework.http.server.PathContainer.parsePath; public class ServerWebExchangeUtilsTests { @@ -94,6 +103,26 @@ public void duplicatedCachingDataBufferHandling() { Assertions.assertThat(dataBufferBeforeCaching).isEqualTo(dataBufferAfterCached); } + @Test + public void forwardedRequestsHaveDisruptiveAttributesAndHeadersRemoved() { + DispatcherHandler handler = Mockito.mock(DispatcherHandler.class); + Mockito.when(handler.handle(any(ServerWebExchange.class))).thenReturn(Mono.empty()); + + ServerWebExchange originalExchange = mockExchange(Map.of()).mutate() + .request(request -> request.headers(headers -> headers.setOrigin("https://example.com"))).build(); + originalExchange.getAttributes().put(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR, parsePath("/example/path")); + + ServerWebExchangeUtils.handle(handler, originalExchange).block(); + + Mockito.verify(handler).handle(assertArg(exchange -> { + Assertions.assertThat(exchange.getAttributes()).as("exchange attributes") + .doesNotContainKey(GATEWAY_PREDICATE_PATH_CONTAINER_ATTR); + + Assertions.assertThat(exchange.getRequest().getHeaders()).as("request headers") + .doesNotContainKey(HttpHeaders.ORIGIN); + })); + } + private MockServerWebExchange mockExchange(Map vars) { return mockExchange(HttpMethod.GET, vars); }