From 781a342d6aa30bb9a21471561620ee3611150788 Mon Sep 17 00:00:00 2001 From: Nova Fallen Date: Fri, 26 Jul 2024 16:22:43 +0000 Subject: [PATCH] Fail aggregation if no inputs have been aggregated. Workflows that haven't aggregated any inputs should be marked as failed rather than succeeded as this likely indicates some configuration issue or bug. If no inputs are aggregated but the operation is SERIALIZE rather than REPORT, the operation should succeed as there is still an opportunity for the empty input to be aggregated with other partial aggregates and produce a valid result. BUG: 354248240 Change-Id: I341370265d0783d3ed65b9058014391b0f9f1764 --- WORKSPACE | 6 +- .../agg_core/pipeline_transform_server.cc | 11 +++ .../pipeline_transform_server_test.cc | 24 +++--- .../fed_sql/confidential_transform_server.cc | 10 +++ .../confidential_transform_server_test.cc | 76 +++++++++++++++++++ 5 files changed, 109 insertions(+), 18 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 04a6f4d..c32dc4f 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -64,9 +64,9 @@ python_register_toolchains( http_archive( name = "org_tensorflow_federated", - sha256 = "e3e198f291375d4d05d584b2555c1a64c72e83fe34ada28ca129997000afc269", - strip_prefix = "tensorflow-federated-a34fb3088695221e326d532e4d417957325dd9cd", - url = "https://github.com/google-parfait/tensorflow-federated/archive/a34fb3088695221e326d532e4d417957325dd9cd.tar.gz", + sha256 = "343d12a98ef8d98202e1dca898d84390a14fdc296af60f889cbb4023f38ebcdb", + strip_prefix = "tensorflow-federated-734a8669dc9842f4355d4bce240cd47883bda0c4", + url = "https://github.com/google-parfait/tensorflow-federated/archive/734a8669dc9842f4355d4bce240cd47883bda0c4.tar.gz", ) # Use a newer version of BoringSSL than what TF gives us, so we can use diff --git a/containers/agg_core/pipeline_transform_server.cc b/containers/agg_core/pipeline_transform_server.cc index 334ee64..baa5900 100644 --- a/containers/agg_core/pipeline_transform_server.cc +++ 
b/containers/agg_core/pipeline_transform_server.cc @@ -251,6 +251,17 @@ absl::Status AggCorePipelineTransform::AggCoreTransform( "The aggregation can't be completed due to failed preconditions."); } + // Fail if there were no valid inputs, as this likely indicates some issue + // with configuration of the overall workload. + FCP_ASSIGN_OR_RETURN(int num_checkpoints_aggregated, + aggregator->GetNumCheckpointsAggregated()); + if (num_checkpoints_aggregated < 1) { + return absl::InvalidArgumentError( + "The aggregation can't be successfully completed because no inputs " + "were aggregated.\n" + "This may be because inputs were ignored due to an earlier error."); + } + FederatedComputeCheckpointBuilderFactory builder_factory; std::unique_ptr checkpoint_builder = builder_factory.Create(); diff --git a/containers/agg_core/pipeline_transform_server_test.cc b/containers/agg_core/pipeline_transform_server_test.cc index c605885..bc2b5a3 100644 --- a/containers/agg_core/pipeline_transform_server_test.cc +++ b/containers/agg_core/pipeline_transform_server_test.cc @@ -326,7 +326,7 @@ TEST_F(AggCoreTransformTest, ASSERT_EQ(cwt->config_properties.fields().at("delta").number_value(), 2.2); } -TEST_F(AggCoreTransformTest, TransformZeroInputsReturnsEmptyCheckpoint) { +TEST_F(AggCoreTransformTest, TransformZeroInputsReturnsInvalidArgumentError) { FederatedComputeCheckpointParserFactory parser_factory; grpc::ClientContext configure_context; ConfigureAndAttestRequest configure_request; @@ -341,21 +341,15 @@ TEST_F(AggCoreTransformTest, TransformZeroInputsReturnsEmptyCheckpoint) { TransformRequest transform_request; grpc::ClientContext transform_context; TransformResponse transform_response; - auto transform_status = stub_->Transform( - &transform_context, transform_request, &transform_response); - - ASSERT_TRUE(transform_status.ok()); - ASSERT_EQ(transform_response.outputs_size(), 1); - ASSERT_TRUE(transform_response.outputs(0).has_unencrypted_data()); + auto status = 
stub_->Transform(&transform_context, transform_request, + &transform_response); - absl::Cord wire_format_result( - transform_response.outputs(0).unencrypted_data()); - auto parser = parser_factory.Create(wire_format_result); - auto col_values = (*parser)->GetTensor("foo_out"); - // A column with a sum of 0 is returned. - ASSERT_EQ(col_values->num_elements(), 1); - ASSERT_EQ(col_values->dtype(), DataType::DT_INT32); - ASSERT_EQ(col_values->AsSpan().at(0), 0); + ASSERT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + ASSERT_THAT( + status.error_message(), + HasSubstr( + "The aggregation can't be successfully completed because no inputs " + "were aggregated")); } TEST_F(AggCoreTransformTest, TransformExecutesFederatedSum) { diff --git a/containers/fed_sql/confidential_transform_server.cc b/containers/fed_sql/confidential_transform_server.cc index fdfb9bc..38cd67f 100644 --- a/containers/fed_sql/confidential_transform_server.cc +++ b/containers/fed_sql/confidential_transform_server.cc @@ -171,6 +171,16 @@ absl::StatusOr FedSqlSession::FinalizeSession( return absl::FailedPreconditionError( "The aggregation can't be completed due to failed preconditions."); } + // Fail if there were no valid inputs, as this likely indicates some issue + // with configuration of the overall workload. 
+ FCP_ASSIGN_OR_RETURN(int num_checkpoints_aggregated, + aggregator_->GetNumCheckpointsAggregated()); + if (num_checkpoints_aggregated < 1) { + return absl::InvalidArgumentError( + "The aggregation can't be successfully completed because no inputs " + "were aggregated.\n" + "This may be because inputs were ignored due to an earlier error."); + } FederatedComputeCheckpointBuilderFactory builder_factory; std::unique_ptr checkpoint_builder = diff --git a/containers/fed_sql/confidential_transform_server_test.cc b/containers/fed_sql/confidential_transform_server_test.cc index 2721704..2380e2b 100644 --- a/containers/fed_sql/confidential_transform_server_test.cc +++ b/containers/fed_sql/confidential_transform_server_test.cc @@ -765,6 +765,82 @@ TEST_F(FedSqlServerFederatedSumTest, SessionMergesAndReports) { ASSERT_EQ(col_values->AsSpan().at(0), 3); } +TEST_F(FedSqlServerFederatedSumTest, SerializeZeroInputsProducesEmptyOutput) { + SessionRequest write_request = + CreateDefaultWriteRequest(AGGREGATION_TYPE_ACCUMULATE, + BuildSingleInt32TensorCheckpoint("foo", {1})); + SessionResponse write_response; + + FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb( + type: FINALIZATION_TYPE_SERIALIZE + )pb"); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + finalize_config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); + + absl::StatusOr> deserialized_agg = + CheckpointAggregator::Deserialize(DefaultConfiguration(), + finalize_response.read().data()); + ASSERT_TRUE(deserialized_agg.ok()); + + FederatedComputeCheckpointBuilderFactory 
builder_factory; + std::unique_ptr checkpoint_builder = + builder_factory.Create(); + + absl::StatusOr num_checkpoints_aggregated = + (*deserialized_agg)->GetNumCheckpointsAggregated(); + ASSERT_TRUE(num_checkpoints_aggregated.ok()) + << num_checkpoints_aggregated.status(); + ASSERT_EQ(*num_checkpoints_aggregated, 0); + + // Merging the empty serialized aggregator with another aggregator should have + // no effect on the output of the other aggregator. + FederatedComputeCheckpointParserFactory parser_factory; + auto input_parser = + parser_factory + .Create(absl::Cord(BuildSingleInt32TensorCheckpoint("foo", {3}))) + .value(); + std::unique_ptr other_aggregator = + CheckpointAggregator::Create(DefaultConfiguration()).value(); + ASSERT_TRUE(other_aggregator->Accumulate(*input_parser).ok()); + ASSERT_TRUE( + other_aggregator->MergeWith(std::move(*deserialized_agg->release())) + .ok()); + + ASSERT_TRUE((*other_aggregator).Report(*checkpoint_builder).ok()); + absl::StatusOr checkpoint = checkpoint_builder->Build(); + auto parser = parser_factory.Create(*checkpoint); + auto col_values = (*parser)->GetTensor("foo_out"); + // A column with a sum of 3 is returned. 
+ ASSERT_EQ(col_values->num_elements(), 1); + ASSERT_EQ(col_values->dtype(), DataType::DT_INT32); + ASSERT_EQ(col_values->AsSpan().at(0), 3); +} + +TEST_F(FedSqlServerFederatedSumTest, ReportZeroInputsReturnsInvalidArgument) { + FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb( + type: FINALIZATION_TYPE_REPORT + )pb"); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + finalize_config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_FALSE(stream_->Read(&finalize_response)); + grpc::Status status = stream_->Finish(); + ASSERT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); +} + TEST_F(FedSqlServerFederatedSumTest, SessionIgnoresUnparseableInputs) { SessionRequest write_request_1 = CreateDefaultWriteRequest(AGGREGATION_TYPE_ACCUMULATE,