diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfiguration.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfiguration.java index f888bec5ec..dabf7d2dfa 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfiguration.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfiguration.java @@ -16,19 +16,19 @@ package org.springframework.cloud.gateway.server.mvc; +import java.util.Map; + import org.springframework.beans.factory.ObjectProvider; +import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.http.client.HttpClientAutoConfiguration; -import org.springframework.boot.autoconfigure.http.client.HttpClientProperties; +import org.springframework.boot.autoconfigure.http.client.HttpClientProperties.Factory; import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.autoconfigure.web.client.RestTemplateAutoConfiguration; -import org.springframework.boot.http.client.ClientHttpRequestFactoryBuilder; -import org.springframework.boot.http.client.ClientHttpRequestFactorySettings; -import org.springframework.boot.http.client.JdkClientHttpRequestFactoryBuilder; -import org.springframework.boot.ssl.SslBundle; -import org.springframework.boot.ssl.SslBundles; +import org.springframework.boot.env.EnvironmentPostProcessor; +import org.springframework.boot.http.client.ClientHttpRequestFactorySettings.Redirects; import org.springframework.boot.web.client.RestClientCustomizer; import org.springframework.cloud.gateway.server.mvc.common.ArgumentSupplierBeanPostProcessor; import org.springframework.cloud.gateway.server.mvc.config.GatewayMvcAotRuntimeHintsRegistrar; @@ -55,8 +55,11 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; import org.springframework.context.annotation.ImportRuntimeHints; +import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.core.env.Environment; +import org.springframework.core.env.MapPropertySource; import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.util.ClassUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.RestClient; @@ -85,8 +88,12 @@ public RouterFunctionHolderFactory routerFunctionHolderFactory(Environment env) } @Bean - public RestClientCustomizer gatewayRestClientCustomizer(ClientHttpRequestFactory requestFactory) { - return restClientBuilder -> restClientBuilder.requestFactory(requestFactory); + public RestClientCustomizer gatewayRestClientCustomizer( + ObjectProvider requestFactoryProvider) { + return restClientBuilder -> { + // for backwards compatibility if user overrode + requestFactoryProvider.ifAvailable(restClientBuilder::requestFactory); + }; } @Bean @@ -111,36 +118,6 @@ public ForwardedRequestHeadersFilter forwardedRequestHeadersFilter() { return new ForwardedRequestHeadersFilter(); } - @Bean - @ConditionalOnMissingBean - public ClientHttpRequestFactory gatewayClientHttpRequestFactory(HttpClientProperties properties, - SslBundles sslBundles) { - - SslBundle sslBundle = null; - if (StringUtils.hasText(properties.getSsl().getBundle())) { - sslBundle = sslBundles.getBundle(properties.getSsl().getBundle()); - } - ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.ofSslBundle(sslBundle) - .withConnectTimeout(properties.getConnectTimeout()) - .withReadTimeout(properties.getReadTimeout()) - .withRedirects(ClientHttpRequestFactorySettings.Redirects.DONT_FOLLOW); - - ClientHttpRequestFactoryBuilder builder = ClientHttpRequestFactoryBuilder.detect(); - if (builder instanceof JdkClientHttpRequestFactoryBuilder) { - // TODO: customize restricted headers - String restrictedHeaders = System.getProperty("jdk.httpclient.allowRestrictedHeaders"); - if (!StringUtils.hasText(restrictedHeaders)) { - System.setProperty("jdk.httpclient.allowRestrictedHeaders", "host"); - } - else if (StringUtils.hasText(restrictedHeaders) && !restrictedHeaders.contains("host")) { - System.setProperty("jdk.httpclient.allowRestrictedHeaders", restrictedHeaders + ",host"); - } - } - - // Autodetect - return builder.build(settings); - } - @Bean @ConditionalOnMissingBean public GatewayMvcProperties gatewayMvcProperties() { @@ -222,4 +199,46 @@ public XForwardedRequestHeadersFilterProperties xForwardedRequestHeadersFilterPr return new XForwardedRequestHeadersFilterProperties(); } + static class GatewayHttpClientEnvironmentPostProcessor implements EnvironmentPostProcessor { + + static final boolean APACHE = ClassUtils.isPresent("org.apache.hc.client5.http.impl.classic.HttpClients", null); + static final boolean JETTY = ClassUtils.isPresent("org.eclipse.jetty.client.HttpClient", null); + static final boolean REACTOR_NETTY = ClassUtils.isPresent("reactor.netty.http.client.HttpClient", null); + static final boolean JDK = ClassUtils.isPresent("java.net.http.HttpClient", null); + static final boolean HIGHER_PRIORITY = APACHE || JETTY || REACTOR_NETTY; + + @Override + public void postProcessEnvironment(ConfigurableEnvironment environment, SpringApplication application) { + Redirects redirects = environment.getProperty("spring.http.client.redirects", Redirects.class); + if (redirects == null) { + // the user hasn't set anything, change the default + environment.getPropertySources() + .addFirst(new MapPropertySource("gatewayHttpClientProperties", + Map.of("spring.http.client.redirects", Redirects.DONT_FOLLOW))); + } + Factory factory = environment.getProperty("spring.http.client.factory", Factory.class); + boolean setJdkHttpClientProperties = false; + + if (factory == null && !HIGHER_PRIORITY) { + // autodetect + setJdkHttpClientProperties = JDK; + } + else if (factory == Factory.JDK) { + setJdkHttpClientProperties = JDK; + } + + if (setJdkHttpClientProperties) { + // TODO: customize restricted headers + String restrictedHeaders = System.getProperty("jdk.httpclient.allowRestrictedHeaders"); + if (!StringUtils.hasText(restrictedHeaders)) { + System.setProperty("jdk.httpclient.allowRestrictedHeaders", "host"); + } + else if (StringUtils.hasText(restrictedHeaders) && !restrictedHeaders.contains("host")) { + System.setProperty("jdk.httpclient.allowRestrictedHeaders", restrictedHeaders + ",host"); + } + } + } + + } + } diff --git a/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories b/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories index 972cdb6d78..789c0150f7 100644 --- a/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories +++ b/spring-cloud-gateway-server-mvc/src/main/resources/META-INF/spring.factories @@ -12,3 +12,6 @@ org.springframework.cloud.gateway.server.mvc.handler.HandlerSupplier=\ org.springframework.cloud.gateway.server.mvc.predicate.PredicateSupplier=\ org.springframework.cloud.gateway.server.mvc.predicate.MvcPredicateSupplier,\ org.springframework.cloud.gateway.server.mvc.predicate.GatewayRequestPredicates.PredicateSupplier + +org.springframework.boot.env.EnvironmentPostProcessor=\ + org.springframework.cloud.gateway.server.mvc.GatewayServerMvcAutoConfiguration.GatewayHttpClientEnvironmentPostProcessor \ No newline at end of file diff --git a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfigurationTests.java b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfigurationTests.java index e5ed6cf174..955264860e 100644 --- a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfigurationTests.java +++ b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/GatewayServerMvcAutoConfigurationTests.java @@ -29,6 +29,9 @@ import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; import org.springframework.boot.autoconfigure.web.client.RestTemplateAutoConfiguration; import org.springframework.boot.builder.SpringApplicationBuilder; +import org.springframework.boot.http.client.ClientHttpRequestFactoryBuilder; +import org.springframework.boot.http.client.ClientHttpRequestFactorySettings; +import org.springframework.boot.http.client.SimpleClientHttpRequestFactoryBuilder; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.cloud.gateway.server.mvc.filter.FormFilter; import org.springframework.cloud.gateway.server.mvc.filter.ForwardedRequestHeadersFilter; @@ -40,8 +43,6 @@ import org.springframework.cloud.gateway.server.mvc.filter.WeightCalculatorFilter; import org.springframework.cloud.gateway.server.mvc.filter.XForwardedRequestHeadersFilter; import org.springframework.context.ConfigurableApplicationContext; -import org.springframework.http.client.JdkClientHttpRequestFactory; -import org.springframework.test.util.ReflectionTestUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -135,7 +136,6 @@ void filterEnabledPropertiesWork() { void gatewayHttpClientPropertiesWork() { ConfigurableApplicationContext context = new SpringApplicationBuilder(TestConfig.class) .properties("spring.main.web-application-type=none", - "spring.cloud.gateway.mvc.http-client.connect-timeout=1s", "spring.cloud.gateway.mvc.http-client.connect-timeout=1s", "spring.cloud.gateway.mvc.http-client.read-timeout=2s", "spring.cloud.gateway.mvc.http-client.ssl-bundle=mybundle", @@ -143,14 +143,16 @@ void gatewayHttpClientPropertiesWork() { "spring.ssl.bundle.pem.mybundle.keystore.certificate=" + cert, "spring.ssl.bundle.pem.mybundle.keystore.key=" + key) .run(); - JdkClientHttpRequestFactory requestFactory = context.getBean(JdkClientHttpRequestFactory.class); + ClientHttpRequestFactorySettings settings = context.getBean(ClientHttpRequestFactorySettings.class); HttpClientProperties properties = context.getBean(HttpClientProperties.class); assertThat(properties.getConnectTimeout()).hasSeconds(1); assertThat(properties.getReadTimeout()).hasSeconds(2); assertThat(properties.getSsl().getBundle()).isEqualTo("mybundle"); assertThat(properties.getFactory()).isNull(); - Object readTimeout = ReflectionTestUtils.getField(requestFactory, "readTimeout"); - assertThat(readTimeout).isEqualTo(Duration.ofSeconds(2)); + assertThat(settings.readTimeout()).isEqualTo(Duration.ofSeconds(2)); + assertThat(settings.connectTimeout()).isEqualTo(Duration.ofSeconds(1)); + assertThat(settings.sslBundle()).isNotNull(); + assertThat(settings.redirects()).isEqualTo(ClientHttpRequestFactorySettings.Redirects.DONT_FOLLOW); } @Test @@ -164,19 +166,30 @@ void bootHttpClientPropertiesWork() { "spring.ssl.bundle.pem.mybundle.keystore.certificate=" + cert, "spring.ssl.bundle.pem.mybundle.keystore.key=" + key) .run(context -> { - assertThat(context).hasSingleBean(JdkClientHttpRequestFactory.class) + assertThat(context).hasSingleBean(ClientHttpRequestFactorySettings.class) .hasSingleBean(HttpClientProperties.class); HttpClientProperties httpClient = context.getBean(HttpClientProperties.class); assertThat(httpClient.getConnectTimeout()).hasSeconds(1); assertThat(httpClient.getReadTimeout()).hasSeconds(2); assertThat(httpClient.getSsl().getBundle()).isEqualTo("mybundle"); assertThat(httpClient.getFactory()).isNull(); - JdkClientHttpRequestFactory requestFactory = context.getBean(JdkClientHttpRequestFactory.class); - Object readTimeout = ReflectionTestUtils.getField(requestFactory, "readTimeout"); - assertThat(readTimeout).isEqualTo(Duration.ofSeconds(2)); + ClientHttpRequestFactorySettings settings = context.getBean(ClientHttpRequestFactorySettings.class); + assertThat(settings.readTimeout()).isEqualTo(Duration.ofSeconds(2)); + assertThat(settings.connectTimeout()).isEqualTo(Duration.ofSeconds(1)); + assertThat(settings.sslBundle()).isNotNull(); + // cant test redirects because EnvironmentPostProcessor is not run }); } + @Test + void settingHttpClientFactoryWorks() { + ConfigurableApplicationContext context = new SpringApplicationBuilder(TestConfig.class) + .properties("spring.main.web-application-type=none", "spring.http.client.factory=simple") + .run(); + ClientHttpRequestFactoryBuilder builder = context.getBean(ClientHttpRequestFactoryBuilder.class); + assertThat(builder).isInstanceOf(SimpleClientHttpRequestFactoryBuilder.class); + } + @SpringBootConfiguration @EnableAutoConfiguration static class TestConfig { diff --git a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/test/TestAutoConfiguration.java b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/test/TestAutoConfiguration.java index 7966b3766c..c9457cb164 100644 --- a/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/test/TestAutoConfiguration.java +++ b/spring-cloud-gateway-server-mvc/src/test/java/org/springframework/cloud/gateway/server/mvc/test/TestAutoConfiguration.java @@ -23,27 +23,22 @@ import org.springframework.boot.web.client.RestTemplateCustomizer; import org.springframework.cloud.gateway.server.mvc.GatewayServerMvcAutoConfiguration; import org.springframework.cloud.gateway.server.mvc.test.client.DefaultTestRestClient; -import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Lazy; import org.springframework.core.env.Environment; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; -import org.springframework.http.client.ClientHttpRequestFactory; @AutoConfiguration(after = GatewayServerMvcAutoConfiguration.class) public class TestAutoConfiguration { @Bean - RestTemplateCustomizer testRestClientRestTemplateCustomizer(ApplicationContext context) { - return restTemplate -> { - restTemplate.setRequestFactory(context.getBean(ClientHttpRequestFactory.class)); - restTemplate.setClientHttpRequestInitializers(List.of(request -> { - if (!request.getHeaders().containsKey(HttpHeaders.ACCEPT)) { - request.getHeaders().setAccept(List.of(MediaType.ALL)); - } - })); - }; + RestTemplateCustomizer testRestClientRestTemplateCustomizer() { + return restTemplate -> restTemplate.setClientHttpRequestInitializers(List.of(request -> { + if (!request.getHeaders().containsKey(HttpHeaders.ACCEPT)) { + request.getHeaders().setAccept(List.of(MediaType.ALL)); + } + })); } @Bean