diff --git a/WORKSPACE b/WORKSPACE index 4f5c16d..0b5229e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -119,9 +119,9 @@ http_archive( "//third_party/federated_compute:libcppbor.patch", "//third_party/federated_compute:visibility.patch", ], - sha256 = "b7debf9d2c5cff2d5a50ed6fd8d25758499d512697685879362e48442d6e370b", - strip_prefix = "federated-compute-b7e18da41f46ba19edc95148feab840649bb3756", - url = "https://github.com/google/federated-compute/archive/b7e18da41f46ba19edc95148feab840649bb3756.tar.gz", + sha256 = "3a8427667e68533085b1037b8cbe6089b9b8f5bbfeaf3abdd2f896af2be846af", + strip_prefix = "federated-compute-9d71326389a006b416b95f2fe13b5fac5aa31e21", + url = "https://github.com/google/federated-compute/archive/9d71326389a006b416b95f2fe13b5fac5aa31e21.tar.gz", ) git_repository( diff --git a/containers/confidential_transform_test_concat/confidential_transform_server.cc b/containers/confidential_transform_test_concat/confidential_transform_server.cc index 2eb7bf4..0f62b1f 100644 --- a/containers/confidential_transform_test_concat/confidential_transform_server.cc +++ b/containers/confidential_transform_test_concat/confidential_transform_server.cc @@ -117,8 +117,7 @@ absl::Status TestConcatConfidentialTransform::Session( if (absl::Status nonce_status = nonce_checker.CheckBlobNonce( write_request.first_request_metadata()); !nonce_status.ok()) { - stream->Write(ToSessionWriteFinishedResponse(nonce_status, - /*available_memory*/ 0)); + stream->Write(ToSessionWriteFinishedResponse(nonce_status)); break; } @@ -127,14 +126,13 @@ absl::Status TestConcatConfidentialTransform::Session( write_request.data()); if (!unencrypted_data.ok()) { stream->Write( - ToSessionWriteFinishedResponse(unencrypted_data.status(), - /*available_memory*/ 0)); + ToSessionWriteFinishedResponse(unencrypted_data.status())); break; } absl::StrAppend(&state, *unencrypted_data); stream->Write(ToSessionWriteFinishedResponse( - absl::OkStatus(), /*available_memory*/ 0, + absl::OkStatus(), write_request.first_request_metadata().total_size_bytes())); break; } diff --git a/containers/confidential_transform_test_concat/confidential_transform_server.h b/containers/confidential_transform_test_concat/confidential_transform_server.h index 17d24f0..6317367 100644 --- a/containers/confidential_transform_test_concat/confidential_transform_server.h +++ b/containers/confidential_transform_test_concat/confidential_transform_server.h @@ -31,8 +31,7 @@ namespace confidential_federated_compute::confidential_transform_test_concat { // Test ConfidentialTransform service that concatenates inputs. This test -// service doesn't return `write_capacity_bytes`, nor does it manage the number -// of sessions. +// service doesn't manage the number of sessions. class TestConcatConfidentialTransform final : public fcp::confidentialcompute::ConfidentialTransform::Service { public: diff --git a/containers/fed_sql/confidential_transform_server.cc b/containers/fed_sql/confidential_transform_server.cc index 6967fcb..34171fd 100644 --- a/containers/fed_sql/confidential_transform_server.cc +++ b/containers/fed_sql/confidential_transform_server.cc @@ -124,29 +124,24 @@ absl::Status HandleWrite( const WriteRequest& request, CheckpointAggregator& aggregator, BlobDecryptor* blob_decryptor, NonceChecker& nonce_checker, grpc::ServerReaderWriter* stream, - long available_memory, const std::vector* intrinsics) { + const std::vector* intrinsics) { if (absl::Status nonce_status = nonce_checker.CheckBlobNonce(request.first_request_metadata()); !nonce_status.ok()) { - stream->Write( - ToSessionWriteFinishedResponse(nonce_status, available_memory)); + stream->Write(ToSessionWriteFinishedResponse(nonce_status)); return absl::OkStatus(); } FedSqlContainerWriteConfiguration write_config; if (!request.first_request_configuration().UnpackTo(&write_config)) { - stream->Write(ToSessionWriteFinishedResponse( - absl::InvalidArgumentError( - "Failed to parse FedSqlContainerWriteConfiguration."), - available_memory)); + stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "Failed to parse FedSqlContainerWriteConfiguration."))); return absl::OkStatus(); } absl::StatusOr unencrypted_data = blob_decryptor->DecryptBlob( request.first_request_metadata(), request.data()); if (!unencrypted_data.ok()) { - stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status(), - - available_memory)); + stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status())); return absl::OkStatus(); } @@ -164,8 +159,7 @@ absl::Status HandleWrite( absl::Status(parser.status().code(), absl::StrCat("Failed to deserialize checkpoint for " "AGGREGATION_TYPE_ACCUMULATE: ", - parser.status().message())), - available_memory)); + parser.status().message())))); return absl::OkStatus(); } FCP_RETURN_IF_ERROR(aggregator.Accumulate(*parser.value())); @@ -179,24 +173,20 @@ absl::Status HandleWrite( absl::Status(other.status().code(), absl::StrCat("Failed to deserialize checkpoint for " "AGGREGATION_TYPE_MERGE: ", - other.status().message())), - available_memory)); + other.status().message())))); return absl::OkStatus(); } FCP_RETURN_IF_ERROR(aggregator.MergeWith(std::move(*other.value()))); break; } default: - stream->Write(ToSessionWriteFinishedResponse( - absl::InvalidArgumentError( - "AggCoreAggregationType must be specified."), - available_memory)); + stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "AggCoreAggregationType must be specified."))); return absl::OkStatus(); } stream->Write(ToSessionWriteFinishedResponse( - absl::OkStatus(), available_memory, - request.first_request_metadata().total_size_bytes())); + absl::OkStatus(), request.first_request_metadata().total_size_bytes())); return absl::OkStatus(); } @@ -334,8 +324,7 @@ absl::Status FedSqlConfidentialTransform::FedSqlInitialize( } absl::Status FedSqlConfidentialTransform::FedSqlSession( - grpc::ServerReaderWriter* stream, - long available_memory) { + grpc::ServerReaderWriter* stream) { BlobDecryptor* blob_decryptor; std::unique_ptr aggregator; const std::vector* intrinsics; @@ -370,8 +359,6 @@ absl::Status FedSqlConfidentialTransform::FedSqlSession( NonceChecker nonce_checker; *configure_response.mutable_configure()->mutable_nonce() = nonce_checker.GetSessionNonce(); - configure_response.mutable_configure()->set_write_capacity_bytes( - available_memory); stream->Write(configure_response); // Initialze result_blob_metadata with unencrypted metadata since @@ -391,21 +378,18 @@ absl::Status FedSqlConfidentialTransform::FedSqlSession( EarliestExpirationTimeMetadata( result_blob_metadata, write_request.first_request_metadata()); if (!earliest_expiration_metadata.ok()) { - stream->Write(ToSessionWriteFinishedResponse( - absl::Status( - earliest_expiration_metadata.status().code(), - absl::StrCat( - "Failed to extract expiration timestamp from " - "BlobMetadata with encryption: ", - earliest_expiration_metadata.status().message())), - available_memory)); + stream->Write(ToSessionWriteFinishedResponse(absl::Status( + earliest_expiration_metadata.status().code(), + absl::StrCat("Failed to extract expiration timestamp from " + "BlobMetadata with encryption: ", + earliest_expiration_metadata.status().message())))); break; } result_blob_metadata = *earliest_expiration_metadata; // TODO: spin up a thread to incorporate each blob. FCP_RETURN_IF_ERROR(HandleWrite(write_request, *aggregator, blob_decryptor, nonce_checker, stream, - available_memory, intrinsics)); + intrinsics)); break; } case SessionRequest::kFinalize: @@ -432,16 +416,16 @@ grpc::Status FedSqlConfidentialTransform::Initialize( grpc::Status FedSqlConfidentialTransform::Session( ServerContext* context, grpc::ServerReaderWriter* stream) { - long available_memory = session_tracker_.AddSession(); - if (available_memory > 0) { - grpc::Status status = ToGrpcStatus(FedSqlSession(stream, available_memory)); - absl::Status remove_session = session_tracker_.RemoveSession(); - if (!remove_session.ok()) { - return ToGrpcStatus(remove_session); - } - return status; + if (absl::Status session_status = session_tracker_.AddSession(); + !session_status.ok()) { + return ToGrpcStatus(session_status); + } + grpc::Status status = ToGrpcStatus(FedSqlSession(stream)); + absl::Status remove_session = session_tracker_.RemoveSession(); + if (!remove_session.ok()) { + return ToGrpcStatus(remove_session); } - return ToGrpcStatus(absl::UnavailableError("No session memory available.")); + return status; } } // namespace confidential_federated_compute::fed_sql diff --git a/containers/fed_sql/confidential_transform_server.h b/containers/fed_sql/confidential_transform_server.h index 98dc801..2b13adc 100644 --- a/containers/fed_sql/confidential_transform_server.h +++ b/containers/fed_sql/confidential_transform_server.h @@ -44,9 +44,9 @@ class FedSqlConfidentialTransform final // TODO: add absl::Nonnull to crypto_stub. explicit FedSqlConfidentialTransform( oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, - int max_num_sessions, long max_session_memory_bytes) + int max_num_sessions) : crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)), - session_tracker_(max_num_sessions, max_session_memory_bytes) {} + session_tracker_(max_num_sessions) {} grpc::Status Initialize( grpc::ServerContext* context, @@ -67,8 +67,7 @@ class FedSqlConfidentialTransform final absl::Status FedSqlSession( grpc::ServerReaderWriter* - stream, - long stream_memory); + stream); oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_; confidential_federated_compute::SessionTracker session_tracker_; diff --git a/containers/fed_sql/confidential_transform_server_test.cc b/containers/fed_sql/confidential_transform_server_test.cc index 95ad169..2721704 100644 --- a/containers/fed_sql/confidential_transform_server_test.cc +++ b/containers/fed_sql/confidential_transform_server_test.cc @@ -100,7 +100,6 @@ using ::testing::Test; using testing::UnorderedElementsAre; inline constexpr int kMaxNumSessions = 8; -inline constexpr long kMaxSessionMemoryBytes = 1000000; std::string BuildSingleInt32TensorCheckpoint( std::string column_name, std::initializer_list input_values) { @@ -166,8 +165,7 @@ class FedSqlServerTest : public Test { BlobMetadata DefaultBlobMetadata() const; testing::NiceMock mock_crypto_stub_; - FedSqlConfidentialTransform service_{&mock_crypto_stub_, kMaxNumSessions, - kMaxSessionMemoryBytes}; + FedSqlConfidentialTransform service_{&mock_crypto_stub_, kMaxNumSessions}; std::unique_ptr server_; std::unique_ptr stub_; }; @@ -408,8 +406,6 @@ TEST_F(FedSqlServerTest, SessionConfigureGeneratesNonce) { ASSERT_TRUE(session_response.has_configure()); ASSERT_GT(session_response.configure().nonce().size(), 0); - ASSERT_EQ(session_response.configure().write_capacity_bytes(), - kMaxSessionMemoryBytes); } TEST_F(FedSqlServerTest, SessionRejectsMoreThanMaximumNumSessions) { @@ -453,7 +449,8 @@ TEST_F(FedSqlServerTest, SessionRejectsMoreThanMaximumNumSessions) { stream = stub_->Session(&rejected_context); ASSERT_TRUE(stream->Write(rejected_request)); ASSERT_FALSE(stream->Read(&rejected_response)); - ASSERT_EQ(stream->Finish().error_code(), grpc::StatusCode::UNAVAILABLE); + ASSERT_EQ(stream->Finish().error_code(), + grpc::StatusCode::FAILED_PRECONDITION); } TEST_F(FedSqlServerTest, SessionBeforeInitialize) { @@ -635,8 +632,6 @@ TEST_F(FedSqlServerFederatedSumTest, SessionWriteAccumulateCommitsBlob) { ASSERT_TRUE(write_response.has_write()); ASSERT_EQ(write_response.write().committed_size_bytes(), data.size()); ASSERT_EQ(write_response.write().status().code(), grpc::OK); - ASSERT_EQ(write_response.write().write_capacity_bytes(), - kMaxSessionMemoryBytes); } TEST_F(FedSqlServerFederatedSumTest, SessionAccumulatesAndReports) { @@ -789,8 +784,6 @@ TEST_F(FedSqlServerFederatedSumTest, SessionIgnoresUnparseableInputs) { ASSERT_TRUE(write_response_2.has_write()); ASSERT_EQ(write_response_2.write().committed_size_bytes(), 0); ASSERT_EQ(write_response_2.write().status().code(), grpc::INVALID_ARGUMENT); - ASSERT_EQ(write_response_2.write().write_capacity_bytes(), - kMaxSessionMemoryBytes); FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb( type: FINALIZATION_TYPE_REPORT diff --git a/containers/session.cc b/containers/session.cc index 2547894..b5d9552 100644 --- a/containers/session.cc +++ b/containers/session.cc @@ -23,13 +23,14 @@ using ::fcp::base::ToGrpcStatus; using ::fcp::confidentialcompute::SessionResponse; using ::fcp::confidentialcompute::WriteFinishedResponse; -long SessionTracker::AddSession() { +absl::Status SessionTracker::AddSession() { absl::MutexLock l(&mutex_); if (num_sessions_ < max_num_sessions_) { num_sessions_++; - return max_session_memory_bytes_; + return absl::OkStatus(); } - return 0; + return absl::FailedPreconditionError( + "SessionTracker: already at the maximum number of sessions."); } absl::Status SessionTracker::RemoveSession() { @@ -43,14 +44,12 @@ absl::Status SessionTracker::RemoveSession() { } SessionResponse ToSessionWriteFinishedResponse(absl::Status status, - long available_memory, long committed_size_bytes) { grpc::Status grpc_status = ToGrpcStatus(std::move(status)); SessionResponse session_response; WriteFinishedResponse* response = session_response.mutable_write(); response->mutable_status()->set_code(grpc_status.error_code()); response->mutable_status()->set_message(grpc_status.error_message()); - response->set_write_capacity_bytes(available_memory); response->set_committed_size_bytes(committed_size_bytes); return session_response; } diff --git a/containers/session.h b/containers/session.h index 1f6b443..f3b570a 100644 --- a/containers/session.h +++ b/containers/session.h @@ -27,13 +27,11 @@ namespace confidential_federated_compute { // This class is threadsafe. class SessionTracker { public: - SessionTracker(int max_num_sessions, long max_session_memory_bytes) - : max_num_sessions_(max_num_sessions), - max_session_memory_bytes_(max_session_memory_bytes) {}; + SessionTracker(int max_num_sessions) : max_num_sessions_(max_num_sessions) {}; // Tries to add a session and returns the amount of memory in bytes that the // session is allowed. Returns 0 if there is no available memory. - long AddSession(); + absl::Status AddSession(); // Tries to remove a session and returns an error if unable to do so. absl::Status RemoveSession(); @@ -48,7 +46,7 @@ class SessionTracker { // Create a SessionResponse with a WriteFinishedResponse. fcp::confidentialcompute::SessionResponse ToSessionWriteFinishedResponse( - absl::Status status, long available_memory, long committed_size_bytes = 0); + absl::Status status, long committed_size_bytes = 0); } // namespace confidential_federated_compute #endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_SESSION_H_ diff --git a/containers/session_test.cc b/containers/session_test.cc index 748d477..92a6081 100644 --- a/containers/session_test.cc +++ b/containers/session_test.cc @@ -23,36 +23,37 @@ namespace { using ::fcp::confidentialcompute::SessionResponse; TEST(SessionTest, AddSession) { - SessionTracker session_tracker(1, 100); - ASSERT_GT(session_tracker.AddSession(), 0); + SessionTracker session_tracker(1); + EXPECT_TRUE(session_tracker.AddSession().ok()); } TEST(SessionTest, MaximumSessionsReachedAddSession) { - SessionTracker session_tracker(1, 100); - ASSERT_GT(session_tracker.AddSession(), 0); - ASSERT_EQ(session_tracker.AddSession(), 0); + SessionTracker session_tracker(1); + EXPECT_TRUE(session_tracker.AddSession().ok()); + EXPECT_EQ(session_tracker.AddSession().code(), + absl::StatusCode::kFailedPrecondition); } TEST(SessionTest, MaximumSessionsReachedCanAddSessionAfterRemoveSession) { - SessionTracker session_tracker(1, 100); - ASSERT_GT(session_tracker.AddSession(), 0); - ASSERT_EQ(session_tracker.AddSession(), 0); - ASSERT_TRUE(session_tracker.RemoveSession().ok()); - ASSERT_GT(session_tracker.AddSession(), 0); + SessionTracker session_tracker(1); + EXPECT_TRUE(session_tracker.AddSession().ok()); + EXPECT_EQ(session_tracker.AddSession().code(), + absl::StatusCode::kFailedPrecondition); + EXPECT_TRUE(session_tracker.RemoveSession().ok()); + EXPECT_TRUE(session_tracker.AddSession().ok()); } TEST(SessionTest, RemoveSessionWithoutAddSessionFails) { - SessionTracker session_tracker(1, 100); - ASSERT_EQ(session_tracker.RemoveSession().code(), + SessionTracker session_tracker(1); + EXPECT_EQ(session_tracker.RemoveSession().code(), absl::StatusCode::kFailedPrecondition); } TEST(SessionTest, ErrorToSessionWriteFinishedResponseTest) { - SessionResponse response = ToSessionWriteFinishedResponse( - absl::InvalidArgumentError("invalid arg"), 42); + SessionResponse response = + ToSessionWriteFinishedResponse(absl::InvalidArgumentError("invalid arg")); ASSERT_TRUE(response.has_write()); EXPECT_EQ(response.write().committed_size_bytes(), 0); - EXPECT_EQ(response.write().write_capacity_bytes(), 42); EXPECT_EQ(response.write().status().code(), grpc::StatusCode::INVALID_ARGUMENT); EXPECT_EQ(response.write().status().message(), "invalid arg"); @@ -60,10 +61,9 @@ TEST(SessionTest, ErrorToSessionWriteFinishedResponseTest) { TEST(SessionTest, OkToSessionWriteFinishedResponseTest) { SessionResponse response = - ToSessionWriteFinishedResponse(absl::OkStatus(), 42, 6); + ToSessionWriteFinishedResponse(absl::OkStatus(), 6); ASSERT_TRUE(response.has_write()); EXPECT_EQ(response.write().committed_size_bytes(), 6); - EXPECT_EQ(response.write().write_capacity_bytes(), 42); EXPECT_EQ(response.write().status().code(), grpc::StatusCode::OK); EXPECT_TRUE(response.write().status().message().empty()); }