Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dataflow Streaming] Support to receive multiple work items in a single StreamingGetWorkResponseChunk #33512

Merged
merged 7 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ public interface DataflowStreamingPipelineOptions extends PipelineOptions {

void setUseSeparateWindmillHeartbeatStreams(Boolean value);

@Description("If true, GetWorkStreams will request multiple work items in a response chunk.")
boolean getWindmillRequestBatchedGetWorkResponse();

void setWindmillRequestBatchedGetWorkResponse(boolean value);

@Description("The number of streams to use for GetData requests.")
@Default.Integer(1)
int getWindmillGetDataStreamCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,8 @@ private static GrpcWindmillStreamFactory.Builder createGrpcwindmillStreamFactory
.setSendKeyedGetDataRequests(
!options.isEnableStreamingEngine()
|| DataflowRunner.hasExperiment(
options, "streaming_engine_disable_new_heartbeat_requests"));
options, "streaming_engine_disable_new_heartbeat_requests"))
.setRequestBatchedGetWorkResponse(options.getWindmillRequestBatchedGetWorkResponse());
}

private static JobHeader createJobHeader(DataflowWorkerHarnessOptions options, long clientId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import com.google.auto.value.AutoValue;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;
Expand All @@ -43,6 +44,7 @@
*/
@NotThreadSafe
final class GetWorkResponseChunkAssembler {

private static final Logger LOG = LoggerFactory.getLogger(GetWorkResponseChunkAssembler.class);

private final GetWorkTimingInfosTracker workTimingInfosTracker;
Expand All @@ -61,17 +63,26 @@ final class GetWorkResponseChunkAssembler {
* Appends the response chunk bytes to the {@link #data }byte buffer. Return the assembled
* WorkItem if all response chunks for a WorkItem have been received.
*/
Optional<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
List<AssembledWorkItem> append(Windmill.StreamingGetWorkResponseChunk chunk) {
if (chunk.hasComputationMetadata()) {
metadata = ComputationMetadata.fromProto(chunk.getComputationMetadata());
}

data = data.concat(chunk.getSerializedWorkItem());
bufferedSize += chunk.getSerializedWorkItem().size();
workTimingInfosTracker.addTimingInfo(chunk.getPerWorkItemTimingInfosList());

// If the entire WorkItem has been received, assemble the WorkItem.
return chunk.getRemainingBytesForWorkItem() == 0 ? flushToWorkItem() : Optional.empty();
List<AssembledWorkItem> response = new ArrayList<>();
for (int i = 0; i < chunk.getSerializedWorkItemList().size(); i++) {
ByteString serializedWorkItem = chunk.getSerializedWorkItemList().get(i);
data = data.concat(serializedWorkItem);
bufferedSize += serializedWorkItem.size();
long remainingSize = 0;
if (i == chunk.getSerializedWorkItemList().size() - 1) {
remainingSize = chunk.getRemainingBytesForWorkItem();
}
if (remainingSize == 0) {
flushToWorkItem().ifPresent(response::add);
}
}
return response;
}

/**
Expand Down Expand Up @@ -100,6 +111,7 @@ private Optional<AssembledWorkItem> flushToWorkItem() {

@AutoValue
abstract static class ComputationMetadata {

private static ComputationMetadata fromProto(
Windmill.ComputationWorkItemMetadata metadataProto) {
return new AutoValue_GetWorkResponseChunkAssembler_ComputationMetadata(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ final class GrpcDirectGetWorkStream
extends AbstractWindmillStream<StreamingGetWorkRequest, StreamingGetWorkResponseChunk>
implements GetWorkStream {
private static final Logger LOG = LoggerFactory.getLogger(GrpcDirectGetWorkStream.class);

private static final StreamingGetWorkRequest HEALTH_CHECK_REQUEST =
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
Expand Down Expand Up @@ -88,6 +89,8 @@ final class GrpcDirectGetWorkStream
*/
private final ConcurrentMap<Long, GetWorkResponseChunkAssembler> workItemAssemblers;

private final boolean requestBatchedGetWorkResponse;

private GrpcDirectGetWorkStream(
String backendWorkerToken,
Function<
Expand All @@ -99,6 +102,7 @@ private GrpcDirectGetWorkStream(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
HeartbeatSender heartbeatSender,
GetDataClient getDataClient,
Expand Down Expand Up @@ -127,6 +131,7 @@ private GrpcDirectGetWorkStream(
.setItems(requestHeader.getMaxItems())
.setBytes(requestHeader.getMaxBytes())
.build());
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
}

static GrpcDirectGetWorkStream create(
Expand All @@ -140,6 +145,7 @@ static GrpcDirectGetWorkStream create(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
HeartbeatSender heartbeatSender,
GetDataClient getDataClient,
Expand All @@ -153,6 +159,7 @@ static GrpcDirectGetWorkStream create(
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
Expand Down Expand Up @@ -209,6 +216,7 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
.setMaxItems(initialGetWorkBudget.items())
.setMaxBytes(initialGetWorkBudget.bytes())
.build())
.setSupportsMultipleWorkItemsInChunk(requestBatchedGetWorkResponse)
.build();
lastRequest.set(request);
budgetTracker.recordBudgetRequested(initialGetWorkBudget);
Expand Down Expand Up @@ -243,7 +251,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
workItemAssemblers
.computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
.append(chunk)
.ifPresent(this::consumeAssembledWorkItem);
.forEach(this::consumeAssembledWorkItem);
}

private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@ final class GrpcGetWorkStream
implements GetWorkStream {

private static final Logger LOG = LoggerFactory.getLogger(GrpcGetWorkStream.class);
private static final StreamingGetWorkRequest HEALTH_CHECK =
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build())
.build();

private final GetWorkRequest request;
private final WorkItemReceiver receiver;
private final ThrottleTimer getWorkThrottleTimer;
private final Map<Long, GetWorkResponseChunkAssembler> workItemAssemblers;
private final AtomicLong inflightMessages;
private final AtomicLong inflightBytes;
private final boolean requestBatchedGetWorkResponse;

private GrpcGetWorkStream(
String backendWorkerToken,
Expand All @@ -64,6 +70,7 @@ private GrpcGetWorkStream(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
super(
Expand All @@ -81,6 +88,7 @@ private GrpcGetWorkStream(
this.workItemAssemblers = new ConcurrentHashMap<>();
this.inflightMessages = new AtomicLong();
this.inflightBytes = new AtomicLong();
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
}

public static GrpcGetWorkStream create(
Expand All @@ -94,6 +102,7 @@ public static GrpcGetWorkStream create(
StreamObserverFactory streamObserverFactory,
Set<AbstractWindmillStream<?, ?>> streamRegistry,
int logEveryNStreamFailures,
boolean requestBatchedGetWorkResponse,
ThrottleTimer getWorkThrottleTimer,
WorkItemReceiver receiver) {
return new GrpcGetWorkStream(
Expand All @@ -104,6 +113,7 @@ public static GrpcGetWorkStream create(
streamObserverFactory,
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
receiver);
}
Expand Down Expand Up @@ -132,7 +142,11 @@ protected synchronized void onNewStream() throws WindmillStreamShutdownException
workItemAssemblers.clear();
inflightMessages.set(request.getMaxItems());
inflightBytes.set(request.getMaxBytes());
trySend(StreamingGetWorkRequest.newBuilder().setRequest(request).build());
trySend(
StreamingGetWorkRequest.newBuilder()
.setSupportsMultipleWorkItemsInChunk(requestBatchedGetWorkResponse)
.setRequest(request)
.build());
}

@Override
Expand All @@ -153,11 +167,7 @@ public void appendSpecificHtml(PrintWriter writer) {

@Override
public void sendHealthCheck() throws WindmillStreamShutdownException {
trySend(
StreamingGetWorkRequest.newBuilder()
.setRequestExtension(
StreamingGetWorkRequestExtension.newBuilder().setMaxItems(0).setMaxBytes(0).build())
.build());
trySend(HEALTH_CHECK);
}

@Override
Expand All @@ -166,7 +176,7 @@ protected void onResponse(StreamingGetWorkResponseChunk chunk) {
workItemAssemblers
.computeIfAbsent(chunk.getStreamId(), unused -> new GetWorkResponseChunkAssembler())
.append(chunk)
.ifPresent(this::consumeAssembledWorkItem);
.forEach(this::consumeAssembledWorkItem);
}

private void consumeAssembledWorkItem(AssembledWorkItem assembledWorkItem) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ static GrpcWindmillServer newTestInstance(
.setSendKeyedGetDataRequests(sendKeyedGetDataRequests)
.setHealthCheckIntervalMillis(
testOptions.getWindmillServiceStreamingRpcHealthCheckPeriodMs())
.setRequestBatchedGetWorkResponse(
testOptions.getWindmillRequestBatchedGetWorkResponse())
.build();

return new GrpcWindmillServer(testOptions, windmillStreamFactory, dispatcherClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public class GrpcWindmillStreamFactory implements StatusDataProvider {
// If true, then active work refreshes will be sent as KeyedGetDataRequests. Otherwise, use the
// newer ComputationHeartbeatRequests.
private final boolean sendKeyedGetDataRequests;
private final boolean requestBatchedGetWorkResponse;
private final Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses;

private GrpcWindmillStreamFactory(
Expand All @@ -99,6 +100,7 @@ private GrpcWindmillStreamFactory(
int streamingRpcBatchLimit,
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
boolean requestBatchedGetWorkResponse,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier) {
this.jobHeader = jobHeader;
Expand All @@ -115,6 +117,7 @@ private GrpcWindmillStreamFactory(
.backoff());
this.streamRegistry = ConcurrentHashMap.newKeySet();
this.sendKeyedGetDataRequests = sendKeyedGetDataRequests;
this.requestBatchedGetWorkResponse = requestBatchedGetWorkResponse;
this.processHeartbeatResponses = processHeartbeatResponses;
this.streamIdGenerator = new AtomicLong();
}
Expand All @@ -126,6 +129,7 @@ static GrpcWindmillStreamFactory create(
int streamingRpcBatchLimit,
int windmillMessagesBetweenIsReadyChecks,
boolean sendKeyedGetDataRequests,
boolean requestBatchedGetWorkResponse,
Consumer<List<ComputationHeartbeatResponse>> processHeartbeatResponses,
Supplier<Duration> maxBackOffSupplier,
int healthCheckIntervalMillis) {
Expand All @@ -136,6 +140,7 @@ static GrpcWindmillStreamFactory create(
streamingRpcBatchLimit,
windmillMessagesBetweenIsReadyChecks,
sendKeyedGetDataRequests,
requestBatchedGetWorkResponse,
processHeartbeatResponses,
maxBackOffSupplier);

Expand Down Expand Up @@ -174,6 +179,7 @@ public static GrpcWindmillStreamFactory.Builder of(JobHeader jobHeader) {
.setStreamingRpcBatchLimit(DEFAULT_STREAMING_RPC_BATCH_LIMIT)
.setHealthCheckIntervalMillis(NO_HEALTH_CHECKS)
.setSendKeyedGetDataRequests(true)
.setRequestBatchedGetWorkResponse(false)
.setProcessHeartbeatResponses(ignored -> {});
}

Expand Down Expand Up @@ -209,6 +215,7 @@ public GetWorkStream createGetWorkStream(
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
processWorkItem);
}
Expand All @@ -229,6 +236,7 @@ public GetWorkStream createDirectGetWorkStream(
newStreamObserverFactory(),
streamRegistry,
logEveryNStreamFailures,
requestBatchedGetWorkResponse,
getWorkThrottleTimer,
heartbeatSender,
getDataClient,
Expand Down Expand Up @@ -356,6 +364,8 @@ Builder setProcessHeartbeatResponses(

Builder setHealthCheckIntervalMillis(int healthCheckIntervalMillis);

Builder setRequestBatchedGetWorkResponse(boolean enabled);

GrpcWindmillStreamFactory build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ public void testStreamsStartCorrectly() throws InterruptedException {
@Test
public void testOnNewWorkerMetadata_correctlyRemovesStaleWindmillServers()
throws InterruptedException {
TestGetWorkBudgetDistributor getWorkBudgetDistributor = spy(new TestGetWorkBudgetDistributor());
GetWorkBudgetDistributor getWorkBudgetDistributor = mock(GetWorkBudgetDistributor.class);
arunpandianp marked this conversation as resolved.
Show resolved Hide resolved
fanOutStreamingEngineWorkProvider =
newFanOutStreamingEngineWorkerHarness(
GetWorkBudget.builder().setItems(1).setBytes(1).build(),
Expand Down
Loading
Loading