Skip to content

Commit

Permalink
Update docs of callbacks
Browse files Browse the repository at this point in the history
Callback functions should not call into JAX. This information was
missing in the docs of the callbacks. This commit adds this information
to the docs.

See: #25861, #24255
  • Loading branch information
roth-jakob committed Jan 19, 2025
1 parent aed9c6f commit 4f8699c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/external-callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple

(The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`).

From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX.

|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
|-------------------------------------|----|----|----|----|----|----|
Expand Down
13 changes: 8 additions & 5 deletions jax/_src/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ def pure_callback(
``pure_callback`` enables calling a Python function in JIT-ed JAX functions.
The input ``callback`` will be passed JAX arrays placed on a local CPU, and
it should also return JAX arrays on CPU.
it should also return JAX arrays on CPU. The ``callback`` function must not
include any calls back into JAX.
The callback is treated as functionally pure, meaning it has no side-effects
and its output value depends only on its argument values. As a consequence, it
Expand Down Expand Up @@ -382,8 +383,9 @@ def pure_callback(
Args:
callback: function to execute on the host. The callback is assumed to be a pure
function (i.e. one without side-effects): if an impure function is passed, it
may behave in unexpected ways, particularly under transformation. The callable
will be passed PyTrees of arrays as arguments, and should return a PyTree of
may behave in unexpected ways, particularly under transformation.
Furthermore, the callback must not call into JAX. The callable will
be passed PyTrees of arrays as arguments, and should return a PyTree of
arrays that matches ``result_shape_dtypes``.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
Expand Down Expand Up @@ -622,14 +624,15 @@ def io_callback(
ordered: bool = False,
**kwargs: Any,
):
"""Calls an impure Python callback.
"""Calls an impure Python callback. The callback function must not include any
calls back into JAX.
For more explanation, see `External Callbacks`_.
Args:
callback: function to execute on the host. It is assumed to be an impure function.
If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to
more efficient execution.
more efficient execution. The ``callback`` must not call into JAX.
result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes,
whose structure matches the expected output of the callback function at runtime.
:class:`jax.ShapeDtypeStruct` is often used to define leaf values.
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,11 @@ def debug_callback(callback: Callable[..., None], *args: Any,
possible while revealing as much about them as possible, such as which parts
of the computation are duplicated or dropped.
Inside of the ``callback`` function there should not be a call back into JAX.
Args:
callback: A Python callable returning None.
callback: A Python callable returning None. The ``callback`` must not call
into JAX.
*args: The positional arguments to the callback.
ordered: A keyword only argument used to indicate whether or not the
staged out computation will enforce ordering of this callback w.r.t.
Expand Down

0 comments on commit 4f8699c

Please sign in to comment.