diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 4e76d7c30944..40a04ff11d2c 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -310,10 +310,10 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, Wraps `XLA's Gather operator `_. - The semantics of gather are complicated, and its API might change in the - future. For most use cases, you should prefer `Numpy-style indexing - `_ - (e.g., `x[:, (1,4,7), ...]`), rather than using `gather` directly. + :func:`gather` is a low-level operator with complicated semantics, and most JAX + users will never need to call it directly. Instead, you should prefer using + `Numpy-style indexing`_, and/or :func:`jax.numpy.ndarray.at`, perhaps in combination + with :func:`jax.vmap`. Args: operand: an array from which slices should be taken @@ -340,6 +340,42 @@ def gather(operand: ArrayLike, start_indices: ArrayLike, Returns: An array containing the gather output. + + Examples: + As mentioned above, you should basically never use :func:`gather` directly, + and instead use NumPy-style indexing expressions to gather values from + arrays. + + For example, here is how you can extract values at particular indices using + straightforward indexing semantics, which will lower to XLA's Gather operator: + + >>> import jax.numpy as jnp + >>> x = jnp.array([10, 11, 12]) + >>> indices = jnp.array([0, 1, 1, 2, 2, 2]) + + >>> x[indices] + Array([10, 11, 11, 12, 12, 12], dtype=int32) + + For control over settings like ``indices_are_sorted``, ``unique_indices``, ``mode``, + and ``fill_value``, you can use the :attr:`jax.numpy.ndarray.at` syntax: + + >>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds") + Array([10, 11, 11, 12, 12, 12], dtype=int32) + + By comparison, here is the equivalent function call using :func:`gather` directly, + which is not something typical users should ever need to do: + + >>> from jax import lax + >>> lax.gather(x, indices[:, None], slice_sizes=(1,), + ... dimension_numbers=lax.GatherDimensionNumbers( + ... offset_dims=(), + ... collapsed_slice_dims=(0,), + ... start_index_map=(0,)), + ... indices_are_sorted=True, + ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) + Array([10, 11, 11, 12, 12, 12], dtype=int32) + + .. _Numpy-style indexing: https://numpy.org/doc/stable/reference/arrays.indexing.html """ if mode is None: mode = GatherScatterMode.PROMISE_IN_BOUNDS @@ -737,10 +773,9 @@ def scatter( If multiple updates are performed to the same index of operand, they may be applied in any order. - The semantics of scatter are complicated, and its API might change in the - future. For most use cases, you should prefer the - :attr:`jax.numpy.ndarray.at` property on JAX arrays which uses - the familiar NumPy indexing syntax. + :func:`scatter` is a low-level operator with complicated semantics, and most + JAX users will never need to call it directly. Instead, you should prefer using + :func:`jax.numpy.ndarray.at` for more familiary NumPy-style indexing syntax. Args: operand: an array to which the scatter should be applied @@ -764,6 +799,39 @@ def scatter( Returns: An array containing the sum of `operand` and the scattered updates. + + Examples: + As mentioned above, you should basically never use :func:`scatter` directly, + and instead perform scatter-style operations using NumPy-style indexing + expressions via :attr:`jax.numpy.ndarray.at`. + + Here is and example of updating entries in an array using :attr:`jax.numpy.ndarray.at`, + which lowers to an XLA Scatter operation: + + >>> x = jnp.zeros(5) + >>> indices = jnp.array([1, 2, 4]) + >>> values = jnp.array([2.0, 3.0, 4.0]) + + >>> x.at[indices].set(values) + Array([0., 2., 3., 0., 4.], dtype=float32) + + This syntax also supports several of the optional arguments to :func:`scatter`, + for example: + + >>> x.at[indices].set(values, indices_are_sorted=True, mode='promise_in_bounds') + Array([0., 2., 3., 0., 4.], dtype=float32) + + By comparison, here is the equivalent function call using :func:`scatter` directly, + which is not something typical users should ever need to do: + + >>> lax.scatter(x, indices[:, None], values, + ... dimension_numbers=lax.ScatterDimensionNumbers( + ... update_window_dims=(), + ... inserted_window_dims=(0,), + ... scatter_dims_to_operand_dims=(0,)), + ... indices_are_sorted=True, + ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) + Array([0., 2., 3., 0., 4.], dtype=float32) """ return scatter_p.bind( operand, scatter_indices, updates, update_jaxpr=None,