diff --git a/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java b/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java index 655e9a6fcb..241bbd4d08 100644 --- a/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java +++ b/d2/src/main/java/com/linkedin/d2/balancer/clients/BackupRequestsClient.java @@ -32,6 +32,7 @@ import com.linkedin.d2.balancer.properties.ServiceProperties; import com.linkedin.d2.balancer.strategies.LoadBalancerStrategy.ExcludedHostHints; import com.linkedin.d2.balancer.util.LoadBalancerUtil; +import com.linkedin.data.ByteString; import com.linkedin.r2.filter.R2Constants; import com.linkedin.r2.message.Request; import com.linkedin.r2.message.RequestContext; @@ -39,6 +40,9 @@ import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.stream.StreamRequest; import com.linkedin.r2.message.stream.StreamResponse; +import com.linkedin.r2.message.stream.entitystream.ByteStringWriter; +import com.linkedin.r2.message.stream.entitystream.EntityStreams; +import com.linkedin.r2.message.stream.entitystream.FullEntityObserver; import com.linkedin.r2.util.NamedThreadFactory; import java.net.URI; import java.util.List; @@ -380,6 +384,29 @@ public void streamRequest(StreamRequest request, Callback callba @Override public void streamRequest(StreamRequest request, RequestContext requestContext, Callback callback) { + // Buffering stream request raises concerns on memory usage and performance. + // Currently only support backup requests with IS_FULL_REQUEST. + if (!isFullRequest(requestContext)) { + _d2Client.streamRequest(request, requestContext, callback); + return; + } + if (!isBuffered(requestContext)) { + final FullEntityObserver observer = new FullEntityObserver(new Callback() + { + @Override + public void onError(Throwable e) + { + LOG.warn("Failed to record request's entity for retrying backup request."); + } + + @Override + public void onSuccess(ByteString result) + { + requestContext.putLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY, result); + } + }); + request.getEntityStream().addObserver(observer); + } if (_isD2Async) { requestAsync(request, requestContext, _d2Client::streamRequest, callback); @@ -524,6 +551,7 @@ public DecoratedCallback(R request, RequestContext requestContext, DecoratorClie executorService.schedule(this::maybeSendBackupRequest, delayNano, TimeUnit.NANOSECONDS); } + @SuppressWarnings("unchecked") private void maybeSendBackupRequest() { Set exclusionSet = ExcludedHostHints.getRequestContextExcludedHosts(_requestContext); @@ -532,9 +560,25 @@ private void maybeSendBackupRequest() if (exclusionSet != null) { exclusionSet.forEach(uri -> ExcludedHostHints.addRequestContextExcludedHost(_backupRequestContext, uri)); + if (_request instanceof StreamRequest && !isBuffered(_requestContext)) { + return; + } if (!_done.get() && _strategy.isBackupRequestAllowed()) { - _client.doRequest(_request, _backupRequestContext, new Callback() + R request = _request; + if (_request instanceof StreamRequest) { + StreamRequest req = (StreamRequest)_request; + req = req.builder() + .build(EntityStreams.newEntityStream(new ByteStringWriter( + (ByteString) _requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY) + ))); + request = (R)req; + if (!isBuffered(_backupRequestContext)) { + _backupRequestContext.putLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY, + _requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY)); + } + } + _client.doRequest(request, _backupRequestContext, new Callback() { @Override public void onSuccess(T result) @@ -721,4 +765,15 @@ public boolean equals(Object obj) } + private static boolean isFullRequest(RequestContext requestContext) + { + Object isFullRequest = requestContext.getLocalAttr(R2Constants.IS_FULL_REQUEST); + return isFullRequest != null && (Boolean)isFullRequest; + } + + private static boolean isBuffered(RequestContext requestContext) + { + Object bufferedBody = requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY); + return bufferedBody != null; + } } diff --git a/d2/src/test/java/com/linkedin/d2/balancer/clients/TestBackupRequestsClient.java b/d2/src/test/java/com/linkedin/d2/balancer/clients/TestBackupRequestsClient.java index 1fdcf0bb7b..5721206475 100644 --- a/d2/src/test/java/com/linkedin/d2/balancer/clients/TestBackupRequestsClient.java +++ b/d2/src/test/java/com/linkedin/d2/balancer/clients/TestBackupRequestsClient.java @@ -34,9 +34,9 @@ import com.linkedin.d2.backuprequests.TestTrackingBackupRequestsStrategy; import com.linkedin.d2.backuprequests.TrackingBackupRequestsStrategy; import com.linkedin.d2.balancer.KeyMapper; +import com.linkedin.d2.balancer.LoadBalancer; import com.linkedin.d2.balancer.ServiceUnavailableException; import com.linkedin.d2.balancer.StaticLoadBalancerState; -import com.linkedin.d2.balancer.LoadBalancer; import com.linkedin.d2.balancer.properties.PartitionData; import com.linkedin.d2.balancer.properties.ServiceProperties; import com.linkedin.d2.balancer.simple.SimpleLoadBalancer; @@ -51,11 +51,17 @@ import com.linkedin.r2.message.rest.RestRequestBuilder; import com.linkedin.r2.message.rest.RestResponse; import com.linkedin.r2.message.rest.RestResponseBuilder; +import com.linkedin.r2.message.stream.StreamRequest; +import com.linkedin.r2.message.stream.StreamRequestBuilder; +import com.linkedin.r2.message.stream.StreamResponse; +import com.linkedin.r2.message.stream.StreamResponseBuilder; +import com.linkedin.r2.message.stream.entitystream.ByteStringWriter; +import com.linkedin.r2.message.stream.entitystream.DrainReader; +import com.linkedin.r2.message.stream.entitystream.EntityStreams; import com.linkedin.r2.transport.common.bridge.client.TransportClient; import com.linkedin.r2.transport.common.bridge.common.TransportCallback; import com.linkedin.r2.transport.common.bridge.common.TransportResponseImpl; import com.linkedin.util.clock.SystemClock; - import java.io.IOException; import java.net.URI; import java.util.ArrayList; @@ -69,6 +75,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.ScheduledExecutorService; @@ -84,11 +91,7 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertNotNull; -import static org.testng.Assert.assertNotSame; -import static org.testng.Assert.assertSame; -import static org.testng.Assert.assertTrue; +import static org.testng.Assert.*; public class TestBackupRequestsClient @@ -98,6 +101,7 @@ public class TestBackupRequestsClient private static final String CLUSTER_NAME = "testCluster"; private static final String PATH = ""; private static final String STRATEGY_NAME = "degrader"; + private static final String BUFFERED_HEADER = "buffered"; private static final ByteString CONTENT = ByteString.copy(new byte[8092]); private ScheduledExecutorService _executor; @@ -126,6 +130,105 @@ public void testRequest(boolean isD2Async) throws Exception assertEquals(response.get().getStatus(), 200); } + @Test(invocationCount = 3, dataProvider = "isD2Async") + public void testStreamRequestWithNoIsFullRequest(boolean isD2Async) throws Exception { + int responseDelayNano = 100000000; //1s till response comes back + int backupDelayNano = 50000000; // make backup request after 0.5 second + Deque hostsReceivingRequest = new ConcurrentLinkedDeque<>(); + BackupRequestsClient client = + createAlwaysBackupClientWithHosts(Arrays.asList("http://test1.com:123", "http://test2.com:123"), + hostsReceivingRequest, responseDelayNano, backupDelayNano, isD2Async); + + URI uri = URI.create("d2://testService"); + + // if there is no IS_FULL_REQUEST set, backup requests will not happen + StreamRequest streamRequest = + new StreamRequestBuilder(uri).build(EntityStreams.newEntityStream(new ByteStringWriter(CONTENT))); + RequestContext context = new RequestContext(); + context.putLocalAttr(R2Constants.OPERATION, "get"); + RequestContext context1 = context.clone(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + + client.streamRequest(streamRequest, context1, new Callback() { + @Override + public void onError(Throwable e) { + failure.set(new AssertionError("Callback onError")); + latch.countDown(); + } + + @Override + public void onSuccess(StreamResponse result) { + try { + assertEquals(result.getStatus(), 200); + assertEquals(result.getHeader("buffered"), "false"); + assertEquals(hostsReceivingRequest.size(), 1); + assertEquals(new HashSet<>(hostsReceivingRequest).size(), 1); + hostsReceivingRequest.clear(); + } catch (AssertionError e) { + failure.set(e); + } + latch.countDown(); + } + }); + + latch.await(2, TimeUnit.SECONDS); + if (failure.get() != null) { + throw failure.get(); + } + } + + @Test(invocationCount = 3, dataProvider = "isD2Async") + public void testStreamRequestWithIsFullRequest(boolean isD2Async) throws Exception { + int responseDelayNano = 500000000; //5s till response comes back + int backupDelayNano = 100000000; // make backup request after 1 second + Deque hostsReceivingRequest = new ConcurrentLinkedDeque<>(); + BackupRequestsClient client = + createAlwaysBackupClientWithHosts(Arrays.asList("http://test1.com:123", "http://test2.com:123"), + hostsReceivingRequest, responseDelayNano, backupDelayNano, isD2Async); + + URI uri = URI.create("d2://testService"); + + // if there is IS_FULL_REQUEST set, backup requests will happen + StreamRequest streamRequest = + new StreamRequestBuilder(uri).build(EntityStreams.newEntityStream(new ByteStringWriter(CONTENT))); + RequestContext context = new RequestContext(); + context.putLocalAttr(R2Constants.OPERATION, "get"); + context.putLocalAttr(R2Constants.IS_FULL_REQUEST, true); + RequestContext context1 = context.clone(); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + + client.streamRequest(streamRequest, context1, new Callback() { + @Override + public void onError(Throwable e) { + failure.set(new AssertionError("Callback onError")); + latch.countDown(); + } + + @Override + public void onSuccess(StreamResponse result) { + try { + assertEquals(result.getStatus(), 200); + assertEquals(result.getHeader("buffered"), "true"); + assertEquals(hostsReceivingRequest.size(), 2); + assertEquals(new HashSet<>(hostsReceivingRequest).size(), 2); + hostsReceivingRequest.clear(); + } catch (AssertionError e) { + failure.set(e); + } + latch.countDown(); + } + }); + + latch.await(6, TimeUnit.SECONDS); + if (failure.get() != null) { + throw failure.get(); + } + } + /** * Backup Request should still work when a hint is given together with the flag indicating that the hint is only a preference, not requirement. */ @@ -629,6 +732,31 @@ public void restRequest(RestRequest request, () -> callback.onResponse(TransportResponseImpl.success(new RestResponseBuilder().build())), responseDelayNano, TimeUnit.NANOSECONDS); } + + @Override + public void streamRequest(StreamRequest request, + RequestContext requestContext, + Map wireAttrs, + TransportCallback callback) { + // whenever a trackerClient is used to make request, record down it's hostname + hostsReceivingRequestList.add(uri); + if (null != requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY)) { + callback.onResponse(TransportResponseImpl.success(new StreamResponseBuilder().setHeader( + BUFFERED_HEADER, String.valueOf(requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY) != null) + ).build(EntityStreams.emptyStream()))); + return; + } + request.getEntityStream().setReader(new DrainReader(){ + public void onDone() { + // delay response to allow backup request to happen + _executor.schedule( + () -> callback.onResponse(TransportResponseImpl.success(new StreamResponseBuilder().setHeader( + BUFFERED_HEADER, String.valueOf(requestContext.getLocalAttr(R2Constants.BACKUP_REQUEST_BUFFERED_BODY) != null) + ).build(EntityStreams.emptyStream()))), responseDelayNano, + TimeUnit.NANOSECONDS); + } + }); + } }; } }; diff --git a/r2-core/src/main/java/com/linkedin/r2/filter/R2Constants.java b/r2-core/src/main/java/com/linkedin/r2/filter/R2Constants.java index c685753233..830e70688b 100644 --- a/r2-core/src/main/java/com/linkedin/r2/filter/R2Constants.java +++ b/r2-core/src/main/java/com/linkedin/r2/filter/R2Constants.java @@ -56,6 +56,7 @@ public class R2Constants public static final int DEFAULT_DATA_CHUNK_SIZE = 8192; public static final boolean DEFAULT_REST_OVER_STREAM = false; public static final String RETRY_MESSAGE_ATTRIBUTE_KEY = "RETRY"; + public static final String BACKUP_REQUEST_BUFFERED_BODY = "BACKUP_REQUEST_BUFFERED_BODY"; @Deprecated public static final String EXPECTED_SERVER_CERT_PRINCIPAL_NAME = "EXPECTED_SERVER_CERT_PRINCIPAL_NAME"; public static final String REQUESTED_SSL_SESSION_VALIDATOR = "REQUESTED_SSL_SESSION_VALIDATOR";