Skip to content

Commit

Permalink
Fix r2-netty illegal state exception due to premature channel recycli…
Browse files Browse the repository at this point in the history
…ng (#973)

Currently, the r2-netty client recycles connections once the response body has been fully received. However, it is possible for a server to return a response before fully consuming the request body. In this case, the channel will be recycled before the request body has made it through the pipeline. If a subsequent request comes in before the prior request is complete, an illegal state exception is thrown by the netty http object encoder.
  • Loading branch information
TylerHorth authored Jan 31, 2024
1 parent 6e6d645 commit 9de7a95
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 12 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and what APIs have changed, if applicable.

## [Unreleased]

## [29.50.1] - 2024-01-31
- Fix r2-netty illegal state exception due to premature channel recycling.

## [29.50.0] - 2024-01-31
- Minor version bump due to internal LinkedIn tooling requirement. No functional changes.

Expand Down Expand Up @@ -5624,7 +5627,8 @@ patch operations can re-use these classes for generating patch messages.

## [0.14.1]

[Unreleased]: https://github.com/linkedin/rest.li/compare/v29.50.0...master
[Unreleased]: https://github.com/linkedin/rest.li/compare/v29.50.1...master
[29.50.1]: https://github.com/linkedin/rest.li/compare/v29.50.0...v29.50.1
[29.50.0]: https://github.com/linkedin/rest.li/compare/v29.49.9...v29.50.0
[29.49.9]: https://github.com/linkedin/rest.li/compare/v29.49.8...v29.49.9
[29.49.8]: https://github.com/linkedin/rest.li/compare/v29.49.7...v29.49.8
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version=29.50.0
version=29.50.1
group=com.linkedin.pegasus
org.gradle.configureondemand=true
org.gradle.parallel=true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@ public enum ChannelPipelineEvent
{
/**
* User event raised in the {@link ChannelPipeline} that indicates the
* response is fully received and the {@link Channel} is ready to be
* returned or disposed.
* request is fully written and the {@link Channel} may be ready to be
* returned or disposed. Channel may be returned once both the request
* and response are complete.
*/
REQUEST_COMPLETE,
/**
* User event raised in the {@link ChannelPipeline} that indicates the
* response is fully received and the {@link Channel} may be ready to be
* returned or disposed. Channel may be returned once both the request
* and response are complete.
*/
RESPONSE_COMPLETE
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.linkedin.r2.filter.R2Constants;
import com.linkedin.r2.message.stream.entitystream.ReadHandle;
import com.linkedin.r2.message.stream.entitystream.Reader;
import com.linkedin.r2.netty.common.ChannelPipelineEvent;
import com.linkedin.r2.netty.common.NettyChannelAttributes;
import com.linkedin.r2.netty.common.StreamingTimeout;
import com.linkedin.util.clock.SystemClock;
Expand Down Expand Up @@ -96,6 +97,7 @@ public void onDataAvailable(ByteString data)
public void onDone()
{
// Write the EOF marker to the channel and flush it immediately.
_ctx.writeAndFlush(EOF);
// Announce on the pipeline that the request side is finished; this is raised right after the
// write is issued, without waiting on the future returned by writeAndFlush.
_ctx.fireUserEventTriggered(ChannelPipelineEvent.REQUEST_COMPLETE);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ public class ChannelLifecycleHandler extends ChannelInboundHandlerAdapter
{
private final boolean _recycle;

// State of the connection:
// The connection is half closed when exactly one of the two completion events has been seen, i.e. either the
// request has been fully sent or the response has been fully received, but not yet both.
private boolean _halfClosed = false;

public ChannelLifecycleHandler(boolean recycle)
{
_recycle = recycle;
Expand Down Expand Up @@ -71,15 +75,19 @@ private boolean isChannelRecyclableException(Throwable cause)
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt)
{
if (ChannelPipelineEvent.RESPONSE_COMPLETE == evt)
if (ChannelPipelineEvent.REQUEST_COMPLETE == evt || ChannelPipelineEvent.RESPONSE_COMPLETE == evt)
{
if (_recycle)
{
tryReturnChannel(ctx);
}
else
{
tryDisposeChannel(ctx);
_halfClosed = !_halfClosed;

if (!_halfClosed) {
if (_recycle)
{
tryReturnChannel(ctx);
}
else
{
tryDisposeChannel(ctx);
}
}
}
ctx.fireUserEventTriggered(evt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

import com.linkedin.common.callback.Callback;
import com.linkedin.data.ByteString;
import com.linkedin.r2.message.rest.RestRequest;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.StreamResponseBuilder;
import com.linkedin.r2.message.stream.entitystream.EntityStreams;
import com.linkedin.r2.message.stream.entitystream.Writer;
import com.linkedin.r2.netty.common.ChannelPipelineEvent;
import com.linkedin.r2.netty.common.NettyChannelAttributes;
import com.linkedin.r2.netty.entitystream.StreamReader;
import com.linkedin.r2.netty.entitystream.StreamWriter;
Expand Down Expand Up @@ -77,6 +79,10 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
}
else
{
if (msg instanceof RestRequest)
{
ctx.fireUserEventTriggered(ChannelPipelineEvent.REQUEST_COMPLETE);
}
ctx.write(msg, promise);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
package com.linkedin.r2.transport.http.client;

import com.linkedin.common.callback.FutureCallback;
import com.linkedin.common.util.None;
import com.linkedin.data.ByteString;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.stream.StreamRequest;
import com.linkedin.r2.message.stream.StreamRequestBuilder;
import com.linkedin.r2.message.stream.StreamResponse;
import com.linkedin.r2.message.stream.entitystream.ByteStringWriter;
import com.linkedin.r2.message.stream.entitystream.EntityStream;
import com.linkedin.r2.message.stream.entitystream.EntityStreams;
import com.linkedin.r2.message.stream.entitystream.FullEntityReader;
import com.linkedin.r2.message.stream.entitystream.WriteHandle;
import com.linkedin.r2.message.stream.entitystream.Writer;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
import com.linkedin.r2.transport.common.bridge.common.TransportCallback;
import com.linkedin.r2.transport.common.bridge.common.TransportResponse;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.HttpVersion;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.io.Charsets;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import static com.linkedin.r2.transport.http.client.HttpClientFactory.*;


/**
 * Tests for the pipeline V2 Netty client against a minimal local Netty HTTP server.
 *
 * <p>The server answers every request as soon as the request head is decoded, without waiting for
 * the request body. The client is configured with a pool size of one, so a second request can
 * only proceed once the first request's channel has been handed back to the pool.
 */
public class TestPipelineV2NettyClient {
  private static final int TIMEOUT_MILLIS = 1_000;
  // NOTE(review): fixed port; the test fails fast in TestServer's constructor if it is already in use.
  private static final int PORT = 8080;
  private static final String LOCALHOST = "http://localhost:" + PORT;

  private TestServer _server;
  private HttpClientFactory _clientFactory;
  private TransportClient _client;

  /**
   * Starts the test server and builds a pipeline V2 client with a single pooled connection.
   */
  @BeforeMethod
  private void setup() {
    _server = new TestServer();
    _clientFactory = new HttpClientFactory.Builder().setUsePipelineV2(true).build();

    HashMap<String, String> clientProperties = new HashMap<>();
    clientProperties.put(HTTP_REQUEST_TIMEOUT, String.valueOf(TIMEOUT_MILLIS));
    clientProperties.put(HTTP_POOL_SIZE, "1");

    _client = _clientFactory.getClient(clientProperties);
  }

  /**
   * Shuts down the client and its factory, then closes the test server.
   *
   * <p>The server is closed in a {@code finally} block so that the listening port and event loop
   * group are released even when the client or factory shutdown fails or times out.
   */
  @AfterMethod
  private void shutdown() throws InterruptedException, ExecutionException, TimeoutException, IOException {
    try {
      FutureCallback<None> clientShutdown = new FutureCallback<>();
      FutureCallback<None> factoryShutdown = new FutureCallback<>();

      _client.shutdown(clientShutdown);
      _clientFactory.shutdown(factoryShutdown);

      clientShutdown.get(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
      factoryShutdown.get(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
    } finally {
      _server.close();
    }
  }

  /**
   * Test response returned before request complete.
   * Connection should not be returned to the pool until after the request payload has been fully uploaded.
   */
  @Test
  public void testResponseReturnedBeforeRequestComplete() throws Exception {
    DelayWriter delayWriter =
        new DelayWriter(new ByteStringWriter(ByteString.copyString("Hello!", StandardCharsets.UTF_8)));

    // The first response is verified while the request body is still being held back by the DelayWriter.
    verifyResponse(postRequest(EntityStreams.newEntityStream(delayWriter)));

    // Issue the second request before the first request body has been released.
    CompletableFuture<StreamResponse> secondResponseFuture = postRequest(EntityStreams.emptyStream());

    // Release the first request body; blocks until the delayed writer has finished.
    delayWriter.run();

    verifyResponse(secondResponseFuture);
  }

  /**
   * Sends a POST request with the given body to the test server.
   *
   * @param body entity stream to use as the request body
   * @return future completed with the response, or completed exceptionally on transport error
   * @throws URISyntaxException if the server URI cannot be parsed
   */
  private CompletableFuture<StreamResponse> postRequest(EntityStream body) throws URISyntaxException {
    StreamRequest streamRequest = new StreamRequestBuilder(new URI(LOCALHOST)).setMethod("POST").build(body);

    CompletableTransportCallback responseFutureCallback = new CompletableTransportCallback();
    _client.streamRequest(streamRequest, new RequestContext(), new HashMap<>(), responseFutureCallback);

    return responseFutureCallback;
  }

  /**
   * Waits for the response and asserts that it has status 200 and the body {@code "GOOD"}.
   *
   * @param responseFuture future for the response to verify
   * @throws Exception if the response or its body is not available within the timeout
   */
  private void verifyResponse(CompletableFuture<StreamResponse> responseFuture) throws Exception {
    StreamResponse response = responseFuture.get(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);

    Assert.assertEquals(response.getStatus(), 200);

    FutureCallback<ByteString> responseBodyFuture = new FutureCallback<>();
    response.getEntityStream().setReader(new FullEntityReader(responseBodyFuture));

    String responseBody = responseBodyFuture.get(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS).asString(StandardCharsets.UTF_8);
    Assert.assertEquals(responseBody, "GOOD");
  }

  /**
   * Minimal HTTP server that replies {@code 200 "GOOD"} as soon as a request head is decoded,
   * without waiting for the request body. The same handler instance is installed on every child
   * channel, hence {@link ChannelHandler.Sharable}.
   */
  @ChannelHandler.Sharable
  private static class TestServer extends ChannelInboundHandlerAdapter implements Closeable {
    private final NioEventLoopGroup _group = new NioEventLoopGroup();
    private final Channel _channel;

    /**
     * Binds the server to {@link #PORT}.
     *
     * @throws IllegalStateException if the bind fails or does not complete within the timeout
     */
    public TestServer() {
      ChannelFuture channelFuture = new ServerBootstrap()
          .group(_group)
          .channel(NioServerSocketChannel.class)
          .childHandler(new ChannelInitializer<NioSocketChannel>() {
            @Override
            protected void initChannel(NioSocketChannel ch) throws Exception {
              ch.pipeline().addLast(new HttpServerCodec(), TestServer.this);
            }
          })
          .bind(new InetSocketAddress(PORT));

      channelFuture.awaitUninterruptibly(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);

      if (!channelFuture.isSuccess()) {
        // Fail fast with the real cause instead of letting the test time out on the client side,
        // and release the event loop threads since close() will never be called on this instance.
        _group.shutdownGracefully();
        throw new IllegalStateException("Failed to bind test server to port " + PORT, channelFuture.cause());
      }

      _channel = channelFuture.channel();
    }

    /**
     * Responds to each decoded request head and discards every other inbound message.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
      try {
        if (msg instanceof HttpRequest) {
          ByteBuf body = Unpooled.copiedBuffer("GOOD", StandardCharsets.UTF_8);
          HttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, body);
          response.headers().add(HttpHeaderNames.CONTENT_LENGTH, body.readableBytes());
          ctx.writeAndFlush(response);
        }
      } finally {
        // Messages are consumed here and never forwarded, so this handler must release them;
        // release() is a no-op for messages that are not reference counted.
        io.netty.util.ReferenceCountUtil.release(msg);
      }
    }

    /**
     * Closes the listening channel and shuts down the event loop group, waiting up to the timeout for each.
     */
    @Override
    public void close() throws IOException {
      _channel.close().awaitUninterruptibly(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
      _group.shutdownGracefully().awaitUninterruptibly(TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
    }
  }

  /**
   * Writer decorator that queues every callback for the delegate until {@link #run()} is invoked,
   * which lets a test hold back the request body.
   */
  private static class DelayWriter implements Writer {
    private final Writer _delegate;
    private final DelayExecutor _executor = new DelayExecutor();

    public DelayWriter(Writer delegate) {
      _delegate = delegate;
    }

    /**
     * Releases the queued callbacks and blocks until the executor has been shut down.
     */
    public void run() throws InterruptedException {
      _executor.run();
    }

    @Override
    public void onInit(WriteHandle wh) {
      // Hand the delegate a wrapping handle so that finishing or failing the stream also stops the executor.
      _executor.execute(() -> _delegate.onInit(new WriteHandle() {
        @Override
        public void write(ByteString data) {
          wh.write(data);
        }

        @Override
        public void done() {
          wh.done();
          _executor.shutdown();
        }

        @Override
        public void error(Throwable throwable) {
          wh.error(throwable);
          _executor.shutdown();
        }

        @Override
        public int remaining() {
          return wh.remaining();
        }
      }));
    }

    @Override
    public void onWritePossible() {
      _executor.execute(_delegate::onWritePossible);
    }

    @Override
    public void onAbort(Throwable e) {
      _executor.execute(() -> _delegate.onAbort(e));
      _executor.shutdown();
    }
  }

  /**
   * Executor that only queues tasks until {@link #run()} is called; {@code run()} then drains the
   * queue on a dedicated thread until {@link #shutdown()} has enqueued the terminate marker.
   */
  private static class DelayExecutor implements Executor {
    // Sentinel compared by identity to stop the worker loop.
    private static final Runnable TERMINATE = () -> {};
    private final BlockingQueue<Runnable> _tasks = new LinkedBlockingQueue<>();
    private final Thread _thread = new Thread(() -> {
      try {
        Runnable task;
        while ((task = _tasks.take()) != TERMINATE) {
          task.run();
        }
      } catch (InterruptedException e) {
        // Stop draining, but preserve the interrupt status rather than silently swallowing it.
        Thread.currentThread().interrupt();
      }
    });

    @Override
    public void execute(Runnable command) {
      _tasks.add(command);
    }

    /**
     * Starts the worker thread and waits for it to finish.
     */
    public void run() throws InterruptedException {
      _thread.start();
      _thread.join();
    }

    /**
     * Enqueues the terminate marker; tasks queued before it still run.
     */
    public void shutdown() {
      _tasks.add(TERMINATE);
    }
  }

  /**
   * Adapts a {@link TransportCallback} to a {@link CompletableFuture}.
   */
  private static class CompletableTransportCallback extends CompletableFuture<StreamResponse>
      implements TransportCallback<StreamResponse> {
    @Override
    public void onResponse(TransportResponse<StreamResponse> response) {
      if (response.hasError()) {
        completeExceptionally(response.getError());
      } else {
        complete(response.getResponse());
      }
    }
  }
}

0 comments on commit 9de7a95

Please sign in to comment.