diff --git a/containers/confidential_transform_test_concat/BUILD b/containers/confidential_transform_test_concat/BUILD index cb7b30c..f10efad 100644 --- a/containers/confidential_transform_test_concat/BUILD +++ b/containers/confidential_transform_test_concat/BUILD @@ -18,10 +18,10 @@ load("@rules_pkg//pkg:tar.bzl", "pkg_tar") cc_library( name = "confidential_transform_server", - srcs = ["confidential_transform_server.cc"], hdrs = ["confidential_transform_server.h"], deps = [ "//containers:blob_metadata", + "//containers:confidential_transform_server_base", "//containers:crypto", "//containers:session", "@com_github_grpc_grpc//:grpc++", @@ -52,6 +52,7 @@ cc_test( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@federated-compute//fcp/confidentialcompute:crypto", "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", "@googletest//:gtest_main", diff --git a/containers/confidential_transform_test_concat/confidential_transform_server.cc b/containers/confidential_transform_test_concat/confidential_transform_server.cc deleted file mode 100644 index 0f62b1f..0000000 --- a/containers/confidential_transform_test_concat/confidential_transform_server.cc +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#include "containers/confidential_transform_test_concat/confidential_transform_server.h" - -#include -#include -#include -#include -#include - -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "containers/blob_metadata.h" -#include "containers/crypto.h" -#include "containers/session.h" -#include "fcp/base/status_converters.h" -#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" -#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" -#include "grpcpp/support/status.h" -#include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" - -namespace confidential_federated_compute::confidential_transform_test_concat { - -namespace { - -using ::fcp::base::ToGrpcStatus; -using ::fcp::confidential_compute::NonceChecker; -using ::fcp::confidentialcompute::BlobMetadata; -using ::fcp::confidentialcompute::ConfidentialTransform; -using ::fcp::confidentialcompute::ConfigureRequest; -using ::fcp::confidentialcompute::ConfigureResponse; -using ::fcp::confidentialcompute::FinalizeRequest; -using ::fcp::confidentialcompute::InitializeRequest; -using ::fcp::confidentialcompute::InitializeResponse; -using ::fcp::confidentialcompute::ReadResponse; -using ::fcp::confidentialcompute::SessionRequest; -using ::fcp::confidentialcompute::SessionResponse; -using ::fcp::confidentialcompute::WriteFinishedResponse; -using ::fcp::confidentialcompute::WriteRequest; -using ::grpc::ServerContext; - -} // namespace - -absl::Status TestConcatConfidentialTransform::Initialize( - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response) { - const BlobDecryptor* blob_decryptor; - { - absl::MutexLock l(&mutex_); - blob_decryptor_.emplace(crypto_stub_); - - // Since blob_decryptor_ is set once in Initialize and never - // modified, and the underlying object is threadsafe, it is safe to store a - // local pointer to it and access the object without a lock after we check - // under the mutex that a value has been set for the std::optional wrapper. - blob_decryptor = &*blob_decryptor_; - } - - FCP_ASSIGN_OR_RETURN(*response->mutable_public_key(), - blob_decryptor->GetPublicKey()); - return absl::OkStatus(); -} - -absl::Status TestConcatConfidentialTransform::Session( - grpc::ServerReaderWriter* stream) { - BlobDecryptor* blob_decryptor; - { - absl::MutexLock l(&mutex_); - if (blob_decryptor_ == std::nullopt) { - return absl::FailedPreconditionError( - "Initialize must be called before Session."); - } - - // Since blob_decryptor_ is set once in Initialize and never - // modified, and the underlying object is threadsafe, it is safe to store a - // local pointer to it and access the object without a lock after we check - // under the mutex that values have been set for the std::optional wrappers. - blob_decryptor = &*blob_decryptor_; - } - - SessionRequest configure_request; - bool success = stream->Read(&configure_request); - if (!success) { - return absl::AbortedError("Session failed to read client message."); - } - - if (!configure_request.has_configure()) { - return absl::FailedPreconditionError( - "Session must be configured with a ConfigureRequest before any other " - "requests."); - } - SessionResponse configure_response; - NonceChecker nonce_checker; - *configure_response.mutable_configure()->mutable_nonce() = - nonce_checker.GetSessionNonce(); - configure_response.mutable_configure(); - stream->Write(configure_response); - - SessionRequest session_request; - std::string state = ""; - while (stream->Read(&session_request)) { - switch (session_request.kind_case()) { - case SessionRequest::kWrite: { - const WriteRequest& write_request = session_request.write(); - if (absl::Status nonce_status = nonce_checker.CheckBlobNonce( - write_request.first_request_metadata()); - !nonce_status.ok()) { - stream->Write(ToSessionWriteFinishedResponse(nonce_status)); - break; - } - - absl::StatusOr unencrypted_data = - blob_decryptor->DecryptBlob(write_request.first_request_metadata(), - write_request.data()); - if (!unencrypted_data.ok()) { - stream->Write( - ToSessionWriteFinishedResponse(unencrypted_data.status())); - break; - } - - absl::StrAppend(&state, *unencrypted_data); - stream->Write(ToSessionWriteFinishedResponse( - absl::OkStatus(), - write_request.first_request_metadata().total_size_bytes())); - break; - } - case SessionRequest::kFinalize: { - SessionResponse response; - ReadResponse* read_response = response.mutable_read(); - read_response->set_finish_read(true); - *(read_response->mutable_data()) = state; - - BlobMetadata result_metadata; - result_metadata.mutable_unencrypted(); - result_metadata.set_total_size_bytes(state.length()); - result_metadata.set_compression_type( - BlobMetadata::COMPRESSION_TYPE_NONE); - *(read_response->mutable_first_response_metadata()) = result_metadata; - - stream->Write(response); - return absl::OkStatus(); - } - case SessionRequest::kConfigure: - default: - return absl::FailedPreconditionError( - absl::StrCat("Session expected a write or finalize request but " - "received request of type: ", - session_request.kind_case())); - } - } - - return absl::AbortedError( - "Session failed to read client write or finalize message."); -} - -grpc::Status TestConcatConfidentialTransform::Initialize( - ServerContext* context, const InitializeRequest* request, - InitializeResponse* response) { - return ToGrpcStatus(Initialize(request, response)); -} - -grpc::Status TestConcatConfidentialTransform::Session( - ServerContext* context, - grpc::ServerReaderWriter* stream) { - grpc::Status status = ToGrpcStatus(Session(stream)); - return status; -} - -} // namespace - // confidential_federated_compute::confidential_transform_test_concat diff --git a/containers/confidential_transform_test_concat/confidential_transform_server.h b/containers/confidential_transform_test_concat/confidential_transform_server.h index 6317367..5622fc9 100644 --- a/containers/confidential_transform_test_concat/confidential_transform_server.h +++ b/containers/confidential_transform_test_concat/confidential_transform_server.h @@ -21,7 +21,9 @@ #include "absl/log/die_if_null.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" +#include "containers/confidential_transform_server_base.h" #include "containers/crypto.h" +#include "containers/session.h" #include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" #include "fcp/protos/confidentialcompute/confidential_transform.pb.h" #include "grpcpp/server_context.h" @@ -30,44 +32,70 @@ namespace confidential_federated_compute::confidential_transform_test_concat { -// Test ConfidentialTransform service that concatenates inputs. This test -// service doesn't manage the number of sessions. -class TestConcatConfidentialTransform final - : public fcp::confidentialcompute::ConfidentialTransform::Service { +// TestConcat implementation of Session interface. Not threadsafe. +class TestConcatSession final : public confidential_federated_compute::Session { public: - // The OrchestratorCrypto stub must not be NULL and must outlive this object. - explicit TestConcatConfidentialTransform( - oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub) - : crypto_stub_(*ABSL_DIE_IF_NULL(crypto_stub)) {} - - grpc::Status Initialize( - grpc::ServerContext* context, - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response) override; + TestConcatSession() {}; + // Currently no per-session configuration. + absl::Status ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) override { + return absl::OkStatus(); + } + // Concatenates the unencrypted data to the result string. + absl::StatusOr SessionWrite( + const fcp::confidentialcompute::WriteRequest& write_request, + std::string unencrypted_data) override { + absl::StrAppend(&state_, unencrypted_data); + return confidential_federated_compute::ToSessionWriteFinishedResponse( + absl::OkStatus(), + write_request.first_request_metadata().total_size_bytes()); + } + // Run any session finalization logic and complete the session. + // After finalization, the session state is no longer mutable. + absl::StatusOr FinalizeSession( + const fcp::confidentialcompute::FinalizeRequest& request, + const fcp::confidentialcompute::BlobMetadata& input_metadata) override { + fcp::confidentialcompute::SessionResponse response; + fcp::confidentialcompute::ReadResponse* read_response = + response.mutable_read(); + read_response->set_finish_read(true); + *(read_response->mutable_data()) = state_; - grpc::Status Session( - grpc::ServerContext* context, - grpc::ServerReaderWriter* - stream) override; + fcp::confidentialcompute::BlobMetadata result_metadata; + result_metadata.mutable_unencrypted(); + result_metadata.set_total_size_bytes(state_.length()); + result_metadata.set_compression_type( + fcp::confidentialcompute::BlobMetadata::COMPRESSION_TYPE_NONE); + *(read_response->mutable_first_response_metadata()) = result_metadata; + return response; + } private: - absl::Status Initialize( - const fcp::confidentialcompute::InitializeRequest* request, - fcp::confidentialcompute::InitializeResponse* response); + std::string state_ = ""; +}; - absl::Status Session( - grpc::ServerReaderWriter* - stream); +// Test ConfidentialTransform service that concatenates inputs. +class TestConcatConfidentialTransform final + : public confidential_federated_compute::ConfidentialTransformBase { + public: + TestConcatConfidentialTransform( + oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub, + int max_num_sessions = 1) + : ConfidentialTransformBase(crypto_stub, max_num_sessions) {}; - oak::containers::v1::OrchestratorCrypto::StubInterface& crypto_stub_; - absl::Mutex mutex_; - // The mutex is used to protect the optional wrapping blob_decryptor_ to - // ensure the BlobDecryptor is initialized, but the BlobDecryptor is itself - // threadsafe. - std::optional blob_decryptor_ - ABSL_GUARDED_BY(mutex_); + protected: + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) override { + google::protobuf::Struct config_properties; + return config_properties; + } + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() override { + return std::make_unique< + confidential_federated_compute::confidential_transform_test_concat:: + TestConcatSession>(); + }; }; } // namespace diff --git a/containers/confidential_transform_test_concat/confidential_transform_server_test.cc b/containers/confidential_transform_test_concat/confidential_transform_server_test.cc index 04c69bc..74f86e5 100644 --- a/containers/confidential_transform_test_concat/confidential_transform_server_test.cc +++ b/containers/confidential_transform_test_concat/confidential_transform_server_test.cc @@ -21,6 +21,7 @@ #include "containers/blob_metadata.h" #include "containers/crypto.h" #include "containers/crypto_test_utils.h" +#include "fcp/confidentialcompute/crypto.h" #include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" #include "fcp/protos/confidentialcompute/confidential_transform.pb.h" #include "gmock/gmock.h" @@ -39,6 +40,7 @@ namespace confidential_federated_compute::confidential_transform_test_concat { namespace { +using ::fcp::confidential_compute::MessageDecryptor; using ::fcp::confidential_compute::NonceAndCounter; using ::fcp::confidential_compute::NonceGenerator; using ::fcp::confidentialcompute::BlobHeader; @@ -211,7 +213,10 @@ TEST_F(TestConcatServerSessionTest, SessionWritesAndFinalizesUnencryptedBlobs) { } TEST_F(TestConcatServerSessionTest, SessionDecryptsMultipleBlobsAndFinalizes) { - std::string reencryption_public_key = "reencryption key"; + MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); std::string ciphertext_associated_data = BlobHeader::default_instance().SerializeAsString(); @@ -222,7 +227,7 @@ TEST_F(TestConcatServerSessionTest, SessionDecryptsMultipleBlobsAndFinalizes) { absl::StatusOr rewrapped_record_0 = crypto_test_utils::CreateRewrappedRecord( message_0, ciphertext_associated_data, public_key_, - nonce_0->blob_nonce, reencryption_public_key); + nonce_0->blob_nonce, *reencryption_public_key); ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status(); SessionRequest request_0; @@ -249,7 +254,7 @@ TEST_F(TestConcatServerSessionTest, SessionDecryptsMultipleBlobsAndFinalizes) { absl::StatusOr rewrapped_record_1 = crypto_test_utils::CreateRewrappedRecord( message_1, ciphertext_associated_data, public_key_, - nonce_1->blob_nonce, reencryption_public_key); + nonce_1->blob_nonce, *reencryption_public_key); ASSERT_TRUE(rewrapped_record_1.ok()) << rewrapped_record_1.status(); SessionRequest request_1; @@ -285,7 +290,10 @@ TEST_F(TestConcatServerSessionTest, SessionDecryptsMultipleBlobsAndFinalizes) { } TEST_F(TestConcatServerSessionTest, SessionIgnoresUndecryptableInputs) { - std::string reencryption_public_key = "reencryption key"; + MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); std::string ciphertext_associated_data = BlobHeader::default_instance().SerializeAsString(); @@ -296,7 +304,7 @@ TEST_F(TestConcatServerSessionTest, SessionIgnoresUndecryptableInputs) { absl::StatusOr rewrapped_record_0 = crypto_test_utils::CreateRewrappedRecord( message_0, ciphertext_associated_data, public_key_, - nonce_0->blob_nonce, reencryption_public_key); + nonce_0->blob_nonce, *reencryption_public_key); ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status(); SessionRequest request_0; @@ -321,7 +329,7 @@ TEST_F(TestConcatServerSessionTest, SessionIgnoresUndecryptableInputs) { absl::StatusOr rewrapped_record_1 = crypto_test_utils::CreateRewrappedRecord( "unused message", ciphertext_associated_data, public_key_, - nonce_1->blob_nonce, reencryption_public_key); + nonce_1->blob_nonce, *reencryption_public_key); ASSERT_TRUE(rewrapped_record_1.ok()) << rewrapped_record_1.status(); SessionRequest invalid_request;