Skip to content

Commit

Permalink
Encrypt the outputs of report requests if the report output node ID is
Browse files Browse the repository at this point in the history
set.

Set the output node ID of encrypted outputs based on the IDs specified
in the InitializeRequest configuration.

BUG: 345838534
Change-Id: Ifdac93b5f12decf0d0c7b979a665f0a1110fd67b
  • Loading branch information
zpgong committed Aug 9, 2024
1 parent abe65e5 commit c0ef67d
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 21 deletions.
66 changes: 47 additions & 19 deletions containers/fed_sql/confidential_transform_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,24 @@ absl::StatusOr<absl::Cord> Serialize(std::vector<TensorColumn> columns) {
return ckpt_builder->Build();
}

absl::StatusOr<Record> EncryptSessionResult(
const BlobMetadata& input_metadata, absl::string_view unencrypted_result,
uint32_t output_access_policy_node_id) {
RecordEncryptor encryptor;
BlobHeader previous_header;
if (!previous_header.ParseFromString(
input_metadata.hpke_plus_aead_data().ciphertext_associated_data())) {
return absl::InvalidArgumentError(
"Failed to parse the BlobHeader when trying to encrypt outputs.");
}
return encryptor.EncryptRecord(unencrypted_result,
input_metadata.hpke_plus_aead_data()
.rewrapped_symmetric_key_associated_data()
.reencryption_public_key(),
previous_header.access_policy_sha256(),
output_access_policy_node_id);
}

} // namespace

absl::StatusOr<std::unique_ptr<CheckpointParser>>
Expand Down Expand Up @@ -271,10 +289,25 @@ absl::StatusOr<SessionResponse> FedSqlSession::FinalizeSession(
FCP_RETURN_IF_ERROR(aggregator_->Report(*checkpoint_builder));
FCP_ASSIGN_OR_RETURN(absl::Cord checkpoint_cord,
checkpoint_builder->Build());
absl::CopyCordToString(checkpoint_cord, &result);
result_metadata.set_compression_type(BlobMetadata::COMPRESSION_TYPE_NONE);
result_metadata.set_total_size_bytes(result.size());
result_metadata.mutable_unencrypted();
std::string unencrypted_result;
absl::CopyCordToString(checkpoint_cord, &unencrypted_result);

if (input_metadata.has_unencrypted() ||
report_output_access_policy_node_id_ == std::nullopt) {
result_metadata.set_compression_type(
BlobMetadata::COMPRESSION_TYPE_NONE);
result_metadata.set_total_size_bytes(unencrypted_result.size());
result_metadata.mutable_unencrypted();
result = std::move(unencrypted_result);
break;
}

FCP_ASSIGN_OR_RETURN(
Record result_record,
EncryptSessionResult(input_metadata, unencrypted_result,
*report_output_access_policy_node_id_));
result_metadata = GetBlobMetadataFromRecord(result_record);
result = result_record.hpke_plus_aead_data().ciphertext();
break;
}
case FINALIZATION_TYPE_SERIALIZE: {
Expand All @@ -286,21 +319,15 @@ absl::StatusOr<SessionResponse> FedSqlSession::FinalizeSession(
result_metadata.mutable_unencrypted();
break;
}
RecordEncryptor encryptor;
BlobHeader previous_header;
if (!previous_header.ParseFromString(input_metadata.hpke_plus_aead_data()
.ciphertext_associated_data())) {
if (serialize_output_access_policy_node_id_ == std::nullopt) {
return absl::InvalidArgumentError(
"Failed to parse the BlobHeader when trying to encrypt outputs.");
"No output access policy node ID set for serialized outputs. This "
"must be set to output serialized state.");
}
FCP_ASSIGN_OR_RETURN(Record result_record,
encryptor.EncryptRecord(
serialized_aggregator,
input_metadata.hpke_plus_aead_data()
.rewrapped_symmetric_key_associated_data()
.reencryption_public_key(),
previous_header.access_policy_sha256(),
finalize_config.output_access_policy_node_id()));
FCP_ASSIGN_OR_RETURN(
Record result_record,
EncryptSessionResult(input_metadata, serialized_aggregator,
*serialize_output_access_policy_node_id_));
result_metadata = GetBlobMetadataFromRecord(result_record);
result = result_record.hpke_plus_aead_data().ciphertext();
break;
Expand Down Expand Up @@ -389,10 +416,11 @@ FedSqlConfidentialTransform::CreateSession() {
// under the mutex that values have been set for the std::optional wrappers.
intrinsics = &*intrinsics_;
}

FCP_ASSIGN_OR_RETURN(aggregator, CheckpointAggregator::Create(intrinsics));
return std::make_unique<FedSqlSession>(
FedSqlSession(std::move(aggregator), *intrinsics));
FedSqlSession(std::move(aggregator), *intrinsics,
serialize_output_access_policy_node_id_,
report_output_access_policy_node_id_));
}

absl::Status FedSqlSession::ConfigureSession(
Expand Down
13 changes: 11 additions & 2 deletions containers/fed_sql/confidential_transform_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,15 @@ class FedSqlSession final : public confidential_federated_compute::Session {
std::unique_ptr<tensorflow_federated::aggregation::CheckpointAggregator>
aggregator,
const std::vector<tensorflow_federated::aggregation::Intrinsic>&
intrinsics)
: aggregator_(std::move(aggregator)), intrinsics_(intrinsics) {};
intrinsics,
const std::optional<uint32_t> serialize_output_access_policy_node_id,
const std::optional<uint32_t> report_output_access_policy_node_id)
: aggregator_(std::move(aggregator)),
intrinsics_(intrinsics),
serialize_output_access_policy_node_id_(
serialize_output_access_policy_node_id),
report_output_access_policy_node_id_(
report_output_access_policy_node_id) {};

// Configure the optional per-client SQL query.
absl::Status ConfigureSession(
Expand Down Expand Up @@ -112,6 +119,8 @@ class FedSqlSession final : public confidential_federated_compute::Session {
aggregator_;
const std::vector<tensorflow_federated::aggregation::Intrinsic>& intrinsics_;
std::optional<const SqlConfiguration> sql_configuration_;
const std::optional<uint32_t> serialize_output_access_policy_node_id_;
const std::optional<uint32_t> report_output_access_policy_node_id_;
};

} // namespace confidential_federated_compute::fed_sql
Expand Down
185 changes: 185 additions & 0 deletions containers/fed_sql/confidential_transform_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ using ::testing::Test;
using testing::UnorderedElementsAre;

inline constexpr int kMaxNumSessions = 8;
inline constexpr int kSerializeOutputNodeId = 1;
inline constexpr int kReportOutputNodeId = 2;

TableSchema CreateTableSchema(std::string name, std::string create_table_sql,
std::vector<ColumnSchema> columns) {
Expand Down Expand Up @@ -877,13 +879,191 @@ TEST_F(FedSqlServerTest, SessionExecutesQueryAndGroupByAggregation) {
EXPECT_THAT(col_values->AsSpan<int64_t>(), UnorderedElementsAre(14, 10, 0));
}

TEST_F(FedSqlServerTest, SerializeEncryptedInputsWithoutOutputNodeIdFails) {
grpc::ClientContext init_context;
InitializeRequest request;
InitializeResponse response;
FedSqlContainerInitializeConfiguration init_config;
*init_config.mutable_agg_configuration() = DefaultConfiguration();
request.mutable_configuration()->PackFrom(init_config);
request.set_max_num_sessions(kMaxNumSessions);

ASSERT_TRUE(stub_->Initialize(&init_context, request, &response).ok());

grpc::ClientContext session_context;
SessionRequest configure_request;
SessionResponse configure_response;
configure_request.mutable_configure();

std::unique_ptr<::grpc::ClientReaderWriter<SessionRequest, SessionResponse>>
stream = stub_->Session(&session_context);
ASSERT_TRUE(stream->Write(configure_request));
ASSERT_TRUE(stream->Read(&configure_response));
auto nonce_generator =
std::make_unique<NonceGenerator>(configure_response.configure().nonce());

std::string input_col_name = "foo";
std::string output_col_name = "foo_out";

MessageDecryptor decryptor;
absl::StatusOr<std::string> reencryption_public_key =
decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0);
ASSERT_TRUE(reencryption_public_key.ok());
std::string ciphertext_associated_data =
BlobHeader::default_instance().SerializeAsString();

std::string message_0 = BuildSingleInt32TensorCheckpoint(input_col_name, {1});
absl::StatusOr<NonceAndCounter> nonce_0 = nonce_generator->GetNextBlobNonce();
ASSERT_TRUE(nonce_0.ok());
absl::StatusOr<Record> rewrapped_record_0 =
crypto_test_utils::CreateRewrappedRecord(
message_0, ciphertext_associated_data, response.public_key(),
nonce_0->blob_nonce, *reencryption_public_key);
ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status();

SessionRequest request_0;
WriteRequest* write_request_0 = request_0.mutable_write();
FedSqlContainerWriteConfiguration config = PARSE_TEXT_PROTO(R"pb(
type: AGGREGATION_TYPE_ACCUMULATE
)pb");
*write_request_0->mutable_first_request_metadata() =
GetBlobMetadataFromRecord(*rewrapped_record_0);
write_request_0->mutable_first_request_metadata()
->mutable_hpke_plus_aead_data()
->set_counter(nonce_0->counter);
write_request_0->mutable_first_request_configuration()->PackFrom(config);
write_request_0->set_commit(true);
write_request_0->set_data(
rewrapped_record_0->hpke_plus_aead_data().ciphertext());

SessionResponse response_0;

ASSERT_TRUE(stream->Write(request_0));
ASSERT_TRUE(stream->Read(&response_0));

FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb(
type: FINALIZATION_TYPE_SERIALIZE
)pb");
SessionRequest finalize_request;
SessionResponse finalize_response;
finalize_request.mutable_finalize()->mutable_configuration()->PackFrom(
finalize_config);
ASSERT_TRUE(stream->Write(finalize_request));
ASSERT_FALSE(stream->Read(&finalize_response));
grpc::Status finish_status = stream->Finish();
EXPECT_EQ(finish_status.error_code(), grpc::StatusCode::INVALID_ARGUMENT);
EXPECT_THAT(
finish_status.error_message(),
HasSubstr("No output access policy node ID set for serialized outputs"));
}

TEST_F(FedSqlServerTest,
ReportEncryptedInputsWithOutputNodeIdOutputsEncryptedResult) {
grpc::ClientContext init_context;
InitializeRequest request;
InitializeResponse response;
FedSqlContainerInitializeConfiguration init_config;
*init_config.mutable_agg_configuration() = DefaultConfiguration();
init_config.set_report_output_access_policy_node_id(kReportOutputNodeId);
request.mutable_configuration()->PackFrom(init_config);
request.set_max_num_sessions(kMaxNumSessions);

ASSERT_TRUE(stub_->Initialize(&init_context, request, &response).ok());

grpc::ClientContext session_context;
SessionRequest configure_request;
SessionResponse configure_response;
configure_request.mutable_configure();

std::unique_ptr<::grpc::ClientReaderWriter<SessionRequest, SessionResponse>>
stream = stub_->Session(&session_context);
ASSERT_TRUE(stream->Write(configure_request));
ASSERT_TRUE(stream->Read(&configure_response));
auto nonce_generator =
std::make_unique<NonceGenerator>(configure_response.configure().nonce());

std::string input_col_name = "foo";
std::string output_col_name = "foo_out";

MessageDecryptor decryptor;
absl::StatusOr<std::string> reencryption_public_key =
decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0);
ASSERT_TRUE(reencryption_public_key.ok());
std::string ciphertext_associated_data =
BlobHeader::default_instance().SerializeAsString();

std::string message_0 = BuildSingleInt32TensorCheckpoint(input_col_name, {1});
absl::StatusOr<NonceAndCounter> nonce_0 = nonce_generator->GetNextBlobNonce();
ASSERT_TRUE(nonce_0.ok());
absl::StatusOr<Record> rewrapped_record_0 =
crypto_test_utils::CreateRewrappedRecord(
message_0, ciphertext_associated_data, response.public_key(),
nonce_0->blob_nonce, *reencryption_public_key);
ASSERT_TRUE(rewrapped_record_0.ok()) << rewrapped_record_0.status();

SessionRequest request_0;
WriteRequest* write_request_0 = request_0.mutable_write();
FedSqlContainerWriteConfiguration config = PARSE_TEXT_PROTO(R"pb(
type: AGGREGATION_TYPE_ACCUMULATE
)pb");
*write_request_0->mutable_first_request_metadata() =
GetBlobMetadataFromRecord(*rewrapped_record_0);
write_request_0->mutable_first_request_metadata()
->mutable_hpke_plus_aead_data()
->set_counter(nonce_0->counter);
write_request_0->mutable_first_request_configuration()->PackFrom(config);
write_request_0->set_commit(true);
write_request_0->set_data(
rewrapped_record_0->hpke_plus_aead_data().ciphertext());

SessionResponse response_0;

ASSERT_TRUE(stream->Write(request_0));
ASSERT_TRUE(stream->Read(&response_0));

FedSqlContainerFinalizeConfiguration finalize_config = PARSE_TEXT_PROTO(R"pb(
type: FINALIZATION_TYPE_REPORT
)pb");
SessionRequest finalize_request;
SessionResponse finalize_response;
finalize_request.mutable_finalize()->mutable_configuration()->PackFrom(
finalize_config);
ASSERT_TRUE(stream->Write(finalize_request));
EXPECT_TRUE(stream->Read(&finalize_response));

ASSERT_TRUE(finalize_response.has_read());
ASSERT_TRUE(finalize_response.read().finish_read());
ASSERT_GT(
finalize_response.read().first_response_metadata().total_size_bytes(), 0);
ASSERT_TRUE(finalize_response.read()
.first_response_metadata()
.has_hpke_plus_aead_data());

BlobMetadata::HpkePlusAeadMetadata result_metadata =
finalize_response.read().first_response_metadata().hpke_plus_aead_data();

BlobHeader result_header;
EXPECT_TRUE(result_header.ParseFromString(
result_metadata.ciphertext_associated_data()));
EXPECT_EQ(result_header.access_policy_node_id(), kReportOutputNodeId);
absl::StatusOr<std::string> decrypted_result =
decryptor.Decrypt(finalize_response.read().data(),
result_metadata.ciphertext_associated_data(),
result_metadata.encrypted_symmetric_key(),
result_metadata.ciphertext_associated_data(),
result_metadata.encapsulated_public_key());
ASSERT_TRUE(decrypted_result.ok()) << decrypted_result.status();
}

class FedSqlServerFederatedSumTest : public FedSqlServerTest {
public:
FedSqlServerFederatedSumTest() {
grpc::ClientContext configure_context;
InitializeRequest request;
InitializeResponse response;
FedSqlContainerInitializeConfiguration init_config;
init_config.set_serialize_output_access_policy_node_id(
kSerializeOutputNodeId);
*init_config.mutable_agg_configuration() = DefaultConfiguration();
request.mutable_configuration()->PackFrom(init_config);
request.set_max_num_sessions(kMaxNumSessions);
Expand Down Expand Up @@ -1442,6 +1622,11 @@ TEST_F(FedSqlServerFederatedSumTest,

BlobMetadata::HpkePlusAeadMetadata result_metadata =
finalize_response.read().first_response_metadata().hpke_plus_aead_data();

BlobHeader result_header;
EXPECT_TRUE(result_header.ParseFromString(
result_metadata.ciphertext_associated_data()));
EXPECT_EQ(result_header.access_policy_node_id(), kSerializeOutputNodeId);
// The decryptor with the earliest set expiration time should be able to
// decrypt the encrypted results. The later decryptor should not.
absl::StatusOr<std::string> decrypted_result =
Expand Down

0 comments on commit c0ef67d

Please sign in to comment.