Skip to content

Commit

Permalink
Fail aggregation if no inputs have been aggregated.
Browse files Browse the repository at this point in the history
Workflows that haven't aggregated any inputs should be marked as failed
rather than succeeded as this likely indicates some configuration issue
or bug.

If no inputs are aggregated but the operation is SERIALIZE rather than
REPORT, the operation should succeed as there is still an opportunity
for the empty input to be aggregated with other partial aggregates and
produce a valid result.

BUG: 354248240
Change-Id: I341370265d0783d3ed65b9058014391b0f9f1764
  • Loading branch information
nfallen committed Jul 26, 2024
1 parent f1a4170 commit 781a342
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 18 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ python_register_toolchains(

http_archive(
name = "org_tensorflow_federated",
sha256 = "e3e198f291375d4d05d584b2555c1a64c72e83fe34ada28ca129997000afc269",
strip_prefix = "tensorflow-federated-a34fb3088695221e326d532e4d417957325dd9cd",
url = "https://github.com/google-parfait/tensorflow-federated/archive/a34fb3088695221e326d532e4d417957325dd9cd.tar.gz",
sha256 = "343d12a98ef8d98202e1dca898d84390a14fdc296af60f889cbb4023f38ebcdb",
strip_prefix = "tensorflow-federated-734a8669dc9842f4355d4bce240cd47883bda0c4",
url = "https://github.com/google-parfait/tensorflow-federated/archive/734a8669dc9842f4355d4bce240cd47883bda0c4.tar.gz",
)

# Use a newer version of BoringSSL than what TF gives us, so we can use
Expand Down
11 changes: 11 additions & 0 deletions containers/agg_core/pipeline_transform_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,17 @@ absl::Status AggCorePipelineTransform::AggCoreTransform(
"The aggregation can't be completed due to failed preconditions.");
}

// Fail if there were no valid inputs, as this likely indicates some issue
// with configuration of the overall workload.
FCP_ASSIGN_OR_RETURN(int num_checkpoints_aggregated,
aggregator->GetNumCheckpointsAggregated());
if (num_checkpoints_aggregated < 1) {
return absl::InvalidArgumentError(
"The aggregation can't be successfully completed because no inputs "
"were aggregated.\n"
"This may be because inputs were ignored due to an earlier error.");
}

FederatedComputeCheckpointBuilderFactory builder_factory;
std::unique_ptr<CheckpointBuilder> checkpoint_builder =
builder_factory.Create();
Expand Down
24 changes: 9 additions & 15 deletions containers/agg_core/pipeline_transform_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ TEST_F(AggCoreTransformTest,
ASSERT_EQ(cwt->config_properties.fields().at("delta").number_value(), 2.2);
}

TEST_F(AggCoreTransformTest, TransformZeroInputsReturnsEmptyCheckpoint) {
TEST_F(AggCoreTransformTest, TransformZeroInputsReturnsInvalidArgumentError) {
FederatedComputeCheckpointParserFactory parser_factory;
grpc::ClientContext configure_context;
ConfigureAndAttestRequest configure_request;
Expand All @@ -341,21 +341,15 @@ TEST_F(AggCoreTransformTest, TransformZeroInputsReturnsEmptyCheckpoint) {
TransformRequest transform_request;
grpc::ClientContext transform_context;
TransformResponse transform_response;
auto transform_status = stub_->Transform(
&transform_context, transform_request, &transform_response);

ASSERT_TRUE(transform_status.ok());
ASSERT_EQ(transform_response.outputs_size(), 1);
ASSERT_TRUE(transform_response.outputs(0).has_unencrypted_data());
auto status = stub_->Transform(&transform_context, transform_request,
&transform_response);

absl::Cord wire_format_result(
transform_response.outputs(0).unencrypted_data());
auto parser = parser_factory.Create(wire_format_result);
auto col_values = (*parser)->GetTensor("foo_out");
// A column with a sum of 0 is returned.
ASSERT_EQ(col_values->num_elements(), 1);
ASSERT_EQ(col_values->dtype(), DataType::DT_INT32);
ASSERT_EQ(col_values->AsSpan<int32_t>().at(0), 0);
ASSERT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT);
ASSERT_THAT(
status.error_message(),
HasSubstr(
"The aggregation can't be successfully completed because no inputs "
"were aggregated"));
}

TEST_F(AggCoreTransformTest, TransformExecutesFederatedSum) {
Expand Down
10 changes: 10 additions & 0 deletions containers/fed_sql/confidential_transform_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ absl::StatusOr<SessionResponse> FedSqlSession::FinalizeSession(
return absl::FailedPreconditionError(
"The aggregation can't be completed due to failed preconditions.");
}
// Fail if there were no valid inputs, as this likely indicates some issue
// with configuration of the overall workload.
FCP_ASSIGN_OR_RETURN(int num_checkpoints_aggregated,
aggregator_->GetNumCheckpointsAggregated());
if (num_checkpoints_aggregated < 1) {
return absl::InvalidArgumentError(
"The aggregation can't be successfully completed because no inputs "
"were aggregated.\n"
"This may be because inputs were ignored due to an earlier error.");
}

FederatedComputeCheckpointBuilderFactory builder_factory;
std::unique_ptr<CheckpointBuilder> checkpoint_builder =
Expand Down
76 changes: 76 additions & 0 deletions containers/fed_sql/confidential_transform_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,82 @@ TEST_F(FedSqlServerFederatedSumTest, SessionMergesAndReports) {
ASSERT_EQ(col_values->AsSpan<int32_t>().at(0), 3);
}

// Finalizes a session with FINALIZATION_TYPE_SERIALIZE without this test
// writing any input, and checks that:
//   * the finalize call still succeeds and returns unencrypted, non-empty
//     serialized aggregator state;
//   * the deserialized aggregator reports zero aggregated checkpoints;
//   * merging that empty aggregator into another aggregator leaves the other
//     aggregator's reported result unchanged.
TEST_F(FedSqlServerFederatedSumTest, SerializeZeroInputsProducesEmptyOutput) {
  // Deliberately no write request is built or sent by this test: the finalize
  // request below is the only message written to the stream.
  FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb(
type: FINALIZATION_TYPE_SERIALIZE
)pb");
  SessionRequest finalize_request;
  SessionResponse finalize_response;
  finalize_request.mutable_finalize()->mutable_configuration()->PackFrom(
      finalize_config);
  ASSERT_TRUE(stream_->Write(finalize_request));
  ASSERT_TRUE(stream_->Read(&finalize_response));

  ASSERT_TRUE(finalize_response.has_read());
  ASSERT_TRUE(finalize_response.read().finish_read());
  ASSERT_GT(
      finalize_response.read().first_response_metadata().total_size_bytes(), 0);
  ASSERT_TRUE(
      finalize_response.read().first_response_metadata().has_unencrypted());

  absl::StatusOr<std::unique_ptr<CheckpointAggregator>> deserialized_agg =
      CheckpointAggregator::Deserialize(DefaultConfiguration(),
                                        finalize_response.read().data());
  ASSERT_TRUE(deserialized_agg.ok()) << deserialized_agg.status();

  absl::StatusOr<int> num_checkpoints_aggregated =
      (*deserialized_agg)->GetNumCheckpointsAggregated();
  ASSERT_TRUE(num_checkpoints_aggregated.ok())
      << num_checkpoints_aggregated.status();
  ASSERT_EQ(*num_checkpoints_aggregated, 0);

  // Merging the empty serialized aggregator with another aggregator should have
  // no effect on the output of the other aggregator.
  FederatedComputeCheckpointParserFactory parser_factory;
  auto input_parser =
      parser_factory
          .Create(absl::Cord(BuildSingleInt32TensorCheckpoint("foo", {3})))
          .value();
  std::unique_ptr<CheckpointAggregator> other_aggregator =
      CheckpointAggregator::Create(DefaultConfiguration()).value();
  ASSERT_TRUE(other_aggregator->Accumulate(*input_parser).ok());
  // Pass the pointee as an rvalue while the unique_ptr keeps ownership, so the
  // moved-from aggregator is still destroyed at scope exit. Calling release()
  // here would hand back a raw pointer that nobody deletes.
  ASSERT_TRUE(other_aggregator->MergeWith(std::move(**deserialized_agg)).ok());

  FederatedComputeCheckpointBuilderFactory builder_factory;
  std::unique_ptr<CheckpointBuilder> checkpoint_builder =
      builder_factory.Create();
  ASSERT_TRUE(other_aggregator->Report(*checkpoint_builder).ok());

  // Check every StatusOr before dereferencing it so a failure is reported as a
  // test failure with the status attached rather than as a crash.
  absl::StatusOr<absl::Cord> checkpoint = checkpoint_builder->Build();
  ASSERT_TRUE(checkpoint.ok()) << checkpoint.status();
  auto parser = parser_factory.Create(*checkpoint);
  ASSERT_TRUE(parser.ok()) << parser.status();
  auto col_values = (*parser)->GetTensor("foo_out");
  ASSERT_TRUE(col_values.ok()) << col_values.status();
  // A column with a sum of 3 is returned.
  ASSERT_EQ(col_values->num_elements(), 1);
  ASSERT_EQ(col_values->dtype(), DataType::DT_INT32);
  ASSERT_EQ(col_values->AsSpan<int32_t>().at(0), 3);
}

// Finalizing with FINALIZATION_TYPE_REPORT when this test has written no
// inputs must terminate the stream with an INVALID_ARGUMENT status.
TEST_F(FedSqlServerFederatedSumTest, ReportZeroInputsReturnsInvalidArgument) {
  FedSqlContainerFinalizeConfiguration report_config = PARSE_TEXT_PROTO(R"pb(
type: FINALIZATION_TYPE_REPORT
)pb");

  // Wrap the configuration in a finalize request and send it.
  SessionRequest request;
  auto* packed_configuration =
      request.mutable_finalize()->mutable_configuration();
  packed_configuration->PackFrom(report_config);
  ASSERT_TRUE(stream_->Write(request));

  // No response message is expected; the failure is surfaced by Finish().
  SessionResponse response;
  ASSERT_FALSE(stream_->Read(&response));
  const grpc::Status finish_status = stream_->Finish();
  ASSERT_EQ(finish_status.error_code(), grpc::StatusCode::INVALID_ARGUMENT);
}

TEST_F(FedSqlServerFederatedSumTest, SessionIgnoresUnparseableInputs) {
SessionRequest write_request_1 =
CreateDefaultWriteRequest(AGGREGATION_TYPE_ACCUMULATE,
Expand Down

0 comments on commit 781a342

Please sign in to comment.