Skip to content

Commit

Permalink
Fix the version collisions of evaluation state manager.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657248426
  • Loading branch information
xiaoyux11 authored and copybara-github committed Jul 29, 2024
1 parent eb7fc72 commit 5c84062
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 9 deletions.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ and this project adheres to

### Changed

* Fixed a bug in `tff.learning.programs.EvaluationManager` that raised an
error when the version IDs of two state-saving operations were the same.
* Fixed a bug in `tff.jax.computation` that raised an error when the
computation had unused arguments.
* Fixed a bug when using `tff.backends.xla` execution stack that raised errors
Expand Down
1 change: 1 addition & 0 deletions tensorflow_federated/python/learning/programs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ py_library(
"//tensorflow_federated/python/program:data_source",
"//tensorflow_federated/python/program:federated_context",
"//tensorflow_federated/python/program:file_program_state_manager",
"//tensorflow_federated/python/program:program_state_manager",
"//tensorflow_federated/python/program:release_manager",
"//tensorflow_federated/python/program:value_reference",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from tensorflow_federated.python.program import data_source as data_source_lib
from tensorflow_federated.python.program import federated_context
from tensorflow_federated.python.program import file_program_state_manager
from tensorflow_federated.python.program import program_state_manager
from tensorflow_federated.python.program import release_manager
from tensorflow_federated.python.program import value_reference

Expand Down Expand Up @@ -105,6 +106,51 @@ def _pop_value(
_EVAL_NAME_PATTERN = 'evaluation_of_train_round_{round_num:05d}'


class AutoVersionAdvanceingStateManager:
  """A file state manager wrapper that automatically advances the version.

  Wraps a `FileProgramStateManager` and keeps the next version number to save
  internally, so callers never pass a version explicitly. A single
  `asyncio.Lock` guards both `load_latest` and `save`, so the version counter
  is read and advanced by one operation at a time.

  NOTE: the spelling of the class name is kept as-is because callers refer to
  it by this name.
  """

  def __init__(
      self,
      state_manager: file_program_state_manager.FileProgramStateManager,
  ):
    """Initializes the AutoVersionAdvanceingStateManager.

    Args:
      state_manager: The file state manager to use for saving and loading
        state.
    """
    self._state_manager = state_manager
    # Version passed to the next `save` call; re-synchronized by `load_latest`.
    self._next_version = 0
    # Serializes `load_latest` and `save` so concurrent callers cannot observe
    # or write the same `_next_version`.
    self._lock = asyncio.Lock()

  async def load_latest(
      self, structure: program_state_manager.ProgramStateStructure
  ) -> program_state_manager.ProgramStateStructure:
    """Returns the latest program state and re-syncs the version counter.

    Args:
      structure: The structure of the saved program state, used to support
        serialization and deserialization of user-defined classes in the
        structure.

    Returns:
      The state returned by the wrapped state manager's `load_latest`. The
      version returned alongside it is not exposed; the next `save` uses that
      version plus one.
    """
    # NOTE(review): the wrapped manager presumably returns `None` as the state
    # when nothing has been saved yet -- confirm against its contract.
    async with self._lock:
      state, version = await self._state_manager.load_latest(structure)
      self._next_version = version + 1
      return state

  async def save(
      self,
      program_state: program_state_manager.ProgramStateStructure,
  ) -> None:
    """Saves `program_state` and automatically advances the version number.

    Args:
      program_state: A `tff.program.ProgramStateStructure` to save.
    """
    async with self._lock:
      await self._state_manager.save(program_state, version=self._next_version)
      # Only reached when the save above did not raise, so a failed save does
      # not consume a version number.
      self._next_version += 1


class EvaluationManager:
"""A manager for facilitating multiple in-progress evaluations.
Expand Down Expand Up @@ -184,8 +230,9 @@ def __init__(
self._create_evaluation_process_fn = create_process_fn
self._cohort_size = cohort_size
self._duration = duration
self._state_manager = create_state_manager_fn(_EVAL_MANAGER_KEY)
self._next_version = 0
self._state_manager = AutoVersionAdvanceingStateManager(
create_state_manager_fn(_EVAL_MANAGER_KEY)
)
self._evaluating_training_checkpoints = np.zeros([0], np.int32)
self._evaluation_start_timestamp_seconds = np.zeros([0], np.int32)
self._pending_tasks: set[asyncio.Task] = set()
Expand Down Expand Up @@ -262,7 +309,7 @@ def _finalize_task(self, task: asyncio.Task):

async def resume_from_previous_state(self) -> None:
"""Load the most recent state and restart in-progress evaluations."""
loaded_state, loaded_version = await self._state_manager.load_latest((
loaded_state = await self._state_manager.load_latest((
self._evaluating_training_checkpoints,
self._evaluation_start_timestamp_seconds,
))
Expand All @@ -273,7 +320,6 @@ async def resume_from_previous_state(self) -> None:
self._evaluating_training_checkpoints,
self._evaluation_start_timestamp_seconds,
) = loaded_state
self._next_version = loaded_version + 1
train_round_nums = self._evaluating_training_checkpoints.tolist()
_logging.info(
'Resuming previous evaluations found for training rounds: %s',
Expand Down Expand Up @@ -403,9 +449,7 @@ async def start_evaluation(
self._evaluating_training_checkpoints,
self._evaluation_start_timestamp_seconds,
),
version=self._next_version,
)
self._next_version += 1

async def record_evaluations_finished(self, train_round: int) -> None:
"""Removes evaluation for `train_round` from the internal state manager.
Expand Down Expand Up @@ -438,9 +482,7 @@ async def record_evaluations_finished(self, train_round: int) -> None:
self._evaluating_training_checkpoints,
self._evaluation_start_timestamp_seconds,
),
version=self._next_version,
)
self._next_version += 1


def extract_and_rewrap_metrics(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ async def test_record_finished_evaluations_removes_from_state(self):
datetime.datetime(2022, 10, 28, 5, 15).timestamp(),
datetime.datetime(2022, 10, 28, 5, 25).timestamp(),
]).astype(np.int32)
manager._next_version = 1
manager._state_manager._next_version = 1
await manager.record_evaluations_finished(5)
# Only train_round 15 should be saved after 5 finishes.
self.assertSequenceEqual(
Expand All @@ -419,6 +419,53 @@ async def test_record_finished_evaluations_removes_from_state(self):
await manager.wait_for_evaluations_to_finish()
self.assertEmpty(manager._pending_tasks)

async def test_record_two_evaluations_finished_removes_from_state(self):
  """Checks two concurrent finish records save under distinct versions."""
  mock_data_source = mock.create_autospec(
      data_source.FederatedDataSource, instance=True, spec_set=True
  )
  mock_metrics_manager = mock.create_autospec(
      release_manager.ReleaseManager, instance=True, spec_set=True
  )
  # Create a state manager with two in-flight evaluations.
  mock_meta_eval_manager = mock.create_autospec(
      file_program_state_manager.FileProgramStateManager,
      instance=True,
      spec_set=True,
  )
  mock_create_state_manager = mock.Mock(side_effect=[mock_meta_eval_manager])
  mock_create_process_fn = mock.Mock()
  manager = evaluation_program_logic.EvaluationManager(
      data_source=mock_data_source,
      aggregated_metrics_manager=mock_metrics_manager,
      create_state_manager_fn=mock_create_state_manager,
      create_process_fn=mock_create_process_fn,
      cohort_size=10,
      duration=datetime.timedelta(milliseconds=10),
  )
  # Set the state directly to avoid starting an asyncio.Task for each resumed
  # evaluation.
  manager._evaluating_training_checkpoints = np.asarray([5, 15]).astype(
      np.int32
  )
  manager._evaluation_start_timestamp_seconds = np.asarray([
      datetime.datetime(2022, 10, 28, 5, 15).timestamp(),
      datetime.datetime(2022, 10, 28, 5, 25).timestamp(),
  ]).astype(np.int32)
  # Start the version counter at 1 so the two saves below are expected to use
  # versions 1 and 2.
  manager._state_manager._next_version = 1
  task1 = asyncio.create_task(manager.record_evaluations_finished(5))
  task2 = asyncio.create_task(manager.record_evaluations_finished(15))
  finished, _ = await asyncio.wait([task1, task2])
  self.assertLen(finished, 2)
  # `asyncio.wait` does not raise exceptions stored in the tasks it waited on;
  # `result()` re-raises them here so a failure inside
  # `record_evaluations_finished` is reported directly instead of only as a
  # mismatch in the recorded `save` calls below.
  for task in finished:
    task.result()
  # Assert that all evaluations were removed from the state manager without
  # version conflicts.
  self.assertSequenceEqual(
      mock_meta_eval_manager.save.call_args_list,
      [
          mock.call((_NumpyMatcher([15]), mock.ANY), version=1),
          mock.call((_NumpyMatcher([]), _NumpyMatcher([])), version=2),
      ],
  )
  self.assertEmpty(manager._pending_tasks)

async def test_resume_previous_evaluations(self):
mock_data_source = mock.create_autospec(
data_source.FederatedDataSource, instance=True, spec_set=True
Expand Down

0 comments on commit 5c84062

Please sign in to comment.