Skip to content

Commit

Permalink
[Dataflow Streaming] Support to receive multiple work items in a sing…
Browse files Browse the repository at this point in the history
…le StreamingGetWorkResponseChunk (#33512)
  • Loading branch information
arunpandianp authored Jan 10, 2025
1 parent 35732ce commit d7c5691
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ public interface DataflowStreamingPipelineOptions extends PipelineOptions {

void setUseSeparateWindmillHeartbeatStreams(Boolean value);

@Description("If true, GetWorkStreams will request multiple work items in a response chunk.")
boolean getWindmillRequestBatchedGetWorkResponse();

void setWindmillRequestBatchedGetWorkResponse(boolean value);

@Description("The number of streams to use for GetData requests.")
@Default.Integer(1)
int getWindmillGetDataStreamCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,8 @@ private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory
.setSendKeyedGetDataRequests(
!options.isEnableStreamingEngine()
|| DataflowRunner.hasExperiment(
options, "streaming_engine_disable_new_heartbeat_requests"));
options, "streaming_engine_disable_new_heartbeat_requests"))
.setRequestBatchedGetWorkResponse(options.getWindmillRequestBatchedGetWorkResponse());
}

private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, long clientId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;
Expand All @@ -43,6 +44,7 @@
*/
@NotThreadSafe
final class GetWorkResponseChunkAssembler {

private static final Logger LOG = LoggerFactory.getLogger(GetWorkResponseChunkAssembler.class);

private final GetWorkTimingInfosTracker workTimingInfosTracker;
Expand All @@ -61,17 +63,26 @@ final class GetWorkResponseChunkAssembler {
* Appends the response chunk bytes to the {@link #data }byte buffer. Return the assembled
* WorkItem if all response chunks for a WorkItem have been received.
*/
Optional<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
List<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
if (chunk.hasComputationMetadata()) {
metadata = ComputationMetadata.fromProto(chunk.getComputationMetadata());
}

data = data.concat(chunk.getSerializedWorkItem());
bufferedSize += chunk.getSerializedWorkItem().size();
workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList());

// If the entire WorkItem has been received, assemble the WorkItem.
return chunk.getRemainingBytesForWorkItem() == 0 ? flushToWorkItem() : Optional.empty();
List<AssembledWorkItem> response = new ArrayList<>();
for (int i = 0; i < chunk.getSerializedWorkItemList().size(); i++) {
ByteString serializedWorkItem = chunk.getSerializedWorkItemList().get(i);
data = data.concat(serializedWorkItem);
bufferedSize += serializedWorkItem.size();
long remainingSize = 0;
if (i == chunk.getSerializedWorkItemList().size() - 1) {
remainingSize = chunk.getRemainingBytesForWorkItem();
}
if (remainingSize == 0) {
flushToWorkItem().ifPresent(response::add);
}
}
return response;
}

/**
Expand Down Expand Up @@ -100,6 +111,7 @@ private Optional<AssembledWorkItem> flushToWorkItem() {

@AutoValue
abstract static class ComputationMetadata {

private static ComputationMetadata fromProto(
Windmill.ComputationWorkItemMetadata metadataProto) {
return new AutoValue_GetWorkResponseChunkAssembler_ComputationMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ final class GrpcDirectGetWorkStream
extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk>
implements GetWorkStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class);

private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST =
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
Expand Down Expand Up @@ -88,6 +89,8 @@ final class GrpcDirectGetWorkStream
*/
private final ConcurrentMap<Long, GetWorkResponseChunkAssembler> workItemAssemblers;

private final boolean requestBatchedGetWorkResponse;

private GrpcDirectGetWorkStream(
String backendWorkerToken,
Function<
Expand All @@ -99,6 +102,7 @@ private GrpcDirectGetWorkStream(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
HeartbeatSender heartbeatSender,
GetDataClient getDataClient,
Expand Down Expand Up @@ -127,6 +131,7 @@ private GrpcDirectGetWorkStream(
.setItems(requestHeader.getMaxItems())
.setBytes(requestHeader.getMaxBytes())
.build());
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
}

static GrpcDirectGetWorkStream create(
Expand All @@ -140,6 +145,7 @@ static GrpcDirectGetWorkStream create(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
HeartbeatSender heartbeatSender,
GetDataClient getDataClient,
Expand All @@ -153,6 +159,7 @@ static GrpcDirectGetWorkStream create(
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
Expand Down Expand Up @@ -209,6 +216,7 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
.setMaxItems(initialGetWorkBudget.items())
.setMaxBytes(initialGetWorkBudget.bytes())
.build())
.setSupportsMultipleWorkItemsInChunk(requestBatchedGetWorkResponse)
.build();
lastRequest.set(request);
budgetTracker.recordBudgetRequested(initialGetWorkBudget);
Expand Down Expand Up @@ -243,7 +251,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
workItemAssemblers
.computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
.append(chunk)
.ifPresent(this::consumeAssembledWorkItem);
.forEach(this::consumeAssembledWorkItem);
}

private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@ final class GrpcGetWorkStream
implements GetWorkStream {

private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkStream.class);
private static final StreamingGetWorkRequest HEALTH_CHECK =
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build())
.build();

private final GetWorkRequest request;
private final WorkItemReceiver receiver;
private final ThrottleTimer getWorkThrottleTimer;
private final Map<Long, GetWorkResponseChunkAssembler> workItemAssemblers;
private final AtomicLong inflightMessages;
private final AtomicLong inflightBytes;
private final boolean requestBatchedGetWorkResponse;

private GrpcGetWorkStream(
String backendWorkerToken,
Expand All @@ -64,6 +70,7 @@ private GrpcGetWorkStream(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
super(
Expand All @@ -81,6 +88,7 @@ private GrpcGetWorkStream(
this.workItemAssemblers = new ConcurrentHashMap<>();
this.inflightMessages = new AtomicLong();
this.inflightBytes = new AtomicLong();
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
}

public static GrpcGetWorkStream create(
Expand All @@ -94,6 +102,7 @@ public static GrpcGetWorkStream create(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
return new GrpcGetWorkStream(
Expand All @@ -104,6 +113,7 @@ public static GrpcGetWorkStream create(
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
receiver);
}
Expand Down Expand Up @@ -132,7 +142,11 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
workItemAssemblers.clear();
inflightMessages.set(request.getMaxItems());
inflightBytes.set(request.getMaxBytes());
trySend(StreamingGetWorkRequest.newBuilder().setRequest(request).build());
trySend(
StreamingGetWorkRequest.newBuilder()
.setSupportsMultipleWorkItemsInChunk(requestBatchedGetWorkResponse)
.setRequest(request)
.build());
}

@Override
Expand All @@ -153,11 +167,7 @@ public void appendSpecificHtml(PrintWriter writer) {

@Override
public void sendHealthCheck() throws WindmillStreamShutdownException {
trySend(
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build())
.build());
trySend(HEALTH_CHECK);
}

@Override
Expand All @@ -166,7 +176,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
workItemAssemblers
.computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
.append(chunk)
.ifPresent(this::consumeAssembledWorkItem);
.forEach(this::consumeAssembledWorkItem);
}

private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ static GrpcWindmillServer newTestInstance(
.setSendKeyedGetDataRequests(sendKeyedGetDataRequests)
.setHealthCheckIntervalMillis(
testOptions.getWindmillServiceStreamingRpcHealthCheckPeriodMs())
.setRequestBatchedGetWorkResponse(
testOptions.getWindmillRequestBatchedGetWorkResponse())
.build();

return new GrpcWindmillServer(testOptions, windmillStreamFactory, dispatcherClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider {
// If true, then active work refreshes will be sent as KeyedGetDataRequests. Otherwise, use the
// newer ComputationHeartbeatRequests.
private final boolean sendKeyedGetDataRequests;
private final boolean requestBatchedGetWorkResponse;
private final Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses;

private GrpcWindmillStreamFactory(
Expand All @@ -99,6 +100,7 @@ private GrpcWindmillStreamFactory(
int streamingRpcBatchLimit,
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
boolean requestBatchedGetWorkResponse,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier) {
this.jobHeader = jobHeader;
Expand All @@ -115,6 +117,7 @@ private GrpcWindmillStreamFactory(
.backoff());
this.streamRegistry = ConcurrentHashMap.newKeySet();
this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
this.processHeartbeatResponses = processHeartbeatResponses;
this.streamIdGenerator = new AtomicLong();
}
Expand All @@ -126,6 +129,7 @@ static GrpcWindmillStreamFactory create(
int streamingRpcBatchLimit,
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
boolean requestBatchedGetWorkResponse,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier,
int healthCheckIntervalMillis) {
Expand All @@ -136,6 +140,7 @@ static GrpcWindmillStreamFactory create(
streamingRpcBatchLimit,
windmillMessagesBetweenIsReadyChecks,
sendKeyedGetDataRequests,
requestBatchedGetWorkResponse,
processHeartbeatResponses,
maxBackOffSupplier);

Expand Down Expand Up @@ -174,6 +179,7 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) {
.setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT)
.setHealthCheckIntervalMillis(NO_HEALTH_CHECKS)
.setSendKeyedGetDataRequests(true)
.setRequestBatchedGetWorkResponse(false)
.setProcessHeartbeatResponses(ignored -> {});
}

Expand Down Expand Up @@ -209,6 +215,7 @@ public GetWorkStream createGetWorkStream(
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
processWorkItem);
}
Expand All @@ -229,6 +236,7 @@ public GetWorkStream createDirectGetWorkStream(
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
Expand Down Expand Up @@ -356,6 +364,8 @@ Builder setProcessHeartbeatResponses(

Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis);

Builder setRequestBatchedGetWorkResponse(boolean enabled);

GrpcWindmillStreamFactory build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ public void testStreamsStartCorrectly() throws InterruptedException {
@Test
public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
throws InterruptedException {
TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor());
GetWorkBudgetDistributor getWorkBudgetDistributor = mock(GetWorkBudgetDistributor.class);
fanOutStreamingEngineWorkProvider =
newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
Expand Down
Loading

0 comments on commit d7c5691

Please sign in to comment.