From 5c84062b9461509c2ebed5dcc581f15ec9eb9e6e Mon Sep 17 00:00:00 2001 From: Yu Xiao Date: Mon, 29 Jul 2024 10:32:47 -0700 Subject: [PATCH] Fix the version collisions of evaluation state manager. PiperOrigin-RevId: 657248426 --- RELEASE.md | 2 + .../python/learning/programs/BUILD | 1 + .../programs/evaluation_program_logic.py | 58 ++++++++++++++++--- .../programs/evaluation_program_logic_test.py | 49 +++++++++++++++- 4 files changed, 101 insertions(+), 9 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 91d89f5ef7..875aa7a4dc 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -25,6 +25,8 @@ and this project adheres to ### Changed +* Fixed a bug in `tff.learning.programs.EvaluationManager` that raised an + error when the version IDs of two state-saving operations were the same. * Fixed a bug in `tff.jax.computation` that raised an error when the computation had unused arguments. * Fixed a bug when using `tff.backends.xla` execution stack that raised errors diff --git a/tensorflow_federated/python/learning/programs/BUILD b/tensorflow_federated/python/learning/programs/BUILD index c893648df7..c74de42e98 100644 --- a/tensorflow_federated/python/learning/programs/BUILD +++ b/tensorflow_federated/python/learning/programs/BUILD @@ -32,6 +32,7 @@ py_library( "//tensorflow_federated/python/program:data_source", "//tensorflow_federated/python/program:federated_context", "//tensorflow_federated/python/program:file_program_state_manager", + "//tensorflow_federated/python/program:program_state_manager", "//tensorflow_federated/python/program:release_manager", "//tensorflow_federated/python/program:value_reference", ], diff --git a/tensorflow_federated/python/learning/programs/evaluation_program_logic.py b/tensorflow_federated/python/learning/programs/evaluation_program_logic.py index ddb859bca2..ed369a0fea 100644 --- a/tensorflow_federated/python/learning/programs/evaluation_program_logic.py +++ b/tensorflow_federated/python/learning/programs/evaluation_program_logic.py @@ -74,6 
+74,7 @@ from tensorflow_federated.python.program import data_source as data_source_lib from tensorflow_federated.python.program import federated_context from tensorflow_federated.python.program import file_program_state_manager +from tensorflow_federated.python.program import program_state_manager from tensorflow_federated.python.program import release_manager from tensorflow_federated.python.program import value_reference @@ -105,6 +106,51 @@ def _pop_value( _EVAL_NAME_PATTERN = 'evaluation_of_train_round_{round_num:05d}' +class AutoVersionAdvancingStateManager: + """A file state manager that automatically advances the version number.""" + + def __init__( + self, + state_manager: file_program_state_manager.FileProgramStateManager, + ): + """Initializes the AutoVersionAdvancingStateManager. + + Args: + state_manager: The file state manager to use for saving and loading state. + """ + self._state_manager = state_manager + self._next_version = 0 + self._lock = asyncio.Lock() # Lock for concurrency safety. + + async def load_latest( + self, structure: program_state_manager.ProgramStateStructure + ) -> program_state_manager.ProgramStateStructure: + """Returns the latest program state. + + Args: + structure: The structure of the saved program state, used to support + serialization and deserialization of user-defined classes in the + structure. + """ + async with self._lock: + state, version = await self._state_manager.load_latest(structure) + self._next_version = version + 1 + return state + + async def save( + self, + program_state: program_state_manager.ProgramStateStructure, + ) -> None: + """Saves `program_state` and automatically advances the version number. + + Args: + program_state: A `tff.program.ProgramStateStructure` to save. 
+ """ + async with self._lock: + await self._state_manager.save(program_state, version=self._next_version) + self._next_version += 1 + + class EvaluationManager: """A manager for facilitating multiple in-progress evaluations. @@ -184,8 +230,9 @@ def __init__( self._create_evaluation_process_fn = create_process_fn self._cohort_size = cohort_size self._duration = duration - self._state_manager = create_state_manager_fn(_EVAL_MANAGER_KEY) - self._next_version = 0 + self._state_manager = AutoVersionAdvancingStateManager( + create_state_manager_fn(_EVAL_MANAGER_KEY) + ) self._evaluating_training_checkpoints = np.zeros([0], np.int32) self._evaluation_start_timestamp_seconds = np.zeros([0], np.int32) self._pending_tasks: set[asyncio.Task] = set() @@ -262,7 +309,7 @@ def _finalize_task(self, task: asyncio.Task): async def resume_from_previous_state(self) -> None: """Load the most recent state and restart in-progress evaluations.""" - loaded_state, loaded_version = await self._state_manager.load_latest(( + loaded_state = await self._state_manager.load_latest(( self._evaluating_training_checkpoints, self._evaluation_start_timestamp_seconds, )) @@ -273,7 +320,6 @@ async def resume_from_previous_state(self) -> None: self._evaluating_training_checkpoints, self._evaluation_start_timestamp_seconds, ) = loaded_state - self._next_version = loaded_version + 1 train_round_nums = self._evaluating_training_checkpoints.tolist() _logging.info( 'Resuming previous evaluations found for training rounds: %s', @@ -403,9 +449,7 @@ async def start_evaluation( self._evaluating_training_checkpoints, self._evaluation_start_timestamp_seconds, ), - version=self._next_version, ) - self._next_version += 1 async def record_evaluations_finished(self, train_round: int) -> None: """Removes evaluation for `train_round` from the internal state manager. 
@@ -438,9 +482,7 @@ async def record_evaluations_finished(self, train_round: int) -> None: self._evaluating_training_checkpoints, self._evaluation_start_timestamp_seconds, ), - version=self._next_version, ) - self._next_version += 1 def extract_and_rewrap_metrics( diff --git a/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py b/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py index 1463e7076e..ea4fedfac1 100644 --- a/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py +++ b/tensorflow_federated/python/learning/programs/evaluation_program_logic_test.py @@ -398,7 +398,7 @@ async def test_record_finished_evaluations_removes_from_state(self): datetime.datetime(2022, 10, 28, 5, 15).timestamp(), datetime.datetime(2022, 10, 28, 5, 25).timestamp(), ]).astype(np.int32) - manager._next_version = 1 + manager._state_manager._next_version = 1 await manager.record_evaluations_finished(5) # Only train_round 15 should be saved after 5 finishes. self.assertSequenceEqual( @@ -419,6 +419,53 @@ async def test_record_finished_evaluations_removes_from_state(self): await manager.wait_for_evaluations_to_finish() self.assertEmpty(manager._pending_tasks) + async def test_record_two_evaluations_finished_removes_from_state(self): + mock_data_source = mock.create_autospec( + data_source.FederatedDataSource, instance=True, spec_set=True + ) + mock_metrics_manager = mock.create_autospec( + release_manager.ReleaseManager, instance=True, spec_set=True + ) + # Create a state manager with two inflight evaluations. 
+ mock_meta_eval_manager = mock.create_autospec( + file_program_state_manager.FileProgramStateManager, + instance=True, + spec_set=True, + ) + mock_create_state_manager = mock.Mock(side_effect=[mock_meta_eval_manager]) + mock_create_process_fn = mock.Mock() + manager = evaluation_program_logic.EvaluationManager( + data_source=mock_data_source, + aggregated_metrics_manager=mock_metrics_manager, + create_state_manager_fn=mock_create_state_manager, + create_process_fn=mock_create_process_fn, + cohort_size=10, + duration=datetime.timedelta(milliseconds=10), + ) + # Directly set the state, avoid starting asyncio.Task for the resumed evals. + manager._evaluating_training_checkpoints = np.asarray([5, 15]).astype( + np.int32 + ) + manager._evaluation_start_timestamp_seconds = np.asarray([ + datetime.datetime(2022, 10, 28, 5, 15).timestamp(), + datetime.datetime(2022, 10, 28, 5, 25).timestamp(), + ]).astype(np.int32) + manager._state_manager._next_version = 1 + task1 = asyncio.create_task(manager.record_evaluations_finished(5)) + task2 = asyncio.create_task(manager.record_evaluations_finished(15)) + finished, _ = await asyncio.wait([task1, task2]) + self.assertLen(finished, 2) + # Assert that all evaluations were removed from the state manager without + # version conflicts. + self.assertSequenceEqual( + mock_meta_eval_manager.save.call_args_list, + [ + mock.call((_NumpyMatcher([15]), mock.ANY), version=1), + mock.call((_NumpyMatcher([]), _NumpyMatcher([])), version=2), + ], + ) + self.assertEmpty(manager._pending_tasks) + async def test_resume_previous_evaluations(self): mock_data_source = mock.create_autospec( data_source.FederatedDataSource, instance=True, spec_set=True