From 6eb09411b518f6cc582b96764856c140fc1736d8 Mon Sep 17 00:00:00 2001 From: Maya Spivak Date: Tue, 23 Jul 2024 20:10:40 +0000 Subject: [PATCH] Refactor confidential_transform_server so that it can be used from multiple containers. This adds a ConfidentialTransformServer base class and a new Session interface which can be implemented to store individual session state. Note that the FedSqlSession implementation is not currently threadsafe because each session is currently handled sequentially. This will likely need to change if blobs are incorporated in parallel. Change-Id: I683eb2b40aa93d9b5534b45a1147c89f559c9e27 --- containers/BUILD | 45 ++ .../confidential_transform_server_base.cc | 220 ++++++ .../confidential_transform_server_base.h | 85 ++ ...confidential_transform_server_base_test.cc | 738 ++++++++++++++++++ containers/fed_sql/BUILD | 1 + .../fed_sql/confidential_transform_server.cc | 209 +---- .../fed_sql/confidential_transform_server.h | 77 +- containers/session.h | 22 + 8 files changed, 1191 insertions(+), 206 deletions(-) create mode 100644 containers/confidential_transform_server_base.cc create mode 100644 containers/confidential_transform_server_base.h create mode 100644 containers/confidential_transform_server_base_test.cc diff --git a/containers/BUILD b/containers/BUILD index 755a253..c8eadc8 100644 --- a/containers/BUILD +++ b/containers/BUILD @@ -154,3 +154,48 @@ cc_test( "@oak//proto/containers:interfaces_cc_proto", ], ) + +cc_library( + name = "confidential_transform_server_base", + srcs = ["confidential_transform_server_base.cc"], + hdrs = ["confidential_transform_server_base.h"], + deps = [ + ":blob_metadata", + ":crypto", + ":session", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + "@federated-compute//fcp/base", + "@federated-compute//fcp/base:status_converters", + "@federated-compute//fcp/confidentialcompute:crypto", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", + "@federated-compute//fcp/protos/confidentialcompute:fed_sql_container_config_cc_proto", + "@oak//proto/containers:orchestrator_crypto_cc_grpc", + ], +) + +cc_test( + name = "confidential_transform_server_base_test", + size = "small", + srcs = ["confidential_transform_server_base_test.cc"], + deps = [ + ":blob_metadata", + ":confidential_transform_server_base", + ":crypto", + ":crypto_test_utils", + "//testing:parse_text_proto", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@federated-compute//fcp/confidentialcompute:crypto", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", + "@federated-compute//fcp/protos/confidentialcompute:fed_sql_container_config_cc_proto", + "@googletest//:gtest_main", + ], +) diff --git a/containers/confidential_transform_server_base.cc b/containers/confidential_transform_server_base.cc new file mode 100644 index 0000000..443f62e --- /dev/null +++ b/containers/confidential_transform_server_base.cc @@ -0,0 +1,220 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "containers/confidential_transform_server_base.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "containers/blob_metadata.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/base/status_converters.h" +#include "fcp/confidentialcompute/crypto.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/support/status.h" + +namespace confidential_federated_compute { + +using ::fcp::base::ToGrpcStatus; +using ::fcp::confidential_compute::NonceChecker; +using ::fcp::confidentialcompute::BlobMetadata; +using ::fcp::confidentialcompute::ConfidentialTransform; +using ::fcp::confidentialcompute::InitializeRequest; +using ::fcp::confidentialcompute::InitializeResponse; +using ::fcp::confidentialcompute::SessionRequest; +using ::fcp::confidentialcompute::SessionResponse; +using ::fcp::confidentialcompute::WriteRequest; +using ::grpc::ServerContext; + +namespace { + +// Decrypts and parses a record and incorporates the record into the session. +// +// Reports status to the client in WriteFinishedResponse. +// +// TODO: handle blobs that span multiple WriteRequests. +absl::Status HandleWrite( + const WriteRequest& request, BlobDecryptor* blob_decryptor, + NonceChecker& nonce_checker, + grpc::ServerReaderWriter* stream, + Session* session) { + if (absl::Status nonce_status = + nonce_checker.CheckBlobNonce(request.first_request_metadata()); + !nonce_status.ok()) { + stream->Write(ToSessionWriteFinishedResponse(nonce_status)); + return absl::OkStatus(); + } + + absl::StatusOr unencrypted_data = blob_decryptor->DecryptBlob( + request.first_request_metadata(), request.data()); + if (!unencrypted_data.ok()) { + stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status())); + return absl::OkStatus(); + } + + FCP_ASSIGN_OR_RETURN( + SessionResponse response, + session->SessionWrite(request, std::move(unencrypted_data.value()))); + + stream->Write(response); + return absl::OkStatus(); +} + +} // namespace + +absl::Status ConfidentialTransformBase::InitializeInternal( + const fcp::confidentialcompute::InitializeRequest* request, + fcp::confidentialcompute::InitializeResponse* response) { + FCP_ASSIGN_OR_RETURN(google::protobuf::Struct config_properties, + InitializeTransform(request)); + const BlobDecryptor* blob_decryptor; + { + absl::MutexLock l(&mutex_); + if (blob_decryptor_ != std::nullopt) { + return absl::FailedPreconditionError( + "Initialize can only be called once."); + } + blob_decryptor_.emplace(crypto_stub_, config_properties); + + // Since blob_decryptor_ is set once in Initialize and never + // modified, and the underlying object is threadsafe, it is safe to store a + // local pointer to it and access the object without a lock after we check + // under the mutex that a value has been set for the std::optional wrapper. + blob_decryptor = &*blob_decryptor_; + } + + FCP_ASSIGN_OR_RETURN(*response->mutable_public_key(), + blob_decryptor->GetPublicKey()); + return absl::OkStatus(); +} + +absl::Status ConfidentialTransformBase::SessionInternal( + grpc::ServerReaderWriter* stream) { + BlobDecryptor* blob_decryptor; + { + absl::MutexLock l(&mutex_); + if (blob_decryptor_ == std::nullopt) { + return absl::FailedPreconditionError( + "Initialize must be called before Session."); + } + + // Since blob_decryptor_ is set once in Initialize and never + // modified, and the underlying object is threadsafe, it is safe to store a + // local pointer to it and access the object without a lock after we check + // under the mutex that values have been set for the std::optional wrappers. + blob_decryptor = &*blob_decryptor_; + } + + SessionRequest configure_request; + bool success = stream->Read(&configure_request); + if (!success) { + return absl::AbortedError("Session failed to read client message."); + } + + if (!configure_request.has_configure()) { + return absl::FailedPreconditionError( + "Session must be configured with a ConfigureRequest before any other " + "requests."); + } + FCP_ASSIGN_OR_RETURN( + std::unique_ptr session, + CreateSession()); + FCP_RETURN_IF_ERROR(session->ConfigureSession(configure_request)); + SessionResponse configure_response; + NonceChecker nonce_checker; + *configure_response.mutable_configure()->mutable_nonce() = + nonce_checker.GetSessionNonce(); + stream->Write(configure_response); + + // Initialze result_blob_metadata with unencrypted metadata since + // EarliestExpirationTimeMetadata expects inputs to have either unencrypted or + // hpke_plus_aead_data. + BlobMetadata result_blob_metadata; + result_blob_metadata.mutable_unencrypted(); + SessionRequest session_request; + while (stream->Read(&session_request)) { + switch (session_request.kind_case()) { + case SessionRequest::kWrite: { + const WriteRequest& write_request = session_request.write(); + // Use the metadata with the earliest expiration timestamp for + // encrypting any results. + absl::StatusOr earliest_expiration_metadata = + EarliestExpirationTimeMetadata( + result_blob_metadata, write_request.first_request_metadata()); + if (!earliest_expiration_metadata.ok()) { + stream->Write(ToSessionWriteFinishedResponse(absl::Status( + earliest_expiration_metadata.status().code(), + absl::StrCat("Failed to extract expiration timestamp from " + "BlobMetadata with encryption: ", + earliest_expiration_metadata.status().message())))); + break; + } + result_blob_metadata = *earliest_expiration_metadata; + // TODO: spin up a thread to incorporate each blob. + FCP_RETURN_IF_ERROR(HandleWrite(write_request, blob_decryptor, + nonce_checker, stream, session.get())); + break; + } + case SessionRequest::kFinalize: { + FCP_ASSIGN_OR_RETURN( + SessionResponse finalize_response, + session->FinalizeSession(session_request.finalize(), + result_blob_metadata)); + stream->Write(finalize_response); + return absl::OkStatus(); + } + case SessionRequest::kConfigure: + default: + return absl::FailedPreconditionError(absl::StrCat( + "Session expected a write request but received request of type: ", + session_request.kind_case())); + } + } + + return absl::AbortedError( + "Session failed to read client write or finalize message."); +} + +grpc::Status ConfidentialTransformBase::Initialize( + ServerContext* context, const InitializeRequest* request, + InitializeResponse* response) { + return ToGrpcStatus(InitializeInternal(request, response)); +} + +grpc::Status ConfidentialTransformBase::Session( + ServerContext* context, + grpc::ServerReaderWriter* stream) { + if (absl::Status session_status = session_tracker_.AddSession(); + !session_status.ok()) { + return ToGrpcStatus(session_status); + } + grpc::Status status = ToGrpcStatus(SessionInternal(stream)); + absl::Status remove_session = session_tracker_.RemoveSession(); + if (!remove_session.ok()) { + return ToGrpcStatus(remove_session); + } + return status; +} + +} // namespace confidential_federated_compute diff --git a/containers/confidential_transform_server_base.h b/containers/confidential_transform_server_base.h new file mode 100644 index 0000000..86b20ef --- /dev/null +++ b/containers/confidential_transform_server_base.h @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_CONFIDENTIAL_TRANSFORM_SERVER_BASE_H_ +#define CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_CONFIDENTIAL_TRANSFORM_SERVER_BASE_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "proto/containers/orchestrator_crypto.grpc.pb.h" + +namespace confidential_federated_compute { + +// Base class that implements the ConfidentialTransform service protocol. +class ConfidentialTransformBase + : public fcp::confidentialcompute::ConfidentialTransform::Service { + public: + grpc::Status Initialize( + grpc::ServerContext* context, + const fcp::confidentialcompute::InitializeRequest* request, + fcp::confidentialcompute::InitializeResponse* response) override; + + grpc::Status Session( + grpc::ServerContext* context, + grpc::ServerReaderWriter* + stream) override; + + protected: + ConfidentialTransformBase( + oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, + int max_num_sessions) + : crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)), + session_tracker_(max_num_sessions) {} + + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) = 0; + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() = 0; + + private: + absl::Status InitializeInternal( + const fcp::confidentialcompute::InitializeRequest* request, + fcp::confidentialcompute::InitializeResponse* response); + + absl::Status SessionInternal( + grpc::ServerReaderWriter* + stream); + + oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_; + confidential_federated_compute::SessionTracker session_tracker_; + absl::Mutex mutex_; + // The mutex is used to protect the optional wrapping blob_decryptor_ to + // ensure the BlobDecryptor is initialized, but the BlobDecryptor is itself + // threadsafe. + std::optional blob_decryptor_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace confidential_federated_compute + +#endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_CONFIDENTIAL_TRANSFORM_SERVER_BASE_H_ diff --git a/containers/confidential_transform_server_base_test.cc b/containers/confidential_transform_server_base_test.cc new file mode 100644 index 0000000..1e30b1b --- /dev/null +++ b/containers/confidential_transform_server_base_test.cc @@ -0,0 +1,738 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "containers/confidential_transform_server_base.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_format.h" +#include "containers/blob_metadata.h" +#include "containers/crypto.h" +#include "containers/crypto_test_utils.h" +#include "containers/session.h" +#include "fcp/confidentialcompute/cose.h" +#include "fcp/confidentialcompute/crypto.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "gmock/gmock.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/server_context.h" +#include "gtest/gtest.h" +#include "proto/containers/orchestrator_crypto_mock.grpc.pb.h" +#include "testing/parse_text_proto.h" + +namespace confidential_federated_compute { + +namespace { + +using ::fcp::confidential_compute::MessageDecryptor; +using ::fcp::confidential_compute::NonceAndCounter; +using ::fcp::confidential_compute::NonceGenerator; +using ::fcp::confidential_compute::OkpCwt; +using ::fcp::confidentialcompute::BlobHeader; +using ::fcp::confidentialcompute::BlobMetadata; +using ::fcp::confidentialcompute::ConfidentialTransform; +using ::fcp::confidentialcompute::FinalizeRequest; +using ::fcp::confidentialcompute::InitializeRequest; +using ::fcp::confidentialcompute::InitializeResponse; +using ::fcp::confidentialcompute::ReadResponse; +using ::fcp::confidentialcompute::Record; +using ::fcp::confidentialcompute::SessionRequest; +using ::fcp::confidentialcompute::SessionResponse; +using ::fcp::confidentialcompute::WriteRequest; +using ::grpc::Server; +using ::grpc::ServerBuilder; +using ::grpc::ServerContext; +using ::grpc::StatusCode; +using ::oak::containers::v1::MockOrchestratorCryptoStub; +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::Return; +using ::testing::Test; + +inline constexpr int kMaxNumSessions = 8; + +class MockSession final : public confidential_federated_compute::Session { + public: + MOCK_METHOD(absl::Status, ConfigureSession, + (SessionRequest configure_request), (override)); + MOCK_METHOD(absl::StatusOr, SessionWrite, + (const WriteRequest& write_request, std::string unencrypted_data), + (override)); + MOCK_METHOD(absl::StatusOr, FinalizeSession, + (const FinalizeRequest& request, + const BlobMetadata& input_metadata), + (override)); +}; + +SessionRequest CreateDefaultWriteRequest(std::string data) { + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + metadata.set_total_size_bytes(data.size()); + SessionRequest request; + WriteRequest* write_request = request.mutable_write(); + *write_request->mutable_first_request_metadata() = metadata; + write_request->set_commit(true); + write_request->set_data(data); + return request; +} + +SessionResponse GetDefaultFinalizeResponse() { + SessionResponse response; + ReadResponse* read_response = response.mutable_read(); + read_response->set_finish_read(true); + std::string result = "test result"; + *(read_response->mutable_data()) = result; + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + metadata.set_total_size_bytes(result.size()); + *(read_response->mutable_first_response_metadata()) = metadata; + return response; +} + +class FakeConfidentialTransform final + : public confidential_federated_compute::ConfidentialTransformBase { + public: + FakeConfidentialTransform( + oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, + int max_num_sessions) + : ConfidentialTransformBase(crypto_stub, max_num_sessions) {}; + + void AddSession( + std::unique_ptr session) { + session_ = std::move(session); + }; + + protected: + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) { + google::rpc::Status config_status; + if (!request->configuration().UnpackTo(&config_status)) { + return absl::InvalidArgumentError("Config cannot be unpacked."); + } + if (config_status.code() != grpc::StatusCode::OK) { + return absl::InvalidArgumentError("Invalid config."); + } + return google::protobuf::Struct(); + } + + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() { + if (session_ == nullptr) { + auto session = + std::make_unique(); + EXPECT_CALL(*session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + return std::move(session); + } + return std::move(session_); + } + + private: + std::unique_ptr session_; +}; + +class ConfidentialTransformServerBaseTest : public Test { + public: + ConfidentialTransformServerBaseTest() { + int port; + const std::string server_address = "[::1]:"; + ServerBuilder builder; + builder.AddListeningPort(server_address + "0", + grpc::InsecureServerCredentials(), &port); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + LOG(INFO) << "Server listening on " << server_address + std::to_string(port) + << std::endl; + stub_ = ConfidentialTransform::NewStub( + grpc::CreateChannel(server_address + std::to_string(port), + grpc::InsecureChannelCredentials())); + } + + ~ConfidentialTransformServerBaseTest() override { server_->Shutdown(); } + + protected: + testing::NiceMock mock_crypto_stub_; + FakeConfidentialTransform service_{&mock_crypto_stub_, kMaxNumSessions}; + std::unique_ptr server_; + std::unique_ptr stub_; +}; + +TEST_F(ConfidentialTransformServerBaseTest, InitializeRequestWrongMessageType) { + grpc::ClientContext context; + google::protobuf::Value value; + InitializeRequest request; + InitializeResponse response; + request.mutable_configuration()->PackFrom(value); + + auto status = stub_->Initialize(&context, request, &response); + ASSERT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + ASSERT_THAT(status.error_message(), HasSubstr("Config cannot be unpacked.")); +} + +TEST_F(ConfidentialTransformServerBaseTest, InitializeMoreThanOnce) { + grpc::ClientContext context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&context, request, &response).ok()); + + grpc::ClientContext second_context; + auto status = stub_->Initialize(&second_context, request, &response); + + ASSERT_EQ(status.error_code(), grpc::StatusCode::FAILED_PRECONDITION); + ASSERT_THAT(status.error_message(), + HasSubstr("Initialize can only be called once")); +} + +TEST_F(ConfidentialTransformServerBaseTest, ValidInitialize) { + grpc::ClientContext context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&context, request, &response).ok()); + + absl::StatusOr cwt = OkpCwt::Decode(response.public_key()); + ASSERT_TRUE(cwt.ok()); +} + +TEST_F(ConfidentialTransformServerBaseTest, SessionConfigureGeneratesNonce) { + grpc::ClientContext configure_context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&configure_context, request, &response).ok()); + + grpc::ClientContext session_context; + SessionRequest session_request; + SessionResponse session_response; + session_request.mutable_configure(); + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(&session_context); + ASSERT_TRUE(stream->Write(session_request)); + ASSERT_TRUE(stream->Read(&session_response)); + + ASSERT_TRUE(session_response.has_configure()); + ASSERT_GT(session_response.configure().nonce().size(), 0); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream->Write(finalize_request)); + ASSERT_TRUE(stream->Read(&finalize_response)); + ASSERT_TRUE(stream->Finish().ok()); +} + +TEST_F(ConfidentialTransformServerBaseTest, + SessionRejectsMoreThanMaximumNumSessions) { + grpc::ClientContext configure_context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&configure_context, request, &response).ok()); + + std::vector>> + streams; + std::vector> contexts; + for (int i = 0; i < kMaxNumSessions; i++) { + std::unique_ptr session_context = + std::make_unique(); + SessionRequest session_request; + SessionResponse session_response; + session_request.mutable_configure(); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(session_context.get()); + ASSERT_TRUE(stream->Write(session_request)); + ASSERT_TRUE(stream->Read(&session_response)); + + // Keep the context and stream so they don't go out of scope and end the + // session. + contexts.emplace_back(std::move(session_context)); + streams.emplace_back(std::move(stream)); + } + + grpc::ClientContext rejected_context; + SessionRequest rejected_request; + SessionResponse rejected_response; + rejected_request.mutable_configure(); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(&rejected_context); + ASSERT_TRUE(stream->Write(rejected_request)); + ASSERT_FALSE(stream->Read(&rejected_response)); + ASSERT_EQ(stream->Finish().error_code(), + grpc::StatusCode::FAILED_PRECONDITION); +} + +TEST_F(ConfidentialTransformServerBaseTest, SessionBeforeInitialize) { + grpc::ClientContext session_context; + SessionRequest configure_request; + SessionResponse configure_response; + configure_request.mutable_configure()->mutable_configuration(); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(&session_context); + ASSERT_TRUE(stream->Write(configure_request)); + ASSERT_FALSE(stream->Read(&configure_response)); + ASSERT_EQ(stream->Finish().error_code(), + grpc::StatusCode::FAILED_PRECONDITION); +} + +class InitializedConfidentialTransformServerBaseTest + : public ConfidentialTransformServerBaseTest { + public: + InitializedConfidentialTransformServerBaseTest() { + grpc::ClientContext configure_context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + CHECK(stub_->Initialize(&configure_context, request, &response).ok()); + public_key_ = response.public_key(); + } + + protected: + void StartSession() { + SessionRequest session_request; + SessionResponse session_response; + session_request.mutable_configure(); + + stream_ = stub_->Session(&session_context_); + CHECK(stream_->Write(session_request)); + CHECK(stream_->Read(&session_response)); + nonce_generator_ = + std::make_unique(session_response.configure().nonce()); + } + grpc::ClientContext session_context_; + std::unique_ptr<::grpc::ClientReaderWriter> + stream_; + std::unique_ptr nonce_generator_; + std::string public_key_; +}; + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionWritesAndFinalizes) { + std::string data = "test data"; + SessionRequest write_request = CreateDefaultWriteRequest(data); + SessionResponse write_response; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillRepeatedly(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), data.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + // Accumulate the same unencrypted blob twice. + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + ASSERT_TRUE(stream_->Finish().ok()); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionIgnoresInvalidInputs) { + std::string data = "test data"; + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce( + Return(ToSessionWriteFinishedResponse(absl::OkStatus(), data.size()))) + .WillOnce(Return(ToSessionWriteFinishedResponse( + absl::InvalidArgumentError("Invalid argument"), 0))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + SessionRequest write_request_1 = CreateDefaultWriteRequest(data); + SessionResponse write_response_1; + + ASSERT_TRUE(stream_->Write(write_request_1)); + ASSERT_TRUE(stream_->Read(&write_response_1)); + + SessionRequest write_request_2 = CreateDefaultWriteRequest(data); + SessionResponse write_response_2; + + ASSERT_TRUE(stream_->Write(write_request_2)); + ASSERT_TRUE(stream_->Read(&write_response_2)); + + ASSERT_TRUE(write_response_2.has_write()); + ASSERT_EQ(write_response_2.write().committed_size_bytes(), 0); + ASSERT_EQ(write_response_2.write().status().code(), grpc::INVALID_ARGUMENT); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionFailsIfWriteFails) { + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce(Return(absl::InternalError("Internal Error"))); + service_.AddSession(std::move(mock_session)); + StartSession(); + + SessionRequest write_request_1 = CreateDefaultWriteRequest("test data"); + SessionResponse write_response_1; + + ASSERT_TRUE(stream_->Write(write_request_1)); + ASSERT_FALSE(stream_->Read(&write_response_1)); + ASSERT_EQ(stream_->Finish().error_code(), grpc::StatusCode::INTERNAL); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionFailsIfFinalizeFails) { + std::string data = "test data"; + SessionRequest write_request = CreateDefaultWriteRequest(data); + SessionResponse write_response; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillRepeatedly(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), data.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(absl::InternalError("Internal Error"))); + service_.AddSession(std::move(mock_session)); + StartSession(); + + // Accumulate the same unencrypted blob twice. + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::INTERNAL); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_FALSE(stream_->Read(&finalize_response)); + ASSERT_EQ(stream_->Finish().error_code(), grpc::StatusCode::INTERNAL); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionDecryptsMultipleRecords) { + std::string message_0 = "test data 0"; + std::string message_1 = "test data 1"; + std::string message_2 = "test data 2"; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_0.size()))) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_1.size()))) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_2.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); + std::string ciphertext_associated_data = + BlobHeader::default_instance().SerializeAsString(); + + absl::StatusOr nonce_0 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_0.ok()); + absl::StatusOr rewrapped_record_0 = + crypto_test_utils::CreateRewrappedRecord( + message_0, ciphertext_associated_data, public_key_, + nonce_0->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status(); + + SessionRequest request_0; + WriteRequest* write_request_0 = request_0.mutable_write(); + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + *write_request_0->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_0); + write_request_0->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_0->counter); + write_request_0->mutable_first_request_configuration()->PackFrom(config); + write_request_0->set_commit(true); + write_request_0->set_data( + rewrapped_record_0->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_0; + + ASSERT_TRUE(stream_->Write(request_0)); + ASSERT_TRUE(stream_->Read(&response_0)); + ASSERT_EQ(response_0.write().status().code(), grpc::OK); + + absl::StatusOr nonce_1 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_1.ok()); + absl::StatusOr rewrapped_record_1 = + crypto_test_utils::CreateRewrappedRecord( + message_1, ciphertext_associated_data, public_key_, + nonce_1->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_1.ok()) << rewrapped_record_1.status(); + + SessionRequest request_1; + WriteRequest* write_request_1 = request_1.mutable_write(); + *write_request_1->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_1); + write_request_1->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_1->counter); + write_request_1->mutable_first_request_configuration()->PackFrom(config); + write_request_1->set_commit(true); + write_request_1->set_data( + rewrapped_record_1->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_1; + + ASSERT_TRUE(stream_->Write(request_1)); + ASSERT_TRUE(stream_->Read(&response_1)); + ASSERT_EQ(response_1.write().status().code(), grpc::OK); + + absl::StatusOr nonce_2 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_2.ok()); + absl::StatusOr rewrapped_record_2 = + crypto_test_utils::CreateRewrappedRecord( + message_2, ciphertext_associated_data, public_key_, + nonce_2->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_2.ok()) << rewrapped_record_2.status(); + + SessionRequest request_2; + WriteRequest* write_request_2 = request_2.mutable_write(); + *write_request_2->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_2); + write_request_2->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_2->counter); + write_request_2->mutable_first_request_configuration()->PackFrom(config); + write_request_2->set_commit(true); + write_request_2->set_data( + rewrapped_record_2->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_2; + + ASSERT_TRUE(stream_->Write(request_2)); + ASSERT_TRUE(stream_->Read(&response_2)); + ASSERT_EQ(response_2.write().status().code(), grpc::OK); + + google::rpc::Status finalize_config; + finalize_config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + finalize_config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + TransformIgnoresUndecryptableInputs) { + std::string message_1 = "test data 1"; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_1.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); + std::string ciphertext_associated_data = "ciphertext associated data"; + + // Create one record that will fail to decrypt and one record that can be + // successfully decrypted. + std::string message_0 = "test data 0"; + absl::StatusOr nonce_0 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_0.ok()); + absl::StatusOr rewrapped_record_0 = + crypto_test_utils::CreateRewrappedRecord( + message_0, ciphertext_associated_data, public_key_, + nonce_0->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status(); + + SessionRequest request_0; + WriteRequest* write_request_0 = request_0.mutable_write(); + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + *write_request_0->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_0); + write_request_0->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_0->counter); + write_request_0->mutable_first_request_configuration()->PackFrom(config); + write_request_0->set_commit(true); + write_request_0->set_data("undecryptable"); + + SessionResponse response_0; + + ASSERT_TRUE(stream_->Write(request_0)); + ASSERT_TRUE(stream_->Read(&response_0)); + ASSERT_EQ(response_0.write().status().code(), grpc::INVALID_ARGUMENT); + + absl::StatusOr nonce_1 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_1.ok()); + absl::StatusOr rewrapped_record_1 = + crypto_test_utils::CreateRewrappedRecord( + message_1, ciphertext_associated_data, public_key_, + nonce_1->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_1.ok()) << rewrapped_record_1.status(); + + SessionRequest request_1; + WriteRequest* write_request_1 = request_1.mutable_write(); + *write_request_1->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_1); + write_request_1->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_1->counter); + write_request_1->mutable_first_request_configuration()->PackFrom(config); + write_request_1->set_commit(true); + write_request_1->set_data( + rewrapped_record_1->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_1; + + ASSERT_TRUE(stream_->Write(request_1)); + ASSERT_TRUE(stream_->Read(&response_1)); + ASSERT_EQ(response_1.write().status().code(), grpc::OK); + + google::rpc::Status finalize_config; + finalize_config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + finalize_config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +} // namespace + +} // namespace confidential_federated_compute diff --git a/containers/fed_sql/BUILD b/containers/fed_sql/BUILD index d93a64e..5ded894 100644 --- a/containers/fed_sql/BUILD +++ b/containers/fed_sql/BUILD @@ -18,6 +18,7 @@ cc_library( hdrs = ["confidential_transform_server.h"], deps = [ "//containers:blob_metadata", + "//containers:confidential_transform_server_base", "//containers:crypto", "//containers:session", "@com_github_grpc_grpc//:grpc++", diff --git a/containers/fed_sql/confidential_transform_server.cc b/containers/fed_sql/confidential_transform_server.cc index 34171fd..fdfb9bc 100644 --- a/containers/fed_sql/confidential_transform_server.cc +++ b/containers/fed_sql/confidential_transform_server.cc @@ -38,7 +38,6 @@ #include "tensorflow_federated/cc/core/impl/aggregation/core/intrinsic.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/config_converter.h" -#include "tensorflow_federated/cc/core/impl/aggregation/protocol/configuration.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_builder.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_parser.h" @@ -46,15 +45,10 @@ namespace confidential_federated_compute::fed_sql { namespace { -using ::fcp::base::ToGrpcStatus; -using ::fcp::confidential_compute::NonceChecker; using ::fcp::confidentialcompute::AGGREGATION_TYPE_ACCUMULATE; using ::fcp::confidentialcompute::AGGREGATION_TYPE_MERGE; using ::fcp::confidentialcompute::BlobHeader; using ::fcp::confidentialcompute::BlobMetadata; -using ::fcp::confidentialcompute::ConfidentialTransform; -using ::fcp::confidentialcompute::ConfigureRequest; -using ::fcp::confidentialcompute::ConfigureResponse; using ::fcp::confidentialcompute::FedSqlContainerFinalizeConfiguration; using ::fcp::confidentialcompute::FedSqlContainerInitializeConfiguration; using ::fcp::confidentialcompute::FedSqlContainerWriteConfiguration; @@ -62,24 +56,19 @@ using ::fcp::confidentialcompute::FINALIZATION_TYPE_REPORT; using ::fcp::confidentialcompute::FINALIZATION_TYPE_SERIALIZE; using ::fcp::confidentialcompute::FinalizeRequest; using ::fcp::confidentialcompute::InitializeRequest; -using ::fcp::confidentialcompute::InitializeResponse; using ::fcp::confidentialcompute::ReadResponse; using ::fcp::confidentialcompute::Record; -using ::fcp::confidentialcompute::SessionRequest; using ::fcp::confidentialcompute::SessionResponse; using ::fcp::confidentialcompute::WriteRequest; -using ::grpc::ServerContext; using ::tensorflow_federated::aggregation::CheckpointAggregator; using ::tensorflow_federated::aggregation::CheckpointBuilder; using ::tensorflow_federated::aggregation::CheckpointParser; -using ::tensorflow_federated::aggregation::Configuration; using ::tensorflow_federated::aggregation::DT_DOUBLE; using ::tensorflow_federated::aggregation:: FederatedComputeCheckpointBuilderFactory; using ::tensorflow_federated::aggregation:: FederatedComputeCheckpointParserFactory; using ::tensorflow_federated::aggregation::Intrinsic; -using ::tensorflow_federated::aggregation::Tensor; constexpr char kFedSqlDpGroupByUri[] = "fedsql_dp_group_by"; @@ -111,38 +100,14 @@ absl::Status ValidateTopLevelIntrinsics( return absl::OkStatus(); } -// Decrypts and parses a record and accumulates it into the state of the -// CheckpointAggregator `aggregator`. -// -// Returns an error if the aggcore state may be invalid and the session needs to -// be shut down. Otherwise, reports status to the client in -// WriteFinishedResponse -// -// TODO: handle blobs that span multiple WriteRequests. -// TODO: add tracking for available memory. -absl::Status HandleWrite( - const WriteRequest& request, CheckpointAggregator& aggregator, - BlobDecryptor* blob_decryptor, NonceChecker& nonce_checker, - grpc::ServerReaderWriter* stream, - const std::vector* intrinsics) { - if (absl::Status nonce_status = - nonce_checker.CheckBlobNonce(request.first_request_metadata()); - !nonce_status.ok()) { - stream->Write(ToSessionWriteFinishedResponse(nonce_status)); - return absl::OkStatus(); - } +} // namespace +absl::StatusOr FedSqlSession::SessionWrite( + const WriteRequest& write_request, std::string unencrypted_data) { FedSqlContainerWriteConfiguration write_config; - if (!request.first_request_configuration().UnpackTo(&write_config)) { - stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError( - "Failed to parse FedSqlContainerWriteConfiguration."))); - return absl::OkStatus(); - } - absl::StatusOr unencrypted_data = blob_decryptor->DecryptBlob( - request.first_request_metadata(), request.data()); - if (!unencrypted_data.ok()) { - stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status())); - return absl::OkStatus(); + if (!write_request.first_request_configuration().UnpackTo(&write_config)) { + return ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "Failed to parse FedSqlContainerWriteConfiguration.")); } // In case of an error with Accumulate or MergeWith, the session is @@ -153,41 +118,37 @@ absl::Status HandleWrite( case AGGREGATION_TYPE_ACCUMULATE: { FederatedComputeCheckpointParserFactory parser_factory; absl::StatusOr> parser = - parser_factory.Create(absl::Cord(std::move(*unencrypted_data))); + parser_factory.Create(absl::Cord(std::move(unencrypted_data))); if (!parser.ok()) { - stream->Write(ToSessionWriteFinishedResponse( + return ToSessionWriteFinishedResponse( absl::Status(parser.status().code(), absl::StrCat("Failed to deserialize checkpoint for " "AGGREGATION_TYPE_ACCUMULATE: ", - parser.status().message())))); - return absl::OkStatus(); + parser.status().message()))); } - FCP_RETURN_IF_ERROR(aggregator.Accumulate(*parser.value())); + FCP_RETURN_IF_ERROR(aggregator_->Accumulate(*parser.value())); break; } case AGGREGATION_TYPE_MERGE: { absl::StatusOr> other = - CheckpointAggregator::Deserialize(intrinsics, *unencrypted_data); + CheckpointAggregator::Deserialize(&intrinsics_, unencrypted_data); if (!other.ok()) { - stream->Write(ToSessionWriteFinishedResponse( + return ToSessionWriteFinishedResponse( absl::Status(other.status().code(), absl::StrCat("Failed to deserialize checkpoint for " "AGGREGATION_TYPE_MERGE: ", - other.status().message())))); - return absl::OkStatus(); + other.status().message()))); } - FCP_RETURN_IF_ERROR(aggregator.MergeWith(std::move(*other.value()))); + FCP_RETURN_IF_ERROR(aggregator_->MergeWith(std::move(*other.value()))); break; } default: - stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError( - "AggCoreAggregationType must be specified."))); - return absl::OkStatus(); + return ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "AggCoreAggregationType must be specified.")); } - - stream->Write(ToSessionWriteFinishedResponse( - absl::OkStatus(), request.first_request_metadata().total_size_bytes())); - return absl::OkStatus(); + return ToSessionWriteFinishedResponse( + absl::OkStatus(), + write_request.first_request_metadata().total_size_bytes()); } // Runs the requested finalization operation and write the uncompressed result @@ -195,11 +156,8 @@ absl::Status HandleWrite( // // Any errors in HandleFinalize kill the stream, since the stream can no longer // be modified after the Finalize call. -absl::Status HandleFinalize( - const FinalizeRequest& request, - std::unique_ptr aggregator, - grpc::ServerReaderWriter* stream, - const BlobMetadata& input_metadata) { +absl::StatusOr FedSqlSession::FinalizeSession( + const FinalizeRequest& request, const BlobMetadata& input_metadata) { FedSqlContainerFinalizeConfiguration finalize_config; if (!request.configuration().UnpackTo(&finalize_config)) { return absl::InvalidArgumentError( @@ -209,7 +167,7 @@ absl::Status HandleFinalize( BlobMetadata result_metadata; switch (finalize_config.type()) { case fcp::confidentialcompute::FINALIZATION_TYPE_REPORT: { - if (!aggregator->CanReport()) { + if (!aggregator_->CanReport()) { return absl::FailedPreconditionError( "The aggregation can't be completed due to failed preconditions."); } @@ -217,7 +175,7 @@ absl::Status HandleFinalize( FederatedComputeCheckpointBuilderFactory builder_factory; std::unique_ptr checkpoint_builder = builder_factory.Create(); - FCP_RETURN_IF_ERROR(aggregator->Report(*checkpoint_builder)); + FCP_RETURN_IF_ERROR(aggregator_->Report(*checkpoint_builder)); FCP_ASSIGN_OR_RETURN(absl::Cord checkpoint_cord, checkpoint_builder->Build()); absl::CopyCordToString(checkpoint_cord, &result); @@ -228,7 +186,7 @@ absl::Status HandleFinalize( } case FINALIZATION_TYPE_SERIALIZE: { FCP_ASSIGN_OR_RETURN(std::string serialized_aggregator, - std::move(*aggregator).Serialize()); + std::move(*aggregator_).Serialize()); if (input_metadata.has_unencrypted()) { result = std::move(serialized_aggregator); result_metadata.set_total_size_bytes(result.size()); @@ -264,15 +222,12 @@ absl::Status HandleFinalize( read_response->set_finish_read(true); *(read_response->mutable_data()) = result; *(read_response->mutable_first_response_metadata()) = result_metadata; - stream->Write(response); - return absl::OkStatus(); + return response; } -} // namespace - -absl::Status FedSqlConfidentialTransform::FedSqlInitialize( - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response) { +absl::StatusOr +FedSqlConfidentialTransform::InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) { FedSqlContainerInitializeConfiguration config; if (!request->configuration().UnpackTo(&config)) { return absl::InvalidArgumentError( @@ -280,7 +235,6 @@ absl::Status FedSqlConfidentialTransform::FedSqlInitialize( } FCP_RETURN_IF_ERROR( CheckpointAggregator::ValidateConfig(config.agg_configuration())); - const BlobDecryptor* blob_decryptor; { absl::MutexLock l(&mutex_); if (intrinsics_ != std::nullopt) { @@ -309,123 +263,30 @@ absl::Status FedSqlConfidentialTransform::FedSqlInitialize( } intrinsics_.emplace(std::move(intrinsics)); - blob_decryptor_.emplace(crypto_stub_, config_properties); - - // Since blob_decryptor_ is set once in Initialize and never - // modified, and the underlying object is threadsafe, it is safe to store a - // local pointer to it and access the object without a lock after we check - // under the mutex that a value has been set for the std::optional wrapper. - blob_decryptor = &*blob_decryptor_; + return config_properties; } - - FCP_ASSIGN_OR_RETURN(*response->mutable_public_key(), - blob_decryptor->GetPublicKey()); - return absl::OkStatus(); } -absl::Status FedSqlConfidentialTransform::FedSqlSession( - grpc::ServerReaderWriter* stream) { - BlobDecryptor* blob_decryptor; +absl::StatusOr> +FedSqlConfidentialTransform::CreateSession() { std::unique_ptr aggregator; const std::vector* intrinsics; { absl::MutexLock l(&mutex_); - if (intrinsics_ == std::nullopt || blob_decryptor_ == std::nullopt) { + if (intrinsics_ == std::nullopt) { return absl::FailedPreconditionError( "Initialize must be called before Session."); } - // Since blob_decryptor_ is set once in Initialize and never + // Since intrinsics_ is set once in Initialize and never // modified, and the underlying object is threadsafe, it is safe to store a // local pointer to it and access the object without a lock after we check // under the mutex that values have been set for the std::optional wrappers. - blob_decryptor = &*blob_decryptor_; intrinsics = &*intrinsics_; } FCP_ASSIGN_OR_RETURN(aggregator, CheckpointAggregator::Create(intrinsics)); - SessionRequest configure_request; - bool success = stream->Read(&configure_request); - if (!success) { - return absl::AbortedError("Session failed to read client message."); - } - - if (!configure_request.has_configure()) { - return absl::FailedPreconditionError( - "Session must be configured with a ConfigureRequest before any other " - "requests."); - } - SessionResponse configure_response; - NonceChecker nonce_checker; - *configure_response.mutable_configure()->mutable_nonce() = - nonce_checker.GetSessionNonce(); - stream->Write(configure_response); - - // Initialze result_blob_metadata with unencrypted metadata since - // EarliestExpirationTimeMetadata expects inputs to have either unencrypted or - // hpke_plus_aead_data. - BlobMetadata result_blob_metadata; - result_blob_metadata.mutable_unencrypted(); - SessionRequest session_request; - while (stream->Read(&session_request)) { - switch (session_request.kind_case()) { - case SessionRequest::kWrite: { - const WriteRequest& write_request = session_request.write(); - // If any of the input blobs are encrypted, then encrypt the result of - // FINALIZATION_TYPE_SERIALIZE. Use the metadata with the earliest - // expiration timestamp. - absl::StatusOr earliest_expiration_metadata = - EarliestExpirationTimeMetadata( - result_blob_metadata, write_request.first_request_metadata()); - if (!earliest_expiration_metadata.ok()) { - stream->Write(ToSessionWriteFinishedResponse(absl::Status( - earliest_expiration_metadata.status().code(), - absl::StrCat("Failed to extract expiration timestamp from " - "BlobMetadata with encryption: ", - earliest_expiration_metadata.status().message())))); - break; - } - result_blob_metadata = *earliest_expiration_metadata; - // TODO: spin up a thread to incorporate each blob. - FCP_RETURN_IF_ERROR(HandleWrite(write_request, *aggregator, - blob_decryptor, nonce_checker, stream, - intrinsics)); - break; - } - case SessionRequest::kFinalize: - return HandleFinalize(session_request.finalize(), std::move(aggregator), - stream, result_blob_metadata); - case SessionRequest::kConfigure: - default: - return absl::FailedPreconditionError(absl::StrCat( - "Session expected a write request but received request of type: ", - session_request.kind_case())); - } - } - - return absl::AbortedError( - "Session failed to read client write or finalize message."); + return std::make_unique( + FedSqlSession(std::move(aggregator), *intrinsics)); } - -grpc::Status FedSqlConfidentialTransform::Initialize( - ServerContext* context, const InitializeRequest* request, - InitializeResponse* response) { - return ToGrpcStatus(FedSqlInitialize(request, response)); -} - -grpc::Status FedSqlConfidentialTransform::Session( - ServerContext* context, - grpc::ServerReaderWriter* stream) { - if (absl::Status session_status = session_tracker_.AddSession(); - !session_status.ok()) { - return ToGrpcStatus(session_status); - } - grpc::Status status = ToGrpcStatus(FedSqlSession(stream)); - absl::Status remove_session = session_tracker_.RemoveSession(); - if (!remove_session.ok()) { - return ToGrpcStatus(remove_session); - } - return status; -} - } // namespace confidential_federated_compute::fed_sql diff --git a/containers/fed_sql/confidential_transform_server.h b/containers/fed_sql/confidential_transform_server.h index 2b13adc..d947e85 100644 --- a/containers/fed_sql/confidential_transform_server.h +++ b/containers/fed_sql/confidential_transform_server.h @@ -21,6 +21,7 @@ #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" +#include "containers/confidential_transform_server_base.h" #include "containers/crypto.h" #include "containers/session.h" #include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" @@ -38,47 +39,59 @@ namespace confidential_federated_compute::fed_sql { // step of FedSQL. // TODO: execute the per-client SQL query step. class FedSqlConfidentialTransform final - : public fcp::confidentialcompute::ConfidentialTransform::Service { + : public confidential_federated_compute::ConfidentialTransformBase { public: - // The OrchestratorCrypto stub must not be NULL and must outlive this object. - // TODO: add absl::Nonnull to crypto_stub. - explicit FedSqlConfidentialTransform( + FedSqlConfidentialTransform( oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, int max_num_sessions) - : crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)), - session_tracker_(max_num_sessions) {} + : ConfidentialTransformBase(crypto_stub, max_num_sessions) {}; - grpc::Status Initialize( - grpc::ServerContext* context, - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response) override; - - grpc::Status Session( - grpc::ServerContext* context, - grpc::ServerReaderWriter* - stream) override; + protected: + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) override; + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() override; private: - absl::Status FedSqlInitialize( - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response); - - absl::Status FedSqlSession( - grpc::ServerReaderWriter* - stream); - - oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_; - confidential_federated_compute::SessionTracker session_tracker_; absl::Mutex mutex_; - // The mutex is used to protect the optional wrapping blob_decryptor_ and - // intrinsics_ to ensure the BlobDecryptor and vector are initialized, but - // the BlobDecryptor and const vector are themselves threadsafe. std::optional> intrinsics_ ABSL_GUARDED_BY(mutex_); - std::optional blob_decryptor_ - ABSL_GUARDED_BY(mutex_); +}; + +// FedSql implementation of Session interface. Not threadsafe. +class FedSqlSession final : public confidential_federated_compute::Session { + public: + FedSqlSession( + std::unique_ptr + aggregator, + const std::vector& + intrinsics) + : aggregator_(std::move(aggregator)), intrinsics_(intrinsics) {}; + // Currently no FedSql per-session configuration. + absl::Status ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) override { + return absl::OkStatus(); + } + // Accumulates a record into the state of the CheckpointAggregator + // `aggregator`. + // + // Returns an error if the aggcore state may be invalid and the session + // needs to be shut down. + absl::StatusOr SessionWrite( + const fcp::confidentialcompute::WriteRequest& write_request, + std::string unencrypted_data) override; + // Run any session finalization logic and complete the session. + // After finalization, the session state is no longer mutable. + absl::StatusOr FinalizeSession( + const fcp::confidentialcompute::FinalizeRequest& request, + const fcp::confidentialcompute::BlobMetadata& input_metadata) override; + + private: + // The aggregator used during the session to accumulate writes. + std::unique_ptr + aggregator_; + const std::vector& intrinsics_; }; } // namespace confidential_federated_compute::fed_sql diff --git a/containers/session.h b/containers/session.h index f3b570a..9b5a883 100644 --- a/containers/session.h +++ b/containers/session.h @@ -18,6 +18,7 @@ #define CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_SESSION_H_ #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "fcp/protos/confidentialcompute/confidential_transform.pb.h" namespace confidential_federated_compute { @@ -48,5 +49,26 @@ class SessionTracker { fcp::confidentialcompute::SessionResponse ToSessionWriteFinishedResponse( absl::Status status, long committed_size_bytes = 0); +// Interface for interacting with a session in a container. Implementations +// may not be threadsafe. +class Session { + public: + // Initialize the session with the given configuration. + virtual absl::Status ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) = 0; + // Incorporate a write request into the session. + virtual absl::StatusOr + SessionWrite(const fcp::confidentialcompute::WriteRequest& write_request, + std::string unencrypted_data) = 0; + // Run any session finalization logic and complete the session. + // After finalization, the session state is no longer mutable. + virtual absl::StatusOr + FinalizeSession( + const fcp::confidentialcompute::FinalizeRequest& request, + const fcp::confidentialcompute::BlobMetadata& input_metadata) = 0; + + virtual ~Session() = default; +}; + } // namespace confidential_federated_compute #endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_SESSION_H_