From 7f1048a0dfde767ab9214391dae07e1654f5dad6 Mon Sep 17 00:00:00 2001 From: sgibb Date: Mon, 11 Mar 2024 18:30:49 -0400 Subject: [PATCH] Updates config suppor for HandlerFunctions to accept RouteProperties Fixes gh-3188 --- ...yMvcPropertiesBeanDefinitionRegistrar.java | 62 ++++++++++--------- .../mvc/config/NormalizedOperationMethod.java | 25 +++++--- .../filter/LoadBalancerHandlerSupplier.java | 5 ++ .../server/mvc/handler/HandlerFunctions.java | 5 ++ ...propertiesbeandefinitionregistrartests.yml | 2 +- 5 files changed, 59 insertions(+), 40 deletions(-) diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/GatewayMvcPropertiesBeanDefinitionRegistrar.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/GatewayMvcPropertiesBeanDefinitionRegistrar.java index 8960fb20f3..80a520650b 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/GatewayMvcPropertiesBeanDefinitionRegistrar.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/GatewayMvcPropertiesBeanDefinitionRegistrar.java @@ -185,16 +185,18 @@ private RouterFunction getRouterFunction(RouteProperties routeProperties, String // TODO: cache? // translate handlerFunction String scheme = routeProperties.getUri().getScheme(); - Map handlerArgs = new HashMap<>(); - // TODO: avoid hardcoded scheme/uri args - // maybe find empty args or single RouteProperties param? - if (scheme.equals("lb")) { - handlerArgs.put("uri", routeProperties.getUri().toString()); - } + Map handlerArgs = new HashMap<>(); Optional handlerOperationMethod = findOperation(handlerOperations, scheme.toLowerCase(), handlerArgs); if (handlerOperationMethod.isEmpty()) { - throw new IllegalStateException("Unable to find HandlerFunction for scheme: " + scheme); + // single RouteProperties param + handlerArgs.clear(); + String routePropsKey = StringUtils.uncapitalize(RouteProperties.class.getSimpleName()); + handlerArgs.put(routePropsKey, routeProperties); + handlerOperationMethod = findOperation(handlerOperations, scheme.toLowerCase(), handlerArgs); + if (handlerOperationMethod.isEmpty()) { + throw new IllegalStateException("Unable to find HandlerFunction for scheme: " + scheme); + } } NormalizedOperationMethod normalizedOpMethod = handlerOperationMethod.get(); Object response = invokeOperation(normalizedOpMethod, normalizedOpMethod.getNormalizedArgs()); @@ -215,21 +217,21 @@ else if (response instanceof HandlerDiscoverer.Result result) { MultiValueMap predicateOperations = predicateDiscoverer.getOperations(); final AtomicReference predicate = new AtomicReference<>(); - routeProperties.getPredicates() - .forEach(predicateProperties -> translate(predicateOperations, predicateProperties.getName(), - predicateProperties.getArgs(), RequestPredicate.class, requestPredicate -> { - log.trace(LogMessage.format("Adding predicate to route %s - %s", routeId, - predicateProperties)); - if (predicate.get() == null) { - predicate.set(requestPredicate); - } - else { - RequestPredicate combined = predicate.get().and(requestPredicate); - predicate.set(combined); - } - log.trace(LogMessage.format("Combined predicate for route %s - %s", routeId, - predicate.get())); - })); + routeProperties.getPredicates().forEach(predicateProperties -> { + Map args = new LinkedHashMap<>(predicateProperties.getArgs()); + translate(predicateOperations, predicateProperties.getName(), args, RequestPredicate.class, + requestPredicate -> { + log.trace(LogMessage.format("Adding predicate to route %s - %s", routeId, predicateProperties)); + if (predicate.get() == null) { + predicate.set(requestPredicate); + } + else { + RequestPredicate combined = predicate.get().and(requestPredicate); + predicate.set(combined); + } + log.trace(LogMessage.format("Combined predicate for route %s - %s", routeId, predicate.get())); + }); + }); // combine predicate and handlerFunction builder.route(predicate.get(), handlerFunction); @@ -237,8 +239,10 @@ else if (response instanceof HandlerDiscoverer.Result result) { // translate filters MultiValueMap filterOperations = filterDiscoverer.getOperations(); - routeProperties.getFilters().forEach(filterProperties -> translate(filterOperations, filterProperties.getName(), - filterProperties.getArgs(), HandlerFilterFunction.class, builder::filter)); + routeProperties.getFilters().forEach(filterProperties -> { + Map args = new LinkedHashMap<>(filterProperties.getArgs()); + translate(filterOperations, filterProperties.getName(), args, HandlerFilterFunction.class, builder::filter); + }); builder.withAttribute(MvcUtils.GATEWAY_ROUTE_ID_ATTR, routeId); @@ -246,7 +250,7 @@ else if (response instanceof HandlerDiscoverer.Result result) { } private void translate(MultiValueMap operations, String operationName, - Map operationArgs, Class returnType, Consumer operationHandler) { + Map operationArgs, Class returnType, Consumer operationHandler) { String normalizedName = StringUtils.uncapitalize(operationName); Optional operationMethod = findOperation(operations, normalizedName, operationArgs); if (operationMethod.isPresent()) { @@ -263,14 +267,14 @@ private void translate(MultiValueMap operations, St } private Optional findOperation(MultiValueMap operations, - String operationName, Map operationArgs) { + String operationName, Map operationArgs) { return operations.getOrDefault(operationName, Collections.emptyList()).stream() .map(operationMethod -> new NormalizedOperationMethod(operationMethod, operationArgs)) .filter(opeMethod -> matchOperation(opeMethod, operationArgs)).findFirst(); } - private static boolean matchOperation(NormalizedOperationMethod operationMethod, Map args) { - Map normalizedArgs = operationMethod.getNormalizedArgs(); + private static boolean matchOperation(NormalizedOperationMethod operationMethod, Map args) { + Map normalizedArgs = operationMethod.getNormalizedArgs(); OperationParameters parameters = operationMethod.getParameters(); if (operationMethod.isConfigurable()) { // this is a special case @@ -288,7 +292,7 @@ private static boolean matchOperation(NormalizedOperationMethod operationMethod, return true; } - private T invokeOperation(OperationMethod operationMethod, Map operationArgs) { + private T invokeOperation(OperationMethod operationMethod, Map operationArgs) { Map args = new HashMap<>(); if (operationMethod.isConfigurable()) { OperationParameter operationParameter = operationMethod.getParameters().get(0); diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/NormalizedOperationMethod.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/NormalizedOperationMethod.java index 7ca9ed542a..9e0b62169a 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/NormalizedOperationMethod.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/config/NormalizedOperationMethod.java @@ -35,13 +35,13 @@ public class NormalizedOperationMethod implements OperationMethod { private final OperationMethod delegate; - private final Map normalizedArgs; + private final Map normalizedArgs; /** * Create a new {@link DefaultOperationMethod} instance. * @param method the source method */ - public NormalizedOperationMethod(OperationMethod delegate, Map args) { + public NormalizedOperationMethod(OperationMethod delegate, Map args) { this.delegate = delegate; normalizedArgs = normalizeArgs(args); } @@ -61,31 +61,36 @@ public OperationParameters getParameters() { return delegate.getParameters(); } - public Map getNormalizedArgs() { + public Map getNormalizedArgs() { return normalizedArgs; } - private Map normalizeArgs(Map operationArgs) { + @Override + public String toString() { + return delegate.toString(); + } + + private Map normalizeArgs(Map operationArgs) { if (hasGeneratedKey(operationArgs)) { Shortcut shortcut = getMethod().getAnnotation(Shortcut.class); if (shortcut != null) { String[] fieldOrder = getFieldOrder(shortcut); return switch (shortcut.type()) { case DEFAULT -> { - Map map = new HashMap<>(); + Map map = new HashMap<>(); int entryIdx = 0; - for (Map.Entry entry : operationArgs.entrySet()) { + for (Map.Entry entry : operationArgs.entrySet()) { String key = normalizeKey(entry.getKey(), entryIdx, operationArgs, fieldOrder); // TODO: support spel? // getValue(parser, beanFactory, entry.getValue()); - String value = entry.getValue(); + Object value = entry.getValue(); map.put(key, value); entryIdx++; } yield map; } case LIST -> { - Map map = new HashMap<>(); + Map map = new HashMap<>(); // field order should be of size 1 Assert.isTrue(fieldOrder != null && fieldOrder.length == 1, "Shortcut Configuration Type GATHER_LIST must have shortcutFieldOrder of size 1"); @@ -110,11 +115,11 @@ private String[] getFieldOrder(Shortcut shortcut) { return fieldOrder; } - private static boolean hasGeneratedKey(Map operationArgs) { + private static boolean hasGeneratedKey(Map operationArgs) { return operationArgs.keySet().stream().anyMatch(key -> key.startsWith(NameUtils.GENERATED_NAME_PREFIX)); } - static String normalizeKey(String key, int entryIdx, Map args, String[] fieldOrder) { + static String normalizeKey(String key, int entryIdx, Map args, String[] fieldOrder) { // RoutePredicateFactory has name hints and this has a fake key name // replace with the matching key hint if (key.startsWith(NameUtils.GENERATED_NAME_PREFIX) && fieldOrder.length > 0 && entryIdx < args.size() diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/LoadBalancerHandlerSupplier.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/LoadBalancerHandlerSupplier.java index 6a3d336545..3cedfa350a 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/LoadBalancerHandlerSupplier.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/filter/LoadBalancerHandlerSupplier.java @@ -22,6 +22,7 @@ import java.util.Collection; import java.util.Collections; +import org.springframework.cloud.gateway.server.mvc.config.RouteProperties; import org.springframework.cloud.gateway.server.mvc.handler.HandlerDiscoverer; import org.springframework.cloud.gateway.server.mvc.handler.HandlerFunctions; import org.springframework.cloud.gateway.server.mvc.handler.HandlerSupplier; @@ -33,6 +34,10 @@ public Collection get() { return Arrays.asList(getClass().getMethods()); } + public static HandlerDiscoverer.Result lb(RouteProperties routeProperties) { + return lb(routeProperties.getUri()); + } + public static HandlerDiscoverer.Result lb(URI uri) { // TODO: how to do something other than http return new HandlerDiscoverer.Result(HandlerFunctions.http(), diff --git a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/HandlerFunctions.java b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/HandlerFunctions.java index 705dfdba7e..81760ff153 100644 --- a/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/HandlerFunctions.java +++ b/spring-cloud-gateway-server-mvc/src/main/java/org/springframework/cloud/gateway/server/mvc/handler/HandlerFunctions.java @@ -26,6 +26,7 @@ import jakarta.servlet.ServletException; import org.springframework.cloud.gateway.server.mvc.common.MvcUtils; +import org.springframework.cloud.gateway.server.mvc.config.RouteProperties; import org.springframework.web.servlet.function.HandlerFunction; import org.springframework.web.servlet.function.ServerRequest; import org.springframework.web.servlet.function.ServerResponse; @@ -36,6 +37,10 @@ private HandlerFunctions() { } + public static HandlerFunction forward(RouteProperties routeProperties) { + return forward(routeProperties.getUri().getPath()); + } + public static HandlerFunction forward(String path) { // ok() is wrong, but can be overridden by the forwarded request. return request -> GatewayServerResponse.ok().build((httpServletRequest, httpServletResponse) -> { diff --git a/spring-cloud-gateway-server-mvc/src/test/resources/application-propertiesbeandefinitionregistrartests.yml b/spring-cloud-gateway-server-mvc/src/test/resources/application-propertiesbeandefinitionregistrartests.yml index ef59ccf2e6..d27c4c7668 100644 --- a/spring-cloud-gateway-server-mvc/src/test/resources/application-propertiesbeandefinitionregistrartests.yml +++ b/spring-cloud-gateway-server-mvc/src/test/resources/application-propertiesbeandefinitionregistrartests.yml @@ -45,7 +45,7 @@ spring.cloud.gateway.mvc: name: X-Test values: listRoute3 - id: listRoute4 - uri: https://example1.com + uri: forward:/mycontroller predicates: - Path=/anything/example1 filters: