Update the API of tff.learning.algorithms to use the new `LoopImplementation` enum.

Callsites should update and preserve previous behavior as follows:
- `use_experimental_simulation_loop=False` → `loop_implementation=tff.learning.LoopImplementation.DATASET_REDUCE`
- `use_experimental_simulation_loop=True` → `loop_implementation=tff.learning.LoopImplementation.DATASET_ITERATOR`
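
For example, a minimal before/after sketch of a typical callsite (here
`model_fn` stands in for a no-arg model constructor and is not part of this
change):

```python
import tensorflow_federated as tff

# Before: the boolean flag, now removed.
process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1),
    use_experimental_simulation_loop=True,
)

# After: the same behavior expressed with the new enum.
process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn=model_fn,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.1),
    loop_implementation=tff.learning.LoopImplementation.DATASET_ITERATOR,
)
```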

PiperOrigin-RevId: 657564409
ZacharyGarrett authored and copybara-github committed Jul 30, 2024
1 parent 5c84062 commit e1cc0c8
Showing 19 changed files with 306 additions and 342 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
@@ -32,6 +32,9 @@ and this project adheres to
* Fixed a bug when using the `tff.backends.xla` execution stack that raised
errors when single-element structures were returned from
`tff.jax.computation` wrapped methods.
* Renamed the boolean `use_experimental_simulation_loop` parameter to
`loop_implementation`, which accepts a `tff.learning.LoopImplementation`
enum, across all `tff.learning.algorithms` methods.
* Modified the model output release frequency to every 10 rounds and the final
round in `tff.learning.programs.train_model`.
* Loosened the `kEpsilonThreshold` constant and updated the tests of
1 change: 1 addition & 0 deletions tensorflow_federated/python/learning/BUILD
@@ -30,6 +30,7 @@ py_library(
deps = [
":client_weight_lib",
":debug_measurements",
":loop_builder",
":model_update_aggregator",
"//tensorflow_federated/python/learning/algorithms",
"//tensorflow_federated/python/learning/metrics",
1 change: 1 addition & 0 deletions tensorflow_federated/python/learning/__init__.py
@@ -45,6 +45,7 @@
from tensorflow_federated.python.learning.client_weight_lib import ClientWeighting
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements_with_mixed_dtype
from tensorflow_federated.python.learning.loop_builder import LoopImplementation
from tensorflow_federated.python.learning.model_update_aggregator import compression_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import ddp_secure_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import dp_aggregator
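
The newly exported `tff.learning.LoopImplementation` carries the two members
used throughout this change. A rough sketch of its shape (the member names
come from this diff; the concrete values and docstring are assumptions):

```python
import enum


class LoopImplementation(enum.Enum):
  """Selects how generated client computations iterate over a tf.data.Dataset."""

  # Uses `tf.data.Dataset.reduce`; the default, replacing
  # `use_experimental_simulation_loop=False`.
  DATASET_REDUCE = 'DATASET_REDUCE'
  # Drives the dataset through an iterator inside the loop; replaces
  # `use_experimental_simulation_loop=True`, which was historically needed
  # for performant GPU simulations.
  DATASET_ITERATOR = 'DATASET_ITERATOR'
```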
3 changes: 3 additions & 0 deletions tensorflow_federated/python/learning/algorithms/BUILD
@@ -45,6 +45,7 @@ py_library(
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:types",
"//tensorflow_federated/python/learning/models:functional",
@@ -93,6 +94,7 @@ py_library(
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:types",
"//tensorflow_federated/python/learning/models:functional",
@@ -136,6 +138,7 @@ py_library(
"//tensorflow_federated/python/core/environments/tensorflow_frontend:tensorflow_computation",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:types",
"//tensorflow_federated/python/learning/models:functional",
23 changes: 10 additions & 13 deletions tensorflow_federated/python/learning/algorithms/fed_avg.py
@@ -41,6 +41,7 @@
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.learning import client_weight_lib
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator
from tensorflow_federated.python.learning.metrics import types
from tensorflow_federated.python.learning.models import functional
@@ -76,7 +77,7 @@ def build_weighted_fed_avg(
model_distributor: Optional[distributors.DistributionProcess] = None,
model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
metrics_aggregator: Optional[types.MetricsAggregatorType] = None,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation = loop_builder.LoopImplementation.DATASET_REDUCE,
) -> learning_process.LearningProcess:
"""Builds a learning process that performs federated averaging.
@@ -156,10 +157,8 @@ def build_weighted_fed_avg(
`tff.learning.models.VariableModel.report_local_unfinalized_metrics()`),
and returns a `tff.Computation` for aggregating the unfinalized metrics.
If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
use_experimental_simulation_loop: Controls the reduce loop function for
input dataset. An experimental reduce loop is used for simulation. It is
currently necessary to set this flag to True for performant GPU
simulations.
loop_implementation: Controls the implementation of the generated training
loop. See `tff.learning.LoopImplementation` for more details.
Returns:
A `tff.learning.templates.LearningProcess`.
@@ -256,7 +255,7 @@ def initial_model_weights_fn():
optimizer=client_optimizer_fn,
client_weighting=client_weighting,
metrics_aggregator=metrics_aggregator,
use_experimental_simulation_loop=use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)
)
else:
@@ -265,7 +264,7 @@ def initial_model_weights_fn():
optimizer=client_optimizer_fn,
client_weighting=client_weighting,
metrics_aggregator=metrics_aggregator,
use_experimental_simulation_loop=use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)
finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer(
server_optimizer_fn, model_weights_type
@@ -294,7 +293,7 @@ def build_unweighted_fed_avg(
model_distributor: Optional[distributors.DistributionProcess] = None,
model_aggregator: Optional[factory.UnweightedAggregationFactory] = None,
metrics_aggregator: types.MetricsAggregatorType = metric_aggregator.sum_then_finalize,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation = loop_builder.LoopImplementation.DATASET_REDUCE,
) -> learning_process.LearningProcess:
"""Builds a learning process that performs federated averaging.
@@ -368,10 +367,8 @@ def build_unweighted_fed_avg(
`tff.learning.models.VariableModel.report_local_unfinalized_metrics()`),
and returns a `tff.Computation` for aggregating the unfinalized metrics.
If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
use_experimental_simulation_loop: Controls the reduce loop function for
input dataset. An experimental reduce loop is used for simulation. It is
currently necessary to set this flag to True for performant GPU
simulations.
loop_implementation: Controls the implementation of the generated training
loop. See `tff.learning.LoopImplementation` for more details.
Returns:
A `tff.learning.templates.LearningProcess`.
@@ -392,5 +389,5 @@ def build_unweighted_fed_avg(
model_distributor=model_distributor,
model_aggregator=factory_utils.as_weighted_aggregator(model_aggregator),
metrics_aggregator=metrics_aggregator,
use_experimental_simulation_loop=use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)
17 changes: 7 additions & 10 deletions tensorflow_federated/python/learning/algorithms/fed_avg_test.py
@@ -56,24 +56,21 @@ def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
self.assertEqual(mock_model_fn.call_count, 3)

@parameterized.named_parameters(
('non-simulation_tff_optimizer', False),
('simulation_tff_optimizer', True),
('dataset_reduce', loop_builder.LoopImplementation.DATASET_REDUCE),
('dataset_iterator', loop_builder.LoopImplementation.DATASET_ITERATOR),
)
@mock.patch.object(
loop_builder,
'_dataset_reduce_fn',
wraps=loop_builder._dataset_reduce_fn,
'build_training_loop',
wraps=loop_builder.build_training_loop,
)
def test_client_tf_dataset_reduce_fn(self, simulation, mock_method):
def test_client_tf_dataset_reduce_fn(self, loop_implementation, mock_method):
fed_avg.build_weighted_fed_avg(
model_fn=model_examples.LinearRegression,
client_optimizer_fn=sgdm.build_sgdm(1.0),
use_experimental_simulation_loop=simulation,
loop_implementation=loop_implementation,
)
if simulation:
mock_method.assert_not_called()
else:
mock_method.assert_called()
mock_method.assert_called_once_with(loop_implementation=loop_implementation)

@mock.patch.object(fed_avg, 'build_weighted_fed_avg')
def test_build_weighted_fed_avg_called_by_unweighted_fed_avg(
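
Downstream tests that, like the one above, previously patched the private
`_dataset_reduce_fn` can wrap the public `build_training_loop` factory
instead and assert on the enum it receives. A minimal sketch of the same
pattern as a context manager:

```python
from unittest import mock

from tensorflow_federated.python.learning import loop_builder

with mock.patch.object(
    loop_builder, 'build_training_loop', wraps=loop_builder.build_training_loop
) as mock_loop:
  ...  # Build the learning process under test here.
  mock_loop.assert_called_once_with(
      loop_implementation=loop_builder.LoopImplementation.DATASET_REDUCE
  )
```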
tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule.py
@@ -31,6 +31,7 @@
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import client_weight_lib
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning.algorithms import fed_avg
from tensorflow_federated.python.learning.metrics import aggregator as metric_aggregator
from tensorflow_federated.python.learning.metrics import types
@@ -57,7 +58,7 @@ def build_scheduled_client_work(
learning_rate_fn: Callable[[int], float],
optimizer_fn: Callable[[float], TFFOrKerasOptimizer],
metrics_aggregator: types.MetricsAggregatorType,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation = loop_builder.LoopImplementation.DATASET_REDUCE,
) -> client_works.ClientWorkProcess:
"""Creates a `ClientWorkProcess` for federated averaging.
@@ -87,10 +88,8 @@ def build_scheduled_client_work(
type of
`tff.learning.models.VariableModel.report_local_unfinalized_metrics()`),
and returns a `tff.Computation` for aggregating the unfinalized metrics.
use_experimental_simulation_loop: Controls the reduce loop function for
input dataset. An experimental reduce loop is used for simulation. It is
currently necessary to set this flag to True for performant GPU
simulations.
loop_implementation: Controls the implementation of the generated training
loop. See `tff.learning.LoopImplementation` for more details.
Returns:
A `ClientWorkProcess`.
@@ -135,12 +134,12 @@ def build_scheduled_client_work(
elif isinstance(whimsy_optimizer, optimizer_base.Optimizer):
build_client_update_fn = functools.partial(
model_delta_client_work.build_model_delta_update_with_tff_optimizer,
use_experimental_simulation_loop=use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)
else:
build_client_update_fn = functools.partial(
model_delta_client_work.build_model_delta_update_with_keras_optimizer,
use_experimental_simulation_loop=use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)

@tensorflow_computation.tf_computation(weights_type, data_type, np.int32)
@@ -150,7 +149,7 @@ def client_update_computation(initial_model_weights, dataset, round_num):
client_update = build_client_update_fn(
model_fn,
weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES,
use_experimental_simulation_loop=use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)
return client_update(optimizer, initial_model_weights, dataset)

@@ -210,7 +209,7 @@ def build_weighted_fed_avg_with_optimizer_schedule(
model_distributor: Optional[distributors.DistributionProcess] = None,
model_aggregator: Optional[factory.WeightedAggregationFactory] = None,
metrics_aggregator: Optional[types.MetricsAggregatorType] = None,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation = loop_builder.LoopImplementation.DATASET_REDUCE,
) -> learning_process.LearningProcess:
"""Builds a learning process for FedAvg with client optimizer scheduling.
@@ -296,10 +295,8 @@ def build_weighted_fed_avg_with_optimizer_schedule(
`tff.learning.models.VariableModel.report_local_unfinalized_metrics()`),
and returns a `tff.Computation` for aggregating the unfinalized metrics.
If `None`, this is set to `tff.learning.metrics.sum_then_finalize`.
use_experimental_simulation_loop: Controls the reduce loop function for
input dataset. An experimental reduce loop is used for simulation. It is
currently necessary to set this flag to True for performant GPU
simulations.
loop_implementation: Controls the implementation of the generated training
loop. See `tff.learning.LoopImplementation` for more details.
Returns:
A `LearningProcess`.
@@ -344,7 +341,7 @@ def initial_model_weights_fn():
client_learning_rate_fn,
client_optimizer_fn,
metrics_aggregator,
use_experimental_simulation_loop,
loop_implementation,
)
finalizer = apply_optimizer_finalizer.build_apply_optimizer_finalizer(
server_optimizer_fn, model_weights_type
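
An updated callsite for the optimizer-schedule variant might look as follows
(a sketch; `model_fn` is again an assumed no-arg model constructor):

```python
import tensorflow_federated as tff

process = tff.learning.algorithms.build_weighted_fed_avg_with_optimizer_schedule(
    model_fn=model_fn,
    client_learning_rate_fn=lambda round_num: 0.1,
    client_optimizer_fn=lambda lr: tff.learning.optimizers.build_sgdm(
        learning_rate=lr
    ),
    loop_implementation=tff.learning.LoopImplementation.DATASET_ITERATOR,
)
```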
tensorflow_federated/python/learning/algorithms/fed_avg_with_optimizer_schedule_test.py
@@ -64,28 +64,24 @@ def test_construction_of_functional_model(self):
)

@parameterized.named_parameters(
('non_simulation', False),
('simulation', True),
('dataset_reduce', loop_builder.LoopImplementation.DATASET_REDUCE),
('dataset_iterator', loop_builder.LoopImplementation.DATASET_ITERATOR),
)
@mock.patch.object(
loop_builder,
'_dataset_reduce_fn',
wraps=loop_builder._dataset_reduce_fn,
'build_training_loop',
wraps=loop_builder.build_training_loop,
)
def test_client_tf_dataset_reduce_fn(self, use_simulation, mock_reduce):
def test_client_tf_dataset_reduce_fn(self, loop_implementation, mock_reduce):
client_learning_rate_fn = lambda x: 0.5
client_optimizer_fn = tf.keras.optimizers.SGD
fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule(
model_fn=model_examples.LinearRegression,
client_learning_rate_fn=client_learning_rate_fn,
client_optimizer_fn=client_optimizer_fn,
use_experimental_simulation_loop=use_simulation,
loop_implementation=loop_implementation,
)

if use_simulation:
mock_reduce.assert_not_called()
else:
mock_reduce.assert_called()
mock_reduce.assert_called_once_with(loop_implementation=loop_implementation)

@parameterized.named_parameters([
('keras_optimizer', lambda x: tf.keras.optimizers.SGD()),
25 changes: 13 additions & 12 deletions tensorflow_federated/python/learning/algorithms/fed_eval.py
@@ -48,7 +48,7 @@ def _build_local_evaluation(
model_fn: Callable[[], variable.VariableModel],
model_weights_type: computation_types.StructType,
batch_type: computation_types.Type,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation,
) -> computation_base.Computation:
"""Builds the local TFF computation for evaluation of the given model.
@@ -72,8 +72,8 @@
model_weights_type: The `tff.Type` of the model parameters that will be used
to initialize the model during evaluation.
batch_type: The type of one entry in the dataset.
use_experimental_simulation_loop: Controls the reduce loop function for
input dataset. An experimental reduce loop is used for simulation.
loop_implementation: Controls the implementation of the generated training
loop. See `tff.learning.LoopImplementation` for more details.
Returns:
A federated computation (an instance of `tff.Computation`) that accepts
@@ -106,9 +106,7 @@ def reduce_fn(num_examples, batch):
return num_examples + tf.cast(model_output.num_examples, tf.int64)

dataset_reduce_fn = loop_builder.build_training_loop(
loop_builder.LoopImplementation.DATASET_ITERATOR
if use_experimental_simulation_loop
else loop_builder.LoopImplementation.DATASET_REDUCE
loop_implementation=loop_implementation
)
num_examples = dataset_reduce_fn(
reduce_fn, dataset, lambda: tf.zeros([], dtype=tf.int64)
@@ -194,7 +192,7 @@ def _build_fed_eval_client_work(
model_fn: Callable[[], variable.VariableModel],
metrics_aggregation_process: Optional[_AggregationProcess],
model_weights_type: computation_types.StructType,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation,
) -> client_works.ClientWorkProcess:
"""Builds a `ClientWorkProcess` that performs model evaluation at clients."""

@@ -233,7 +231,7 @@ def init_fn():
return metrics_aggregation_process.initialize()

client_update_computation = _build_local_evaluation(
model_fn, model_weights_type, batch_type, use_experimental_simulation_loop
model_fn,
model_weights_type,
batch_type,
loop_implementation=loop_implementation,
)

@federated_computation.federated_computation(
@@ -359,7 +360,7 @@ def build_fed_eval(
metrics_aggregation_process: Optional[
aggregation_process.AggregationProcess
] = None,
use_experimental_simulation_loop: bool = False,
loop_implementation: loop_builder.LoopImplementation = loop_builder.LoopImplementation.DATASET_REDUCE,
) -> learning_process.LearningProcess:
"""Builds a learning process that performs federated evaluation.
@@ -416,8 +417,8 @@ def build_fed_eval(
None, the `tff.templates.AggregationProcess` created by the
`SumThenFinalizeFactory` with metric finalizers defined in the model is
used.
use_experimental_simulation_loop: Controls the reduce loop function for
input dataset. An experimental reduce loop is used for simulation.
loop_implementation: Controls the implementation of the generated training
loop. See `tff.learning.LoopImplementation` for more details.
Returns:
A `tff.learning.templates.LearningProcess` that performs federated evaluation on
@@ -467,7 +468,7 @@ def initial_model_weights_fn():
model_fn,
metrics_aggregation_process,
model_weights_type,
use_experimental_simulation_loop,
loop_implementation=loop_implementation,
)

client_work_result_type = computation_types.FederatedType(
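
Evaluation follows the same pattern; a sketch of an updated `build_fed_eval`
callsite, under the same assumption about `model_fn`:

```python
import tensorflow_federated as tff

eval_process = tff.learning.algorithms.build_fed_eval(
    model_fn=model_fn,
    loop_implementation=tff.learning.LoopImplementation.DATASET_REDUCE,
)
state = eval_process.initialize()
```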
(Diffs for the remaining changed files not shown.)
