Make committed a public property of jax.Array.
Why?

Because users need to know whether an array is committed: JAX raises errors
and makes dispatching decisions based on the committedness of a jax.Array,
while the placement of such arrays on devices is an internal implementation
detail.
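
A minimal sketch of what this exposes (not part of the commit; it assumes the committed/uncommitted semantics documented in jax/_src/basearray.py below):

```python
import jax
import numpy as np

# Explicit placement commits the array to that device.
a = jax.device_put(np.arange(8), jax.devices()[0])

# No device given: placed on the default device, but left uncommitted.
b = jax.device_put(np.arange(8))

print(a.committed)  # True
print(b.committed)  # False
```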

PiperOrigin-RevId: 686329828
yashk2810 authored and Google-ML-Automation committed Oct 16, 2024
1 parent ad99ab1 commit 66c6292
Showing 7 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/jax.rst
@@ -150,6 +150,7 @@ Array properties and methods
Array.choose
Array.clip
Array.compress
Array.committed
Array.conj
Array.conjugate
Array.copy
4 changes: 4 additions & 0 deletions jax/_src/array.py
@@ -276,6 +276,10 @@ def device(self):
  def weak_type(self):
    return self.aval.weak_type

  @property
  def committed(self) -> bool:
    return self._committed

  def __str__(self):
    return str(self._value)

27 changes: 27 additions & 0 deletions jax/_src/basearray.py
@@ -113,6 +113,33 @@ def is_fully_replicated(self) -> bool:
  def sharding(self) -> Sharding:
    """The sharding for the array."""

  @property
  @abc.abstractmethod
  def committed(self) -> bool:
    """Whether the array is committed or not.

    An array is committed when it is explicitly placed on device(s) via JAX
    APIs. For example, `jax.device_put(np.arange(8), jax.devices()[0])` is
    committed to device 0, while `jax.device_put(np.arange(8))` is uncommitted
    and will be placed on the default device.

    Computations involving committed inputs run on the committed device(s),
    and the result is committed to the same device(s). Invoking an operation
    on arguments that are committed to different device(s) raises an error.

    For example:
    ```
    a = jax.device_put(np.arange(8), jax.devices()[0])
    b = jax.device_put(np.arange(8), jax.devices()[1])
    a + b  # Raises an error
    ```

    See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
    for more information.
    """

  @property
  @abc.abstractmethod
  def device(self) -> Device | Sharding:
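
The docstring above also describes how committedness propagates and why mixing devices fails. A hedged sketch of both behaviors (not part of the commit; the error branch assumes at least two local devices and does not pin down the exception type):

```python
import jax
import numpy as np

a = jax.device_put(np.arange(8), jax.devices()[0])
b = a * 2
print(a.committed, b.committed)  # True True: results follow committed inputs

# Per the docstring, operating on arrays committed to different devices errors out.
if len(jax.devices()) > 1:
  c = jax.device_put(np.arange(8), jax.devices()[1])
  try:
    a + c
  except Exception as err:
    print("cross-device op failed:", type(err).__name__)
```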
2 changes: 2 additions & 0 deletions jax/_src/basearray.pyi
@@ -238,6 +238,8 @@ class Array(abc.ABC):
  @property
  def sharding(self) -> Sharding: ...
  @property
  def committed(self) -> bool: ...
  @property
  def device(self) -> Device | Sharding: ...
  @property
  def addressable_shards(self) -> Sequence[Shard]: ...
7 changes: 7 additions & 0 deletions jax/_src/core.py
@@ -747,6 +747,13 @@ def sharding(self):
f"The 'sharding' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")

@property
def committed(self):
raise ConcretizationTypeError(
self,
f"The 'committed' attribute is not available on {self._error_repr()}."
f"{self._origin_msg()}")

@property
def device(self):
# This attribute is part of the jax.Array API, but only defined on concrete arrays.
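
As with `sharding` and `device`, the Tracer property above makes `committed` unavailable on abstract values, so reading it inside a traced computation raises `ConcretizationTypeError`. A small sketch of that behavior (not part of the commit):

```python
import jax
import numpy as np

@jax.jit
def f(x):
  # `x` is a tracer here, so the attribute access raises at trace time.
  return x.committed

try:
  f(np.arange(4))
except jax.errors.ConcretizationTypeError as err:
  print("raised:", type(err).__name__)
```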
4 changes: 4 additions & 0 deletions jax/_src/earray.py
@@ -83,6 +83,10 @@ def sharding(self):
    phys_sharding = self._data.sharding
    return sharding_impls.logical_sharding(self.aval, phys_sharding)

  @property
  def committed(self):
    return self._data.committed

  @property
  def device(self):
    if isinstance(self._data.sharding, sharding_impls.SingleDeviceSharding):
4 changes: 4 additions & 0 deletions jax/_src/prng.py
@@ -233,6 +233,10 @@ def global_shards(self) -> list[Shard]:
  def sharding(self):
    return logical_sharding(self.aval, self._base_array.sharding)

  @property
  def committed(self):
    return self._base_array.committed

  def _is_scalar(self):
    base_ndim = len(self._impl.key_shape)
    return self._base_array.ndim == base_ndim
