Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grpc status headers #2973

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ public void hello(HelloRequest request, StreamObserver<HelloResponse> responseOb
HelloResponse response = HelloResponse.newBuilder().setGreeting(greeting).build();

responseObserver.onNext(response);

if ("failWithRuntimeExceptionAfterData!".equals(request.getFirstName())) {
StatusRuntimeException exception = Status.RESOURCE_EXHAUSTED.withDescription("Too long firstNames?")
.asRuntimeException();
responseObserver.onError(exception);
return;
}

responseObserver.onCompleted();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.springframework.boot.test.web.server.LocalServerPort;

import static io.grpc.Status.FAILED_PRECONDITION;
import static io.grpc.Status.RESOURCE_EXHAUSTED;
import static io.grpc.netty.NegotiationType.TLS;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment;

Expand Down Expand Up @@ -75,15 +76,35 @@ private ManagedChannel createSecuredChannel(int port) throws SSLException {
@Test
public void gRPCUnaryCallShouldHandleRuntimeException() throws SSLException {
ManagedChannel channel = createSecuredChannel(gatewayPort);
boolean thrown = false;

try {
HelloServiceGrpc.newBlockingStub(channel)
.hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeException!").build());
}
catch (StatusRuntimeException e) {
Assertions.assertThat(FAILED_PRECONDITION.getCode()).isEqualTo(e.getStatus().getCode());
Assertions.assertThat("Invalid firstName").isEqualTo(e.getStatus().getDescription());
thrown = true;
Assertions.assertThat(e.getStatus().getCode()).isEqualTo(FAILED_PRECONDITION.getCode());
Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Invalid firstName");
}
Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue();
}

@Test
public void gRPCUnaryCallShouldHandleRuntimeException2() throws SSLException {
ManagedChannel channel = createSecuredChannel(gatewayPort);
boolean thrown = false;
try {
HelloServiceGrpc.newBlockingStub(channel)
.hello(HelloRequest.newBuilder().setFirstName("failWithRuntimeExceptionAfterData!").build())
.getGreeting();
}
catch (StatusRuntimeException e) {
thrown = true;
Assertions.assertThat(e.getStatus().getCode()).isEqualTo(RESOURCE_EXHAUSTED.getCode());
Assertions.assertThat(e.getStatus().getDescription()).isEqualTo("Too long firstNames?");
}
Assertions.assertThat(thrown).withFailMessage("Expected exception not thrown!").isTrue();
}

private TrustManager[] createTrustAllTrustManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.springframework.cloud.gateway.filter.headers;

import reactor.netty.http.client.HttpClientResponse;
import reactor.netty.http.server.HttpServerResponse;

import org.springframework.core.Ordered;
Expand All @@ -26,6 +27,8 @@
import org.springframework.util.StringUtils;
import org.springframework.web.server.ServerWebExchange;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CLIENT_RESPONSE_ATTR;

/**
* @author Alberto C. Ríos
*/
Expand All @@ -37,45 +40,62 @@ public class GRPCResponseHeadersFilter implements HttpHeadersFilter, Ordered {

@Override
public HttpHeaders filter(HttpHeaders headers, ServerWebExchange exchange) {
ServerHttpResponse response = exchange.getResponse();
HttpHeaders responseHeaders = response.getHeaders();
if (isGRPC(exchange)) {
String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER;
String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER);
if (originalTrailerHeaderValue != null) {
trailerHeaderValue += "," + originalTrailerHeaderValue;
}
responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue);
ServerHttpResponse response = exchange.getResponse();
HttpHeaders responseHeaders = response.getHeaders();

while (response instanceof ServerHttpResponseDecorator) {
response = ((ServerHttpResponseDecorator) response).getDelegate();
if (headers.containsKey(GRPC_STATUS_HEADER)) {
if (!"0".equals(headers.getFirst(GRPC_STATUS_HEADER))) {
response.setComplete(); // avoid empty DATA frame
}
}
if (response instanceof AbstractServerHttpResponse) {
String grpcStatus = getGrpcStatus(headers);
String grpcMessage = getGrpcMessage(headers);
((HttpServerResponse) ((AbstractServerHttpResponse) response).getNativeResponse()).trailerHeaders(h -> {
h.set(GRPC_STATUS_HEADER, grpcStatus);
h.set(GRPC_MESSAGE_HEADER, grpcMessage);

HttpClientResponse nettyInResponse = exchange.getAttribute(CLIENT_RESPONSE_ATTR);
if (nettyInResponse != null) {
nettyInResponse.trailerHeaders().subscribe(entries -> {
if (entries.contains(GRPC_STATUS_HEADER)) {
addTrailingHeader(entries, response, responseHeaders);
}
});
}

}

return headers;
}

private boolean isGRPC(ServerWebExchange exchange) {
String contentTypeValue = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
return StringUtils.startsWithIgnoreCase(contentTypeValue, "application/grpc");
private void addTrailingHeader(io.netty.handler.codec.http.HttpHeaders sourceHeaders, ServerHttpResponse response,
HttpHeaders responseHeaders) {
String trailerHeaderValue = GRPC_STATUS_HEADER + "," + GRPC_MESSAGE_HEADER;
String originalTrailerHeaderValue = responseHeaders.getFirst(HttpHeaders.TRAILER);
if (originalTrailerHeaderValue != null) {
trailerHeaderValue += "," + originalTrailerHeaderValue;
}
responseHeaders.set(HttpHeaders.TRAILER, trailerHeaderValue);

HttpServerResponse nettyOutResponse = getNettyResponse(response);
if (nettyOutResponse != null) {
String grpcStatus = sourceHeaders.get(GRPC_STATUS_HEADER, "0");
String grpcMessage = sourceHeaders.get(GRPC_MESSAGE_HEADER, "");
nettyOutResponse.trailerHeaders(h -> {
h.set(GRPC_STATUS_HEADER, grpcStatus);
h.set(GRPC_MESSAGE_HEADER, grpcMessage);
});
}
}

private String getGrpcStatus(HttpHeaders headers) {
final String grpcStatusValue = headers.getFirst(GRPC_STATUS_HEADER);
return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : "0";
private HttpServerResponse getNettyResponse(ServerHttpResponse response) {
while (response instanceof ServerHttpResponseDecorator) {
response = ((ServerHttpResponseDecorator) response).getDelegate();
}
if (response instanceof AbstractServerHttpResponse) {
return ((AbstractServerHttpResponse) response).getNativeResponse();
}
return null;
}

private String getGrpcMessage(HttpHeaders headers) {
final String grpcStatusValue = headers.getFirst(GRPC_MESSAGE_HEADER);
return StringUtils.hasText(grpcStatusValue) ? grpcStatusValue : "";
private boolean isGRPC(ServerWebExchange exchange) {
String contentTypeValue = exchange.getRequest().getHeaders().getFirst(HttpHeaders.CONTENT_TYPE);
return StringUtils.startsWithIgnoreCase(contentTypeValue, "application/grpc");
}

@Override
Expand Down