Skip to content

Commit

Permalink
Update to tensorflow federated version 0.86.0
Browse files Browse the repository at this point in the history
Change-Id: Ib7124afff3647efdd39f1be7f600b9f01ce0d0ed
  • Loading branch information
mayaspivak committed Aug 29, 2024
1 parent 21d05d4 commit 2a64341
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
6 changes: 3 additions & 3 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ python_register_toolchains(

http_archive(
name = "org_tensorflow_federated",
sha256 = "343d12a98ef8d98202e1dca898d84390a14fdc296af60f889cbb4023f38ebcdb",
strip_prefix = "tensorflow-federated-734a8669dc9842f4355d4bce240cd47883bda0c4",
url = "https://github.com/google-parfait/tensorflow-federated/archive/734a8669dc9842f4355d4bce240cd47883bda0c4.tar.gz",
sha256 = "5a514838ea601056da299ad8946cc40db2024131b1260a3454fb161317b1edf8",
strip_prefix = "tensorflow-federated-a7af3c978771c2a9ebd1bc3588597e65bc78249d",
url = "https://github.com/google-parfait/tensorflow-federated/archive/a7af3c978771c2a9ebd1bc3588597e65bc78249d.tar.gz",
)

# Use a newer version of BoringSSL than what TF gives us, so we can use
Expand Down
7 changes: 3 additions & 4 deletions containers/tff_worker/pipeline_transform_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "tensorflow_federated/cc/core/impl/executors/executor.h"
#include "tensorflow_federated/cc/core/impl/executors/federating_executor.h"
#include "tensorflow_federated/cc/core/impl/executors/reference_resolving_executor.h"
#include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h"
#include "tensorflow_federated/cc/core/impl/executors/tensorflow_executor.h"
#include "tensorflow_federated/proto/v0/executor.pb.h"

Expand Down Expand Up @@ -149,7 +150,7 @@ absl::StatusOr<tff::v0::Value> RestoreClientCheckpointToDict(
FCP_ASSIGN_OR_RETURN(auto parser,
parser_factory.Create(client_stacked_tensor_result));

// The output is a federated CLIENTS-placed Value with a struct of Tensors
// The output is a federated CLIENTS-placed Value with a struct of Arrays
// given the names from the `fed_sql_tf_checkpoint_spec` and the values in
// the FCP checkpoint.
tff::v0::Value_Federated* federated = restored_value.mutable_federated();
Expand All @@ -164,11 +165,9 @@ absl::StatusOr<tff::v0::Value> RestoreClientCheckpointToDict(
parser->GetTensor(column.name()));
absl::StatusOr<tf::Tensor> tensor =
ToTfTensor(std::move(tensor_column_values));
tf::TensorProto tensor_proto;
tensor->AsProtoTensorContent(&tensor_proto);
tff::v0::Value_Struct_Element* element = value_struct->add_element();
element->set_name(column.name());
element->mutable_value()->mutable_tensor()->PackFrom(tensor_proto);
tff::SerializeTensorValue(tensor.value(), element->mutable_value());
}
return restored_value;
}
Expand Down
20 changes: 9 additions & 11 deletions containers/tff_worker/pipeline_transform_server_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "tensorflow_federated/cc/core/impl/aggregation/protocol/federated_compute_checkpoint_builder.h"
#include "tensorflow_federated/cc/core/impl/aggregation/testing/test_data.h"
#include "tensorflow_federated/cc/core/impl/executors/cardinalities.h"
#include "tensorflow_federated/cc/core/impl/executors/tensor_serialization.h"
#include "tensorflow_federated/proto/v0/computation.pb.h"
#include "tensorflow_federated/proto/v0/executor.pb.h"

Expand Down Expand Up @@ -119,10 +120,8 @@ tff_proto::Value BuildFederatedIntClientValue(float int_value) {
tensorflow::Tensor tensor(tensorflow::DT_FLOAT, shape);
auto flat = tensor.flat<float>();
flat(0) = int_value;
tensorflow::TensorProto tensor_proto;
tensor.AsProtoTensorContent(&tensor_proto);
tensorflow_federated::v0::Value* federated_value = federated->add_value();
federated_value->mutable_tensor()->PackFrom(tensor_proto);
tensorflow_federated::SerializeTensorValue(tensor, federated_value);
return value;
}

Expand Down Expand Up @@ -344,17 +343,16 @@ TEST_F(TffPipelineTransformTest, TransformExecutesClientWork) {
EXPECT_TRUE(value.federated().value(0).has_struct_());
EXPECT_EQ(value.federated().value(0).struct_().element_size(), 1);
EXPECT_TRUE(
value.federated().value(0).struct_().element(0).value().has_tensor());
tensorflow::TensorProto output_tensor_proto;
value.federated().value(0).struct_().element(0).value().tensor().UnpackTo(
&output_tensor_proto);
tensorflow::Tensor output_tensor;
CHECK(output_tensor.FromProto(output_tensor_proto));
EXPECT_EQ(output_tensor.NumElements(), 3);
value.federated().value(0).struct_().element(0).value().has_array());
absl::StatusOr<tensorflow::Tensor> output_tensor =
tensorflow_federated::DeserializeTensorValue(
value.federated().value(0).struct_().element(0).value());
EXPECT_TRUE(output_tensor.ok());
EXPECT_EQ(output_tensor.value().NumElements(), 3);

// Test client work computation adds the broadcasted value to each of the
// values in the input tensor.
auto flat = output_tensor.unaligned_flat<float>();
auto flat = output_tensor.value().unaligned_flat<float>();
EXPECT_EQ(flat(0), 11);
EXPECT_EQ(flat(1), 12);
EXPECT_EQ(flat(2), 13);
Expand Down

0 comments on commit 2a64341

Please sign in to comment.