Skip to content

Commit

Permalink
Remove code related to write_capacity_bytes.
Browse files Browse the repository at this point in the history
This field is not needed because we can instead rely on gRPC's flow
control. The server will have have a fixed size thread pool per session
and block the request-reading while loop if the thread pool is full.
Blocking the while loop would trigger the gRPC level flow control again,
which will block client-side Writes. Once we switch to the async
callback API, client's won't send subsequent Writes until OnWriteDone is
called.

BUG: 345838534

Change-Id: If87308657bd4c715d191f8831e306bb91f8d1fcf
  • Loading branch information
zpgong committed Jul 17, 2024
1 parent 2eab494 commit b6dd662
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 93 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ http_archive(
"//third_party/federated_compute:libcppbor.patch",
"//third_party/federated_compute:visibility.patch",
],
sha256 = "b7debf9d2c5cff2d5a50ed6fd8d25758499d512697685879362e48442d6e370b",
strip_prefix = "federated-compute-b7e18da41f46ba19edc95148feab840649bb3756",
url = "https://github.com/google/federated-compute/archive/b7e18da41f46ba19edc95148feab840649bb3756.tar.gz",
sha256 = "3a8427667e68533085b1037b8cbe6089b9b8f5bbfeaf3abdd2f896af2be846af",
strip_prefix = "federated-compute-9d71326389a006b416b95f2fe13b5fac5aa31e21",
url = "https://github.com/google/federated-compute/archive/9d71326389a006b416b95f2fe13b5fac5aa31e21.tar.gz",
)

git_repository(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ absl::Status TestConcatConfidentialTransform::Session(
if (absl::Status nonce_status = nonce_checker.CheckBlobNonce(
write_request.first_request_metadata());
!nonce_status.ok()) {
stream->Write(ToSessionWriteFinishedResponse(nonce_status,
/*available_memory*/ 0));
stream->Write(ToSessionWriteFinishedResponse(nonce_status));
break;
}

Expand All @@ -127,14 +126,13 @@ absl::Status TestConcatConfidentialTransform::Session(
write_request.data());
if (!unencrypted_data.ok()) {
stream->Write(
ToSessionWriteFinishedResponse(unencrypted_data.status(),
/*available_memory*/ 0));
ToSessionWriteFinishedResponse(unencrypted_data.status()));
break;
}

absl::StrAppend(&state, *unencrypted_data);
stream->Write(ToSessionWriteFinishedResponse(
absl::OkStatus(), /*available_memory*/ 0,
absl::OkStatus(),
write_request.first_request_metadata().total_size_bytes()));
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
namespace confidential_federated_compute::confidential_transform_test_concat {

// Test ConfidentialTransform service that concatenates inputs. This test
// service doesn't return `write_capacity_bytes`, nor does it manage the number
// of sessions.
// service doesn't manage the number of sessions.
class TestConcatConfidentialTransform final
: public fcp::confidentialcompute::ConfidentialTransform::Service {
public:
Expand Down
68 changes: 26 additions & 42 deletions containers/fed_sql/confidential_transform_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,29 +124,24 @@ absl::Status HandleWrite(
const WriteRequest& request, CheckpointAggregator& aggregator,
BlobDecryptor* blob_decryptor, NonceChecker& nonce_checker,
grpc::ServerReaderWriter<SessionResponse, SessionRequest>* stream,
long available_memory, const std::vector<Intrinsic>* intrinsics) {
const std::vector<Intrinsic>* intrinsics) {
if (absl::Status nonce_status =
nonce_checker.CheckBlobNonce(request.first_request_metadata());
!nonce_status.ok()) {
stream->Write(
ToSessionWriteFinishedResponse(nonce_status, available_memory));
stream->Write(ToSessionWriteFinishedResponse(nonce_status));
return absl::OkStatus();
}

FedSqlContainerWriteConfiguration write_config;
if (!request.first_request_configuration().UnpackTo(&write_config)) {
stream->Write(ToSessionWriteFinishedResponse(
absl::InvalidArgumentError(
"Failed to parse FedSqlContainerWriteConfiguration."),
available_memory));
stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError(
"Failed to parse FedSqlContainerWriteConfiguration.")));
return absl::OkStatus();
}
absl::StatusOr<std::string> unencrypted_data = blob_decryptor->DecryptBlob(
request.first_request_metadata(), request.data());
if (!unencrypted_data.ok()) {
stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status(),

available_memory));
stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status()));
return absl::OkStatus();
}

Expand All @@ -164,8 +159,7 @@ absl::Status HandleWrite(
absl::Status(parser.status().code(),
absl::StrCat("Failed to deserialize checkpoint for "
"AGGREGATION_TYPE_ACCUMULATE: ",
parser.status().message())),
available_memory));
parser.status().message()))));
return absl::OkStatus();
}
FCP_RETURN_IF_ERROR(aggregator.Accumulate(*parser.value()));
Expand All @@ -179,24 +173,20 @@ absl::Status HandleWrite(
absl::Status(other.status().code(),
absl::StrCat("Failed to deserialize checkpoint for "
"AGGREGATION_TYPE_MERGE: ",
other.status().message())),
available_memory));
other.status().message()))));
return absl::OkStatus();
}
FCP_RETURN_IF_ERROR(aggregator.MergeWith(std::move(*other.value())));
break;
}
default:
stream->Write(ToSessionWriteFinishedResponse(
absl::InvalidArgumentError(
"AggCoreAggregationType must be specified."),
available_memory));
stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError(
"AggCoreAggregationType must be specified.")));
return absl::OkStatus();
}

stream->Write(ToSessionWriteFinishedResponse(
absl::OkStatus(), available_memory,
request.first_request_metadata().total_size_bytes()));
absl::OkStatus(), request.first_request_metadata().total_size_bytes()));
return absl::OkStatus();
}

Expand Down Expand Up @@ -334,8 +324,7 @@ absl::Status FedSqlConfidentialTransform::FedSqlInitialize(
}

absl::Status FedSqlConfidentialTransform::FedSqlSession(
grpc::ServerReaderWriter<SessionResponse, SessionRequest>* stream,
long available_memory) {
grpc::ServerReaderWriter<SessionResponse, SessionRequest>* stream) {
BlobDecryptor* blob_decryptor;
std::unique_ptr<CheckpointAggregator> aggregator;
const std::vector<Intrinsic>* intrinsics;
Expand Down Expand Up @@ -370,8 +359,6 @@ absl::Status FedSqlConfidentialTransform::FedSqlSession(
NonceChecker nonce_checker;
*configure_response.mutable_configure()->mutable_nonce() =
nonce_checker.GetSessionNonce();
configure_response.mutable_configure()->set_write_capacity_bytes(
available_memory);
stream->Write(configure_response);

// Initialze result_blob_metadata with unencrypted metadata since
Expand All @@ -391,21 +378,18 @@ absl::Status FedSqlConfidentialTransform::FedSqlSession(
EarliestExpirationTimeMetadata(
result_blob_metadata, write_request.first_request_metadata());
if (!earliest_expiration_metadata.ok()) {
stream->Write(ToSessionWriteFinishedResponse(
absl::Status(
earliest_expiration_metadata.status().code(),
absl::StrCat(
"Failed to extract expiration timestamp from "
"BlobMetadata with encryption: ",
earliest_expiration_metadata.status().message())),
available_memory));
stream->Write(ToSessionWriteFinishedResponse(absl::Status(
earliest_expiration_metadata.status().code(),
absl::StrCat("Failed to extract expiration timestamp from "
"BlobMetadata with encryption: ",
earliest_expiration_metadata.status().message()))));
break;
}
result_blob_metadata = *earliest_expiration_metadata;
// TODO: spin up a thread to incorporate each blob.
FCP_RETURN_IF_ERROR(HandleWrite(write_request, *aggregator,
blob_decryptor, nonce_checker, stream,
available_memory, intrinsics));
intrinsics));
break;
}
case SessionRequest::kFinalize:
Expand All @@ -432,16 +416,16 @@ grpc::Status FedSqlConfidentialTransform::Initialize(
grpc::Status FedSqlConfidentialTransform::Session(
ServerContext* context,
grpc::ServerReaderWriter<SessionResponse, SessionRequest>* stream) {
long available_memory = session_tracker_.AddSession();
if (available_memory > 0) {
grpc::Status status = ToGrpcStatus(FedSqlSession(stream, available_memory));
absl::Status remove_session = session_tracker_.RemoveSession();
if (!remove_session.ok()) {
return ToGrpcStatus(remove_session);
}
return status;
if (absl::Status session_status = session_tracker_.AddSession();
!session_status.ok()) {
return ToGrpcStatus(session_status);
}
grpc::Status status = ToGrpcStatus(FedSqlSession(stream));
absl::Status remove_session = session_tracker_.RemoveSession();
if (!remove_session.ok()) {
return ToGrpcStatus(remove_session);
}
return ToGrpcStatus(absl::UnavailableError("No session memory available."));
return status;
}

} // namespace confidential_federated_compute::fed_sql
7 changes: 3 additions & 4 deletions containers/fed_sql/confidential_transform_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ class FedSqlConfidentialTransform final
// TODO: add absl::Nonnull to crypto_stub.
explicit FedSqlConfidentialTransform(
oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub,
int max_num_sessions, long max_session_memory_bytes)
int max_num_sessions)
: crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)),
session_tracker_(max_num_sessions, max_session_memory_bytes) {}
session_tracker_(max_num_sessions) {}

grpc::Status Initialize(
grpc::ServerContext* context,
Expand All @@ -67,8 +67,7 @@ class FedSqlConfidentialTransform final
absl::Status FedSqlSession(
grpc::ServerReaderWriter<fcp::confidentialcompute::SessionResponse,
fcp::confidentialcompute::SessionRequest>*
stream,
long stream_memory);
stream);

oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_;
confidential_federated_compute::SessionTracker session_tracker_;
Expand Down
13 changes: 3 additions & 10 deletions containers/fed_sql/confidential_transform_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ using ::testing::Test;
using testing::UnorderedElementsAre;

inline constexpr int kMaxNumSessions = 8;
inline constexpr long kMaxSessionMemoryBytes = 1000000;

std::string BuildSingleInt32TensorCheckpoint(
std::string column_name, std::initializer_list<int32_t> input_values) {
Expand Down Expand Up @@ -166,8 +165,7 @@ class FedSqlServerTest : public Test {
BlobMetadata DefaultBlobMetadata() const;

testing::NiceMock<MockOrchestratorCryptoStub> mock_crypto_stub_;
FedSqlConfidentialTransform service_{&mock_crypto_stub_, kMaxNumSessions,
kMaxSessionMemoryBytes};
FedSqlConfidentialTransform service_{&mock_crypto_stub_, kMaxNumSessions};
std::unique_ptr<Server> server_;
std::unique_ptr<ConfidentialTransform::Stub> stub_;
};
Expand Down Expand Up @@ -408,8 +406,6 @@ TEST_F(FedSqlServerTest, SessionConfigureGeneratesNonce) {

ASSERT_TRUE(session_response.has_configure());
ASSERT_GT(session_response.configure().nonce().size(), 0);
ASSERT_EQ(session_response.configure().write_capacity_bytes(),
kMaxSessionMemoryBytes);
}

TEST_F(FedSqlServerTest, SessionRejectsMoreThanMaximumNumSessions) {
Expand Down Expand Up @@ -453,7 +449,8 @@ TEST_F(FedSqlServerTest, SessionRejectsMoreThanMaximumNumSessions) {
stream = stub_->Session(&rejected_context);
ASSERT_TRUE(stream->Write(rejected_request));
ASSERT_FALSE(stream->Read(&rejected_response));
ASSERT_EQ(stream->Finish().error_code(), grpc::StatusCode::UNAVAILABLE);
ASSERT_EQ(stream->Finish().error_code(),
grpc::StatusCode::FAILED_PRECONDITION);
}

TEST_F(FedSqlServerTest, SessionBeforeInitialize) {
Expand Down Expand Up @@ -635,8 +632,6 @@ TEST_F(FedSqlServerFederatedSumTest, SessionWriteAccumulateCommitsBlob) {
ASSERT_TRUE(write_response.has_write());
ASSERT_EQ(write_response.write().committed_size_bytes(), data.size());
ASSERT_EQ(write_response.write().status().code(), grpc::OK);
ASSERT_EQ(write_response.write().write_capacity_bytes(),
kMaxSessionMemoryBytes);
}

TEST_F(FedSqlServerFederatedSumTest, SessionAccumulatesAndReports) {
Expand Down Expand Up @@ -789,8 +784,6 @@ TEST_F(FedSqlServerFederatedSumTest, SessionIgnoresUnparseableInputs) {
ASSERT_TRUE(write_response_2.has_write());
ASSERT_EQ(write_response_2.write().committed_size_bytes(), 0);
ASSERT_EQ(write_response_2.write().status().code(), grpc::INVALID_ARGUMENT);
ASSERT_EQ(write_response_2.write().write_capacity_bytes(),
kMaxSessionMemoryBytes);

FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb(
type: FINALIZATION_TYPE_REPORT
Expand Down
9 changes: 4 additions & 5 deletions containers/session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ using ::fcp::base::ToGrpcStatus;
using ::fcp::confidentialcompute::SessionResponse;
using ::fcp::confidentialcompute::WriteFinishedResponse;

long SessionTracker::AddSession() {
absl::Status SessionTracker::AddSession() {
absl::MutexLock l(&mutex_);
if (num_sessions_ < max_num_sessions_) {
num_sessions_++;
return max_session_memory_bytes_;
return absl::OkStatus();
}
return 0;
return absl::FailedPreconditionError(
"SessionTracker: already at the maximum number of sessions.");
}

absl::Status SessionTracker::RemoveSession() {
Expand All @@ -43,14 +44,12 @@ absl::Status SessionTracker::RemoveSession() {
}

SessionResponse ToSessionWriteFinishedResponse(absl::Status status,
long available_memory,
long committed_size_bytes) {
grpc::Status grpc_status = ToGrpcStatus(std::move(status));
SessionResponse session_response;
WriteFinishedResponse* response = session_response.mutable_write();
response->mutable_status()->set_code(grpc_status.error_code());
response->mutable_status()->set_message(grpc_status.error_message());
response->set_write_capacity_bytes(available_memory);
response->set_committed_size_bytes(committed_size_bytes);
return session_response;
}
Expand Down
8 changes: 3 additions & 5 deletions containers/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,11 @@ namespace confidential_federated_compute {
// This class is threadsafe.
class SessionTracker {
public:
SessionTracker(int max_num_sessions, long max_session_memory_bytes)
: max_num_sessions_(max_num_sessions),
max_session_memory_bytes_(max_session_memory_bytes) {};
SessionTracker(int max_num_sessions) : max_num_sessions_(max_num_sessions) {};

// Tries to add a session and returns the amount of memory in bytes that the
// session is allowed. Returns 0 if there is no available memory.
long AddSession();
absl::Status AddSession();

// Tries to remove a session and returns an error if unable to do so.
absl::Status RemoveSession();
Expand All @@ -48,7 +46,7 @@ class SessionTracker {

// Create a SessionResponse with a WriteFinishedResponse.
fcp::confidentialcompute::SessionResponse ToSessionWriteFinishedResponse(
absl::Status status, long available_memory, long committed_size_bytes = 0);
absl::Status status, long committed_size_bytes = 0);

} // namespace confidential_federated_compute
#endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_SESSION_H_
34 changes: 17 additions & 17 deletions containers/session_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,47 +23,47 @@ namespace {
using ::fcp::confidentialcompute::SessionResponse;

TEST(SessionTest, AddSession) {
SessionTracker session_tracker(1, 100);
ASSERT_GT(session_tracker.AddSession(), 0);
SessionTracker session_tracker(1);
EXPECT_TRUE(session_tracker.AddSession().ok());
}

TEST(SessionTest, MaximumSessionsReachedAddSession) {
SessionTracker session_tracker(1, 100);
ASSERT_GT(session_tracker.AddSession(), 0);
ASSERT_EQ(session_tracker.AddSession(), 0);
SessionTracker session_tracker(1);
EXPECT_TRUE(session_tracker.AddSession().ok());
EXPECT_EQ(session_tracker.AddSession().code(),
absl::StatusCode::kFailedPrecondition);
}

TEST(SessionTest, MaximumSessionsReachedCanAddSessionAfterRemoveSession) {
SessionTracker session_tracker(1, 100);
ASSERT_GT(session_tracker.AddSession(), 0);
ASSERT_EQ(session_tracker.AddSession(), 0);
ASSERT_TRUE(session_tracker.RemoveSession().ok());
ASSERT_GT(session_tracker.AddSession(), 0);
SessionTracker session_tracker(1);
EXPECT_TRUE(session_tracker.AddSession().ok());
EXPECT_EQ(session_tracker.AddSession().code(),
absl::StatusCode::kFailedPrecondition);
EXPECT_TRUE(session_tracker.RemoveSession().ok());
EXPECT_TRUE(session_tracker.AddSession().ok());
}

TEST(SessionTest, RemoveSessionWithoutAddSessionFails) {
SessionTracker session_tracker(1, 100);
ASSERT_EQ(session_tracker.RemoveSession().code(),
SessionTracker session_tracker(1);
EXPECT_EQ(session_tracker.RemoveSession().code(),
absl::StatusCode::kFailedPrecondition);
}

TEST(SessionTest, ErrorToSessionWriteFinishedResponseTest) {
SessionResponse response = ToSessionWriteFinishedResponse(
absl::InvalidArgumentError("invalid arg"), 42);
SessionResponse response =
ToSessionWriteFinishedResponse(absl::InvalidArgumentError("invalid arg"));
ASSERT_TRUE(response.has_write());
EXPECT_EQ(response.write().committed_size_bytes(), 0);
EXPECT_EQ(response.write().write_capacity_bytes(), 42);
EXPECT_EQ(response.write().status().code(),
grpc::StatusCode::INVALID_ARGUMENT);
EXPECT_EQ(response.write().status().message(), "invalid arg");
}

TEST(SessionTest, OkToSessionWriteFinishedResponseTest) {
SessionResponse response =
ToSessionWriteFinishedResponse(absl::OkStatus(), 42, 6);
ToSessionWriteFinishedResponse(absl::OkStatus(), 6);
ASSERT_TRUE(response.has_write());
EXPECT_EQ(response.write().committed_size_bytes(), 6);
EXPECT_EQ(response.write().write_capacity_bytes(), 42);
EXPECT_EQ(response.write().status().code(), grpc::StatusCode::OK);
EXPECT_TRUE(response.write().status().message().empty());
}
Expand Down

0 comments on commit b6dd662

Please sign in to comment.