Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update docs of callbacks #25982

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/external-callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple

(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX.

|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|-------------------------------------|----|----|----|----|----|----|
Expand Down
13 changes: 8 additions & 5 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ def pure_callback(

``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
it should also return JAX arrays on CPU.
it should also return JAX arrays on CPU. The ``callback`` function must not
include any calls back into JAX.

The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
Expand Down Expand Up @@ -382,8 +383,9 @@ def pure_callback(
Args:
callback: function to execute on the host. The callback is assumed to be a pure
function (i.e. one without side-effects): if an impure function is passed, it
may behave in unexpected ways, particularly under transformation. The callable
will be passed PyTrees of arrays as arguments, and should return a PyTree of
may behave in unexpected ways, particularly under transformation.
Furthermore, the callback must not call into JAX. The callable will
be passed PyTrees of arrays as arguments, and should return a PyTree of
arrays that matches ``result_shape_dtypes``.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
Expand Down Expand Up @@ -622,14 +624,15 @@ def io_callback(
ordered: bool = False,
**kwargs: Any,
):
"""Calls an impure Python callback.
"""Calls an impure Python callback. The callback function must not include any
calls back into JAX.

For more explanation, see `External Callbacks`_.

Args:
callback: function to execute on the host. It is assumed to be an impure function.
If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to
more efficient execution.
more efficient execution. The ``callback`` must not call into JAX.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,11 @@ def debug_callback(callback: Callable[..., None], *args: Any,
possible while revealing as much about them as possible, such as which parts
of the computation are duplicated or dropped.
Inside of the ``callback`` function there should not be a call back into JAX.
Args:
callback: A Python callable returning None.
callback: A Python callable returning None. The ``callback`` must not call
into JAX.
*args: The positional arguments to the callback.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this callback w.r.t.
Expand Down
Loading