
Commit

Merge pull request jax-ml#24312 from jakevdp:gather-doc
PiperOrigin-RevId: 686372450
Google-ML-Automation committed Oct 16, 2024
2 parents 66c6292 + 284ca8b commit 56eea2b
Showing 1 changed file with 76 additions and 8 deletions.
84 changes: 76 additions & 8 deletions jax/_src/lax/slicing.py
@@ -310,10 +310,10 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
Wraps `XLA's Gather operator
<https://www.tensorflow.org/xla/operation_semantics#gather>`_.
The semantics of gather are complicated, and its API might change in the
future. For most use cases, you should prefer `Numpy-style indexing
<https://numpy.org/doc/stable/reference/arrays.indexing.html>`_
(e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly.
:func:`gather` is a low-level operator with complicated semantics, and most JAX
users will never need to call it directly. Instead, you should prefer using
`Numpy-style indexing`_, and/or :attr:`jax.numpy.ndarray.at`, perhaps in combination
with :func:`jax.vmap`.
Args:
operand: an array from which slices should be taken
@@ -340,6 +340,42 @@ def gather(operand: ArrayLike, start_indices: ArrayLike,
Returns:
An array containing the gather output.
Examples:
As mentioned above, you should basically never use :func:`gather` directly,
and instead use NumPy-style indexing expressions to gather values from
arrays.
For example, here is how you can extract values at particular indices using
straightforward indexing semantics, which will lower to XLA's Gather operator:
>>> import jax.numpy as jnp
>>> x = jnp.array([10, 11, 12])
>>> indices = jnp.array([0, 1, 1, 2, 2, 2])
>>> x[indices]
Array([10, 11, 11, 12, 12, 12], dtype=int32)
For control over settings like ``indices_are_sorted``, ``unique_indices``, ``mode``,
and ``fill_value``, you can use the :attr:`jax.numpy.ndarray.at` syntax:
>>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds")
Array([10, 11, 11, 12, 12, 12], dtype=int32)
By comparison, here is the equivalent function call using :func:`gather` directly,
which is not something typical users should ever need to do:
>>> from jax import lax
>>> lax.gather(x, indices[:, None], slice_sizes=(1,),
... dimension_numbers=lax.GatherDimensionNumbers(
... offset_dims=(),
... collapsed_slice_dims=(0,),
... start_index_map=(0,)),
... indices_are_sorted=True,
... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
Array([10, 11, 11, 12, 12, 12], dtype=int32)
.. _Numpy-style indexing: https://numpy.org/doc/stable/reference/arrays.indexing.html
"""
if mode is None:
mode = GatherScatterMode.PROMISE_IN_BOUNDS
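Reviewer note: the following is a minimal sketch, not part of the commit, that checks the claim in the new docstring that NumPy-style indexing and the direct `lax.gather` call agree. The out-of-bounds indices, the `fill_value` of -1, and the `matrix`/`columns` arrays in the `jax.vmap` portion are assumptions chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([10, 11, 12])
indices = jnp.array([0, 1, 1, 2, 2, 2])

# NumPy-style indexing, which lowers to XLA's Gather operator.
via_indexing = x[indices]

# The same lookup written against lax.gather directly (copied from the docstring).
via_gather = lax.gather(
    x, indices[:, None], slice_sizes=(1,),
    dimension_numbers=lax.GatherDimensionNumbers(
        offset_dims=(),
        collapsed_slice_dims=(0,),
        start_index_map=(0,)),
    mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)

# Explicit out-of-bounds handling through the .at[...].get() syntax.
# Index 5 is out of range for an array of length 3 (illustrative values).
oob = jnp.array([0, 5])
filled = x.at[oob].get(mode="fill", fill_value=-1)  # out-of-bounds slot gets -1
clipped = x.at[oob].get(mode="clip")                # index 5 is clamped to 2

# Combining indexing with jax.vmap: pick one column per row of a matrix.
matrix = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
columns = jnp.array([2, 0])
per_row = jax.vmap(lambda row, i: row[i])(matrix, columns)
```

The `mode` strings used here are the ones accepted by `jax.numpy.ndarray.at`; the direct `lax.gather` call needs the full `GatherDimensionNumbers` specification to express the same one-dimensional lookup, which is the motivation the docstring gives for preferring indexing.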
@@ -737,10 +773,9 @@ def scatter(
If multiple updates are performed to the same index of operand, they may be
applied in any order.
The semantics of scatter are complicated, and its API might change in the
future. For most use cases, you should prefer the
:attr:`jax.numpy.ndarray.at` property on JAX arrays which uses
the familiar NumPy indexing syntax.
:func:`scatter` is a low-level operator with complicated semantics, and most
JAX users will never need to call it directly. Instead, you should prefer using
:attr:`jax.numpy.ndarray.at` for more familiar NumPy-style indexing syntax.
Args:
operand: an array to which the scatter should be applied
@@ -764,6 +799,39 @@ def scatter(
Returns:
An array containing the sum of `operand` and the scattered updates.
Examples:
As mentioned above, you should basically never use :func:`scatter` directly,
and instead perform scatter-style operations using NumPy-style indexing
expressions via :attr:`jax.numpy.ndarray.at`.
Here is an example of updating entries in an array using :attr:`jax.numpy.ndarray.at`,
which lowers to an XLA Scatter operation:
>>> x = jnp.zeros(5)
>>> indices = jnp.array([1, 2, 4])
>>> values = jnp.array([2.0, 3.0, 4.0])
>>> x.at[indices].set(values)
Array([0., 2., 3., 0., 4.], dtype=float32)
This syntax also supports several of the optional arguments to :func:`scatter`,
for example:
>>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds')
Array([0., 2., 3., 0., 4.], dtype=float32)
By comparison, here is the equivalent function call using :func:`scatter` directly,
which is not something typical users should ever need to do:
>>> lax.scatter(x, indices[:, None], values,
... dimension_numbers=lax.ScatterDimensionNumbers(
... update_window_dims=(),
... inserted_window_dims=(0,),
... scatter_dims_to_operand_dims=(0,)),
... indices_are_sorted=True,
... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)
Array([0., 2., 3., 0., 4.], dtype=float32)
"""
return scatter_p.bind(
operand, scatter_indices, updates, update_jaxpr=None,
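Reviewer note: a minimal sketch, not part of the commit, that checks the scatter example from the new docstring and illustrates two related `.at[...]` behaviors. The duplicate-index and out-of-bounds arrays are assumptions chosen for illustration. Because the docstring notes that multiple updates to the same index may be applied in any order, the duplicate case below uses `.add`, whose result does not depend on that order.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(5)
indices = jnp.array([1, 2, 4])
values = jnp.array([2.0, 3.0, 4.0])

# Scatter-style update through the .at[...] property.
via_at = x.at[indices].set(values)

# The same update written against lax.scatter directly (copied from the docstring).
via_scatter = lax.scatter(
    x, indices[:, None], values,
    dimension_numbers=lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,)),
    indices_are_sorted=True,
    mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS)

# Duplicate indices: .add accumulates every update, so index 0 receives 1.0 twice.
accumulated = jnp.zeros(3).at[jnp.array([0, 0, 2])].add(1.0)

# Out-of-bounds updates can be discarded explicitly with mode="drop".
# Index 7 is out of range for an array of length 3 (illustrative values).
dropped = jnp.zeros(3).at[jnp.array([1, 7])].set(
    jnp.array([5.0, 9.0]), mode="drop")
```

JAX arrays are immutable, so each `.at[...]` call returns a new array and leaves `x` unchanged.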