From 4f8699c8a1b7f5f2ec017f6ffd226e52e3e4654d Mon Sep 17 00:00:00 2001 From: Jakob Roth Date: Sun, 19 Jan 2025 11:26:25 +0100 Subject: [PATCH] Update docs of callbacks Callback functions should not call into JAX. This information was missing in the docs of the callbacks. This commit adds this information to the docs. See: #25861, #24255 --- docs/external-callbacks.md | 2 +- jax/_src/callback.py | 13 ++++++++----- jax/_src/debugging.py | 5 ++++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/docs/external-callbacks.md b/docs/external-callbacks.md index c404f320fca7..6a64c23e188b 100644 --- a/docs/external-callbacks.md +++ b/docs/external-callbacks.md @@ -71,7 +71,7 @@ In earlier versions of JAX, there was only one kind of callback available, imple (The {func}`jax.debug.print` function you used previously is a wrapper around {func}`jax.debug.callback`). -From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. +From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. All three of them must **not** include any calls back into JAX. |callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution | |-------------------------------------|----|----|----|----|----|----| diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 098273d27240..73c9ec8f231c 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -350,7 +350,8 @@ def pure_callback( ``pure_callback`` enables calling a Python function in JIT-ed JAX functions. The input ``callback`` will be passed JAX arrays placed on a local CPU, and - it should also return JAX arrays on CPU. + it should also return JAX arrays on CPU. The ``callback`` function must not + include any calls back into JAX. The callback is treated as functionally pure, meaning it has no side-effects and its output value depends only on its argument values. As a consequence, it @@ -382,8 +383,9 @@ def pure_callback( Args: callback: function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it - may behave in unexpected ways, particularly under transformation. The callable - will be passed PyTrees of arrays as arguments, and should return a PyTree of + may behave in unexpected ways, particularly under transformation. + Furthermore, the callback must not call into JAX. The callable will + be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches ``result_shape_dtypes``. result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes, whose structure matches the expected output of the callback function at runtime. @@ -622,14 +624,15 @@ def io_callback( ordered: bool = False, **kwargs: Any, ): - """Calls an impure Python callback. + """Calls an impure Python callback. The callback function must not include any + calls back into JAX. For more explanation, see `External Callbacks`_. Args: callback: function to execute on the host. It is assumed to be an impure function. If ``callback`` is pure, using :func:`jax.pure_callback` instead may lead to - more efficient execution. + more efficient execution. The ``callback`` must not call into JAX. result_shape_dtypes: pytree whose leaves have ``shape`` and ``dtype`` attributes, whose structure matches the expected output of the callback function at runtime. :class:`jax.ShapeDtypeStruct` is often used to define leaf values. diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 7685ac2bf38e..fab9ab296e36 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -249,8 +249,11 @@ def debug_callback(callback: Callable[..., None], *args: Any, possible while revealing as much about them as possible, such as which parts of the computation are duplicated or dropped. + Inside of the ``callback`` function there should not be a call back into JAX. + Args: - callback: A Python callable returning None. + callback: A Python callable returning None. The ``callback`` must not call + into JAX. *args: The positional arguments to the callback. ordered: A keyword only argument used to indicate whether or not the staged out computation will enforce ordering of this callback w.r.t.