diff --git a/CHANGELOG.md b/CHANGELOG.md index 02dbc259d9e0..074ce841805d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead. * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use `jax.errors.JaxRuntimeError` instead. + * The default behavior of {func}`jax.pure_callback` and + {func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has + the `vectorized` parameter to those functions. The `vmap_method` parameter + should be used instead for better defined behavior. See the discussion in + {jax-issue}`#23881` for more details. * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index a8cd5219d4b5..8b5d5ea6c907 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -303,9 +303,9 @@ " # type (which corresponds to numpy's `float32` type), and it must be a\n", " # static parameter (i.e. not a JAX array).\n", " eps=np.float32(eps),\n", - " # The `vectorized` parameter controls this function's behavior under `vmap`\n", + " # The `vmap_method` parameter controls this function's behavior under `vmap`\n", " # as discussed below.\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " )\n", "\n", "\n", @@ -325,7 +325,7 @@ "Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n", "Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n", "\n", - "The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", + "The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n", "\n", "```{tip}\n", "If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) 
to {func}`~jax.extend.ffi.ffi_call`.\n",
@@ -336,19 +336,29 @@
    "(ffi-call-vmap)=\n",
    "### Batching with `vmap`\n",
    "\n",
-    "All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n",
-    "By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
-    "This default implementation is general purpose, but it doesn't parallelize very well.\n",
-    "But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n",
+    "{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
+    "The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
    "\n",
-    "The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n",
+    "The simplest `vmap_method` is `\"sequential\"`.\n",
+    "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
+    "This implementation is general purpose, but it doesn't parallelize very well.\n",
+    "Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
+    "\n",
+    "In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
+    "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
    "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
    "\n",
    "```python\n",
    "ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n",
    "```\n",
    "\n",
-    "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:"
+    "```{tip}\n",
+    "Note that things get a bit more complicated when we have multiple input arguments.\n",
+    "For simplicity, we will use the `\"broadcast_fullrank\"` method throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
+    "The documentation for {func}`~jax.pure_callback` includes some examples of this.\n",
+    "```\n",
+    "\n",
+    "Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
   ]
  },
 {
@@ -380,7 +390,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
+    "Using `vmap_method=\"sequential\"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
   ]
  },
 {
@@ -389,24 +399,24 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "def rms_norm_not_vectorized(x, eps=1e-5):\n",
+    "def rms_norm_sequential(x, eps=1e-5):\n",
    "  return jex.ffi.ffi_call(\n",
    "
\"rms_norm\",\n", " jax.ShapeDtypeStruct(x.shape, x.dtype),\n", " x,\n", " eps=np.float32(eps),\n", - " vectorized=False, # This is the default behavior\n", + " vmap_method=\"sequential\",\n", " )\n", "\n", "\n", - "jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)" + "jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." ] }, { @@ -454,7 +464,7 @@ " ),\n", " x,\n", " eps=np.float32(eps),\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " )\n", " return y, (res, x)\n", "\n", @@ -471,7 +481,7 @@ " res,\n", " x,\n", " ct,\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " ),\n", " )\n", "\n", @@ -561,7 +571,7 @@ " out_type,\n", " x,\n", " eps=np.float32(eps),\n", - " vectorized=True,\n", + " vmap_method=\"broadcast_fullrank\",\n", " )\n", "\n", " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n", diff --git a/docs/ffi.md b/docs/ffi.md index cc3863ed99b2..b3d1dcf46364 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -264,9 +264,9 @@ def rms_norm(x, eps=1e-5): # type (which corresponds to numpy's `float32` type), and it must be a # static parameter (i.e. not a JAX array). eps=np.float32(eps), - # The `vectorized` parameter controls this function's behavior under `vmap` + # The `vmap_method` parameter controls this function's behavior under `vmap` # as discussed below. - vectorized=True, + vmap_method="broadcast_fullrank", ) @@ -282,7 +282,7 @@ It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_cal Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`. Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments. -The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. +The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next. ```{tip} If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`. @@ -293,19 +293,29 @@ One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support so (ffi-call-vmap)= ### Batching with `vmap` -All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient. 
-By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
-This default implementation is general purpose, but it doesn't parallelize very well.
-But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.
+{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
+The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
 
-The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.
+The simplest `vmap_method` is `"sequential"`.
+In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
+This implementation is general purpose, but it doesn't parallelize very well.
+Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
+
+In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
+The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
 Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
 
 ```python
 ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
 ```
 
-Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:
+```{tip}
+Note that things get a bit more complicated when we have multiple input arguments.
+For simplicity, we will use the `"broadcast_fullrank"` method throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
+The documentation for {func}`~jax.pure_callback` includes some examples of this +``` + +Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box: ```{code-cell} ipython3 np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) @@ -317,23 +327,23 @@ We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms jax.make_jaxpr(jax.vmap(rms_norm))(x) ``` -If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: +Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body: ```{code-cell} ipython3 -def rms_norm_not_vectorized(x, eps=1e-5): +def rms_norm_sequential(x, eps=1e-5): return jex.ffi.ffi_call( "rms_norm", jax.ShapeDtypeStruct(x.shape, x.dtype), x, eps=np.float32(eps), - vectorized=False, # This is the default behavior + vmap_method="sequential", ) -jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) +jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x) ``` -If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +++ @@ -372,7 +382,7 @@ def rms_norm_fwd(x, eps=1e-5): ), x, eps=np.float32(eps), - vectorized=True, + vmap_method="broadcast_fullrank", ) return y, (res, x) @@ -389,7 +399,7 @@ def rms_norm_bwd(eps, res, ct): res, x, ct, - vectorized=True, + vmap_method="broadcast_fullrank", ), ) @@ -469,7 +479,7 @@ def rms_norm_cross_platform(x, eps=1e-5): out_type, x, eps=np.float32(eps), - vectorized=True, + vmap_method="broadcast_fullrank", ) return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda")) diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py index 4e0ed1d195b4..d063f1cf319c 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.py +++ b/examples/ffi/src/jax_ffi_example/rms_norm.py @@ -60,8 +60,7 @@ def rms_norm(x, eps=1e-5): # type (which corresponds to numpy's `float32` type), and it must be a # static parameter (i.e. not a JAX array). eps=np.float32(eps), - # The `vectorized` parameter controls this function's behavior under `vmap`. 
- vectorized=True, + vmap_method="broadcast_fullrank", ) @@ -74,7 +73,7 @@ def rms_norm_fwd(x, eps=1e-5): ), x, eps=np.float32(eps), - vectorized=True, + vmap_method="broadcast_fullrank", ) return y, (res, x) @@ -91,7 +90,7 @@ def rms_norm_bwd(eps, res, ct): res, x, ct, - vectorized=True, + vmap_method="broadcast_fullrank", ), ) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 3a18dcdfa2ac..5fcd2c4c4260 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -22,6 +22,7 @@ import jax from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects @@ -31,13 +32,18 @@ from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.lax import lax from jax._src.lax.control_flow.loops import map as lax_map from jax._src.lib import xla_client as xc from jax._src.sharding_impls import SingleDeviceSharding +from jax._src.typing import DeprecatedArg import numpy as np logger = logging.getLogger(__name__) +# TODO(dfm): Remove after 6 months. +# Added Oct 1, 2024 +deprecations.register("jax-callback-vectorized") # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") @@ -45,6 +51,7 @@ dispatch.prim_requires_devices_during_lowering.add(pure_callback_p) map, unsafe_map = util.safe_map, map +zip, unsafe_zip = util.safe_zip, zip @dataclasses.dataclass(frozen=True) @@ -69,9 +76,10 @@ def pure_callback_impl( result_avals, callback: _FlatCallback, sharding: SingleDeviceSharding | None, - vectorized: bool, + vectorized: bool | DeprecatedArg, + vmap_method: str | None, ): - del sharding, vectorized, result_avals + del sharding, vectorized, vmap_method, result_avals try: cpu_device, *_ = jax.local_devices(backend="cpu") except RuntimeError as e: @@ -99,9 +107,10 @@ def pure_callback_abstract_eval( callback: _FlatCallback, result_avals, sharding: SingleDeviceSharding | None, - vectorized: bool, + vectorized: bool | DeprecatedArg, + vmap_method: str | None, ): - del avals, callback, sharding, vectorized + del avals, callback, sharding, vectorized, vmap_method return result_avals @@ -129,25 +138,51 @@ def callback_batching_rule( args, dims, *, - vectorized: bool, + vectorized: bool | None | DeprecatedArg, + vmap_method: str | None, result_avals: Sequence[core.ShapedArray], **kwargs: Any, ): - axis_size = next(a.shape[d] for a, d in zip(args, dims) - if d is not batching.not_mapped) + if isinstance(vectorized, DeprecatedArg) and vmap_method is None: + deprecations.warn( + "jax-callback-vectorized", + f"The default behavior of {prim.name} under vmap will soon " + "change. Currently, the default behavior is to generate a sequential " + "vmap (i.e. a loop), but in the future the default will be to raise " + "an error. 
To keep the current default, set vmap_method='sequential'.", + stacklevel=6) + vmap_method = "sequential" + + axis_size, = {a.shape[d] for a, d in zip(args, dims) + if d is not batching.not_mapped} new_args = [arg if dim is batching.not_mapped else batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)] - if vectorized: - result_avals = tuple( - core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore - for aval in result_avals) + batched_result_avals = tuple( + core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) + for aval in result_avals) + if vmap_method == "legacy_vectorized": + # This method is kept to support the behavior that was previously exposed + # when using `vectorized=True`. outvals = prim.bind( *new_args, vectorized=vectorized, - result_avals=result_avals, + vmap_method=vmap_method, + result_avals=batched_result_avals, **kwargs, ) - else: + elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank": + size = axis_size if vmap_method == "broadcast_fullrank" else 1 + bcast_args = [ + lax.broadcast(x, (size,)) if d is batching.not_mapped else x + for x, d in zip(new_args, dims)] + outvals = prim.bind( + *bcast_args, + vectorized=vectorized, + vmap_method=vmap_method, + result_avals=batched_result_avals, + **kwargs, + ) + elif vmap_method == "sequential": is_batched = [d is not batching.not_mapped for d in dims] unbatched_args, batched_args = util.partition_list(is_batched, new_args) def _batch_fun(batched_args): @@ -156,9 +191,15 @@ def _batch_fun(batched_args): *merged_args, result_avals=result_avals, vectorized=vectorized, + vmap_method=vmap_method, **kwargs, ) outvals = lax_map(_batch_fun, batched_args) + else: + raise NotImplementedError( + f"vmap is only supported for the {prim.name} primitive when vmap_method " + "is one of 'sequential', 'broadcast', 'broadcast_fullrank', or " + "'legacy_vectorized'.") return tuple(outvals), (0,) * len(outvals) @@ -261,7 +302,8 @@ def pure_callback( result_shape_dtypes: Any, *args: Any, sharding: SingleDeviceSharding | None = None, - vectorized: bool = False, + vectorized: bool | None | DeprecatedArg = DeprecatedArg(), + vmap_method: str | None = None, **kwargs: Any, ): """Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc. @@ -279,17 +321,25 @@ def pure_callback( `jit`-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data-dependence allows. - When `vmap`-ed the behavior will depend on the value of the - ``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback - is assumed to obey - ``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``. - Therefore, the callback will be called directly on batched inputs (where the - batch axes are the leading dimensions). Additionally, the callbacks should - return outputs that have corresponding leading batch axes. If not vectorized - ``callback`` will be mapped sequentially across the batched axis. - For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free - to set ``vectorized=True`` because the ``np.matmul`` function handles - arbitrary leading batch dimensions. + When `vmap`-ed the behavior will depend on the value of the ``vmap_method``. + + * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` + is deprecated and it will eventually raise ``NotImplementedError``. 
+  * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
+    the batched arguments, calling ``callback`` once for each batch element.
+  * ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1``
+    added as the leading dimension of unbatched inputs.
+  * ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the
+    inputs are tiled to the expected batched shape.
+
+  If necessary, the legacy behavior provided by the deprecated
+  ``vectorized=True`` argument can be recovered using
+  ``vmap_method="legacy_vectorized"``.
+
+  The current default behavior is to use ``vmap_method="sequential"`` when
+  not specified, but this behavior is deprecated, and in the future, the
+  default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
+  explicitly specified.
 
   Args:
     callback: function to execute on the host. The callback is assumed to be a pure
@@ -303,8 +353,8 @@
     *args: arguments to be passed to the callback function
     sharding: optional sharding that specifies the device from which the
       callback should be invoked.
-    vectorized: boolean specifying whether the callback function can operate in a
-      vectorized manner.
+    vmap_method: string specifying how the callback transforms under
+      :func:`~jax.vmap` as described above.
     **kwargs: keyword arguments to be passed to the callback function
 
   Returns:
@@ -316,8 +366,62 @@
   - :func:`jax.debug.callback`: callback designed for general-purpose debugging.
   - :func:`jax.debug.print`: callback designed for printing.
 
+  Examples:
+    The behavior of ``pure_callback`` under :func:`~jax.vmap` is controlled by
+    the ``vmap_method`` argument as described above. It is useful to consider
+    some explicit examples that demonstrate the semantics. For example,
+    consider the following function:
+
+    >>> def callback(x, y):
+    ...   print(jnp.shape(x), jnp.shape(y))
+    ...   return x + y
+
+    >>> def fun(x, y, *, vmap_method):
+    ...   shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
+    ...   dtype = jnp.result_type(x, y)
+    ...   out_type = jax.ShapeDtypeStruct(shape, dtype)
+    ...   return jax.pure_callback(callback, out_type, x, y,
+    ...                            vmap_method=vmap_method)
+
+    Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1``
+    to ``y``:
+
+    >>> from functools import partial
+    >>> x = jnp.arange(4)
+    >>> y = 1.0
+    >>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
+    (4,) (1,)
+    Array([1., 2., 3., 4.], dtype=float32)
+
+    Whereas ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to
+    ``y``:
+
+    >>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"),
+    ...          in_axes=(0, None))(x, y)
+    (4,) (4,)
+    Array([1., 2., 3., 4.], dtype=float32)
+
   .. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
   """
+  if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
+    deprecations.warn(
+        "jax-callback-vectorized",
+        "The vectorized argument of jax.pure_callback is deprecated and setting "
+        "it will soon raise an error. To avoid an error in the future, and to "
+        "suppress this warning, please use the vmap_method argument instead.",
+        stacklevel=2)
+    if vmap_method is not None:
+      raise ValueError(
+          "the vectorized and vmap_method arguments of jax.pure_callback cannot "
+          "be used together. 
Please use the vmap_method argument.")
+    vmap_method = "legacy_vectorized" if vectorized else "sequential"
+  allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
+                          "legacy_vectorized", None]
+  if vmap_method not in allowed_vmap_methods:
+    raise ValueError(
+        f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "
+        f"but got: {vmap_method}")
+
   flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
   tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
   result_avals = tree_util.tree_map(
@@ -329,6 +433,7 @@
       result_avals=tuple(flat_result_avals),
       sharding=sharding,
       vectorized=vectorized,
+      vmap_method=vmap_method,
   )
   return tree_util.tree_unflatten(out_tree, out_flat)
diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py
index d63c74d74098..6bbba0cbd88e 100644
--- a/jax/_src/extend/ffi.py
+++ b/jax/_src/extend/ffi.py
@@ -23,6 +23,7 @@
 import numpy as np
 
 from jax._src import core
+from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import effects
 from jax._src import util
@@ -34,7 +35,8 @@
 from jax._src.lib import jaxlib
 from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
-from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape
+from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
+                             Shape)
 
 map, unsafe_map = util.safe_map, map
 FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
@@ -199,23 +201,22 @@ def ffi_call(
     target_name: str,
     result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
     *args: ArrayLike,
-    vectorized: bool = False,
     has_side_effect: bool = False,
+    vmap_method: str | None = None,
+    vectorized: bool | DeprecatedArg = DeprecatedArg(),
     **kwargs: Any,
 ) -> Array | list[Array]:
   """Call a foreign function interface (FFI) target.
 
   Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
-  :func:`~jax.vmap` depends on the value of ``vectorized``. When ``vectorized``
-  is ``True``, the FFI target is assumed to satisfy: ``ffi_call(xs) ==
-  jnp.stack([ffi_call(x) for x in xs])``. In other words, calling the FFI target
-  with an extra leading dimension should return the same result as calling it
-  within a loop and stacking along the zeroth axis. Therefore, the FFI target
-  will be called directly on batched inputs (where the batch axes are the
-  leading dimensions). Additionally, the callbacks should return outputs that
-  have corresponding leading batch axes. If ``vectorized`` is ``False`` (the
-  default behavior), transforming this ``ffi_call`` under :func:`~jax.vmap` will
-  result in a :func:`~jax.lax.scan` with the ``ffi_call`` in the body.
+  :func:`~jax.vmap` depends on the value of ``vmap_method``. See the
+  :func:`~jax.pure_callback` documentation for more details about the allowed
+  values and examples of their behavior.
+
+  The current default behavior is to use ``vmap_method="sequential"`` when
+  not specified, but this behavior is deprecated, and in the future, the
+  default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
+  explicitly specified.
 
   Args:
     target_name: the name of the XLA FFI custom call target that was registered
@@ -226,11 +227,11 @@
       used to define the elements of ``result_shape_dtypes``.
       ``jax.core.abstract_token`` may be used to represent a token-typed output.
     *args: the arguments passed to the custom call.
-    vectorized: boolean specifying whether the FFI call can operate in a
-      vectorized manner, as described above.
    has_side_effect: boolean specifying whether the custom call has side
      effects. When ``True``, the FFI call will be executed even when the
      outputs are not used.
+    vmap_method: string specifying how the FFI call transforms under
+      :func:`~jax.vmap` as described above.
     **kwargs: keyword arguments that are passed as named attributes to the
       custom call using XLA's FFI interface.
 
   Returns:
     One or more :class:`~jax.Array` objects whose shapes and dtypes match
     ``result_shape_dtypes``.
   """
+  if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
+    deprecations.warn(
+        "jax-callback-vectorized",
+        "The vectorized argument of ffi_call is deprecated and setting "
+        "it will soon raise an error. To avoid an error in the future, and to "
+        "suppress this warning, please use the vmap_method argument instead.",
+        stacklevel=2)
+    if vmap_method is not None:
+      raise ValueError(
+          "the vectorized and vmap_method arguments of ffi_call cannot "
+          "be used together. Please use the vmap_method argument.")
+    vmap_method = "legacy_vectorized" if vectorized else "sequential"
+  allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
+                          "legacy_vectorized", None]
+  if vmap_method not in allowed_vmap_methods:
+    raise ValueError(
+        f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "
+        f"but got: {vmap_method}")
+
   if isinstance(result_shape_dtypes, Sequence):
     multiple_results = True
     result_avals = _result_avals(result_shape_dtypes)
@@ -248,6 +268,7 @@
       *args,
       result_avals=result_avals,
       vectorized=vectorized,
+      vmap_method=vmap_method,
       target_name=target_name,
       has_side_effect=has_side_effect,
       **_wrap_kwargs_hashable(kwargs),
@@ -342,11 +363,12 @@ def ffi_call_abstract_eval(
     *avals_in,
     result_avals: tuple[core.AbstractValue, ...],
     target_name: str,
-    vectorized: bool,
+    vectorized: bool | DeprecatedArg,
+    vmap_method: str | None,
     has_side_effect: bool,
     **kwargs: Any,
 ):
-  del avals_in, target_name, vectorized, kwargs
+  del avals_in, target_name, vectorized, vmap_method, kwargs
   effects = {_FfiEffect} if has_side_effect else core.no_effects
   return result_avals, effects
@@ -370,11 +392,12 @@ def ffi_call_lowering(
     *operands: ir.Value,
     result_avals: tuple[core.AbstractValue, ...],
     target_name: str,
-    vectorized: bool,
+    vectorized: bool | DeprecatedArg,
+    vmap_method: str | None,
     has_side_effect: bool,
     **kwargs: Any,
 ) -> Sequence[ir.Value]:
-  del result_avals, vectorized
+  del result_avals, vectorized, vmap_method
   rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
   return rule(ctx, *operands, **_unwrap_kwargs_hashable(kwargs))
diff --git a/tests/extend_test.py b/tests/extend_test.py
index e56b0936c7bb..805ad937bc02 100644
--- a/tests/extend_test.py
+++ b/tests/extend_test.py
@@ -245,10 +245,11 @@ def testFfiCall(self, shape, dtype):
   @jtu.sample_product(
       shape=[(1,), (4,), (5,)],
       dtype=(np.int32,),
-      vectorized=(False, True),
+      vmap_method=("broadcast", "broadcast_fullrank", "sequential",
+                   "legacy_vectorized"),
   )
   @jtu.run_on_devices("gpu")
-  def testFfiCallBatching(self, shape, dtype, vectorized):
+  def testFfiCallBatching(self, shape, dtype, vmap_method):
     shape = (10,) + shape
     pivots_size = shape[-1]
     permutation_size = 2 * pivots_size
@@ -256,15 +257,29 @@ def testFfiCallBatching(self, shape, dtype, vectorized):
     pivots = jnp.broadcast_to(pivots, shape)
     expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
     actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
-        x, permutation_size, 
vectorized=vectorized))(pivots) + x, permutation_size, vmap_method=vmap_method))(pivots) self.assertArraysEqual(actual, expected) + @jtu.run_on_devices("gpu") + def testVectorizedDeprecation(self): + pivots_size = 4 + shape = (10, pivots_size) + permutation_size = 2 * pivots_size + pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, + dtype=np.int32) + pivots = jnp.broadcast_to(pivots, shape) + with self.assertWarns(DeprecationWarning): + ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) + with self.assertWarns(DeprecationWarning): + jax.vmap( + lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots) + # TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation` # custom call target because that's the only one in jaxlib that uses the # new FFI interface. Once more are available, consider using something that # can be run on multiple platforms. -def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True): +def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs): return jex.ffi.ffi_call( "cu_lu_pivots_to_permutation", jax.ShapeDtypeStruct( @@ -272,7 +287,7 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) dtype=pivots.dtype, ), pivots, - vectorized=vectorized, + **kwargs, ) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 21cb31f693ef..605a8c84389c 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1724,7 +1724,8 @@ def fun(x): "batching rules are implemented only for id_tap, not for call"): jax.vmap(fun)(np.ones((2, 3))) else: - jax.vmap(fun)(np.ones((2, 3))) + with jtu.ignore_warning(category=DeprecationWarning): + jax.vmap(fun)(np.ones((2, 3))) def test_call_error_bad_result_shape(self): with self.assertRaisesRegex( diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index d9887cf7b482..389a4181ebd9 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -635,13 +635,13 @@ def test_can_vmap_pure_callback(self): @jax.jit @jax.vmap def f(x): - return jax.pure_callback(np.sin, x, x) + return jax.pure_callback(np.sin, x, x, vmap_method="sequential") out = f(jnp.arange(4.)) np.testing.assert_allclose(out, np.sin(np.arange(4.))) @jax.jit def g(x): - return jax.pure_callback(np.sin, x, x) + return jax.pure_callback(np.sin, x, x, vmap_method="sequential") out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2))) np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T) @@ -649,7 +649,8 @@ def g(x): @functools.partial(jax.vmap, in_axes=(0, None)) def h(x, y): out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype)) - return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y) + return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y, + vmap_method="sequential") out = h(jnp.arange(4.), 4.) 
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4., rtol=1E-7, check_dtypes=False) @@ -658,7 +659,8 @@ def h(x, y): @functools.partial(jax.vmap) def h(x, y): out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype)) - return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y) + return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y, + vmap_method="sequential") out = h(jnp.arange(4.), jnp.arange(10., 14.)) self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.), rtol=1E-7, check_dtypes=False) @@ -667,7 +669,8 @@ def h(x, y): @functools.partial(jax.vmap, in_axes=1, out_axes=1) def h(x, y): out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype)) - return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y) + return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y, + vmap_method="sequential") out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None]) self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.)[None], @@ -682,7 +685,7 @@ def cb(x): @jax.jit @jax.vmap def f(x): - return jax.pure_callback(cb, x, x) + return jax.pure_callback(cb, x, x, vmap_method="sequential") np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.))) @@ -693,7 +696,7 @@ def cb2(x): @jax.jit @jax.vmap def g(x): - return jax.pure_callback(cb2, x, x, vectorized=True) + return jax.pure_callback(cb2, x, x, vmap_method="broadcast") np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.))) @@ -701,7 +704,7 @@ def g(x): @functools.partial(jax.vmap, in_axes=(0, None)) def h(x, y): return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y, - vectorized=True) + vmap_method="broadcast") out = h(jnp.arange(4.), 4.) np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.) @@ -709,7 +712,7 @@ def h(x, y): @functools.partial(jax.vmap, in_axes=(1, None), out_axes=1) def h(x, y): return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y, - vectorized=True) + vmap_method="legacy_vectorized") out = h(jnp.arange(4.)[None], 4.) np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.) @@ -722,7 +725,7 @@ def cb(x): @jax.jit @jax.vmap def f(x): - return jax.pure_callback(cb, x, x, vectorized=True) + return jax.pure_callback(cb, x, x, vmap_method="broadcast") with self.assertRaises(RuntimeError): f(jnp.arange(4.)) @@ -981,6 +984,52 @@ def f(x): out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x) np.testing.assert_allclose(out, 2 * jnp.log(x + 1)) + def test_vmap_method_raise(self): + @jax.vmap + def f(x): + # Setting vectorized to None disables the current default behavior of + # falling back on sequential. 
+ return jax.pure_callback(np.sin, x, x, vectorized=None) + + with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"): + f(jnp.arange(4.)) + + def test_deprecated_vectorized(self): + def f(x, **kwargs): + return jax.pure_callback(np.sin, x, x, **kwargs) + + with self.assertWarnsRegex(DeprecationWarning, "The default behavior"): + jax.vmap(f)(jnp.arange(4.0)) + + with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): + f(jnp.arange(4.0), vectorized=True) + + with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"): + f(jnp.arange(4.0), vectorized=False) + + def test_vmap_method_broadcast(self): + def callback(x, y): + self.assertTupleEqual(x.shape, (4,)) + self.assertTupleEqual(y.shape, (1,)) + return x + y + + def f(x, y): + return jax.pure_callback(callback, x, x, y, vmap_method="broadcast") + + jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error + + def test_vmap_method_broadcast_fullrank(self): + def callback(x, y): + self.assertTupleEqual(x.shape, (4,)) + self.assertTupleEqual(y.shape, (4,)) + return x + y + + def f(x, y): + return jax.pure_callback(callback, x, x, y, + vmap_method="broadcast_fullrank") + + jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error + class IOCallbackTest(jtu.JaxTestCase):
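Reviewer note: to make the new semantics concrete outside the test suite, here is a small standalone sketch (not part of this patch; the callback and shapes are illustrative) that exercises the three main `vmap_method` options through `jax.pure_callback`, which shares its batching behavior with `jex.ffi.ffi_call`. It assumes a JAX build that already includes this change.

```python
# Standalone sketch (assumes a JAX build that includes the new `vmap_method`).
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def callback(x, y):
  # Runs on the host and reports the shapes it actually receives.
  print("host callback saw:", np.shape(x), np.shape(y))
  return np.sin(x) + y


def fun(x, y, *, vmap_method):
  shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
  dtype = jnp.result_type(x, y)
  out_type = jax.ShapeDtypeStruct(shape, dtype)
  return jax.pure_callback(callback, out_type, x, y, vmap_method=vmap_method)


x = jnp.arange(4.0)
y = 1.0

# "sequential": rewritten as a loop, one host call per batch element.
jax.vmap(partial(fun, vmap_method="sequential"), in_axes=(0, None))(x, y)

# "broadcast": unbatched inputs get a new leading axis of size 1.
jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)

# "broadcast_fullrank": unbatched inputs are tiled to the full batch size.
jax.vmap(partial(fun, vmap_method="broadcast_fullrank"), in_axes=(0, None))(x, y)
```

The printed shapes show what the host function receives in each mode, matching the behavior described in the updated `pure_callback` docstring.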
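It can also be useful to verify that a host function satisfies the stacking identity that the `"broadcast"`/`"broadcast_fullrank"` methods assume, `vmap(f)(xs) == jnp.stack([f(x) for x in xs])`. The check below is a hypothetical sketch, again using `jax.pure_callback` as a stand-in because a registered FFI target isn't available here.

```python
# Illustrative check of the stacking identity assumed by the broadcast
# methods; `jax.pure_callback` stands in for an FFI call.
import jax
import jax.numpy as jnp
import numpy as np


def cumsum_callback(x):
  # Safe under extra leading batch dimensions because it only operates
  # along the trailing axis.
  return np.cumsum(x, axis=-1)


def cumsum_host(x):
  out_type = jax.ShapeDtypeStruct(jnp.shape(x), x.dtype)
  return jax.pure_callback(cumsum_callback, out_type, x,
                           vmap_method="broadcast_fullrank")


xs = jnp.arange(12.0).reshape(3, 4)
batched = jax.vmap(cumsum_host)(xs)
stacked = jnp.stack([cumsum_host(x) for x in xs])
np.testing.assert_allclose(batched, stacked)
```

If a host function fails this kind of check, `vmap_method="sequential"` remains the safe fallback.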