diff --git a/containers/BUILD b/containers/BUILD index 755a253..c8eadc8 100644 --- a/containers/BUILD +++ b/containers/BUILD @@ -154,3 +154,48 @@ cc_test( "@oak//proto/containers:interfaces_cc_proto", ], ) + +cc_library( + name = "confidential_transform_server_base", + srcs = ["confidential_transform_server_base.cc"], + hdrs = ["confidential_transform_server_base.h"], + deps = [ + ":blob_metadata", + ":crypto", + ":session", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + "@federated-compute//fcp/base", + "@federated-compute//fcp/base:status_converters", + "@federated-compute//fcp/confidentialcompute:crypto", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", + "@federated-compute//fcp/protos/confidentialcompute:fed_sql_container_config_cc_proto", + "@oak//proto/containers:orchestrator_crypto_cc_grpc", + ], +) + +cc_test( + name = "confidential_transform_server_base_test", + size = "small", + srcs = ["confidential_transform_server_base_test.cc"], + deps = [ + ":blob_metadata", + ":confidential_transform_server_base", + ":crypto", + ":crypto_test_utils", + "//testing:parse_text_proto", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@federated-compute//fcp/confidentialcompute:crypto", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", + "@federated-compute//fcp/protos/confidentialcompute:fed_sql_container_config_cc_proto", + "@googletest//:gtest_main", + ], +) diff --git a/containers/confidential_transform_server_base.cc b/containers/confidential_transform_server_base.cc new file mode 100644 index 0000000..443f62e --- /dev/null +++ b/containers/confidential_transform_server_base.cc @@ -0,0 +1,220 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "containers/confidential_transform_server_base.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "containers/blob_metadata.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/base/status_converters.h" +#include "fcp/confidentialcompute/crypto.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/support/status.h" + +namespace confidential_federated_compute { + +using ::fcp::base::ToGrpcStatus; +using ::fcp::confidential_compute::NonceChecker; +using ::fcp::confidentialcompute::BlobMetadata; +using ::fcp::confidentialcompute::ConfidentialTransform; +using ::fcp::confidentialcompute::InitializeRequest; +using ::fcp::confidentialcompute::InitializeResponse; +using ::fcp::confidentialcompute::SessionRequest; +using ::fcp::confidentialcompute::SessionResponse; +using ::fcp::confidentialcompute::WriteRequest; +using ::grpc::ServerContext; + +namespace { + +// Decrypts and parses a record and incorporates the record into the session. +// +// Reports status to the client in WriteFinishedResponse. +// +// TODO: handle blobs that span multiple WriteRequests. +absl::Status HandleWrite( + const WriteRequest& request, BlobDecryptor* blob_decryptor, + NonceChecker& nonce_checker, + grpc::ServerReaderWriter* stream, + Session* session) { + if (absl::Status nonce_status = + nonce_checker.CheckBlobNonce(request.first_request_metadata()); + !nonce_status.ok()) { + stream->Write(ToSessionWriteFinishedResponse(nonce_status)); + return absl::OkStatus(); + } + + absl::StatusOr unencrypted_data = blob_decryptor->DecryptBlob( + request.first_request_metadata(), request.data()); + if (!unencrypted_data.ok()) { + stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status())); + return absl::OkStatus(); + } + + FCP_ASSIGN_OR_RETURN( + SessionResponse response, + session->SessionWrite(request, std::move(unencrypted_data.value()))); + + stream->Write(response); + return absl::OkStatus(); +} + +} // namespace + +absl::Status ConfidentialTransformBase::InitializeInternal( + const fcp::confidentialcompute::InitializeRequest* request, + fcp::confidentialcompute::InitializeResponse* response) { + FCP_ASSIGN_OR_RETURN(google::protobuf::Struct config_properties, + InitializeTransform(request)); + const BlobDecryptor* blob_decryptor; + { + absl::MutexLock l(&mutex_); + if (blob_decryptor_ != std::nullopt) { + return absl::FailedPreconditionError( + "Initialize can only be called once."); + } + blob_decryptor_.emplace(crypto_stub_, config_properties); + + // Since blob_decryptor_ is set once in Initialize and never + // modified, and the underlying object is threadsafe, it is safe to store a + // local pointer to it and access the object without a lock after we check + // under the mutex that a value has been set for the std::optional wrapper. + blob_decryptor = &*blob_decryptor_; + } + + FCP_ASSIGN_OR_RETURN(*response->mutable_public_key(), + blob_decryptor->GetPublicKey()); + return absl::OkStatus(); +} + +absl::Status ConfidentialTransformBase::SessionInternal( + grpc::ServerReaderWriter* stream) { + BlobDecryptor* blob_decryptor; + { + absl::MutexLock l(&mutex_); + if (blob_decryptor_ == std::nullopt) { + return absl::FailedPreconditionError( + "Initialize must be called before Session."); + } + + // Since blob_decryptor_ is set once in Initialize and never + // modified, and the underlying object is threadsafe, it is safe to store a + // local pointer to it and access the object without a lock after we check + // under the mutex that values have been set for the std::optional wrappers. + blob_decryptor = &*blob_decryptor_; + } + + SessionRequest configure_request; + bool success = stream->Read(&configure_request); + if (!success) { + return absl::AbortedError("Session failed to read client message."); + } + + if (!configure_request.has_configure()) { + return absl::FailedPreconditionError( + "Session must be configured with a ConfigureRequest before any other " + "requests."); + } + FCP_ASSIGN_OR_RETURN( + std::unique_ptr session, + CreateSession()); + FCP_RETURN_IF_ERROR(session->ConfigureSession(configure_request)); + SessionResponse configure_response; + NonceChecker nonce_checker; + *configure_response.mutable_configure()->mutable_nonce() = + nonce_checker.GetSessionNonce(); + stream->Write(configure_response); + + // Initialze result_blob_metadata with unencrypted metadata since + // EarliestExpirationTimeMetadata expects inputs to have either unencrypted or + // hpke_plus_aead_data. + BlobMetadata result_blob_metadata; + result_blob_metadata.mutable_unencrypted(); + SessionRequest session_request; + while (stream->Read(&session_request)) { + switch (session_request.kind_case()) { + case SessionRequest::kWrite: { + const WriteRequest& write_request = session_request.write(); + // Use the metadata with the earliest expiration timestamp for + // encrypting any results. + absl::StatusOr earliest_expiration_metadata = + EarliestExpirationTimeMetadata( + result_blob_metadata, write_request.first_request_metadata()); + if (!earliest_expiration_metadata.ok()) { + stream->Write(ToSessionWriteFinishedResponse(absl::Status( + earliest_expiration_metadata.status().code(), + absl::StrCat("Failed to extract expiration timestamp from " + "BlobMetadata with encryption: ", + earliest_expiration_metadata.status().message())))); + break; + } + result_blob_metadata = *earliest_expiration_metadata; + // TODO: spin up a thread to incorporate each blob. + FCP_RETURN_IF_ERROR(HandleWrite(write_request, blob_decryptor, + nonce_checker, stream, session.get())); + break; + } + case SessionRequest::kFinalize: { + FCP_ASSIGN_OR_RETURN( + SessionResponse finalize_response, + session->FinalizeSession(session_request.finalize(), + result_blob_metadata)); + stream->Write(finalize_response); + return absl::OkStatus(); + } + case SessionRequest::kConfigure: + default: + return absl::FailedPreconditionError(absl::StrCat( + "Session expected a write request but received request of type: ", + session_request.kind_case())); + } + } + + return absl::AbortedError( + "Session failed to read client write or finalize message."); +} + +grpc::Status ConfidentialTransformBase::Initialize( + ServerContext* context, const InitializeRequest* request, + InitializeResponse* response) { + return ToGrpcStatus(InitializeInternal(request, response)); +} + +grpc::Status ConfidentialTransformBase::Session( + ServerContext* context, + grpc::ServerReaderWriter* stream) { + if (absl::Status session_status = session_tracker_.AddSession(); + !session_status.ok()) { + return ToGrpcStatus(session_status); + } + grpc::Status status = ToGrpcStatus(SessionInternal(stream)); + absl::Status remove_session = session_tracker_.RemoveSession(); + if (!remove_session.ok()) { + return ToGrpcStatus(remove_session); + } + return status; +} + +} // namespace confidential_federated_compute diff --git a/containers/confidential_transform_server_base.h b/containers/confidential_transform_server_base.h new file mode 100644 index 0000000..86b20ef --- /dev/null +++ b/containers/confidential_transform_server_base.h @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_CONFIDENTIAL_TRANSFORM_SERVER_BASE_H_ +#define CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_CONFIDENTIAL_TRANSFORM_SERVER_BASE_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "proto/containers/orchestrator_crypto.grpc.pb.h" + +namespace confidential_federated_compute { + +// Base class that implements the ConfidentialTransform service protocol. +class ConfidentialTransformBase + : public fcp::confidentialcompute::ConfidentialTransform::Service { + public: + grpc::Status Initialize( + grpc::ServerContext* context, + const fcp::confidentialcompute::InitializeRequest* request, + fcp::confidentialcompute::InitializeResponse* response) override; + + grpc::Status Session( + grpc::ServerContext* context, + grpc::ServerReaderWriter* + stream) override; + + protected: + ConfidentialTransformBase( + oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, + int max_num_sessions) + : crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)), + session_tracker_(max_num_sessions) {} + + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) = 0; + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() = 0; + + private: + absl::Status InitializeInternal( + const fcp::confidentialcompute::InitializeRequest* request, + fcp::confidentialcompute::InitializeResponse* response); + + absl::Status SessionInternal( + grpc::ServerReaderWriter* + stream); + + oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_; + confidential_federated_compute::SessionTracker session_tracker_; + absl::Mutex mutex_; + // The mutex is used to protect the optional wrapping blob_decryptor_ to + // ensure the BlobDecryptor is initialized, but the BlobDecryptor is itself + // threadsafe. + std::optional blob_decryptor_ + ABSL_GUARDED_BY(mutex_); +}; + +} // namespace confidential_federated_compute + +#endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_CONFIDENTIAL_TRANSFORM_SERVER_BASE_H_ diff --git a/containers/confidential_transform_server_base_test.cc b/containers/confidential_transform_server_base_test.cc new file mode 100644 index 0000000..1e30b1b --- /dev/null +++ b/containers/confidential_transform_server_base_test.cc @@ -0,0 +1,738 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "containers/confidential_transform_server_base.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/str_format.h" +#include "containers/blob_metadata.h" +#include "containers/crypto.h" +#include "containers/crypto_test_utils.h" +#include "containers/session.h" +#include "fcp/confidentialcompute/cose.h" +#include "fcp/confidentialcompute/crypto.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "gmock/gmock.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/channel.h" +#include "grpcpp/client_context.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" +#include "grpcpp/server_context.h" +#include "gtest/gtest.h" +#include "proto/containers/orchestrator_crypto_mock.grpc.pb.h" +#include "testing/parse_text_proto.h" + +namespace confidential_federated_compute { + +namespace { + +using ::fcp::confidential_compute::MessageDecryptor; +using ::fcp::confidential_compute::NonceAndCounter; +using ::fcp::confidential_compute::NonceGenerator; +using ::fcp::confidential_compute::OkpCwt; +using ::fcp::confidentialcompute::BlobHeader; +using ::fcp::confidentialcompute::BlobMetadata; +using ::fcp::confidentialcompute::ConfidentialTransform; +using ::fcp::confidentialcompute::FinalizeRequest; +using ::fcp::confidentialcompute::InitializeRequest; +using ::fcp::confidentialcompute::InitializeResponse; +using ::fcp::confidentialcompute::ReadResponse; +using ::fcp::confidentialcompute::Record; +using ::fcp::confidentialcompute::SessionRequest; +using ::fcp::confidentialcompute::SessionResponse; +using ::fcp::confidentialcompute::WriteRequest; +using ::grpc::Server; +using ::grpc::ServerBuilder; +using ::grpc::ServerContext; +using ::grpc::StatusCode; +using ::oak::containers::v1::MockOrchestratorCryptoStub; +using ::testing::_; +using ::testing::HasSubstr; +using ::testing::Return; +using ::testing::Test; + +inline constexpr int kMaxNumSessions = 8; + +class MockSession final : public confidential_federated_compute::Session { + public: + MOCK_METHOD(absl::Status, ConfigureSession, + (SessionRequest configure_request), (override)); + MOCK_METHOD(absl::StatusOr, SessionWrite, + (const WriteRequest& write_request, std::string unencrypted_data), + (override)); + MOCK_METHOD(absl::StatusOr, FinalizeSession, + (const FinalizeRequest& request, + const BlobMetadata& input_metadata), + (override)); +}; + +SessionRequest CreateDefaultWriteRequest(std::string data) { + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + metadata.set_total_size_bytes(data.size()); + SessionRequest request; + WriteRequest* write_request = request.mutable_write(); + *write_request->mutable_first_request_metadata() = metadata; + write_request->set_commit(true); + write_request->set_data(data); + return request; +} + +SessionResponse GetDefaultFinalizeResponse() { + SessionResponse response; + ReadResponse* read_response = response.mutable_read(); + read_response->set_finish_read(true); + std::string result = "test result"; + *(read_response->mutable_data()) = result; + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + metadata.set_total_size_bytes(result.size()); + *(read_response->mutable_first_response_metadata()) = metadata; + return response; +} + +class FakeConfidentialTransform final + : public confidential_federated_compute::ConfidentialTransformBase { + public: + FakeConfidentialTransform( + oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, + int max_num_sessions) + : ConfidentialTransformBase(crypto_stub, max_num_sessions) {}; + + void AddSession( + std::unique_ptr session) { + session_ = std::move(session); + }; + + protected: + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) { + google::rpc::Status config_status; + if (!request->configuration().UnpackTo(&config_status)) { + return absl::InvalidArgumentError("Config cannot be unpacked."); + } + if (config_status.code() != grpc::StatusCode::OK) { + return absl::InvalidArgumentError("Invalid config."); + } + return google::protobuf::Struct(); + } + + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() { + if (session_ == nullptr) { + auto session = + std::make_unique(); + EXPECT_CALL(*session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + return std::move(session); + } + return std::move(session_); + } + + private: + std::unique_ptr session_; +}; + +class ConfidentialTransformServerBaseTest : public Test { + public: + ConfidentialTransformServerBaseTest() { + int port; + const std::string server_address = "[::1]:"; + ServerBuilder builder; + builder.AddListeningPort(server_address + "0", + grpc::InsecureServerCredentials(), &port); + builder.RegisterService(&service_); + server_ = builder.BuildAndStart(); + LOG(INFO) << "Server listening on " << server_address + std::to_string(port) + << std::endl; + stub_ = ConfidentialTransform::NewStub( + grpc::CreateChannel(server_address + std::to_string(port), + grpc::InsecureChannelCredentials())); + } + + ~ConfidentialTransformServerBaseTest() override { server_->Shutdown(); } + + protected: + testing::NiceMock mock_crypto_stub_; + FakeConfidentialTransform service_{&mock_crypto_stub_, kMaxNumSessions}; + std::unique_ptr server_; + std::unique_ptr stub_; +}; + +TEST_F(ConfidentialTransformServerBaseTest, InitializeRequestWrongMessageType) { + grpc::ClientContext context; + google::protobuf::Value value; + InitializeRequest request; + InitializeResponse response; + request.mutable_configuration()->PackFrom(value); + + auto status = stub_->Initialize(&context, request, &response); + ASSERT_EQ(status.error_code(), grpc::StatusCode::INVALID_ARGUMENT); + ASSERT_THAT(status.error_message(), HasSubstr("Config cannot be unpacked.")); +} + +TEST_F(ConfidentialTransformServerBaseTest, InitializeMoreThanOnce) { + grpc::ClientContext context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&context, request, &response).ok()); + + grpc::ClientContext second_context; + auto status = stub_->Initialize(&second_context, request, &response); + + ASSERT_EQ(status.error_code(), grpc::StatusCode::FAILED_PRECONDITION); + ASSERT_THAT(status.error_message(), + HasSubstr("Initialize can only be called once")); +} + +TEST_F(ConfidentialTransformServerBaseTest, ValidInitialize) { + grpc::ClientContext context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&context, request, &response).ok()); + + absl::StatusOr cwt = OkpCwt::Decode(response.public_key()); + ASSERT_TRUE(cwt.ok()); +} + +TEST_F(ConfidentialTransformServerBaseTest, SessionConfigureGeneratesNonce) { + grpc::ClientContext configure_context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&configure_context, request, &response).ok()); + + grpc::ClientContext session_context; + SessionRequest session_request; + SessionResponse session_response; + session_request.mutable_configure(); + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(&session_context); + ASSERT_TRUE(stream->Write(session_request)); + ASSERT_TRUE(stream->Read(&session_response)); + + ASSERT_TRUE(session_response.has_configure()); + ASSERT_GT(session_response.configure().nonce().size(), 0); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream->Write(finalize_request)); + ASSERT_TRUE(stream->Read(&finalize_response)); + ASSERT_TRUE(stream->Finish().ok()); +} + +TEST_F(ConfidentialTransformServerBaseTest, + SessionRejectsMoreThanMaximumNumSessions) { + grpc::ClientContext configure_context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + ASSERT_TRUE(stub_->Initialize(&configure_context, request, &response).ok()); + + std::vector>> + streams; + std::vector> contexts; + for (int i = 0; i < kMaxNumSessions; i++) { + std::unique_ptr session_context = + std::make_unique(); + SessionRequest session_request; + SessionResponse session_response; + session_request.mutable_configure(); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(session_context.get()); + ASSERT_TRUE(stream->Write(session_request)); + ASSERT_TRUE(stream->Read(&session_response)); + + // Keep the context and stream so they don't go out of scope and end the + // session. + contexts.emplace_back(std::move(session_context)); + streams.emplace_back(std::move(stream)); + } + + grpc::ClientContext rejected_context; + SessionRequest rejected_request; + SessionResponse rejected_response; + rejected_request.mutable_configure(); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(&rejected_context); + ASSERT_TRUE(stream->Write(rejected_request)); + ASSERT_FALSE(stream->Read(&rejected_response)); + ASSERT_EQ(stream->Finish().error_code(), + grpc::StatusCode::FAILED_PRECONDITION); +} + +TEST_F(ConfidentialTransformServerBaseTest, SessionBeforeInitialize) { + grpc::ClientContext session_context; + SessionRequest configure_request; + SessionResponse configure_response; + configure_request.mutable_configure()->mutable_configuration(); + + std::unique_ptr<::grpc::ClientReaderWriter> + stream = stub_->Session(&session_context); + ASSERT_TRUE(stream->Write(configure_request)); + ASSERT_FALSE(stream->Read(&configure_response)); + ASSERT_EQ(stream->Finish().error_code(), + grpc::StatusCode::FAILED_PRECONDITION); +} + +class InitializedConfidentialTransformServerBaseTest + : public ConfidentialTransformServerBaseTest { + public: + InitializedConfidentialTransformServerBaseTest() { + grpc::ClientContext configure_context; + InitializeRequest request; + InitializeResponse response; + google::rpc::Status config_status; + config_status.set_code(grpc::StatusCode::OK); + request.mutable_configuration()->PackFrom(config_status); + + CHECK(stub_->Initialize(&configure_context, request, &response).ok()); + public_key_ = response.public_key(); + } + + protected: + void StartSession() { + SessionRequest session_request; + SessionResponse session_response; + session_request.mutable_configure(); + + stream_ = stub_->Session(&session_context_); + CHECK(stream_->Write(session_request)); + CHECK(stream_->Read(&session_response)); + nonce_generator_ = + std::make_unique(session_response.configure().nonce()); + } + grpc::ClientContext session_context_; + std::unique_ptr<::grpc::ClientReaderWriter> + stream_; + std::unique_ptr nonce_generator_; + std::string public_key_; +}; + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionWritesAndFinalizes) { + std::string data = "test data"; + SessionRequest write_request = CreateDefaultWriteRequest(data); + SessionResponse write_response; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillRepeatedly(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), data.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + // Accumulate the same unencrypted blob twice. + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + ASSERT_TRUE(stream_->Finish().ok()); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionIgnoresInvalidInputs) { + std::string data = "test data"; + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce( + Return(ToSessionWriteFinishedResponse(absl::OkStatus(), data.size()))) + .WillOnce(Return(ToSessionWriteFinishedResponse( + absl::InvalidArgumentError("Invalid argument"), 0))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + SessionRequest write_request_1 = CreateDefaultWriteRequest(data); + SessionResponse write_response_1; + + ASSERT_TRUE(stream_->Write(write_request_1)); + ASSERT_TRUE(stream_->Read(&write_response_1)); + + SessionRequest write_request_2 = CreateDefaultWriteRequest(data); + SessionResponse write_response_2; + + ASSERT_TRUE(stream_->Write(write_request_2)); + ASSERT_TRUE(stream_->Read(&write_response_2)); + + ASSERT_TRUE(write_response_2.has_write()); + ASSERT_EQ(write_response_2.write().committed_size_bytes(), 0); + ASSERT_EQ(write_response_2.write().status().code(), grpc::INVALID_ARGUMENT); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionFailsIfWriteFails) { + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce(Return(absl::InternalError("Internal Error"))); + service_.AddSession(std::move(mock_session)); + StartSession(); + + SessionRequest write_request_1 = CreateDefaultWriteRequest("test data"); + SessionResponse write_response_1; + + ASSERT_TRUE(stream_->Write(write_request_1)); + ASSERT_FALSE(stream_->Read(&write_response_1)); + ASSERT_EQ(stream_->Finish().error_code(), grpc::StatusCode::INTERNAL); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionFailsIfFinalizeFails) { + std::string data = "test data"; + SessionRequest write_request = CreateDefaultWriteRequest(data); + SessionResponse write_response; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillRepeatedly(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), data.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(absl::InternalError("Internal Error"))); + service_.AddSession(std::move(mock_session)); + StartSession(); + + // Accumulate the same unencrypted blob twice. + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + ASSERT_TRUE(stream_->Write(write_request)); + ASSERT_TRUE(stream_->Read(&write_response)); + + google::rpc::Status config; + config.set_code(grpc::StatusCode::INTERNAL); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_FALSE(stream_->Read(&finalize_response)); + ASSERT_EQ(stream_->Finish().error_code(), grpc::StatusCode::INTERNAL); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + SessionDecryptsMultipleRecords) { + std::string message_0 = "test data 0"; + std::string message_1 = "test data 1"; + std::string message_2 = "test data 2"; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_0.size()))) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_1.size()))) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_2.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); + std::string ciphertext_associated_data = + BlobHeader::default_instance().SerializeAsString(); + + absl::StatusOr nonce_0 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_0.ok()); + absl::StatusOr rewrapped_record_0 = + crypto_test_utils::CreateRewrappedRecord( + message_0, ciphertext_associated_data, public_key_, + nonce_0->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status(); + + SessionRequest request_0; + WriteRequest* write_request_0 = request_0.mutable_write(); + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + *write_request_0->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_0); + write_request_0->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_0->counter); + write_request_0->mutable_first_request_configuration()->PackFrom(config); + write_request_0->set_commit(true); + write_request_0->set_data( + rewrapped_record_0->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_0; + + ASSERT_TRUE(stream_->Write(request_0)); + ASSERT_TRUE(stream_->Read(&response_0)); + ASSERT_EQ(response_0.write().status().code(), grpc::OK); + + absl::StatusOr nonce_1 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_1.ok()); + absl::StatusOr rewrapped_record_1 = + crypto_test_utils::CreateRewrappedRecord( + message_1, ciphertext_associated_data, public_key_, + nonce_1->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_1.ok()) << rewrapped_record_1.status(); + + SessionRequest request_1; + WriteRequest* write_request_1 = request_1.mutable_write(); + *write_request_1->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_1); + write_request_1->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_1->counter); + write_request_1->mutable_first_request_configuration()->PackFrom(config); + write_request_1->set_commit(true); + write_request_1->set_data( + rewrapped_record_1->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_1; + + ASSERT_TRUE(stream_->Write(request_1)); + ASSERT_TRUE(stream_->Read(&response_1)); + ASSERT_EQ(response_1.write().status().code(), grpc::OK); + + absl::StatusOr nonce_2 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_2.ok()); + absl::StatusOr rewrapped_record_2 = + crypto_test_utils::CreateRewrappedRecord( + message_2, ciphertext_associated_data, public_key_, + nonce_2->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_2.ok()) << rewrapped_record_2.status(); + + SessionRequest request_2; + WriteRequest* write_request_2 = request_2.mutable_write(); + *write_request_2->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_2); + write_request_2->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_2->counter); + write_request_2->mutable_first_request_configuration()->PackFrom(config); + write_request_2->set_commit(true); + write_request_2->set_data( + rewrapped_record_2->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_2; + + ASSERT_TRUE(stream_->Write(request_2)); + ASSERT_TRUE(stream_->Read(&response_2)); + ASSERT_EQ(response_2.write().status().code(), grpc::OK); + + google::rpc::Status finalize_config; + finalize_config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + finalize_config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +TEST_F(InitializedConfidentialTransformServerBaseTest, + TransformIgnoresUndecryptableInputs) { + std::string message_1 = "test data 1"; + + auto mock_session = + std::make_unique(); + EXPECT_CALL(*mock_session, ConfigureSession(_)) + .WillOnce(Return(absl::OkStatus())); + EXPECT_CALL(*mock_session, SessionWrite(_, _)) + .WillOnce(Return( + ToSessionWriteFinishedResponse(absl::OkStatus(), message_1.size()))); + EXPECT_CALL(*mock_session, FinalizeSession(_, _)) + .WillOnce(Return(GetDefaultFinalizeResponse())); + service_.AddSession(std::move(mock_session)); + StartSession(); + + MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); + std::string ciphertext_associated_data = "ciphertext associated data"; + + // Create one record that will fail to decrypt and one record that can be + // successfully decrypted. + std::string message_0 = "test data 0"; + absl::StatusOr nonce_0 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_0.ok()); + absl::StatusOr rewrapped_record_0 = + crypto_test_utils::CreateRewrappedRecord( + message_0, ciphertext_associated_data, public_key_, + nonce_0->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status(); + + SessionRequest request_0; + WriteRequest* write_request_0 = request_0.mutable_write(); + google::rpc::Status config; + config.set_code(grpc::StatusCode::OK); + *write_request_0->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_0); + write_request_0->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_0->counter); + write_request_0->mutable_first_request_configuration()->PackFrom(config); + write_request_0->set_commit(true); + write_request_0->set_data("undecryptable"); + + SessionResponse response_0; + + ASSERT_TRUE(stream_->Write(request_0)); + ASSERT_TRUE(stream_->Read(&response_0)); + ASSERT_EQ(response_0.write().status().code(), grpc::INVALID_ARGUMENT); + + absl::StatusOr nonce_1 = + nonce_generator_->GetNextBlobNonce(); + ASSERT_TRUE(nonce_1.ok()); + absl::StatusOr rewrapped_record_1 = + crypto_test_utils::CreateRewrappedRecord( + message_1, ciphertext_associated_data, public_key_, + nonce_1->blob_nonce, *reencryption_public_key); + ASSERT_TRUE(rewrapped_record_1.ok()) << rewrapped_record_1.status(); + + SessionRequest request_1; + WriteRequest* write_request_1 = request_1.mutable_write(); + *write_request_1->mutable_first_request_metadata() = + GetBlobMetadataFromRecord(*rewrapped_record_1); + write_request_1->mutable_first_request_metadata() + ->mutable_hpke_plus_aead_data() + ->set_counter(nonce_1->counter); + write_request_1->mutable_first_request_configuration()->PackFrom(config); + write_request_1->set_commit(true); + write_request_1->set_data( + rewrapped_record_1->hpke_plus_aead_data().ciphertext()); + + SessionResponse response_1; + + ASSERT_TRUE(stream_->Write(request_1)); + ASSERT_TRUE(stream_->Read(&response_1)); + ASSERT_EQ(response_1.write().status().code(), grpc::OK); + + google::rpc::Status finalize_config; + finalize_config.set_code(grpc::StatusCode::OK); + SessionRequest finalize_request; + SessionResponse finalize_response; + finalize_request.mutable_finalize()->mutable_configuration()->PackFrom( + finalize_config); + ASSERT_TRUE(stream_->Write(finalize_request)); + ASSERT_TRUE(stream_->Read(&finalize_response)); + + ASSERT_TRUE(finalize_response.has_read()); + ASSERT_TRUE(finalize_response.read().finish_read()); + ASSERT_GT( + finalize_response.read().first_response_metadata().total_size_bytes(), 0); + ASSERT_TRUE( + finalize_response.read().first_response_metadata().has_unencrypted()); +} + +} // namespace + +} // namespace confidential_federated_compute diff --git a/containers/fed_sql/BUILD b/containers/fed_sql/BUILD index d93a64e..5ded894 100644 --- a/containers/fed_sql/BUILD +++ b/containers/fed_sql/BUILD @@ -18,6 +18,7 @@ cc_library( hdrs = ["confidential_transform_server.h"], deps = [ "//containers:blob_metadata", + "//containers:confidential_transform_server_base", "//containers:crypto", "//containers:session", "@com_github_grpc_grpc//:grpc++", diff --git a/containers/fed_sql/confidential_transform_server.cc b/containers/fed_sql/confidential_transform_server.cc index 34171fd..fdfb9bc 100644 --- a/containers/fed_sql/confidential_transform_server.cc +++ b/containers/fed_sql/confidential_transform_server.cc @@ -38,7 +38,6 @@ #include "tensorflow_federated/cc/core/impl/aggregation/core/intrinsic.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/checkpoint_aggregator.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/config_converter.h" -#include "tensorflow_federated/cc/core/impl/aggregation/protocol/configuration.pb.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_builder.h" #include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_parser.h" @@ -46,15 +45,10 @@ namespace confidential_federated_compute::fed_sql { namespace { -using ::fcp::base::ToGrpcStatus; -using ::fcp::confidential_compute::NonceChecker; using ::fcp::confidentialcompute::AGGREGATION_TYPE_ACCUMULATE; using ::fcp::confidentialcompute::AGGREGATION_TYPE_MERGE; using ::fcp::confidentialcompute::BlobHeader; using ::fcp::confidentialcompute::BlobMetadata; -using ::fcp::confidentialcompute::ConfidentialTransform; -using ::fcp::confidentialcompute::ConfigureRequest; -using ::fcp::confidentialcompute::ConfigureResponse; using ::fcp::confidentialcompute::FedSqlContainerFinalizeConfiguration; using ::fcp::confidentialcompute::FedSqlContainerInitializeConfiguration; using ::fcp::confidentialcompute::FedSqlContainerWriteConfiguration; @@ -62,24 +56,19 @@ using ::fcp::confidentialcompute::FINALIZATION_TYPE_REPORT; using ::fcp::confidentialcompute::FINALIZATION_TYPE_SERIALIZE; using ::fcp::confidentialcompute::FinalizeRequest; using ::fcp::confidentialcompute::InitializeRequest; -using ::fcp::confidentialcompute::InitializeResponse; using ::fcp::confidentialcompute::ReadResponse; using ::fcp::confidentialcompute::Record; -using ::fcp::confidentialcompute::SessionRequest; using ::fcp::confidentialcompute::SessionResponse; using ::fcp::confidentialcompute::WriteRequest; -using ::grpc::ServerContext; using ::tensorflow_federated::aggregation::CheckpointAggregator; using ::tensorflow_federated::aggregation::CheckpointBuilder; using ::tensorflow_federated::aggregation::CheckpointParser; -using ::tensorflow_federated::aggregation::Configuration; using ::tensorflow_federated::aggregation::DT_DOUBLE; using ::tensorflow_federated::aggregation:: FederatedComputeCheckpointBuilderFactory; using ::tensorflow_federated::aggregation:: FederatedComputeCheckpointParserFactory; using ::tensorflow_federated::aggregation::Intrinsic; -using ::tensorflow_federated::aggregation::Tensor; constexpr char kFedSqlDpGroupByUri[] = "fedsql_dp_group_by"; @@ -111,38 +100,14 @@ absl::Status ValidateTopLevelIntrinsics( return absl::OkStatus(); } -// Decrypts and parses a record and accumulates it into the state of the -// CheckpointAggregator `aggregator`. -// -// Returns an error if the aggcore state may be invalid and the session needs to -// be shut down. Otherwise, reports status to the client in -// WriteFinishedResponse -// -// TODO: handle blobs that span multiple WriteRequests. -// TODO: add tracking for available memory. -absl::Status HandleWrite( - const WriteRequest& request, CheckpointAggregator& aggregator, - BlobDecryptor* blob_decryptor, NonceChecker& nonce_checker, - grpc::ServerReaderWriter* stream, - const std::vector* intrinsics) { - if (absl::Status nonce_status = - nonce_checker.CheckBlobNonce(request.first_request_metadata()); - !nonce_status.ok()) { - stream->Write(ToSessionWriteFinishedResponse(nonce_status)); - return absl::OkStatus(); - } +} // namespace +absl::StatusOr FedSqlSession::SessionWrite( + const WriteRequest& write_request, std::string unencrypted_data) { FedSqlContainerWriteConfiguration write_config; - if (!request.first_request_configuration().UnpackTo(&write_config)) { - stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError( - "Failed to parse FedSqlContainerWriteConfiguration."))); - return absl::OkStatus(); - } - absl::StatusOr unencrypted_data = blob_decryptor->DecryptBlob( - request.first_request_metadata(), request.data()); - if (!unencrypted_data.ok()) { - stream->Write(ToSessionWriteFinishedResponse(unencrypted_data.status())); - return absl::OkStatus(); + if (!write_request.first_request_configuration().UnpackTo(&write_config)) { + return ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "Failed to parse FedSqlContainerWriteConfiguration.")); } // In case of an error with Accumulate or MergeWith, the session is @@ -153,41 +118,37 @@ absl::Status HandleWrite( case AGGREGATION_TYPE_ACCUMULATE: { FederatedComputeCheckpointParserFactory parser_factory; absl::StatusOr> parser = - parser_factory.Create(absl::Cord(std::move(*unencrypted_data))); + parser_factory.Create(absl::Cord(std::move(unencrypted_data))); if (!parser.ok()) { - stream->Write(ToSessionWriteFinishedResponse( + return ToSessionWriteFinishedResponse( absl::Status(parser.status().code(), absl::StrCat("Failed to deserialize checkpoint for " "AGGREGATION_TYPE_ACCUMULATE: ", - parser.status().message())))); - return absl::OkStatus(); + parser.status().message()))); } - FCP_RETURN_IF_ERROR(aggregator.Accumulate(*parser.value())); + FCP_RETURN_IF_ERROR(aggregator_->Accumulate(*parser.value())); break; } case AGGREGATION_TYPE_MERGE: { absl::StatusOr> other = - CheckpointAggregator::Deserialize(intrinsics, *unencrypted_data); + CheckpointAggregator::Deserialize(&intrinsics_, unencrypted_data); if (!other.ok()) { - stream->Write(ToSessionWriteFinishedResponse( + return ToSessionWriteFinishedResponse( absl::Status(other.status().code(), absl::StrCat("Failed to deserialize checkpoint for " "AGGREGATION_TYPE_MERGE: ", - other.status().message())))); - return absl::OkStatus(); + other.status().message()))); } - FCP_RETURN_IF_ERROR(aggregator.MergeWith(std::move(*other.value()))); + FCP_RETURN_IF_ERROR(aggregator_->MergeWith(std::move(*other.value()))); break; } default: - stream->Write(ToSessionWriteFinishedResponse(absl::InvalidArgumentError( - "AggCoreAggregationType must be specified."))); - return absl::OkStatus(); + return ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "AggCoreAggregationType must be specified.")); } - - stream->Write(ToSessionWriteFinishedResponse( - absl::OkStatus(), request.first_request_metadata().total_size_bytes())); - return absl::OkStatus(); + return ToSessionWriteFinishedResponse( + absl::OkStatus(), + write_request.first_request_metadata().total_size_bytes()); } // Runs the requested finalization operation and write the uncompressed result @@ -195,11 +156,8 @@ absl::Status HandleWrite( // // Any errors in HandleFinalize kill the stream, since the stream can no longer // be modified after the Finalize call. -absl::Status HandleFinalize( - const FinalizeRequest& request, - std::unique_ptr aggregator, - grpc::ServerReaderWriter* stream, - const BlobMetadata& input_metadata) { +absl::StatusOr FedSqlSession::FinalizeSession( + const FinalizeRequest& request, const BlobMetadata& input_metadata) { FedSqlContainerFinalizeConfiguration finalize_config; if (!request.configuration().UnpackTo(&finalize_config)) { return absl::InvalidArgumentError( @@ -209,7 +167,7 @@ absl::Status HandleFinalize( BlobMetadata result_metadata; switch (finalize_config.type()) { case fcp::confidentialcompute::FINALIZATION_TYPE_REPORT: { - if (!aggregator->CanReport()) { + if (!aggregator_->CanReport()) { return absl::FailedPreconditionError( "The aggregation can't be completed due to failed preconditions."); } @@ -217,7 +175,7 @@ absl::Status HandleFinalize( FederatedComputeCheckpointBuilderFactory builder_factory; std::unique_ptr checkpoint_builder = builder_factory.Create(); - FCP_RETURN_IF_ERROR(aggregator->Report(*checkpoint_builder)); + FCP_RETURN_IF_ERROR(aggregator_->Report(*checkpoint_builder)); FCP_ASSIGN_OR_RETURN(absl::Cord checkpoint_cord, checkpoint_builder->Build()); absl::CopyCordToString(checkpoint_cord, &result); @@ -228,7 +186,7 @@ absl::Status HandleFinalize( } case FINALIZATION_TYPE_SERIALIZE: { FCP_ASSIGN_OR_RETURN(std::string serialized_aggregator, - std::move(*aggregator).Serialize()); + std::move(*aggregator_).Serialize()); if (input_metadata.has_unencrypted()) { result = std::move(serialized_aggregator); result_metadata.set_total_size_bytes(result.size()); @@ -264,15 +222,12 @@ absl::Status HandleFinalize( read_response->set_finish_read(true); *(read_response->mutable_data()) = result; *(read_response->mutable_first_response_metadata()) = result_metadata; - stream->Write(response); - return absl::OkStatus(); + return response; } -} // namespace - -absl::Status FedSqlConfidentialTransform::FedSqlInitialize( - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response) { +absl::StatusOr +FedSqlConfidentialTransform::InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) { FedSqlContainerInitializeConfiguration config; if (!request->configuration().UnpackTo(&config)) { return absl::InvalidArgumentError( @@ -280,7 +235,6 @@ absl::Status FedSqlConfidentialTransform::FedSqlInitialize( } FCP_RETURN_IF_ERROR( CheckpointAggregator::ValidateConfig(config.agg_configuration())); - const BlobDecryptor* blob_decryptor; { absl::MutexLock l(&mutex_); if (intrinsics_ != std::nullopt) { @@ -309,123 +263,30 @@ absl::Status FedSqlConfidentialTransform::FedSqlInitialize( } intrinsics_.emplace(std::move(intrinsics)); - blob_decryptor_.emplace(crypto_stub_, config_properties); - - // Since blob_decryptor_ is set once in Initialize and never - // modified, and the underlying object is threadsafe, it is safe to store a - // local pointer to it and access the object without a lock after we check - // under the mutex that a value has been set for the std::optional wrapper. - blob_decryptor = &*blob_decryptor_; + return config_properties; } - - FCP_ASSIGN_OR_RETURN(*response->mutable_public_key(), - blob_decryptor->GetPublicKey()); - return absl::OkStatus(); } -absl::Status FedSqlConfidentialTransform::FedSqlSession( - grpc::ServerReaderWriter* stream) { - BlobDecryptor* blob_decryptor; +absl::StatusOr> +FedSqlConfidentialTransform::CreateSession() { std::unique_ptr aggregator; const std::vector* intrinsics; { absl::MutexLock l(&mutex_); - if (intrinsics_ == std::nullopt || blob_decryptor_ == std::nullopt) { + if (intrinsics_ == std::nullopt) { return absl::FailedPreconditionError( "Initialize must be called before Session."); } - // Since blob_decryptor_ is set once in Initialize and never + // Since intrinsics_ is set once in Initialize and never // modified, and the underlying object is threadsafe, it is safe to store a // local pointer to it and access the object without a lock after we check // under the mutex that values have been set for the std::optional wrappers. - blob_decryptor = &*blob_decryptor_; intrinsics = &*intrinsics_; } FCP_ASSIGN_OR_RETURN(aggregator, CheckpointAggregator::Create(intrinsics)); - SessionRequest configure_request; - bool success = stream->Read(&configure_request); - if (!success) { - return absl::AbortedError("Session failed to read client message."); - } - - if (!configure_request.has_configure()) { - return absl::FailedPreconditionError( - "Session must be configured with a ConfigureRequest before any other " - "requests."); - } - SessionResponse configure_response; - NonceChecker nonce_checker; - *configure_response.mutable_configure()->mutable_nonce() = - nonce_checker.GetSessionNonce(); - stream->Write(configure_response); - - // Initialze result_blob_metadata with unencrypted metadata since - // EarliestExpirationTimeMetadata expects inputs to have either unencrypted or - // hpke_plus_aead_data. - BlobMetadata result_blob_metadata; - result_blob_metadata.mutable_unencrypted(); - SessionRequest session_request; - while (stream->Read(&session_request)) { - switch (session_request.kind_case()) { - case SessionRequest::kWrite: { - const WriteRequest& write_request = session_request.write(); - // If any of the input blobs are encrypted, then encrypt the result of - // FINALIZATION_TYPE_SERIALIZE. Use the metadata with the earliest - // expiration timestamp. - absl::StatusOr earliest_expiration_metadata = - EarliestExpirationTimeMetadata( - result_blob_metadata, write_request.first_request_metadata()); - if (!earliest_expiration_metadata.ok()) { - stream->Write(ToSessionWriteFinishedResponse(absl::Status( - earliest_expiration_metadata.status().code(), - absl::StrCat("Failed to extract expiration timestamp from " - "BlobMetadata with encryption: ", - earliest_expiration_metadata.status().message())))); - break; - } - result_blob_metadata = *earliest_expiration_metadata; - // TODO: spin up a thread to incorporate each blob. - FCP_RETURN_IF_ERROR(HandleWrite(write_request, *aggregator, - blob_decryptor, nonce_checker, stream, - intrinsics)); - break; - } - case SessionRequest::kFinalize: - return HandleFinalize(session_request.finalize(), std::move(aggregator), - stream, result_blob_metadata); - case SessionRequest::kConfigure: - default: - return absl::FailedPreconditionError(absl::StrCat( - "Session expected a write request but received request of type: ", - session_request.kind_case())); - } - } - - return absl::AbortedError( - "Session failed to read client write or finalize message."); + return std::make_unique( + FedSqlSession(std::move(aggregator), *intrinsics)); } - -grpc::Status FedSqlConfidentialTransform::Initialize( - ServerContext* context, const InitializeRequest* request, - InitializeResponse* response) { - return ToGrpcStatus(FedSqlInitialize(request, response)); -} - -grpc::Status FedSqlConfidentialTransform::Session( - ServerContext* context, - grpc::ServerReaderWriter* stream) { - if (absl::Status session_status = session_tracker_.AddSession(); - !session_status.ok()) { - return ToGrpcStatus(session_status); - } - grpc::Status status = ToGrpcStatus(FedSqlSession(stream)); - absl::Status remove_session = session_tracker_.RemoveSession(); - if (!remove_session.ok()) { - return ToGrpcStatus(remove_session); - } - return status; -} - } // namespace confidential_federated_compute::fed_sql diff --git a/containers/fed_sql/confidential_transform_server.h b/containers/fed_sql/confidential_transform_server.h index 2b13adc..d947e85 100644 --- a/containers/fed_sql/confidential_transform_server.h +++ b/containers/fed_sql/confidential_transform_server.h @@ -21,6 +21,7 @@ #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" +#include "containers/confidential_transform_server_base.h" #include "containers/crypto.h" #include "containers/session.h" #include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" @@ -38,47 +39,59 @@ namespace confidential_federated_compute::fed_sql { // step of FedSQL. // TODO: execute the per-client SQL query step. class FedSqlConfidentialTransform final - : public fcp::confidentialcompute::ConfidentialTransform::Service { + : public confidential_federated_compute::ConfidentialTransformBase { public: - // The OrchestratorCrypto stub must not be NULL and must outlive this object. - // TODO: add absl::Nonnull to crypto_stub. - explicit FedSqlConfidentialTransform( + FedSqlConfidentialTransform( oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, int max_num_sessions) - : crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)), - session_tracker_(max_num_sessions) {} + : ConfidentialTransformBase(crypto_stub, max_num_sessions) {}; - grpc::Status Initialize( - grpc::ServerContext* context, - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response) override; - - grpc::Status Session( - grpc::ServerContext* context, - grpc::ServerReaderWriter* - stream) override; + protected: + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) override; + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() override; private: - absl::Status FedSqlInitialize( - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response); - - absl::Status FedSqlSession( - grpc::ServerReaderWriter* - stream); - - oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_; - confidential_federated_compute::SessionTracker session_tracker_; absl::Mutex mutex_; - // The mutex is used to protect the optional wrapping blob_decryptor_ and - // intrinsics_ to ensure the BlobDecryptor and vector are initialized, but - // the BlobDecryptor and const vector are themselves threadsafe. std::optional> intrinsics_ ABSL_GUARDED_BY(mutex_); - std::optional blob_decryptor_ - ABSL_GUARDED_BY(mutex_); +}; + +// FedSql implementation of Session interface. Not threadsafe. +class FedSqlSession final : public confidential_federated_compute::Session { + public: + FedSqlSession( + std::unique_ptr + aggregator, + const std::vector& + intrinsics) + : aggregator_(std::move(aggregator)), intrinsics_(intrinsics) {}; + // Currently no FedSql per-session configuration. + absl::Status ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) override { + return absl::OkStatus(); + } + // Accumulates a record into the state of the CheckpointAggregator + // `aggregator`. + // + // Returns an error if the aggcore state may be invalid and the session + // needs to be shut down. + absl::StatusOr SessionWrite( + const fcp::confidentialcompute::WriteRequest& write_request, + std::string unencrypted_data) override; + // Run any session finalization logic and complete the session. + // After finalization, the session state is no longer mutable. + absl::StatusOr FinalizeSession( + const fcp::confidentialcompute::FinalizeRequest& request, + const fcp::confidentialcompute::BlobMetadata& input_metadata) override; + + private: + // The aggregator used during the session to accumulate writes. + std::unique_ptr + aggregator_; + const std::vector& intrinsics_; }; } // namespace confidential_federated_compute::fed_sql diff --git a/containers/session.h b/containers/session.h index f3b570a..9b5a883 100644 --- a/containers/session.h +++ b/containers/session.h @@ -18,6 +18,7 @@ #define CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_SESSION_H_ #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "fcp/protos/confidentialcompute/confidential_transform.pb.h" namespace confidential_federated_compute { @@ -48,5 +49,26 @@ class SessionTracker { fcp::confidentialcompute::SessionResponse ToSessionWriteFinishedResponse( absl::Status status, long committed_size_bytes = 0); +// Interface for interacting with a session in a container. Implementations +// may not be threadsafe. +class Session { + public: + // Initialize the session with the given configuration. + virtual absl::Status ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) = 0; + // Incorporate a write request into the session. + virtual absl::StatusOr + SessionWrite(const fcp::confidentialcompute::WriteRequest& write_request, + std::string unencrypted_data) = 0; + // Run any session finalization logic and complete the session. + // After finalization, the session state is no longer mutable. + virtual absl::StatusOr + FinalizeSession( + const fcp::confidentialcompute::FinalizeRequest& request, + const fcp::confidentialcompute::BlobMetadata& input_metadata) = 0; + + virtual ~Session() = default; +}; + } // namespace confidential_federated_compute #endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_SESSION_H_