diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 560b5c5cc2f4..541f74b9c7c7 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -34,6 +34,8 @@ from jax.sharding import PartitionSpec as P from jax._src import distributed from jax._src.util import safe_zip +from jax._src import xla_bridge +from jax._src.lib import xla_client import numpy as np @@ -470,3 +472,80 @@ def gtl_abstract_eval(arr, *, global_mesh, pspec): def _gtl_lowering(ctx, x, *, global_mesh, pspec): return [x] mlir.register_lowering(global_array_to_host_local_array_p, _gtl_lowering) + + +def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]: + """Returns the subset of the provided devices that are live and healthy. + + `live_devices` is a low-level fault tolerance primitive that can be used to + implement fault tolerant multi-process JAX programs. + + Barrier Semantics + + It's important that every process agrees on which devices are live to avoid + the processes' behavior from diverging. For example, imagine a set of + processes trying to run an AllGather, but they all disagree on which devices + should be participating in the AllGather. This is buggy. + + To ensure that every process agrees on the set of live devices, the + `live_devices` function has barrier-like semantics. Consider an invocation + `live_devices(devices)` where `devices` includes devices across a set of + processes P. The invocation acts as a barrier, waiting for every process in P + to call `live_devices(devices)`. Afterwards, `live_devices` returns the same + set of live devices `A` to all the processes in P. This ensures that every + process agrees on the set of live devices. + + Note that `live_devices` does not actually act as a barrir for *every* + process in P because some processes in P might have failed. Instead, the + `live_devices` function waits only for the processes with a device in the + returned set of live devices A. + + An Example + + Imagine we have four processes, each with two devices: + + Process A: Devices 1 and 2 + Process B: Devices 3 and 4 + Process C: Devices 5 and 6 + Process D: Devices 7 and 8 + + Further imagine that process D fails and that every process calls + `live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5, + and 6. Because these devices are hosted by processes A, B, and C, the call to + `live_devices` acts as a barrier across processes A, B, and C. Process D, + which failed, is ignored. + + Args: + devices: A list of devices. Note that the provided devices must include at + least one local device. + + Returns: + The subset of the provided devices that are live and healthy. + + Raises: + RuntimeError: If the distributed runtime was not initialized. + ValueError: If no local devices are provided. + """ + client = distributed.global_state.client + if client is None: + raise RuntimeError('Distributed JAX not initialized.') + + if not devices: + # TODO(mwhittaker): Make devices optional. If it's not provided, use + # jax.devices() as a default. + raise ValueError('No devices provided.') + + process_ids = {d.process_index for d in devices} + if xla_bridge.process_index() not in process_ids: + # A process can only participate in an live_devices call if it hosts some + # of the provided devices. + raise ValueError('Provided devices do not have any local devices.') + + if len(process_ids) == 1: + # If the provided devices are hosted by a single process (this one), then we + # don't have to perform any distributed computation. We know our local + # devices are all live. + return devices + + live_process_ids = client.get_live_nodes(list(process_ids)) + return [d for d in devices if d.process_index in live_process_ids]