Added jax.experimental.multihost_utils.live_devices API.
This API is intended to enable fault tolerant multi-controller JAX programs.

PiperOrigin-RevId: 703178517
mwhittaker authored and Google-ML-Automation committed Jan 21, 2025
1 parent 051861b commit df5be8b
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions jax/experimental/multihost_utils.py
@@ -34,6 +34,8 @@
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src.util import safe_zip
from jax._src import xla_bridge
from jax._src.lib import xla_client
import numpy as np


@@ -470,3 +472,80 @@ def gtl_abstract_eval(arr, *, global_mesh, pspec):
def _gtl_lowering(ctx, x, *, global_mesh, pspec):
  return [x]
mlir.register_lowering(global_array_to_host_local_array_p, _gtl_lowering)


def live_devices(devices: list[xla_client.Device]) -> list[xla_client.Device]:
  """Returns the subset of the provided devices that are live and healthy.

  `live_devices` is a low-level fault tolerance primitive that can be used to
  implement fault tolerant multi-process JAX programs.

  Barrier Semantics

  It's important that every process agrees on which devices are live to
  prevent the processes' behavior from diverging. For example, imagine a set of
  processes trying to run an AllGather, but they all disagree on which devices
  should be participating in the AllGather. This is buggy.

  To ensure that every process agrees on the set of live devices, the
  `live_devices` function has barrier-like semantics. Consider an invocation
  `live_devices(devices)` where `devices` includes devices across a set of
  processes P. The invocation acts as a barrier, waiting for every process in P
  to call `live_devices(devices)`. Afterwards, `live_devices` returns the same
  set of live devices `A` to all the processes in P. This ensures that every
  process agrees on the set of live devices.

  Note that `live_devices` does not actually act as a barrier for *every*
  process in P because some processes in P might have failed. Instead, the
  `live_devices` function waits only for the processes with a device in the
  returned set of live devices `A`.

  An Example

  Imagine we have four processes, each with two devices:

    Process A: Devices 1 and 2
    Process B: Devices 3 and 4
    Process C: Devices 5 and 6
    Process D: Devices 7 and 8

  Further imagine that process D fails and that every process calls
  `live_devices(jax.devices())`. The invocation returns devices 1, 2, 3, 4, 5,
  and 6. Because these devices are hosted by processes A, B, and C, the call to
  `live_devices` acts as a barrier across processes A, B, and C. Process D,
  which failed, is ignored.

  Args:
    devices: A list of devices. Note that the provided devices must include at
      least one local device.

  Returns:
    The subset of the provided devices that are live and healthy.

  Raises:
    RuntimeError: If the distributed runtime was not initialized.
    ValueError: If no local devices are provided.
  """
  client = distributed.global_state.client
  if client is None:
    raise RuntimeError('Distributed JAX not initialized.')

  if not devices:
    # TODO(mwhittaker): Make devices optional. If it's not provided, use
    # jax.devices() as a default.
    raise ValueError('No devices provided.')

  process_ids = {d.process_index for d in devices}
  if xla_bridge.process_index() not in process_ids:
    # A process can only participate in a live_devices call if it hosts some
    # of the provided devices.
    raise ValueError('Provided devices do not have any local devices.')

  if len(process_ids) == 1:
    # If the provided devices are hosted by a single process (this one), then we
    # don't have to perform any distributed computation. We know our local
    # devices are all live.
    return devices

  live_process_ids = client.get_live_nodes(list(process_ids))
  return [d for d in devices if d.process_index in live_process_ids]
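
The four-process example from the docstring can be traced with a small standalone sketch. It is not part of the commit: `FakeDevice` and `filter_live` are hypothetical stand-ins, the numbering of processes A through D as 0 through 3 is an assumption, and only the final filtering step of `live_devices` is reproduced. The barrier and the `client.get_live_nodes` query need a running distributed runtime and are not modeled.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
  """Hypothetical stand-in for xla_client.Device (illustration only)."""
  id: int
  process_index: int


def filter_live(devices, live_process_ids):
  """Keeps the devices whose hosting process is reported live.

  Mirrors only the last list comprehension in live_devices. The list of live
  process ids is passed in directly instead of being queried from the
  distributed client.
  """
  live = set(live_process_ids)
  return [d for d in devices if d.process_index in live]


# Assumed numbering: processes A, B, C, D are 0, 1, 2, 3, with two devices each
# (devices 1-2 on process 0, 3-4 on process 1, and so on).
devices = [FakeDevice(id=i + 1, process_index=i // 2) for i in range(8)]

# Process D (index 3) has failed, so only processes 0, 1, and 2 are live.
survivors = filter_live(devices, live_process_ids=[0, 1, 2])
print([d.id for d in survivors])  # prints [1, 2, 3, 4, 5, 6]
```

In a real multi-process program, each surviving process would presumably call `live_devices(jax.devices())` itself and receive the same list, since the call only returns once the live processes have all reached it.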
