diff --git a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilterTests.java b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilterTests.java index f48ab97ee7..62f8375825 100644 --- a/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilterTests.java +++ b/spring-cloud-gateway-server/src/test/java/org/springframework/cloud/gateway/filter/ReactiveLoadBalancerClientFilterTests.java @@ -312,6 +312,7 @@ void shouldPassRequestToLoadBalancer() { URI lbUri = URI.create("lb://service1?a=b"); ServerWebExchange serverWebExchange = mock(ServerWebExchange.class); when(serverWebExchange.getAttribute(GATEWAY_REQUEST_URL_ATTR)).thenReturn(lbUri); + when(serverWebExchange.getAttributes()).thenReturn(new HashMap<>(Map.of("myattr", "myattrval"))); when(serverWebExchange.getRequiredAttribute(GATEWAY_ORIGINAL_REQUEST_URL_ATTR)) .thenReturn(new LinkedHashSet<>()); when(serverWebExchange.getRequest()).thenReturn(request); @@ -323,9 +324,12 @@ void shouldPassRequestToLoadBalancer() { filter.filter(serverWebExchange, chain); - verify(loadBalancer).choose(argThat((Request passedRequest) -> ((RequestDataContext) passedRequest.getContext()) - .getClientRequest().getUrl().equals(request.getURI()) - && ((RequestDataContext) passedRequest.getContext()).getHint().equals(hint))); + verify(loadBalancer).choose(argThat((Request passedRequest) -> { + RequestDataContext context = (RequestDataContext) passedRequest.getContext(); + return context.getClientRequest().getUrl().equals(request.getURI()) + && "myattrval".equals(context.getClientRequest().getAttributes().get("myattr")) + && context.getHint().equals(hint); + })); }