diff --git a/WORKSPACE b/WORKSPACE index 5469195..91d1e84 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -122,9 +122,9 @@ http_archive( "//third_party/federated_compute:libcppbor.patch", "//third_party/federated_compute:visibility.patch", ], - sha256 = "1a5e61e54b384e404ad64034636b1a411c969bc725d85d0f567d452353c7d18c", - strip_prefix = "federated-compute-6ff27b581f3ddada0f3dff9732fb7aa43b2da827", - url = "https://github.com/google/federated-compute/archive/6ff27b581f3ddada0f3dff9732fb7aa43b2da827.tar.gz", + sha256 = "813ec78c5b28a71335b795e4313b7c6c4a17497b459ff2cd38e8a516dc403988", + strip_prefix = "federated-compute-703e249f3258e0f5e1359ec0a877f4d5164c6eea", + url = "https://github.com/google/federated-compute/archive/703e249f3258e0f5e1359ec0a877f4d5164c6eea.tar.gz", ) http_archive( diff --git a/containers/tff_server/BUILD b/containers/tff_server/BUILD new file mode 100644 index 0000000..575f60e --- /dev/null +++ b/containers/tff_server/BUILD @@ -0,0 +1,91 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cc_library( + name = "confidential_transform_server", + srcs = ["confidential_transform_server.cc"], + hdrs = ["confidential_transform_server.h"], + deps = [ + "//containers:blob_metadata", + "//containers:confidential_transform_server_base", + "//containers:crypto", + "//containers:session", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + "@federated-compute//fcp/base", + "@federated-compute//fcp/base:status_converters", + "@federated-compute//fcp/confidentialcompute:crypto", + "@federated-compute//fcp/confidentialcompute:tff_execution_helper", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", + "@federated-compute//fcp/protos/confidentialcompute:tff_config_cc_proto", + "@oak//proto/containers:orchestrator_crypto_cc_grpc", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/core:tensor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/protocol:federated_compute_checkpoint_parser", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/tensorflow:converters", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:executor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:federating_executor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:reference_resolving_executor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:tensorflow_executor", + "@org_tensorflow_federated//tensorflow_federated/proto/v0:executor_cc_proto", + ], +) + +cc_test( + name = "confidential_transform_server_test", + size = "small", + srcs = ["confidential_transform_server_test.cc"], + data = [ + "//containers/tff_server:testing/client_data_function.txtpb", + "//containers/tff_server:testing/no_argument_function.txtpb", + "//containers/tff_server:testing/server_data_function.txtpb", + ], + deps = [ + ":confidential_transform_server", + "//containers:blob_metadata", + "//containers:confidential_transform_server_base", + "//containers:crypto", + "//containers:session", + "//testing:parse_text_proto", + "@com_github_grpc_grpc//:grpc++", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + "@federated-compute//fcp/base", + "@federated-compute//fcp/base:status_converters", + "@federated-compute//fcp/confidentialcompute:crypto", + "@federated-compute//fcp/confidentialcompute:tff_execution_helper", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_grpc", + "@federated-compute//fcp/protos/confidentialcompute:confidential_transform_cc_proto", + "@federated-compute//fcp/protos/confidentialcompute:tff_config_cc_proto", + "@googletest//:gtest_main", + "@oak//proto/containers:orchestrator_crypto_cc_grpc", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/core:tensor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/protocol:federated_compute_checkpoint_builder", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/protocol:federated_compute_checkpoint_parser", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/tensorflow:converters", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/aggregation/testing:test_data", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:executor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:federating_executor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:reference_resolving_executor", + "@org_tensorflow_federated//tensorflow_federated/cc/core/impl/executors:tensorflow_executor", + "@org_tensorflow_federated//tensorflow_federated/proto/v0:executor_cc_proto", + ], +) diff --git a/containers/tff_server/confidential_transform_server.cc b/containers/tff_server/confidential_transform_server.cc new file mode 100644 index 0000000..cca86e7 --- /dev/null +++ b/containers/tff_server/confidential_transform_server.cc @@ -0,0 +1,275 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "containers/tff_server/confidential_transform_server.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "containers/blob_metadata.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/base/status_converters.h" +#include "fcp/confidentialcompute/crypto.h" +#include "fcp/confidentialcompute/tff_execution_helper.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "fcp/protos/confidentialcompute/tff_config.pb.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/support/status.h" +#include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" +#include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_parser.h" +#include "tensorflow_federated/cc/core/impl/aggregation/tensorflow/converters.h" +#include "tensorflow_federated/cc/core/impl/executors/executor.h" +#include "tensorflow_federated/cc/core/impl/executors/federating_executor.h" +#include "tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.h" +#include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" +#include "tensorflow_federated/cc/core/impl/executors/tensorflow_executor.h" + +namespace confidential_federated_compute::tff_server { + +using ::fcp::confidentialcompute::BlobHeader; +using ::fcp::confidentialcompute::BlobMetadata; +using ::fcp::confidentialcompute::FinalizeRequest; +using ::fcp::confidentialcompute::ReadResponse; +using ::fcp::confidentialcompute::Record; +using ::fcp::confidentialcompute::SessionResponse; +using ::fcp::confidentialcompute::TffSessionConfig; +using ::fcp::confidentialcompute::TffSessionWriteConfig; +using ::fcp::confidentialcompute::WriteRequest; +using ::tensorflow_federated::aggregation::CheckpointParser; +using ::tensorflow_federated::aggregation:: + FederatedComputeCheckpointParserFactory; + +absl::Status TffSession::ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) { + if (child_executor_ != nullptr) { + return absl::FailedPreconditionError("Session already configured."); + } + + TffSessionConfig session_config; + if (!configure_request.has_configure() || + !configure_request.configure().configuration().UnpackTo( + &session_config)) { + return absl::InvalidArgumentError("TffSessionConfig invalid."); + } + + auto leaf_executor_fn = []() { + return tensorflow_federated::CreateReferenceResolvingExecutor( + tensorflow_federated::CreateTensorFlowExecutor()); + }; + tensorflow_federated::CardinalityMap cardinality_map; + cardinality_map[tensorflow_federated::kClientsUri] = + session_config.num_clients(); + FCP_ASSIGN_OR_RETURN( + auto federating_executor, + tensorflow_federated::CreateFederatingExecutor( + /*server_child=*/leaf_executor_fn(), + /*client_child=*/leaf_executor_fn(), cardinality_map)); + child_executor_ = tensorflow_federated::CreateReferenceResolvingExecutor( + federating_executor); + function_ = std::move(session_config.function()); + if (session_config.has_initial_arg()) { + argument_ = std::move(session_config.initial_arg()); + } + output_access_policy_node_id_ = session_config.output_access_policy_node_id(); + + return absl::OkStatus(); +} + +absl::StatusOr TffSession::ParseData( + const std::string& uri, std::string unencrypted_data, + int64_t total_size_bytes) { + tensorflow_federated::v0::Value value; + if (!value.ParseFromString(unencrypted_data)) { + return ToSessionWriteFinishedResponse(absl::InvalidArgumentError( + "Failed to deserialize the data to a TFF Value.")); + } + auto [it, inserted] = data_by_uri_.insert({uri, std::move(value)}); + if (!inserted) { + return ToSessionWriteFinishedResponse(absl::FailedPreconditionError( + "Data corresponding to URI already written to session.")); + } + return ToSessionWriteFinishedResponse(absl::OkStatus(), total_size_bytes); +} + +absl::StatusOr TffSession::ParseClientData( + const std::string& uri, std::string unencrypted_data, + int64_t total_size_bytes) { + FederatedComputeCheckpointParserFactory parser_factory; + absl::StatusOr> parser = + parser_factory.Create(absl::Cord(std::move(unencrypted_data))); + if (!parser.ok()) { + return ToSessionWriteFinishedResponse(absl::Status( + parser.status().code(), + absl::StrCat("Failed to deserialize the federated compute checkpoint. ", + parser.status().message()))); + } + auto [it, inserted] = + client_checkpoint_parser_by_uri_.insert({uri, std::move(parser.value())}); + if (!inserted) { + return ToSessionWriteFinishedResponse(absl::FailedPreconditionError( + "Data corresponding to URI already written to session.")); + } + return ToSessionWriteFinishedResponse(absl::OkStatus(), total_size_bytes); +} + +absl::StatusOr TffSession::SessionWrite( + const WriteRequest& write_request, std::string unencrypted_data) { + if (child_executor_ == nullptr) { + return absl::FailedPreconditionError( + "Session must be configured before data can be written."); + } + + TffSessionWriteConfig write_config; + if (!write_request.has_first_request_configuration() || + !write_request.first_request_configuration().UnpackTo(&write_config)) { + return ToSessionWriteFinishedResponse( + absl::InvalidArgumentError("Failed to parse TffSessionWriteConfig.")); + } + + if (write_config.client_upload()) { + return ParseClientData( + write_config.uri(), std::move(unencrypted_data), + write_request.first_request_metadata().total_size_bytes()); + } + + return ParseData(write_config.uri(), std::move(unencrypted_data), + write_request.first_request_metadata().total_size_bytes()); +} + +absl::StatusOr TffSession::FetchData( + const std::string& uri) { + auto data = data_by_uri_.find(uri); + if (data == data_by_uri_.end()) { + return absl::InvalidArgumentError( + "Data in argument was not provided to the transform."); + } + return data->second; +} + +absl::StatusOr TffSession::FetchClientData( + const std::string& uri, const std::string& key) { + auto parser = client_checkpoint_parser_by_uri_.find(uri); + if (parser == client_checkpoint_parser_by_uri_.end()) { + return absl::InvalidArgumentError( + "Data in argument was not provided to the transform."); + } + // Note that each key can only be accessed a single time from the parser. So, + // this relies on the fact that a given uri, key pair will only appear once in + // the input argument. + absl::StatusOr agg_tensor = + parser->second->GetTensor(key); + if (!agg_tensor.ok()) { + return absl::Status( + agg_tensor.status().code(), + absl::StrCat("Invalid tensor name. ", agg_tensor.status().message())); + } + absl::StatusOr tensor = + tensorflow_federated::aggregation::tensorflow::ToTfTensor( + std::move(*agg_tensor)); + if (!tensor.ok()) { + return absl::Status( + tensor.status().code(), + absl::StrCat("Invalid tensor data. ", tensor.status().message())); + } + tensorflow_federated::v0::Value value; + FCP_RETURN_IF_ERROR( + tensorflow_federated::SerializeTensorValue(std::move(*tensor), &value)); + + return value; +} + +absl::StatusOr TffSession::FinalizeSession( + const FinalizeRequest& request, const BlobMetadata& input_metadata) { + if (child_executor_ == nullptr) { + return absl::FailedPreconditionError( + "Session must be configured before it can be finalized."); + } + + FCP_ASSIGN_OR_RETURN(tensorflow_federated::OwnedValueId fn_handle, + child_executor_->CreateValue(function_)); + + std::optional optional_arg_handle; + if (argument_.has_value()) { + FCP_ASSIGN_OR_RETURN( + tensorflow_federated::v0::Value replaced_arg, + fcp::confidential_compute::ReplaceDatas( + *argument_, + [this](std::string uri) { return this->FetchData(uri); }, + [this](std::string uri, std::string key) { + return this->FetchClientData(uri, key); + })) + .replaced_value; + FCP_ASSIGN_OR_RETURN( + std::shared_ptr arg_handle, + fcp::confidential_compute::Embed(replaced_arg, child_executor_)); + optional_arg_handle = std::move(*arg_handle); + } + + FCP_ASSIGN_OR_RETURN( + tensorflow_federated::OwnedValueId call_handle, + child_executor_->CreateCall(fn_handle, optional_arg_handle)); + tensorflow_federated::v0::Value call_result; + FCP_RETURN_IF_ERROR(child_executor_->Materialize(call_handle, &call_result)); + std::string unencrypted_result = call_result.SerializeAsString(); + + // If all inputs are unencrypted, output result can be unencrypted. + if (input_metadata.has_unencrypted()) { + BlobMetadata result_metadata; + result_metadata.set_total_size_bytes(unencrypted_result.size()); + result_metadata.mutable_unencrypted(); + SessionResponse unencrypted_response; + ReadResponse* unencrypted_read_response = + unencrypted_response.mutable_read(); + unencrypted_read_response->set_finish_read(true); + *(unencrypted_read_response->mutable_data()) = + std::move(unencrypted_result); + *(unencrypted_read_response->mutable_first_response_metadata()) = + std::move(result_metadata); + return unencrypted_response; + } + + RecordEncryptor encryptor; + BlobHeader previous_header; + if (!previous_header.ParseFromString( + input_metadata.hpke_plus_aead_data().ciphertext_associated_data())) { + return absl::InvalidArgumentError( + "Failed to parse the BlobHeader when trying to encrypt outputs."); + } + FCP_ASSIGN_OR_RETURN( + Record result_record, + encryptor.EncryptRecord(unencrypted_result, + input_metadata.hpke_plus_aead_data() + .rewrapped_symmetric_key_associated_data() + .reencryption_public_key(), + previous_header.access_policy_sha256(), + output_access_policy_node_id_)); + SessionResponse response; + ReadResponse* read_response = response.mutable_read(); + read_response->set_finish_read(true); + *(read_response->mutable_data()) = + std::move(result_record.hpke_plus_aead_data().ciphertext()); + *(read_response->mutable_first_response_metadata()) = + GetBlobMetadataFromRecord(result_record); + return response; +} +} // namespace confidential_federated_compute::tff_server diff --git a/containers/tff_server/confidential_transform_server.h b/containers/tff_server/confidential_transform_server.h new file mode 100644 index 0000000..a956eb3 --- /dev/null +++ b/containers/tff_server/confidential_transform_server.h @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifndef CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_TFF_SERVER_CONFIDENTIAL_TRANSFORM_SERVER_H_ +#define CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_TFF_SERVER_CONFIDENTIAL_TRANSFORM_SERVER_H_ + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "containers/confidential_transform_server_base.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "fcp/protos/confidentialcompute/tff_config.pb.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/server_context.h" +#include "grpcpp/support/status.h" +#include "proto/containers/orchestrator_crypto.grpc.pb.h" +#include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_parser.h" +#include "tensorflow_federated/cc/core/impl/executors/executor.h" +#include "tensorflow_federated/proto/v0/executor.pb.h" + +namespace confidential_federated_compute::tff_server { + +// TFF implementation of Session interface. Not threadsafe. +class TffSession final : public confidential_federated_compute::Session { + public: + TffSession() {}; + + // Configure the session with the computation to be run, the initial + // argument, and create the TFF executor with the specified number of clients. + absl::Status ConfigureSession( + fcp::confidentialcompute::SessionRequest configure_request) override; + // Adds a data blob from a given URI into the session and parses the data into + // a TFF value. + absl::StatusOr SessionWrite( + const fcp::confidentialcompute::WriteRequest& write_request, + std::string unencrypted_data) override; + // Resolves all data URIs in the initial argument, embeds the argument into + // the TFF stack, executes the computation, and encrypts and outputs the + // result. + absl::StatusOr FinalizeSession( + const fcp::confidentialcompute::FinalizeRequest& request, + const fcp::confidentialcompute::BlobMetadata& input_metadata) override; + + private: + absl::StatusOr FetchData( + const std::string& uri); + absl::StatusOr FetchClientData( + const std::string& uri, const std::string& key); + absl::StatusOr ParseData( + const std::string& uri, std::string unencrypted_data, + int64_t total_size_bytes); + absl::StatusOr ParseClientData( + const std::string& uri, std::string unencrypted_data, + int64_t total_size_bytes); + + tensorflow_federated::v0::Value function_; + std::optional argument_ = std::nullopt; + uint32_t output_access_policy_node_id_; + // TFF executor to which lambda computations can be delegated after + // Data values have been resolved. + std::shared_ptr child_executor_; + // Map of URI to TFF Values that have been added into the session. + absl::flat_hash_map + data_by_uri_; + // Map of URI to ClientCheckpointParsers that contain client data uploads that + // have been added into the session. + absl::flat_hash_map< + std::string, + std::unique_ptr> + client_checkpoint_parser_by_uri_; +}; + +// ConfidentialTransform service for Tensorflow Federated. +class TffConfidentialTransform final + : public confidential_federated_compute::ConfidentialTransformBase { + public: + TffConfidentialTransform( + oak::containers::v1::OrchestratorCrypto::StubInterface* crypto_stub) + : ConfidentialTransformBase(crypto_stub) {}; + + protected: + // No transform specific initialization for TFF. + virtual absl::StatusOr InitializeTransform( + const fcp::confidentialcompute::InitializeRequest* request) override { + return absl::OkStatus(); + }; + + virtual absl::StatusOr< + std::unique_ptr> + CreateSession() override { + return std::make_unique(); + } +}; + +} // namespace confidential_federated_compute::tff_server + +#endif // CONFIDENTIAL_FEDERATED_COMPUTE_CONTAINERS_TFF_SERVER_CONFIDENTIAL_TRANSFORM_SERVER_H_ diff --git a/containers/tff_server/confidential_transform_server_test.cc b/containers/tff_server/confidential_transform_server_test.cc new file mode 100644 index 0000000..0a0d364 --- /dev/null +++ b/containers/tff_server/confidential_transform_server_test.cc @@ -0,0 +1,1040 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "containers/tff_server/confidential_transform_server.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "containers/blob_metadata.h" +#include "containers/crypto.h" +#include "containers/session.h" +#include "fcp/base/status_converters.h" +#include "fcp/confidentialcompute/crypto.h" +#include "fcp/confidentialcompute/tff_execution_helper.h" +#include "fcp/protos/confidentialcompute/confidential_transform.grpc.pb.h" +#include "fcp/protos/confidentialcompute/confidential_transform.pb.h" +#include "fcp/protos/confidentialcompute/file_info.pb.h" +#include "fcp/protos/confidentialcompute/tff_config.pb.h" +#include "gmock/gmock.h" +#include "google/protobuf/repeated_ptr_field.h" +#include "grpcpp/support/status.h" +#include "gtest/gtest.h" +#include "tensorflow_federated/cc/core/impl/aggregation/base/monitoring.h" +#include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_builder.h" +#include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_parser.h" +#include "tensorflow_federated/cc/core/impl/aggregation/tensorflow/converters.h" +#include "tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h" +#include "tensorflow_federated/cc/core/impl/executors/executor.h" +#include "tensorflow_federated/cc/core/impl/executors/federating_executor.h" +#include "tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.h" +#include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h" +#include "tensorflow_federated/cc/core/impl/executors/tensorflow_executor.h" +#include "tensorflow_federated/proto/v0/computation.pb.h" +#include "testing/parse_text_proto.h" + +namespace confidential_federated_compute::tff_server { + +namespace { + +using ::fcp::confidentialcompute::BlobHeader; +using ::fcp::confidentialcompute::BlobMetadata; +using ::fcp::confidentialcompute::FileInfo; +using ::fcp::confidentialcompute::FinalizeRequest; +using ::fcp::confidentialcompute::ReadResponse; +using ::fcp::confidentialcompute::Record; +using ::fcp::confidentialcompute::SessionRequest; +using ::fcp::confidentialcompute::SessionResponse; +using ::fcp::confidentialcompute::TffSessionConfig; +using ::fcp::confidentialcompute::TffSessionWriteConfig; +using ::fcp::confidentialcompute::WriteRequest; +using ::tensorflow_federated::aggregation::CheckpointBuilder; +using ::tensorflow_federated::aggregation::CheckpointParser; +using ::tensorflow_federated::aggregation::CreateTestData; +using ::tensorflow_federated::aggregation::DataType; +using ::tensorflow_federated::aggregation:: + FederatedComputeCheckpointBuilderFactory; +using ::tensorflow_federated::aggregation:: + FederatedComputeCheckpointParserFactory; +using ::tensorflow_federated::aggregation::Tensor; +using ::tensorflow_federated::aggregation::TensorShape; +using ::tensorflow_federated::v0::Value; +using ::testing::HasSubstr; + +constexpr absl::string_view kNoArgumentComputationPath = + "containers/tff_server/testing/no_argument_function.txtpb"; +constexpr absl::string_view kServerDataComputationPath = + "containers/tff_server/testing/server_data_function.txtpb"; +constexpr absl::string_view kClientDataComputationPath = + "containers/tff_server/testing/client_data_function.txtpb"; + +absl::StatusOr LoadFileAsTffValue(absl::string_view path) { + // Before creating the std::ifstream, convert the absl::string_view to + // std::string. + std::string path_str(path); + std::ifstream file_istream(path_str); + if (!file_istream) { + return absl::FailedPreconditionError("Error loading file: " + path_str); + } + std::stringstream file_stream; + file_stream << file_istream.rdbuf(); + tensorflow_federated::v0::Computation computation; + if (!google::protobuf::TextFormat::ParseFromString( + std::move(file_stream.str()), &computation)) { + return absl::InvalidArgumentError( + "Error parsing TFF Computation from file."); + } + Value value; + *value.mutable_computation() = std::move(computation); + return value; +} + +std::string BuildClientCheckpoint(std::initializer_list input_values) { + FederatedComputeCheckpointBuilderFactory builder_factory; + std::unique_ptr ckpt_builder = builder_factory.Create(); + + absl::StatusOr t1 = + Tensor::Create(DataType::DT_INT32, + TensorShape({static_cast(input_values.size())}), + CreateTestData(input_values)); + CHECK_OK(t1); + CHECK_OK(ckpt_builder->Add("key", *t1)); + auto checkpoint = ckpt_builder->Build(); + CHECK_OK(checkpoint.status()); + + std::string checkpoint_string; + absl::CopyCordToString(*checkpoint, &checkpoint_string); + return checkpoint_string; +} + +WriteRequest CreateDefaultWriteRequest(std::string data, bool client_upload) { + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + metadata.set_total_size_bytes(data.size()); + TffSessionWriteConfig config; + config.set_uri("uri"); + config.set_client_upload(client_upload); + WriteRequest write_request; + *write_request.mutable_first_request_metadata() = metadata; + write_request.mutable_first_request_configuration()->PackFrom(config); + write_request.set_commit(true); + write_request.set_data(data); + return write_request; +} + +WriteRequest CreateDefaultClientWriteRequest(std::string data) { + return CreateDefaultWriteRequest(data, true); +} + +WriteRequest CreateWriteRequest(std::string data, FileInfo file_info) { + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + metadata.set_total_size_bytes(data.size()); + TffSessionWriteConfig config; + config.set_uri(file_info.uri()); + config.set_client_upload(file_info.client_upload()); + WriteRequest write_request; + *write_request.mutable_first_request_metadata() = metadata; + write_request.mutable_first_request_configuration()->PackFrom(config); + write_request.set_commit(true); + write_request.set_data(data); + return write_request; +} + +TffSessionConfig DefaultSessionConfiguration() { + TffSessionConfig config = PARSE_TEXT_PROTO(R"pb( + num_clients: 3 + output_access_policy_node_id: 3 + )pb"); + *config.mutable_function() = *LoadFileAsTffValue(kNoArgumentComputationPath); + return config; +} + +TffSessionConfig CreateSessionConfiguration(Value function, Value argument, + int num_clients) { + TffSessionConfig config = PARSE_TEXT_PROTO(R"pb( + output_access_policy_node_id: 3 + )pb"); + *config.mutable_function() = std::move(function); + *config.mutable_initial_arg() = std::move(argument); + config.set_num_clients(num_clients); + return config; +} + +TEST(TffSessionTest, ConfigureSessionSuccess) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); +} + +TEST(TffSessionTest, SessionAlreadyConfiguredFailure) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + auto status = session.ConfigureSession(request); + ASSERT_EQ(status.code(), absl::StatusCode::kFailedPrecondition); + ASSERT_THAT(status.message(), HasSubstr("Session already configured.")); +} + +TEST(TffSessionTest, InvalidConfigureRequestFailure) { + TffSession session; + SessionRequest request; + auto status = session.ConfigureSession(request); + ASSERT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT(status.message(), HasSubstr("TffSessionConfig invalid.")); +} + +TEST(TffSessionTest, WriteBeforeConfigureFailure) { + TffSession session; + auto status = + session.SessionWrite(CreateDefaultClientWriteRequest("data"), "data") + .status(); + ASSERT_EQ(status.code(), absl::StatusCode::kFailedPrecondition); + ASSERT_THAT(status.message(), HasSubstr("Session must be configured before")); +} + +TEST(TffSessionTest, InvalidWriteConfigSessionWriteSuccessButDataIgnored) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + WriteRequest write_request = CreateDefaultClientWriteRequest("data"); + write_request.clear_first_request_configuration(); + auto response = session.SessionWrite(write_request, "data").value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::INVALID_ARGUMENT); + ASSERT_THAT(response.write().status().message(), + HasSubstr("Failed to parse TffSessionWriteConfig.")); +} + +TEST(TffSessionTest, WriteInvalidDataSessionWriteSuccessButDataIgnored) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + WriteRequest write_request = CreateDefaultWriteRequest("data", false); + auto response = session.SessionWrite(write_request, "data").value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::INVALID_ARGUMENT); + ASSERT_THAT(response.write().status().message(), + HasSubstr("Failed to deserialize the data")); +} + +TEST(TffSessionTest, + WriteInvalidClientCheckpointWriteSessionSuccessButDataIgnored) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + WriteRequest write_request = CreateDefaultClientWriteRequest("data"); + auto response = session.SessionWrite(write_request, "data").value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::INVALID_ARGUMENT); + ASSERT_THAT( + response.write().status().message(), + HasSubstr("Failed to deserialize the federated compute checkpoint.")); +} + +TEST(TffSessionTest, WriteDataSameUriWriteSessionSuccessButDataIgnored) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + Value data; + data.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(20); + std::string data_string = data.SerializeAsString(); + WriteRequest write_request = CreateDefaultWriteRequest(data_string, false); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::FAILED_PRECONDITION); + ASSERT_THAT( + response.write().status().message(), + HasSubstr("Data corresponding to URI already written to session.")); +} + +TEST(TffSessionTest, WriteClientDataSameUriWriteSessionSuccessButDataIgnored) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + std::string data_string = BuildClientCheckpoint({1}); + WriteRequest write_request = CreateDefaultClientWriteRequest(data_string); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::FAILED_PRECONDITION); + ASSERT_THAT( + response.write().status().message(), + HasSubstr("Data corresponding to URI already written to session.")); +} + +TEST(TffSessionTest, WriteDataSuccess) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + Value data; + data.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(20); + std::string data_string = data.SerializeAsString(); + WriteRequest write_request = CreateDefaultWriteRequest(data_string, false); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); +} + +TEST(TffSessionTest, WriteClientDataSuccess) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + std::string data_string = BuildClientCheckpoint({1}); + WriteRequest write_request = CreateDefaultClientWriteRequest(data_string); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); +} + +TEST(TffSessionTest, FinalizeBeforeConfigureFailure) { + TffSession session; + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest request; + auto status = session.FinalizeSession(request, metadata).status(); + ASSERT_EQ(status.code(), absl::StatusCode::kFailedPrecondition); + ASSERT_THAT(status.message(), HasSubstr("Session must be configured")); +} + +TEST(TffSessionTest, FinalizeInvalidFunctionFailure) { + TffSession session; + SessionRequest request; + TffSessionConfig config = DefaultSessionConfiguration(); + config.clear_function(); + request.mutable_configure()->mutable_configuration()->PackFrom(config); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto status = session.FinalizeSession(finalize_request, metadata).status(); + ASSERT_EQ(status.code(), absl::StatusCode::kUnimplemented); +} + +TEST(TffSessionTest, FinalizeWithoutArgumentSuccess) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto result = session.FinalizeSession(finalize_request, metadata).value(); + ReadResponse read_response = result.read(); + Value value; + value.ParseFromString(read_response.data()); + tensorflow::Tensor output_tensor = + tensorflow_federated::DeserializeTensorValue(value.federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat = output_tensor.unaligned_flat(); + EXPECT_EQ(flat(0), 10) << flat(0); +} + +TEST(TffSessionTest, FinalizeMultipleDataInputsSuccess) { + FileInfo file_info_1; + file_info_1.set_uri("server1"); + FileInfo file_info_2; + file_info_2.set_uri("server2"); + FileInfo file_info_3; + file_info_3.set_uri("server3"); + file_info_3.set_key("key3"); + + // Create Function + Value function = *LoadFileAsTffValue(kServerDataComputationPath); + + // Create Argument + Value argument; + argument.mutable_federated() + ->mutable_type() + ->mutable_placement() + ->mutable_value() + ->set_uri("server"); + argument.mutable_federated()->mutable_type()->set_all_equal(true); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_3); + + // Configure Session + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + CreateSessionConfiguration(std::move(function), std::move(argument), 3)); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + // Write Data to Session + Value data1; + data1.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(10); + std::string data_string = data1.SerializeAsString(); + WriteRequest write_request = CreateWriteRequest(data_string, file_info_1); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + Value data2; + data2.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(20); + data_string = data2.SerializeAsString(); + write_request = CreateWriteRequest(data_string, file_info_2); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + Value data3; + data3.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(30); + Value data4; + data4.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(-1); + Value data5; + auto data_struct = data5.mutable_struct_(); + auto struct_v1 = data_struct->add_element(); + struct_v1->set_name("key3"); + *struct_v1->mutable_value() = data3; + auto struct_v2 = data_struct->add_element(); + struct_v2->set_name("foo"); + *struct_v2->mutable_value() = data4; + data_string = data5.SerializeAsString(); + write_request = CreateWriteRequest(data_string, file_info_3); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + // Finalize Session + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto result = session.FinalizeSession(finalize_request, metadata).value(); + ReadResponse read_response = result.read(); + Value value; + value.ParseFromString(read_response.data()); + tensorflow::Tensor output_tensor = + tensorflow_federated::DeserializeTensorValue( + value.struct_().element(0).value().federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat = output_tensor.unaligned_flat(); + // 30 scaled by 10 = 300 + EXPECT_EQ(flat(0), 300) << flat(0); + output_tensor = tensorflow_federated::DeserializeTensorValue( + value.struct_().element(1).value().federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat2 = output_tensor.unaligned_flat(); + // 30 scaled by 10 * num clients = 900 + EXPECT_EQ(flat2(0), 900) << flat2(0); +} + +TEST(TffSessionTest, FinalizeMultipleClientDataInputsSuccess) { + FileInfo file_info_1; + file_info_1.set_uri("client1"); + file_info_1.set_key("key"); + file_info_1.set_client_upload(true); + FileInfo file_info_2; + file_info_2.set_uri("client2"); + file_info_2.set_key("key"); + file_info_2.set_client_upload(true); + FileInfo file_info_3; + file_info_3.set_uri("client3"); + file_info_3.set_key("key"); + file_info_3.set_client_upload(true); + + // Create Function + Value function = *LoadFileAsTffValue(kClientDataComputationPath); + + // Create Argument + Value argument; + argument.mutable_federated() + ->mutable_type() + ->mutable_placement() + ->mutable_value() + ->set_uri("clients"); + argument.mutable_federated()->mutable_type()->set_all_equal(false); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_1); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_2); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_3); + + // Configure Session + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + CreateSessionConfiguration(std::move(function), std::move(argument), 3)); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + // Write Data to Session + std::string data_string = BuildClientCheckpoint({10}); + WriteRequest write_request = CreateWriteRequest(data_string, file_info_1); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = BuildClientCheckpoint({20}); + write_request = CreateWriteRequest(data_string, file_info_2); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = BuildClientCheckpoint({30}); + write_request = CreateWriteRequest(data_string, file_info_3); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + // Finalize Session + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto result = session.FinalizeSession(finalize_request, metadata).value(); + ReadResponse read_response = result.read(); + Value value; + value.ParseFromString(read_response.data()); + tensorflow::Tensor output_tensor = + tensorflow_federated::DeserializeTensorValue(value.federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat = output_tensor.unaligned_flat(); + EXPECT_EQ(flat(0), 60) << flat(0); +} + +TEST(TffSessionTest, FinalizeIgnoresInvalidDataInputsSuccess) { + FileInfo file_info_1; + file_info_1.set_uri("server1"); + FileInfo file_info_2; + file_info_2.set_uri("server2"); + FileInfo file_info_3; + file_info_3.set_uri("server3"); + file_info_3.set_key("key3"); + + // Create Function + Value function = *LoadFileAsTffValue(kServerDataComputationPath); + + // Create Argument + Value argument; + argument.mutable_federated() + ->mutable_type() + ->mutable_placement() + ->mutable_value() + ->set_uri("server"); + argument.mutable_federated()->mutable_type()->set_all_equal(true); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_3); + + // Configure Session + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + CreateSessionConfiguration(std::move(function), std::move(argument), 3)); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + // Write Data to Session + Value data1; + data1.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(10); + std::string data_string = data1.SerializeAsString(); + WriteRequest write_request = CreateWriteRequest(data_string, file_info_1); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = "invalid data"; + write_request = CreateWriteRequest(data_string, file_info_2); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::INVALID_ARGUMENT); + + Value data3; + data3.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(30); + Value data4; + data4.mutable_computation() + ->mutable_literal() + ->mutable_value() + ->mutable_int32_list() + ->add_value(-1); + Value data5; + auto data_struct = data5.mutable_struct_(); + auto struct_v1 = data_struct->add_element(); + struct_v1->set_name("key3"); + *struct_v1->mutable_value() = data3; + auto struct_v2 = data_struct->add_element(); + struct_v2->set_name("foo"); + *struct_v2->mutable_value() = data4; + data_string = data5.SerializeAsString(); + write_request = CreateWriteRequest(data_string, file_info_3); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + // Finalize Session + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto result = session.FinalizeSession(finalize_request, metadata).value(); + ReadResponse read_response = result.read(); + Value value; + value.ParseFromString(read_response.data()); + tensorflow::Tensor output_tensor = + tensorflow_federated::DeserializeTensorValue( + value.struct_().element(0).value().federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat = output_tensor.unaligned_flat(); + // 30 scaled by 10 = 300 + EXPECT_EQ(flat(0), 300) << flat(0); + output_tensor = tensorflow_federated::DeserializeTensorValue( + value.struct_().element(1).value().federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat2 = output_tensor.unaligned_flat(); + // 30 scaled by 10 * num clients = 900 + EXPECT_EQ(flat2(0), 900) << flat2(0); +} + +TEST(TffSessionTest, FinalizeIgnoresInvalidClientDataInputsSuccess) { + FileInfo file_info_1; + file_info_1.set_uri("client1"); + file_info_1.set_key("key"); + file_info_1.set_client_upload(true); + FileInfo file_info_2; + file_info_2.set_uri("client2"); + file_info_2.set_key("key"); + file_info_2.set_client_upload(true); + FileInfo file_info_3; + file_info_3.set_uri("client3"); + file_info_3.set_key("key"); + file_info_3.set_client_upload(true); + + // Create Function + Value function = *LoadFileAsTffValue(kClientDataComputationPath); + + // Create Argument + Value argument; + argument.mutable_federated() + ->mutable_type() + ->mutable_placement() + ->mutable_value() + ->set_uri("clients"); + argument.mutable_federated()->mutable_type()->set_all_equal(false); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_1); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_2); + + // Configure Session + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + CreateSessionConfiguration(std::move(function), std::move(argument), 2)); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + // Write Data to Session + std::string data_string = BuildClientCheckpoint({10}); + WriteRequest write_request = CreateWriteRequest(data_string, file_info_1); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = BuildClientCheckpoint({20}); + write_request = CreateWriteRequest(data_string, file_info_2); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = "Invalid data."; + write_request = CreateWriteRequest(data_string, file_info_3); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::INVALID_ARGUMENT); + + // Finalize Session + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto result = session.FinalizeSession(finalize_request, metadata).value(); + ReadResponse read_response = result.read(); + Value value; + value.ParseFromString(read_response.data()); + tensorflow::Tensor output_tensor = + tensorflow_federated::DeserializeTensorValue(value.federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat = output_tensor.unaligned_flat(); + EXPECT_EQ(flat(0), 30) << flat(0); +} + +TEST(TffSessionTest, FinalizeEncryptsOutputSuccess) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + std::string ciphertext_associated_data = + BlobHeader::default_instance().SerializeAsString(); + fcp::confidential_compute::MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + )pb"); + metadata.mutable_hpke_plus_aead_data()->set_ciphertext_associated_data( + ciphertext_associated_data); + metadata.mutable_hpke_plus_aead_data() + ->mutable_rewrapped_symmetric_key_associated_data() + ->set_reencryption_public_key(reencryption_public_key.value()); + FinalizeRequest finalize_request; + auto result = session.FinalizeSession(finalize_request, metadata).value(); + + ReadResponse read_response = result.read(); + ASSERT_TRUE( + read_response.first_response_metadata().has_hpke_plus_aead_data()); + BlobMetadata::HpkePlusAeadMetadata result_metadata = + read_response.first_response_metadata().hpke_plus_aead_data(); + absl::StatusOr decrypted_result = decryptor.Decrypt( + read_response.data(), result_metadata.ciphertext_associated_data(), + result_metadata.encrypted_symmetric_key(), + result_metadata.ciphertext_associated_data(), + result_metadata.encapsulated_public_key()); + ASSERT_TRUE(decrypted_result.ok()) << decrypted_result.status(); + + Value value; + value.ParseFromString(decrypted_result.value()); + tensorflow::Tensor output_tensor = + tensorflow_federated::DeserializeTensorValue(value.federated().value(0)) + .value(); + EXPECT_EQ(output_tensor.NumElements(), 1); + auto flat = output_tensor.unaligned_flat(); + EXPECT_EQ(flat(0), 10) << flat(0); +} + +TEST(TffSessionTest, FinalizeInvalidClientTensorFailure) { + FileInfo file_info_1; + file_info_1.set_uri("client1"); + file_info_1.set_key("key"); + file_info_1.set_client_upload(true); + + // Create Function + Value function = *LoadFileAsTffValue(kClientDataComputationPath); + + // Create Argument + Value argument; + argument.mutable_federated() + ->mutable_type() + ->mutable_placement() + ->mutable_value() + ->set_uri("clients"); + argument.mutable_federated()->mutable_type()->set_all_equal(false); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_1); + + // Configure Session + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + CreateSessionConfiguration(std::move(function), std::move(argument), 1)); + + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + // Create checkpoint tensor with invalid key. + FederatedComputeCheckpointBuilderFactory builder_factory; + std::unique_ptr ckpt_builder = builder_factory.Create(); + + absl::StatusOr t1 = + Tensor::Create(DataType::DT_INT32, TensorShape({static_cast(1)}), + CreateTestData({1})); + CHECK_OK(t1); + CHECK_OK(ckpt_builder->Add("invalid_key", *t1)); + auto checkpoint = ckpt_builder->Build(); + CHECK_OK(checkpoint.status()); + + std::string checkpoint_string; + absl::CopyCordToString(*checkpoint, &checkpoint_string); + + WriteRequest write_request = + CreateDefaultClientWriteRequest(checkpoint_string); + auto response = + session.SessionWrite(write_request, checkpoint_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), checkpoint_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + // Finalize Session + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto status = session.FinalizeSession(finalize_request, metadata).status(); + ASSERT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT(status.message(), HasSubstr("Data in argument was not provided")); +} + +TEST(TffSessionTest, FinalizeDataNotReplaceableFailure) { + FileInfo file_info_1; + file_info_1.set_uri("client1"); + file_info_1.set_key("key"); + file_info_1.set_client_upload(true); + FileInfo file_info_2; + file_info_2.set_uri("client2"); + file_info_2.set_key("key"); + file_info_2.set_client_upload(true); + FileInfo file_info_3; + file_info_3.set_uri("client3"); + file_info_3.set_key("key"); + file_info_3.set_client_upload(true); + + // Create Function + Value function = *LoadFileAsTffValue(kClientDataComputationPath); + + // Create Argument + Value argument; + argument.mutable_federated() + ->mutable_type() + ->mutable_placement() + ->mutable_value() + ->set_uri("clients"); + argument.mutable_federated()->mutable_type()->set_all_equal(false); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_1); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_2); + argument.mutable_federated() + ->add_value() + ->mutable_computation() + ->mutable_data() + ->mutable_content() + ->PackFrom(file_info_3); + + // Configure Session + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + CreateSessionConfiguration(std::move(function), std::move(argument), 3)); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + // Write Data to Session + std::string data_string = BuildClientCheckpoint({10}); + WriteRequest write_request = CreateWriteRequest(data_string, file_info_1); + auto response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = BuildClientCheckpoint({20}); + write_request = CreateWriteRequest(data_string, file_info_2); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), data_string.size()); + ASSERT_EQ(response.write().status().code(), grpc::StatusCode::OK); + + data_string = "Invalid data."; + write_request = CreateWriteRequest(data_string, file_info_3); + response = session.SessionWrite(write_request, data_string).value(); + ASSERT_EQ(response.write().committed_size_bytes(), 0); + ASSERT_EQ(response.write().status().code(), + grpc::StatusCode::INVALID_ARGUMENT); + + // Finalize Session + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + unencrypted {} + )pb"); + FinalizeRequest finalize_request; + auto status = session.FinalizeSession(finalize_request, metadata).status(); + ASSERT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT(status.message(), HasSubstr("Data in argument was not provided")); +} + +TEST(TffSessionTest, FinalizeEncryptWithInvalidBlobHeaderFailure) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + fcp::confidential_compute::MessageDecryptor decryptor; + absl::StatusOr reencryption_public_key = + decryptor.GetPublicKey([](absl::string_view) { return ""; }, 0); + ASSERT_TRUE(reencryption_public_key.ok()); + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + )pb"); + metadata.mutable_hpke_plus_aead_data()->set_ciphertext_associated_data( + "invalid ciphertext_associated_data"); + metadata.mutable_hpke_plus_aead_data() + ->mutable_rewrapped_symmetric_key_associated_data() + ->set_reencryption_public_key(reencryption_public_key.value()); + FinalizeRequest finalize_request; + auto status = session.FinalizeSession(finalize_request, metadata).status(); + ASSERT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT(status.message(), HasSubstr("Failed to parse the BlobHeader")); +} + +TEST(TffSessionTest, FinalizeEncryptOutputRecordErrorFailure) { + TffSession session; + SessionRequest request; + request.mutable_configure()->mutable_configuration()->PackFrom( + DefaultSessionConfiguration()); + ASSERT_TRUE(session.ConfigureSession(request).ok()); + + std::string ciphertext_associated_data = + BlobHeader::default_instance().SerializeAsString(); + BlobMetadata metadata = PARSE_TEXT_PROTO(R"pb( + compression_type: COMPRESSION_TYPE_NONE + )pb"); + metadata.mutable_hpke_plus_aead_data()->set_ciphertext_associated_data( + ciphertext_associated_data); + metadata.mutable_hpke_plus_aead_data() + ->mutable_rewrapped_symmetric_key_associated_data() + ->set_reencryption_public_key("invalid key"); + FinalizeRequest finalize_request; + auto status = session.FinalizeSession(finalize_request, metadata).status(); + ASSERT_EQ(status.code(), absl::StatusCode::kInvalidArgument); + ASSERT_THAT(status.message(), HasSubstr("failed to decode CWT")); +} + +} // namespace + +} // namespace confidential_federated_compute::tff_server diff --git a/containers/tff_server/testing/.gitignore b/containers/tff_server/testing/.gitignore new file mode 100644 index 0000000..f9606a3 --- /dev/null +++ b/containers/tff_server/testing/.gitignore @@ -0,0 +1 @@ +/venv diff --git a/containers/tff_server/testing/client_data_function.txtpb b/containers/tff_server/testing/client_data_function.txtpb new file mode 100644 index 0000000..90e0ff9 --- /dev/null +++ b/containers/tff_server/testing/client_data_function.txtpb @@ -0,0 +1,850 @@ +type { + function { + parameter { + federated { + placement { + value { + uri: "clients" + } + } + member { + tensor { + dtype: DT_INT32 + } + } + } + } + result { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } +} +lambda { + parameter_name: "client_data_comp_arg" + result { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + block { + local { + name: "fc_client_data_comp_symbol_0" + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + call { + function { + type { + function { + parameter { + struct { + element { + value { + federated { + placement { + value { + uri: "clients" + } + } + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + result { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + intrinsic { + uri: "federated_aggregate" + } + } + argument { + type { + struct { + element { + value { + federated { + placement { + value { + uri: "clients" + } + } + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + struct { + element { + value { + type { + federated { + placement { + value { + uri: "clients" + } + } + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "client_data_comp_arg" + } + } + } + element { + value { + type { + tensor { + dtype: DT_INT32 + } + } + literal { + value { + dtype: DT_INT32 + shape { + } + int32_list { + value: 0 + } + } + } + } + } + element { + value { + type { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + lambda { + parameter_name: "a" + result { + type { + tensor { + dtype: DT_INT32 + } + } + call { + function { + type { + function { + parameter { + struct { + element { + name: "x" + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + name: "y" + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_x" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_y" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "add" + op: "AddV2" + input: "arg_x" + input: "arg_y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node { + name: "Identity" + op: "Identity" + input: "add" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + struct { + element { + tensor { + tensor_name: "arg_x:0" + } + } + element { + tensor { + tensor_name: "arg_y:0" + } + } + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + argument { + type { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + reference { + name: "a" + } + } + } + } + } + } + } + element { + value { + type { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + lambda { + parameter_name: "a" + result { + type { + tensor { + dtype: DT_INT32 + } + } + call { + function { + type { + function { + parameter { + struct { + element { + name: "x" + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + name: "y" + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_x" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_y" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "add" + op: "AddV2" + input: "arg_x" + input: "arg_y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node { + name: "Identity" + op: "Identity" + input: "add" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + struct { + element { + tensor { + tensor_name: "arg_x:0" + } + } + element { + tensor { + tensor_name: "arg_y:0" + } + } + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + argument { + type { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + reference { + name: "a" + } + } + } + } + } + } + } + element { + value { + type { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Identity" + op: "Identity" + input: "arg" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + tensor { + tensor_name: "arg:0" + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + } + } + } + } + } + } + result { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "fc_client_data_comp_symbol_0" + } + } + } + } +} diff --git a/containers/tff_server/testing/generate_test_computations.py b/containers/tff_server/testing/generate_test_computations.py new file mode 100644 index 0000000..96e929d --- /dev/null +++ b/containers/tff_server/testing/generate_test_computations.py @@ -0,0 +1,104 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Binary for generating serialized TFF computations for testing. + +Run using: + + python3 -m venv venv && source venv/bin/activate + pip install --upgrade pip + pip install --upgrade tensorflow-federated==0.75.0 + python3 generate_test_computations.py + deactivate +""" + +import collections +import numpy as np +import os + +from absl import app +from absl import flags + +from google.protobuf import text_format +import tensorflow_federated as tff + +OUTPUT_DIR = flags.DEFINE_string('output_dir', '.', 'Output directory') + +NO_ARGUMENT_FUNCTION = 'no_argument_function.txtpb' +SERVER_DATA_FUNCTION = 'server_data_function.txtpb' +CLIENT_DATA_FUNCTION = 'client_data_function.txtpb' + +@tff.federated_computation +def no_argument_comp(): + return tff.federated_value(10, tff.SERVER) + +@tff.tf_computation(np.int32, np.int32) +def add(x, y): + return x + y + +@tff.tf_computation(np.int32) +def identity(x): + return x + +@tff.tf_computation(np.int32) +def scale(x): + return x * 10 + +@tff.federated_computation(tff.FederatedType(np.int32, tff.CLIENTS)) +def client_data_comp(client_data): + return tff.federated_aggregate(client_data, 0, add, add, identity) + +@tff.federated_computation(tff.FederatedType(np.int32, tff.SERVER)) +def server_data_comp(server_state): + scaled_server_state = tff.federated_map(scale, server_state) + broadcasted_server_state = tff.federated_broadcast(scaled_server_state) + summed_broadcast = tff.federated_aggregate( + broadcasted_server_state, 0, add, add, identity + ) + return scaled_server_state, summed_broadcast + +def generate_test_computations() -> None: + """Generates serialized test computations and writes them out to files.""" + no_argument_function_text_proto = text_format.MessageToString( + tff.framework.serialize_computation(no_argument_comp)) + no_argument_function_filepath = os.path.join( + OUTPUT_DIR.value, NO_ARGUMENT_FUNCTION) + + with open(no_argument_function_filepath, 'w') as f: + f.write(no_argument_function_text_proto) + + server_data_function_text_proto = text_format.MessageToString( + tff.framework.serialize_computation(server_data_comp)) + server_data_function_filepath = os.path.join( + OUTPUT_DIR.value, SERVER_DATA_FUNCTION) + + with open(server_data_function_filepath, 'w') as f: + f.write(server_data_function_text_proto) + + client_data_function_text_proto = text_format.MessageToString( + tff.framework.serialize_computation(client_data_comp)) + client_data_function_filepath = os.path.join( + OUTPUT_DIR.value, CLIENT_DATA_FUNCTION) + + with open(client_data_function_filepath, 'w') as f: + f.write(client_data_function_text_proto) + + +def main(argv: collections.abc.Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + generate_test_computations() + + +if __name__ == '__main__': + app.run(main) diff --git a/containers/tff_server/testing/no_argument_function.txtpb b/containers/tff_server/testing/no_argument_function.txtpb new file mode 100644 index 0000000..da819e4 --- /dev/null +++ b/containers/tff_server/testing/no_argument_function.txtpb @@ -0,0 +1,128 @@ +type { + function { + result { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } +} +lambda { + result { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + block { + local { + name: "fc_no_argument_comp_symbol_0" + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + call { + function { + type { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + intrinsic { + uri: "federated_value_at_server" + } + } + argument { + type { + tensor { + dtype: DT_INT32 + } + } + literal { + value { + dtype: DT_INT32 + shape { + } + int32_list { + value: 10 + } + } + } + } + } + } + } + result { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "fc_no_argument_comp_symbol_0" + } + } + } + } +} diff --git a/containers/tff_server/testing/server_data_function.txtpb b/containers/tff_server/testing/server_data_function.txtpb new file mode 100644 index 0000000..2b52a76 --- /dev/null +++ b/containers/tff_server/testing/server_data_function.txtpb @@ -0,0 +1,1308 @@ +type { + function { + parameter { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + result { + struct { + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + } +} +lambda { + parameter_name: "server_data_comp_arg" + result { + type { + struct { + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + block { + local { + name: "fc_server_data_comp_symbol_0" + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + call { + function { + type { + function { + parameter { + struct { + element { + value { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + result { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + intrinsic { + uri: "federated_apply" + } + } + argument { + type { + struct { + element { + value { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + struct { + element { + value { + type { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "mul/y" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } + } + node { + name: "mul" + op: "Mul" + input: "arg" + input: "mul/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node { + name: "Identity" + op: "Identity" + input: "mul" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + tensor { + tensor_name: "arg:0" + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + } + element { + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "server_data_comp_arg" + } + } + } + } + } + } + } + } + local { + name: "fc_server_data_comp_symbol_1" + value { + type { + federated { + placement { + value { + uri: "clients" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + call { + function { + type { + function { + parameter { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + result { + federated { + placement { + value { + uri: "clients" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + intrinsic { + uri: "federated_broadcast" + } + } + argument { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "fc_server_data_comp_symbol_0" + } + } + } + } + } + local { + name: "fc_server_data_comp_symbol_2" + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + call { + function { + type { + function { + parameter { + struct { + element { + value { + federated { + placement { + value { + uri: "clients" + } + } + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + result { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + intrinsic { + uri: "federated_aggregate" + } + } + argument { + type { + struct { + element { + value { + federated { + placement { + value { + uri: "clients" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + struct { + element { + value { + type { + federated { + placement { + value { + uri: "clients" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "fc_server_data_comp_symbol_1" + } + } + } + element { + value { + type { + tensor { + dtype: DT_INT32 + } + } + literal { + value { + dtype: DT_INT32 + shape { + } + int32_list { + value: 0 + } + } + } + } + } + element { + value { + type { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + lambda { + parameter_name: "a" + result { + type { + tensor { + dtype: DT_INT32 + } + } + call { + function { + type { + function { + parameter { + struct { + element { + name: "x" + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + name: "y" + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_x" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_y" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "add" + op: "AddV2" + input: "arg_x" + input: "arg_y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node { + name: "Identity" + op: "Identity" + input: "add" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + struct { + element { + tensor { + tensor_name: "arg_x:0" + } + } + element { + tensor { + tensor_name: "arg_y:0" + } + } + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + argument { + type { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + reference { + name: "a" + } + } + } + } + } + } + } + element { + value { + type { + function { + parameter { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + lambda { + parameter_name: "a" + result { + type { + tensor { + dtype: DT_INT32 + } + } + call { + function { + type { + function { + parameter { + struct { + element { + name: "x" + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + name: "y" + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_x" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg_y" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "add" + op: "AddV2" + input: "arg_x" + input: "arg_y" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + node { + name: "Identity" + op: "Identity" + input: "add" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + struct { + element { + tensor { + tensor_name: "arg_x:0" + } + } + element { + tensor { + tensor_name: "arg_y:0" + } + } + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + argument { + type { + struct { + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + element { + value { + tensor { + dtype: DT_INT32 + } + } + } + } + } + reference { + name: "a" + } + } + } + } + } + } + } + element { + value { + type { + function { + parameter { + tensor { + dtype: DT_INT32 + } + } + result { + tensor { + dtype: DT_INT32 + } + } + } + } + tensorflow { + graph_def { + [type.googleapis.com/tensorflow.GraphDef] { + node { + name: "session_token_tensor" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "arg" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Identity" + op: "Identity" + input: "arg" + attr { + key: "T" + value { + type: DT_INT32 + } + } + } + versions { + producer: 1575 + } + } + } + parameter { + tensor { + tensor_name: "arg:0" + } + } + result { + tensor { + tensor_name: "Identity:0" + } + } + session_token_tensor_name: "session_token_tensor:0" + } + } + } + } + } + } + } + } + result { + type { + struct { + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + element { + value { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + } + } + } + struct { + element { + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "fc_server_data_comp_symbol_0" + } + } + } + element { + value { + type { + federated { + placement { + value { + uri: "server" + } + } + all_equal: true + member { + tensor { + dtype: DT_INT32 + } + } + } + } + reference { + name: "fc_server_data_comp_symbol_2" + } + } + } + } + } + } + } +}