diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml index 9a49ed2a3e61..ea87d4e29e40 100644 --- a/.github/workflows/asan.yaml +++ b/.github/workflows/asan.yaml @@ -12,7 +12,7 @@ on: branches: - main paths: - - '**/workflows/asan.yml' + - '**/workflows/asan.yaml' jobs: asan: diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml new file mode 100644 index 000000000000..4a2e2ecb7fe6 --- /dev/null +++ b/.github/workflows/bazel_cpu_rbe.yml @@ -0,0 +1,41 @@ +name: CI - Bazel CPU tests (RBE) + +on: + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + run_tests: + if: github.event.repository.fork == false + strategy: + matrix: + runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"] + + runs-on: ${{ matrix.runner }} + # TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "3.12" + + steps: + - uses: actions/checkout@v3 + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Run Bazel CPU Tests with RBE + run: ./ci/run_bazel_test_cpu_rbe.sh \ No newline at end of file diff --git a/.github/workflows/rocm-nightly-upstream-sync.yml b/.github/workflows/rocm-nightly-upstream-sync.yml index e915ccba390d..e8cb5f480313 100644 --- a/.github/workflows/rocm-nightly-upstream-sync.yml +++ b/.github/workflows/rocm-nightly-upstream-sync.yml @@ -1,5 +1,5 @@ # Pulls the latest changes from upstream into main and opens a PR to merge -# them into rocm-main. +# them into rocm-main branch. name: ROCm Nightly Upstream Sync on: diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ab334c15904..10b1fc808970 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. `platforms` instead. * Hashing of tracers, which has been deprecated since version 0.4.30, now results in a `TypeError`. + * {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional + inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel` + on the function inputs. + * `jax.clear_backends` was removed after being deprecated in v0.4.26. * New Features * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for @@ -49,6 +53,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. declared inline via {func}`dataclasses.field`. See the function documentation for examples. +* Bug fixes + * Fixed a bug where the GPU implementations of LU and QR decomposition would + result in an indexing overflow for batch sizes close to int32 max. See + {jax-issue}`#24843` for more details. + ## jax 0.4.35 (Oct 22, 2024) * Breaking Changes @@ -79,7 +88,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The semi-public API `jax.lib.xla_client.register_custom_call_target` has been deprecated. 
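As an aside on the `toeplitz` entry in the changelog hunk above: a small sketch (not part of the patch) of what the new implicit batching looks like and how `jax.numpy.ravel` recovers the old behavior; the concrete shapes are my reading of the changelog wording rather than something stated in the diff.

```python
import jax.numpy as jnp
from jax.scipy.linalg import toeplitz

c = jnp.arange(6.0).reshape(2, 3)

# New behavior: leading dimensions are treated as batch dimensions, so this
# builds one Toeplitz matrix per row of `c`, giving shape (2, 3, 3).
batched = toeplitz(c)

# Previous behavior flattened the input first, producing a single 6x6 matrix.
# Ravel the input explicitly to recover it.
unbatched = toeplitz(jnp.ravel(c))
```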
Use the JAX FFI instead. * The semi-public APIs `jax.lib.xla_client.dtype_to_etype`, - `jax.lib.xla_client.ops`, + `jax.lib.xla_client.ops`, `jax.lib.xla_client.shape_from_pyval`, `jax.lib.xla_client.PrimitiveType`, `jax.lib.xla_client.Shape`, `jax.lib.xla_client.XlaBuilder`, and `jax.lib.xla_client.XlaComputation` have been deprecated. Use StableHLO diff --git a/ci/README.md b/ci/README.md new file mode 100644 index 000000000000..ea867df52f97 --- /dev/null +++ b/ci/README.md @@ -0,0 +1,10 @@ +# JAX continuous integration + +> [!WARNING] +> This folder is still under construction. It is part of an ongoing +> effort to improve the structure of CI and build related files within the +> JAX repo. This warning will be removed when the contents of this +> directory are stable and appropriate documentation around its usage is in +> place. + +******************************************************************************** \ No newline at end of file diff --git a/ci/envs/default.env b/ci/envs/default.env new file mode 100644 index 000000000000..528c02701acc --- /dev/null +++ b/ci/envs/default.env @@ -0,0 +1,37 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# This file contains all the default values for the "JAXCI_" environment +# variables used in the CI scripts. These variables are used to control the +# behavior of the CI scripts such as the Python version used, path to JAX/XLA +# repo, if to clone XLA repo, etc. + +# The path to the JAX git repository. +export JAXCI_JAX_GIT_DIR=$(pwd) + +# Controls the version of Hermetic Python to use. Use system default if not +# set. +export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} + +# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local +# copy of XLA instead of the pinned version in the WORKSPACE. When +# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. +export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} + +# If set to 1, the builds will clone the XLA repository at HEAD and set its +# path in JAXCI_XLA_GIT_DIR. +export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} + +# Allows overriding the XLA commit that is used. +export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} \ No newline at end of file diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh new file mode 100755 index 000000000000..6ba9f6dce239 --- /dev/null +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Runs Bazel CPU tests with RBE. +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Clone XLA at HEAD if path to local XLA is not provided +if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + export JAXCI_CLONE_MAIN_XLA=1 +fi + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel CPU tests with RBE. +os=$(uname -s | awk '{print tolower($0)}') +arch=$(uname -m) + +# When running on Mac or Linux Aarch64, we only build the test targets and +# not run them. These platforms do not have native RBE support so we +# RBE cross-compile them on remote Linux x86 machines. As the tests still +# need to be run on the host machine and because running the tests on a +# single machine can take a long time, we skip running them on these +# platforms. +if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then + echo "Building RBE CPU tests..." + bazel build --config=rbe_cross_compile_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64=0 \ + --test_output=errors \ + --color=yes \ + //tests:cpu_tests //tests:backend_independent_tests +else + echo "Running RBE CPU tests..." + bazel test --config=rbe_${os}_${arch} \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ + --test_env=JAX_NUM_GENERATED_CASES=25 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64=0 \ + --test_output=errors \ + --color=yes \ + //tests:cpu_tests //tests:backend_independent_tests +fi \ No newline at end of file diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh new file mode 100644 index 000000000000..e77e84f3c07f --- /dev/null +++ b/ci/utilities/setup_build_environment.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Set up the build environment for JAX CI jobs. 
This script depends on the +# "JAXCI_" environment variables set or sourced in the build script. + +# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI +# jobs running on Linux runners in GitHub Actions. Without this, git complains +# that the directory has dubious ownership and refuses to run any commands. +# Avoid running on Windows runners as git runs into issues with not being able +# to lock the config file. Other git commands seem to work on the Windows +# runners so we can skip this step for Windows. +# TODO(b/375073267): Remove this once we understand why git repositories are +# being marked as unsafe inside the self-hosted runners. +if [[ ! $(uname -s) =~ "MSYS_NT" ]]; then + git config --global --add safe.directory $JAXCI_JAX_GIT_DIR +fi + +function clone_main_xla() { + echo "Cloning XLA at HEAD to $(pwd)/xla" + git clone --depth=1 https://github.com/openxla/xla.git $(pwd)/xla + export JAXCI_XLA_GIT_DIR=$(pwd)/xla +} + +# Clone XLA at HEAD if required. +if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then + # Clone only if $(pwd)/xla does not exist to avoid failure on re-runs. + if [[ ! -d $(pwd)/xla ]]; then + clone_main_xla + else + echo "JAXCI_CLONE_MAIN_XLA set but local XLA folder already exists: $(pwd)/xla so using that instead." + # Set JAXCI_XLA_GIT_DIR if local XLA already exists + export JAXCI_XLA_GIT_DIR=$(pwd)/xla + fi +fi + +# If a XLA commit is provided, check out XLA at that commit. +if [[ ! -z "$JAXCI_XLA_COMMIT" ]]; then + # Clone XLA at HEAD if a path to local XLA is not provided. + if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then + clone_main_xla + fi + pushd "$JAXCI_XLA_GIT_DIR" + + git fetch --depth=1 origin "$JAXCI_XLA_COMMIT" + echo "JAXCI_XLA_COMMIT is set. Checking out XLA at $JAXCI_XLA_COMMIT" + git checkout "$JAXCI_XLA_COMMIT" + + popd +fi + +if [[ ! -z ${JAXCI_XLA_GIT_DIR} ]]; then + echo "INFO: Overriding XLA to be read from $JAXCI_XLA_GIT_DIR instead of the" + echo "pinned version in the WORKSPACE." + echo "If you would like to revert this behavior, unset JAXCI_CLONE_MAIN_XLA" + echo "and JAXCI_XLA_COMMIT in your environment. Note that the Bazel RBE test" + echo "commands overrides the XLA repository and thus require a local copy of" + echo "XLA to run." 
+fi \ No newline at end of file diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index fcb7b570e493..f4b61cbcf7dc 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -679,7 +679,7 @@ class RmsNormFwdClass: NamedSharding(mesh, PartitionSpec(None, None))) invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) output_shardings = (arg_shardings[0], invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables impl = partial(RmsNormFwdClass.impl, eps=eps) @@ -739,7 +739,7 @@ class RmsNormBwdClass: output_shardings = (output_sharding, invvar_sharding, invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables def impl(g, invvar, x, weight): grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind( diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py index 31a00c49071e..1cdf67c41a90 100644 --- a/docs/Custom_Operation_for_GPUs.py +++ b/docs/Custom_Operation_for_GPUs.py @@ -353,7 +353,7 @@ def partition(eps: float, mesh : jax.sharding.Mesh, NamedSharding(mesh, PartitionSpec(None, None))) # TODO: TE don't force anything. invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0])) output_shardings = (arg_shardings[0], invvar_sharding) - # Sharded_impl only accepts positional arugments + # Sharded_impl only accepts positional arguments # And they should be Jax traceable variables impl = partial(RmsNormFwdClass.impl, eps=eps) diff --git a/docs/_static/style.css b/docs/_static/style.css index 296912ace2c8..2c1dfcbcbf08 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -8,15 +8,15 @@ background-color: #fff; } -.getting-started { +.installation { background-color: rgba(78, 150, 253, var(--block-bg-opacity)); } -.user-guides { +.getting-started { background-color: rgba(0, 169, 154, var(--block-bg-opacity)); } -.developer-docs { +.user-guides { background-color: rgba(171, 0, 182, var(--block-bg-opacity)); } diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index ea8a86fa80f1..f1a699b5c56c 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -26,10 +26,7 @@ "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", "\n", - "This tutorial comes with two supplementary files:\n", - "\n", - "* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n", - "* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n", + "The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n", "\n", "## A simple example\n", "\n", @@ -101,7 +98,7 @@ "\n", "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", "For more information about this interface, take a look at [the XLA custom call 
documentation](https://openxla.org/xla/custom_call).\n", - "The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n", + "The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n", "\n", "```c++\n", "#include \n", @@ -129,12 +126,11 @@ "// A wrapper function providing the interface between the XLA FFI call and our\n", "// library function `ComputeRmsNorm` above. This function handles the batch\n", "// dimensions by calling `ComputeRmsNorm` within a loop.\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y) {\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y) {\n", " auto [totalSize, lastDim] = GetDims(x);\n", " if (lastDim == 0) {\n", - " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", - " \"RmsNorm input must be an array\");\n", + " return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n", " }\n", " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", " ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n", @@ -149,8 +145,8 @@ " RmsNorm, RmsNormImpl,\n", " ffi::Ffi::Bind()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", @@ -173,8 +169,7 @@ "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", "In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n", "\n", - "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n", - "The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)." + "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble." ] }, { @@ -433,7 +428,7 @@ "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", "2. 
`rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", "\n", - "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n", "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", "\n", "This custom derivative rule can be wired in as follows:" @@ -508,16 +503,16 @@ "When defining our FFI wrapper for CPU, the function signature that we used was:\n", "\n", "```c++\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y)\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "To update this to interface with a CUDA kernel, this signature becomes:\n", "\n", "```c++\n", "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n", - " ffi::Buffer x,\n", - " ffi::Result> y)\n", + " ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "And the handler definition is updated to include a `Ctx` in its binding:\n", @@ -528,8 +523,8 @@ " ffi::Ffi::Bind()\n", " .Ctx>()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index 5afc8f809d4d..dbe901237ed4 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts: In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. -This tutorial comes with two supplementary files: - -* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and -* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code. +The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi). ## A simple example @@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call). 
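For readers following the `ffi.md` hunks here, the Python side that eventually consumes this handler is not shown in this excerpt. The following is a rough sketch based on the tutorial's structure and on the `counter.py` example added elsewhere in this patch; the shared-library filename and the use of `ctypes` with `jax.extend.ffi.pycapsule` are illustrative assumptions, not text from the diff.

```python
import ctypes
import numpy as np
import jax
import jax.extend as jex

# Load the compiled shared library and register the FFI handler it defines.
# The library filename here is illustrative.
lib = ctypes.cdll.LoadLibrary("./librms_norm.so")
jex.ffi.register_ffi_target(
    "rms_norm", jex.ffi.pycapsule(lib.RmsNorm), platform="cpu")

def rms_norm(x, eps=1e-5):
  # The result has the same shape and dtype as the input; `eps` is passed to
  # the handler as an FFI attribute via a keyword argument.
  return jex.ffi.ffi_call(
      "rms_norm", jax.ShapeDtypeStruct(x.shape, x.dtype))(x, eps=np.float32(eps))
```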
-The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here: +The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here: ```c++ #include @@ -124,12 +121,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -144,8 +140,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` @@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this fun In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below. To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble. -The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt). ```{code-cell} ipython3 :tags: [hide-output] @@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. -We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end. +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end. The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. 
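The wiring that the next sentence refers to is not reproduced in this excerpt's hunks. A minimal sketch of what it can look like, assuming the `rms_norm_fwd` and `rms_norm_bwd` targets from `rms_norm.cc` are already registered under those names (the exact signatures here are illustrative rather than copied from the tutorial):

```python
from functools import partial

import numpy as np
import jax
import jax.extend as jex

def rms_norm_fwd(x, eps=1e-5):
  # Two outputs with different shapes: the primal result and the per-row
  # residuals consumed by the backward pass.
  y, res = jex.ffi.ffi_call(
      "rms_norm_fwd",
      (jax.ShapeDtypeStruct(x.shape, x.dtype),         # primal output
       jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)),   # residuals
  )(x, eps=np.float32(eps))
  return y, (res, x)

def rms_norm_bwd(eps, residuals, ct_y):
  del eps  # not needed on the backward pass
  res, x = residuals
  ct_x = jex.ffi.ffi_call(
      "rms_norm_bwd", jax.ShapeDtypeStruct(ct_y.shape, ct_y.dtype),
  )(res, x, ct_y)
  return (ct_x,)

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def rms_norm(x, eps=1e-5):
  y, _ = rms_norm_fwd(x, eps)
  return y

rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
```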
This custom derivative rule can be wired in as follows: @@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac When defining our FFI wrapper for CPU, the function signature that we used was: ```c++ -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) ``` To update this to interface with a CUDA kernel, this signature becomes: ```c++ ffi::Error RmsNormImpl(cudaStream_t stream, float eps, - ffi::Buffer x, - ffi::Result> y) + ffi::Buffer x, + ffi::ResultBuffer y) ``` And the handler definition is updated to include a `Ctx` in its binding: @@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER( ffi::Ffi::Bind() .Ctx>() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc index 4dc8a890410c..467f13d44ac2 100644 --- a/docs/ffi/rms_norm.cc +++ b/docs/ffi/rms_norm.cc @@ -56,12 +56,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); -ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormFwd, RmsNormFwdImpl, - ffi::Ffi::Bind() - .Attr("eps") - .Arg>() // x - .Ret>() // y - .Ret>() // res +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y + .Ret>() // res ); void ComputeRmsNormBwd(int64_t size, float res, const float *x, @@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, } } -ffi::Error RmsNormBwdImpl(ffi::Buffer res, - ffi::Buffer x, - ffi::Buffer ct_y, - ffi::Result> ct_x) { +ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, + ffi::Buffer ct_y, + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < 
totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), @@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer res, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormBwd, RmsNormBwdImpl, - ffi::Ffi::Bind() - .Arg>() // res - .Arg>() // x - .Arg>() // ct_y - .Ret>() // ct_x +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>() // res + .Arg>() // x + .Arg>() // ct_y + .Ret>() // ct_x ); diff --git a/docs/index.rst b/docs/index.rst index ba724f8e77ab..5f3bce5cf7da 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,12 @@ designed for high-performance numerical computing and large-scale machine learni .. grid:: 3 + .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Installation + :columns: 12 6 6 4 + :link: installation + :link-type: ref + :class-card: installation + .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started :columns: 12 6 6 4 :link: beginner-guide @@ -44,12 +50,6 @@ designed for high-performance numerical computing and large-scale machine learni :link-type: ref :class-card: user-guides - .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes - :columns: 12 6 6 4 - :link: contributor-guide - :link-type: ref - :class-card: developer-docs - If you're looking to train neural networks, use Flax_ and start with its tutorials. For an end-to-end transformer library built on JAX, see MaxText_. diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md index 246d3a6cb084..37afa2f594e3 100644 --- a/docs/persistent_compilation_cache.md +++ b/docs/persistent_compilation_cache.md @@ -24,6 +24,7 @@ import jax.numpy as jnp jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1) jax.config.update("jax_persistent_cache_min_compile_time_secs", 0) +jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir") @jax.jit def f(x): @@ -87,6 +88,23 @@ cc.set_cache_dir("/tmp/jax_cache") Note that both criteria need to be satisfied for a function to be cached. +### Additional caching + +XLA supports additional caching mechanism which can be enabled alongside JAX's +persistent compilation cache to further improve recompilation time. 
+ +* `jax_persistent_cache_enable_xla_caches`: Possible values: + + * `all`: enable all XLA caching features + + * `none`: don't enable any extra XLA caching features + + * `xla_gpu_kernel_cache_file`: only enable the kernel cache + + * `xla_gpu_per_fusion_autotune_cache_dir`: (default value) only enable the + autotuning cache + + ### Google Cloud When running on Google Cloud, the compilation cache can be placed on a Google diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt index 4179f4bd9ad4..9f9090e2b7ef 100644 --- a/examples/ffi/CMakeLists.txt +++ b/examples/ffi/CMakeLists.txt @@ -12,13 +12,18 @@ message(STATUS "XLA include directory: ${XLA_DIR}") find_package(nanobind CONFIG REQUIRED) -nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") -target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) -install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +set( + JAX_FFI_EXAMPLE_PROJECTS + "rms_norm" + "attrs" + "counter" +) -nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc") -target_include_directories(_attrs PUBLIC ${XLA_DIR}) -install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS}) + nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc") + target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR}) + install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) +endforeach() if(JAX_FFI_EXAMPLE_ENABLE_CUDA) enable_language(CUDA) diff --git a/examples/ffi/README.md b/examples/ffi/README.md index cc7018782a25..eb730b483b76 100644 --- a/examples/ffi/README.md +++ b/examples/ffi/README.md @@ -3,7 +3,26 @@ This directory includes an example project demonstrating the use of JAX's foreign function interface (FFI). The JAX docs provide more information about this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), -but the example in this directory explicitly demonstrates: +but the example in this directory complements that document by demonstrating +(and testing!) the full packaging workflow, and some more advanced use cases. +Within the example project, there are several example calls: -1. One way to package and distribute FFI targets, and -2. Some more advanced use cases. +1. `rms_norm`: This is the example from the tutorial on the JAX docs, and it + demonstrates the most basic use of the FFI. It also includes customization of + behavior under automatic differentiation using `jax.custom_vjp`. + +2. `counter`: This example demonstrates a common pattern for how an FFI call can + use global cache to maintain state between calls. This pattern is useful when + an FFI call requires an expensive initialization step which shouldn't be + run on every execution, or if there is other shared state that could be + reused between calls. In this simple example we just count the number of + times the call was executed. + +3. `attrs`: An example demonstrating the different ways that attributes can be + passed to the FFI. For example, we can pass arrays, variadic attributes, and + user-defined types. Full support of user-defined types isn't yet supported + by XLA, so that example will be added in the future. + +4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with + CUDA. The specifics of the kernels are not very important, but the general + structure, and packaging of the extension are useful for testing. 
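Since the `attrs` entry above is described only in prose (its C++ diff follows next), here is a hedged sketch of what calling an attribute-taking target can look like from Python. The target name `array_attr`, the wrapper name, and the keyword-argument plumbing are assumptions modeled on the `counter.py` example in this patch, not text from the README.

```python
import numpy as np
import jax
import jax.extend as jex

def array_attr_sum(values):
  # Keyword arguments on the inner call are passed to the handler as FFI
  # attributes; here a static int32 array is forwarded as the `array`
  # attribute and the target returns its scalar sum.
  return jex.ffi.ffi_call(
      "array_attr", jax.ShapeDtypeStruct((), np.int32))(
          array=np.asarray(values, dtype=np.int32))
```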
diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/attrs.cc index 2a6e8d847cf4..7ff5c98e52e1 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.cc +++ b/examples/ffi/src/jax_ffi_example/attrs.cc @@ -22,7 +22,7 @@ namespace nb = nanobind; namespace ffi = xla::ffi; ffi::Error ArrayAttrImpl(ffi::Span array, - ffi::Result> res) { + ffi::ResultBufferR0 res) { int64_t total = 0; for (int32_t x : array) { total += x; @@ -37,8 +37,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, .Ret>()); ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, - ffi::Result> secret, - ffi::Result> count) { + ffi::ResultBufferR0 secret, + ffi::ResultBufferR0 count) { auto maybe_secret = attrs.get("secret"); if (maybe_secret.has_error()) { return maybe_secret.error(); diff --git a/examples/ffi/src/jax_ffi_example/counter.cc b/examples/ffi/src/jax_ffi_example/counter.cc new file mode 100644 index 000000000000..d7f17e730fd6 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/counter.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0 out) { + static std::mutex mutex; + static auto& cache = *new std::unordered_map(); + { + const std::lock_guard lock(mutex); + auto it = cache.find(index); + if (it != cache.end()) { + out->typed_data()[0] = ++it->second; + } else { + cache.insert({index, 0}); + out->typed_data()[0] = 0; + } + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Counter, CounterImpl, + ffi::Ffi::Bind().Attr("index").Ret>()); + +NB_MODULE(_counter, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); + return registrations; + }); +} diff --git a/examples/ffi/src/jax_ffi_example/counter.py b/examples/ffi/src/jax_ffi_example/counter.py new file mode 100644 index 000000000000..12c7f015bf58 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/counter.py @@ -0,0 +1,38 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""An example demonstrating how an FFI call can maintain "state" between calls + +In this case, the ``counter`` call simply accumulates the number of times it +was executed, but this pattern can also be used for more advanced use cases. +For example, this pattern is used in jaxlib for: + +1. The GPU solver linear algebra kernels which require an expensive "handler" + initialization, and +2. The ``triton_call`` function which caches the compiled triton modules after + their first use. +""" + +import jax +import jax.extend as jex + +from jax_ffi_example import _counter + +for name, target in _counter.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +def counter(index): + return jex.ffi.ffi_call( + "counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index)) diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu index 858b5f8a888a..240adb6d6a8c 100644 --- a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu +++ b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu @@ -44,11 +44,9 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *c, // Buffer type provides buffer dimensions, so the "n" argument here is not // strictly necessary, but it allows us to demonstrate the use of attributes // (.Attr in the FFI handler definition above). -ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, - ffi::Buffer b, - ffi::Result> c, - ffi::Result> b_plus_1, - size_t n) { +ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, + ffi::Buffer b, ffi::ResultBuffer c, + ffi::ResultBuffer b_plus_1, size_t n) { const int block_dim = 128; const int grid_dim = 1; // Note how we access regular Buffer data vs Result Buffer data: @@ -73,12 +71,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooFwd, FooFwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // a - .Arg>() // b - .Ret>() // c - .Ret>() // b_plus_1 + .Arg>() // a + .Arg>() // b + .Ret>() // c + .Ret>() // b_plus_1 .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled //----------------------------------------------------------------------------// // Backward pass // @@ -106,11 +104,11 @@ __global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c } ffi::Error FooBwdHost(cudaStream_t stream, - ffi::Buffer c_grad, - ffi::Buffer a, - ffi::Result> b_plus_1, - ffi::Result> a_grad, - ffi::Result> b_grad, + ffi::Buffer c_grad, + ffi::Buffer a, + ffi::ResultBuffer b_plus_1, + ffi::ResultBuffer a_grad, + ffi::ResultBuffer b_grad, size_t n) { const int block_dim = 128; const int grid_dim = 1; @@ -131,10 +129,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooBwd, FooBwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // c_grad - .Arg>() // a - .Arg>() // b_plus_1 - .Ret>() // a_grad - .Ret>() // b_grad + .Arg>() // c_grad + .Arg>() // a + .Arg>() // b_plus_1 + .Ret>() // a_grad + .Ret>() // b_grad .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 2fb8d96c8461..455a0e557620 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -59,11 +59,10 @@ std::pair GetDims(const ffi::Buffer &buffer) { // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. 
ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -82,12 +81,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ); ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -118,11 +116,10 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, ffi::Buffer ct_y, - ffi::Result> ct_x) { + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), diff --git a/examples/ffi/tests/counter_test.py b/examples/ffi/tests/counter_test.py new file mode 100644 index 000000000000..1e2ad38a363f --- /dev/null +++ b/examples/ffi/tests/counter_test.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest + +import jax +from jax._src import test_util as jtu + +from jax_ffi_example import counter + +jax.config.parse_flags_with_absl() + + +class CounterTests(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["cpu"]): + self.skipTest("Unsupported platform") + + def test_basic(self): + self.assertEqual(counter.counter(0), 0) + self.assertEqual(counter.counter(0), 1) + self.assertEqual(counter.counter(0), 2) + self.assertEqual(counter.counter(1), 0) + self.assertEqual(counter.counter(0), 3) + + def test_jit(self): + @jax.jit + def counter_fun(x): + return x, counter.counter(2) + + self.assertEqual(counter_fun(0)[1], 0) + self.assertEqual(counter_fun(0)[1], 1) + + # Persists across different cache hits + self.assertEqual(counter_fun(1)[1], 2) + + # Persists after the cache is cleared + counter_fun.clear_cache() + self.assertEqual(counter_fun(0)[1], 3) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/BUILD b/jax/BUILD index 71be67368f3b..0da99677dc7b 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -430,6 +430,7 @@ pytype_strict_library( ":config", ":mlir", ":monitoring", + ":path", ":profiler", ":traceback_util", ":xla_bridge", @@ -722,6 +723,7 @@ py_library( ":jax", ":mlir", "//jax/_src/lib", + "//jax/extend:ffi", "//jaxlib/mlir:arithmetic_dialect", "//jaxlib/mlir:builtin_dialect", "//jaxlib/mlir:func_dialect", diff --git a/jax/__init__.py b/jax/__init__.py index 7916ef0e3962..8ca7721da445 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -83,7 +83,6 @@ from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies -from jax._src.api import clear_backends as _deprecated_clear_backends from jax._src.api import clear_caches as clear_caches from jax._src.custom_derivatives import closure_convert as closure_convert from jax._src.custom_derivatives import custom_gradient as custom_gradient @@ -218,16 +217,15 @@ "or jax.tree_util.tree_map (any JAX version).", _deprecated_tree_map ), - # Added Mar 18, 2024 + # Finalized Nov 12 2024; remove after Feb 12 2025 "clear_backends": ( - "jax.clear_backends is deprecated.", - _deprecated_clear_backends + "jax.clear_backends was removed in JAX v0.4.36", + None ), } import typing as _typing if _typing.TYPE_CHECKING: - from jax._src.api import clear_backends as clear_backends from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 5ed0b0192a7b..fc135ac8f28c 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -716,6 +716,8 @@ def remat_vmap(axis_data, args, dims, *, jaxpr, **params): # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) new_params = dict(eqn.params, jaxpr=new_jaxpr) if (not any(used_inputs) and not any(used_outputs) and diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 329abd6b7570..1bfce85d592c 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ 
-68,11 +68,13 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]: else: return tuple(map(_ensure_str, x)) -@lu.transformation_with_aux -def flatten_fun(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun(fun, io_tree, *py_args): in_tree_expected, out_tree = io_tree @@ -82,11 +84,13 @@ def apply_flat_fun(fun, io_tree, *py_args): ans = fun(*args) return tree_unflatten(out_tree, ans) -@lu.transformation_with_aux -def flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} - yield tree_flatten(ans) + ans = f(*py_args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def apply_flat_fun_nokwargs(fun, io_tree, py_args): in_tree_expected, out_tree = io_tree @@ -118,17 +122,18 @@ def flattened_fun_in_tree( else: return in_tree, lambda: out_tree_store.val, has_kwargs -@lu.transformation_with_aux -def flatten_fun_nokwargs2(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_nokwargs2(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - pair = yield py_args, {} + pair = f(*py_args) if not isinstance(pair, (list, tuple)) or len(pair) != 2: raise TypeError("expected function with aux output to return a two-element " f"tuple, but got type {type(pair)} with value {pair!r}") ans, aux = pair ans_flat, ans_tree = tree_flatten(ans) aux_flat, aux_tree = tree_flatten(aux) - yield (ans_flat, aux_flat), (ans_tree, aux_tree) + store.store((ans_tree, aux_tree)) + return ans_flat, aux_flat class _HashableWithStrictTypeEquality: """Box object used when comparing static arguments as a jit key. 
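The `api_util.py` hunks above (and the `checkify.py` and `custom_derivatives.py` hunks below) migrate generator-style `lu.transformation_with_aux` transforms to the new `transformation_with_aux2` form, in which the wrapped callable and an aux store are explicit parameters. A minimal sketch of the pattern using a made-up transform; it relies on internal `jax._src` APIs, so treat it as illustrative only.

```python
from jax._src import linear_util as lu
from jax.tree_util import tree_flatten

@lu.transformation_with_aux2
def _count_output_leaves(f, store, *args, **kwargs):
  # New style: call the wrapped callable `f` directly instead of yielding its
  # arguments, and record the auxiliary output via `store.store(...)`.
  ans = f(*args, **kwargs)
  leaves, _ = tree_flatten(ans)
  store.store(len(leaves))
  return ans

# Applying the transform to a `lu.WrappedFun` returns the transformed function
# plus a thunk for reading the stored aux value after the call, mirroring how
# `flatten_fun` is used in this patch:
#   g, num_leaves = _count_output_leaves(lu.wrap_init(fn))
#   outs = g.call_wrapped(*args); n = num_leaves()
```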
@@ -277,8 +282,8 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...], return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args -@lu.transformation -def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): +@lu.transformation2 +def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs): sentinel = object() args = [sentinel] * (len(fixed_args) + len(dyn_args)) for i, arg in zip(dyn_argnums, dyn_args): @@ -286,9 +291,7 @@ def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs): fixed_args_ = iter(fixed_args) args = [next(fixed_args_).val if x is sentinel else x for x in args] assert next(fixed_args_, sentinel) is sentinel - ans = yield args, kwargs - yield ans - + return f(*args, **kwargs) def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], kwargs: dict[str, Any]): @@ -311,11 +314,10 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...], return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs -@lu.transformation -def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): +@lu.transformation2 +def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs): kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs) - ans = yield args, kwargs - yield ans + return f(*args, **kwargs) @lru_cache(maxsize=4096) @@ -435,9 +437,9 @@ def flat_out_axes( f, out_axes = _flat_out_axes(f, tuple(leaves), treedef) return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef)) -@lu.transformation_with_aux -def _flat_out_axes(leaves, treedef, *args, **kwargs): - ans = yield args, kwargs +@lu.transformation_with_aux2 +def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs): + ans = f(*args, **kwargs) spec = tree_unflatten(treedef, leaves) try: spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None)) @@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs): "that the `out_axes` argument to `pmap` is a pytree prefix of the " "pmapped function's output.") raise ValueError(msg) from None - yield ans, spec_flat + store.store(spec_flat) + return ans def check_callable(fun): # In Python 3.10+, the only thing stopping us from supporting staticmethods @@ -683,11 +686,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames, return tuple(f'{name}{keystr(path)}' for name, x in ba.arguments.items() for path, l in generate_key_paths(x) if l is not static) -@lu.transformation_with_aux -def result_paths(*args, **kwargs): +@lu.transformation_with_aux2 +def result_paths(f, store, *args, **kwargs): "linear_util transform to get output pytree paths of pre-flattened function." - ans = yield args, kwargs - yield ans, [keystr(path) for path, _ in generate_key_paths(ans)] + ans = f(*args, **kwargs) + store.store([keystr(path) for path, _ in generate_key_paths(ans)]) + return ans def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None, result_paths: tuple[str, ...] | None = None, diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 71886b453bef..013b766b8550 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -343,7 +343,7 @@ def pure_callback( * Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method`` is deprecated and it will eventually raise ``NotImplementedError``. * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over - the batched arugments, calling ``callback`` once for each batch element. 
+ the batched arguments, calling ``callback`` once for each batch element. * ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1`` added as the leading dimension unbatched inputs. * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 55db5d13e848..22fde8bd1cb5 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -330,11 +330,12 @@ def update_error(error, pred, code, metadata, payload, effect_type): ## Checkify transformation for plumbing functional error values. -@lu.transformation_with_aux -def _flatten_and_get_error_metadata_thunk(*invals): - error, out = yield invals, {} +@lu.transformation_with_aux2 +def _flatten_and_get_error_metadata_thunk(f, store, *invals): + error, out = f(*invals) out_vals, out_tree = jtu.tree_flatten((error, out)) - yield out_vals, (out_tree, set(error._pred.keys())) + store.store((out_tree, set(error._pred.keys()))) + return out_vals def default_checkify_rule(primitive: core.Primitive, error: Error, enabled_errors, *invals: core.Value, @@ -438,10 +439,12 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors, consts = tuple(c.x for c in hashable_consts) return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args) -@lu.transformation_with_aux -def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) +@lu.transformation_with_aux2 +def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, out_tree = tree_flatten(ans) + store.store(out_tree) + return ans def _reduce_any_error(error: Error): diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index c7665da961af..8ff52bd2f559 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,7 +80,8 @@ def cloud_tpu_init() -> None: os.environ.setdefault('TPU_ML_PLATFORM', 'JAX') os.environ.setdefault('TPU_ML_PLATFORM_VERSION', version.__version__) os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + if '--xla_tpu_use_enhanced_launch_barrier' not in os.environ['LIBTPU_INIT_ARGS']: + os.environ['LIBTPU_INIT_ARGS'] = os.environ.get('LIBTPU_INIT_ARGS','') + ' --xla_tpu_use_enhanced_launch_barrier=true' # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 113f7507c4b0..ebb1a2b54855 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -29,10 +29,12 @@ from jax._src import distributed from jax._src import lib from jax._src import monitoring +from jax._src import path as pathlib from jax._src import profiler from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir import numpy as np @@ -241,6 +243,31 @@ def get_compile_options( debug_options.xla_detailed_logging = detailed_logging + # If persistent cache is enabled, also enable additional XLA caching features. + if compilation_cache.is_persistent_cache_enabled() and jaxlib_version > (0, 4, 35): + # compilation_cache_dir can't be None here, but the type checker is a bit + # strict. 
+ path = pathlib.Path(config.compilation_cache_dir.value or "") + enabled_flags = config.persistent_cache_enable_xla_caches.value or "" + + if enabled_flags == "all" or "xla_gpu_kernel_cache_file" in enabled_flags: + kernel_cache_path = path / "xla_gpu_kernel_cache_file" + debug_options.xla_gpu_kernel_cache_file = str(kernel_cache_path) + # This option is required to use the kernel cache. + debug_options.xla_gpu_enable_llvm_module_compilation_parallelism = True + logger.debug("Enabling XLA kernel cache at '%s'", kernel_cache_path) + + if enabled_flags == "all" or "xla_gpu_per_fusion_autotune_cache_dir" in enabled_flags: + autotune_cache_path = path / "xla_gpu_per_fusion_autotune_cache_dir" + debug_options.xla_gpu_per_fusion_autotune_cache_dir = str(autotune_cache_path) + logger.debug("Enabling XLA autotuning cache at '%s'", autotune_cache_path) + + # Set caching mode so that only process 0 can write to the cache. + if distributed.global_state.process_id == 0: + debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.UPDATE + else: + debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.READ + return compile_options diff --git a/jax/_src/config.py b/jax/_src/config.py index f3edde69981f..72f394dba76f 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1369,6 +1369,15 @@ def _update_jax_memories_thread_local(val): ' filesystem being used for the cache. ' '* > 0: the actual minimum size desired; no overrides.')) +# TODO: Change default to all +persistent_cache_enable_xla_caches = optional_string_state( + name='jax_persistent_cache_enable_xla_caches', + default='xla_gpu_per_fusion_autotune_cache_dir', + help=('When the persistent cache is enabled, additional XLA caching will ' + 'also be enabled automatically. This option can be used to configure' + 'which XLA caching methods will be enabled.'), +) + compilation_cache_include_metadata_in_key = bool_state( name='jax_compilation_cache_include_metadata_in_key', default=False, @@ -1561,7 +1570,9 @@ def _update_default_device_thread_local(val): def _validate_default_device(val): - if val is not None and not isinstance(val, xla_client.Device): + if (val is not None and + not isinstance(val, xla_client.Device) and + val not in ['cpu', 'gpu', 'tpu']): # TODO(skyewm): this is a workaround for non-PJRT Device types. Remove when # all JAX backends use a single C++ device interface. if 'Device' in str(type(val)): @@ -1569,12 +1580,11 @@ def _validate_default_device(val): 'Allowing non-`xla_client.Device` default device: %s, type: %s', repr(val), type(val)) return - raise ValueError('jax.default_device must be passed a Device object (e.g. ' - f"`jax.devices('cpu')[0]`), got: {val!r}") + raise ValueError('jax.default_device must be passed either a Device object (e.g. ' + f"`jax.devices('cpu')[0]`) or a platform name string like 'cpu' or 'gpu'" + f", got: {val!r}") -# TODO(skye): default_device only accepts devices for now. Make it work with -# platform names as well (e.g. "cpu" to mean the same as jax.devices("cpu")[0]). 
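The `_validate_default_device` change above now accepts platform-name strings alongside `Device` objects, which is why the `TODO(skye)` comment just above is removed. A short usage sketch, based on my reading of the new validation logic and the updated error message:

```python
import jax
import jax.numpy as jnp

# Either form is now accepted; previously only the Device object form was.
jax.config.update("jax_default_device", jax.devices("cpu")[0])
jax.config.update("jax_default_device", "cpu")

# The context-manager form accepts the same values.
with jax.default_device("cpu"):
  x = jnp.ones(3)
```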
default_device = string_or_object_state( name='jax_default_device', default=None, diff --git a/jax/_src/core.py b/jax/_src/core.py index beb755348a0f..a1fcdac65df0 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -446,8 +446,16 @@ def bind(self, *args, **params): # TODO: figure out how to handle function arguments # assert (not config.enable_checks.value or # all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args - with take_current_trace() as cur_trace: - return self.bind_with_trace(cur_trace, args, params) + + # This is equivalent to "with take_current_trace()", but the bind() code + # is called frequently and it's slightly faster to avoid using a context + # manager object. + prev_trace = trace_ctx.trace + trace_ctx.set_trace(eval_trace) + try: + return self.bind_with_trace(prev_trace, args, params) + finally: + trace_ctx.set_trace(prev_trace) def bind_with_trace(self, trace, args, params): return trace.process_primitive(self, args, params) @@ -1648,8 +1656,10 @@ def str_short(self, short_dtypes=False): self.dtype.name) dt_str = dt_str.replace('void', 'float0') if hasattr(self, 'sharding') and self.sharding is not None: - shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec)) - return f'{dt_str}[{shapestr}]' + shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) + axis_types = self.sharding.mesh.axis_types + axt = _get_axis_type_str(axis_types) if axis_types is not None else '' + return f'{dt_str}[{shapestr}]{axt}' else: shapestr = ','.join(map(str, self.shape)) return f'{dt_str}[{shapestr}]' @@ -1661,15 +1671,32 @@ def _len(self, ignored_tracer): raise TypeError("len() of unsized object") from err # same as numpy error +def _get_axis_type_str(axis_types): + from jax._src.mesh import AxisTypes # type: ignore + + out = [] + for t, axes in axis_types.items(): + a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes + if t == AxisTypes.Collective: + out.append(f"C:{a}") + elif t == AxisTypes.User: + out.append(f"U:{a}") + else: + assert t == AxisTypes.Auto + out.append(f"A:{a}") + return f"{{{', '.join(out)}}}" + def _get_shape_sharding_str(shape, spec): + out = [] for s1, s2 in zip(shape, spec): if s2 is None: - yield f"{s1}" + out.append(f"{s1}") elif isinstance(s2, tuple): ss = ','.join(s for s in s2) - yield f"{s1}@({ss})" + out.append(f"{s1}@({ss})") else: - yield f"{s1}@{s2}" + out.append(f"{s1}@{s2}") + return ','.join(out) def _get_abstract_sharding(val): from jax._src.sharding_impls import NamedSharding # pytype: disable=import-error diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 375efeb712b8..69130cc1831e 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -75,13 +75,14 @@ def _zeros_like_pytree(x): # like the api_util.py function, but also grabs output avals for error checking -@lu.transformation_with_aux -def _flatten_fun_nokwargs(in_tree, *args_flat): +@lu.transformation_with_aux2 +def _flatten_fun_nokwargs(f, store, in_tree, *args_flat): py_args = tree_unflatten(in_tree, args_flat) - ans = yield py_args, {} + ans = f(*py_args) ans_flat, ans_tree = tree_flatten(ans) ans_avals = [core.get_aval(x) for x in ans_flat] - yield ans_flat, (ans_tree, ans_avals) + store.store((ans_tree, ans_avals)) + return ans_flat ### JVPs @@ -266,18 +267,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable def _add_args(f, extra_args): return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args)) -@lu.transformation 
-def _add_args_(extra_args, *args, **kwargs): +@lu.transformation2 +def _add_args_(f, extra_args, *args, **kwargs): extra_args = tuple(arg.val for arg in extra_args) all_args = (extra_args + args) - yield (yield all_args, kwargs) + return f(*all_args, **kwargs) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args): primals_in, tangents_in = split_list(args, [len(args) // 2]) py_primals = tree_unflatten(in_tree, primals_in) py_tangents = tree_unflatten(in_tree, tangents_in) - pair_out = yield (py_primals, py_tangents), {} + pair_out = f(py_primals, py_tangents) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom JVP rule {jvp_name} for function {primal_name} " "must produce a pair (list or tuple of length two) representing " @@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): if av_et != av_t) raise TypeError(msg.format('\n'.join(disagreements))) - yield primals_out + tangents_out, (out_tree, primal_avals) + store.store((out_tree, primal_avals)) + return primals_out + tangents_out class CustomJVPCallPrimitive(core.Primitive): multiple_results = True @@ -652,15 +654,15 @@ def _check_for_tracers(x): "arguments should typically not be indicated as nondiff_argnums.") raise UnexpectedTracerError(msg) -@partial(lu.transformation_with_aux, use_eq_store=True) -def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, +@partial(lu.transformation_with_aux2, use_eq_store=True) +def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, *args): if symbolic_zeros: args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])] else: args = args[::2] py_args = tree_unflatten(in_tree, args) - pair_out = yield py_args, {} + pair_out = f(*py_args) if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " "must produce a pair (list or tuple of length two) where the first " @@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - yield (*res, *primals_out), (out_tree, res_tree) + store.store((out_tree, res_tree)) + return (*res, *primals_out) -@lu.transformation -def _flatten_bwd(in_tree, in_avals, out_trees, *args): +@lu.transformation2 +def _flatten_bwd(f, in_tree, in_avals, out_trees, *args): out_tree, res_tree = out_trees() assert len(args) == res_tree.num_leaves + out_tree.num_leaves res, cts_out = split_list(args, [res_tree.num_leaves]) py_res = tree_unflatten(res_tree, res) py_cts_out = tree_unflatten(out_tree, cts_out) - py_cts_in = yield (py_res, py_cts_out), {} + py_cts_in = f(py_res, py_cts_out) if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)): py_cts_in = tuple(py_cts_in) # For each None in py_cts_in, indicating an argument for which the rule @@ -775,7 +778,7 @@ def append(x, d): f"to an input of shape/dtype {a.str_short()}.") raise ValueError(msg) results.append(ct) - yield results + return results # TODO(mattjj): remove both these exceptions to cotangent compatibility check def _temporary_dtype_exception(a, a_) -> bool: @@ -1425,11 +1428,11 @@ def fun_jaxpr_thunk(): return wrapped_fwd -@lu.transformation -def 
_fix_fwd_args(*args): +@lu.transformation2 +def _fix_fwd_args(f, *args): args = [(x, True) for x in args] args = [x for pair in args for x in pair] - yield (yield args, {}) + return f(*args) def _remat_opt_impl( *args, @@ -1531,6 +1534,8 @@ def _remat_opt_transpose( "remat optimization for custom_vjp does not support higher-order AD") def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): + if not any(used_outs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]]) outvars = [v for used, v in zip(used_outs, eqn.outvars) if used] if any(used_res): diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index ac0418932b83..f5b0c3fd68b1 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -339,6 +339,8 @@ def _issubclass(a: Any, b: Any) -> bool: return False +_types_for_issubdtype = (type, np.dtype, ExtendedDType) + # TODO(jakevdp): consider whether to disallow None here. We allow it # because np.issubdtype allows it (and treats it as equivalent to float64). def issubdtype(a: DTypeLike | ExtendedDType | None, @@ -360,8 +362,8 @@ def issubdtype(a: DTypeLike | ExtendedDType | None, # unhashable (e.g. custom objects with a dtype attribute). The following check is # fast and covers the majority of calls to this function within JAX library code. return _issubdtype_cached( - a if isinstance(a, (type, np.dtype, ExtendedDType)) else np.dtype(a), # type: ignore[arg-type] - b if isinstance(b, (type, np.dtype, ExtendedDType)) else np.dtype(b), # type: ignore[arg-type] + a if isinstance(a, _types_for_issubdtype) else np.dtype(a), # type: ignore[arg-type] + b if isinstance(b, _types_for_issubdtype) else np.dtype(b), # type: ignore[arg-type] ) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index d080aae759a6..99340e728545 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -68,42 +68,43 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, fun, aux = jvp_subtrace_aux(fun) return jvpfun(fun, instantiate, transform_stack), aux -@lu.transformation -def jvpfun(instantiate, transform_stack, primals, tangents): +@lu.transformation2 +def jvpfun(f, instantiate, transform_stack, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) with ctx: - out_primals, out_tangents = yield (tag, primals, tangents), {} + out_primals, out_tangents = f(tag, primals, tangents) if type(instantiate) is bool: instantiate = [instantiate] * len(out_tangents) out_tangents = [instantiate_zeros(t) if inst else t for t, inst in zip(out_tangents, instantiate)] - yield out_primals, out_tangents + return out_primals, out_tangents -@lu.transformation -def jvp_subtrace(tag, primals, tangents): +@lu.transformation2 +def jvp_subtrace(f, tag, primals, tangents): with core.take_current_trace() as parent_trace: trace = JVPTrace(parent_trace, tag) in_tracers = [maybe_jvp_tracer(trace, x, t) for x, t in zip(primals, tangents)] with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out = unzip2(map(trace.to_primal_tangent_pair, ans)) - yield out + return out -@lu.transformation_with_aux -def jvp_subtrace_aux(tag, primals, tangents): +@lu.transformation_with_aux2 +def jvp_subtrace_aux(f, store, tag, primals, tangents): with core.take_current_trace() as parent_trace: 
trace = JVPTrace(parent_trace, tag) with core.set_current_trace(trace): - ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {} + ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents))) out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag else x for x in aux] - yield (out_primals, out_tangents), aux_primals + store.store(aux_primals) + return out_primals, out_tangents def linearize(traceable, *primals, **kwargs): has_aux = kwargs.pop('has_aux', False) @@ -262,10 +263,11 @@ def get_primitive_transpose(p): "Transpose rule (for reverse-mode differentiation) for '{}' " "not implemented".format(p)) from err -@lu.transformation_with_aux -def nonzero_tangent_outputs(*args, **kwargs): - results = (_, tangents_out) = yield args, kwargs - yield results, [type(r) is not Zero for r in tangents_out] +@lu.transformation_with_aux2 +def nonzero_tangent_outputs(f, store, *args, **kwargs): + results = (_, tangents_out) = f(*args, **kwargs) + store.store([type(r) is not Zero for r in tangents_out]) + return results class JVPTrace(Trace): @@ -543,15 +545,16 @@ def zero_jvp(primitive, primals, tangents, **params): def instantiate_zeros(tangent): return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent -@lu.transformation_with_aux -def traceable(in_tree, *primals_and_tangents): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, tangents)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) tangents_out = [None if type(t) is Zero else t for t in tangents_out] out_flat, out_tree = tree_flatten((primals_out, tangents_out)) - yield out_flat, out_tree + store.store(out_tree) + return out_flat def call_transpose(primitive, params, call_jaxpr, args, ct, _): @@ -588,10 +591,11 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals): primitive_transposes[core.closed_call_p] = _closed_call_transpose -@lu.transformation_with_aux -def nonzero_outputs(*args, **kwargs): - results = yield args, kwargs - yield results, [type(r) is not Zero for r in results] +@lu.transformation_with_aux2 +def nonzero_outputs(f, store, *args, **kwargs): + results = f(*args, **kwargs) + store.store([type(r) is not Zero for r in results]) + return results def map_transpose(primitive, params, call_jaxpr, args, ct, _): all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts @@ -655,17 +659,18 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() -@lu.transformation_with_aux -def f_jvp_traceable(nonzeros, *primals_and_nztangents): +@lu.transformation_with_aux2 +def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = iter(primals_and_nztangents[num_primals:]) tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] - primals_out, tangents_out = yield (primals, tangents), {} + primals_out, tangents_out = f(primals, tangents) out_nonzeros = [type(t) is not Zero for t in tangents_out] nonzero_tangents_out = [t for t in tangents_out if 
type(t) is not Zero] - yield list(primals_out) + nonzero_tangents_out, out_nonzeros + store.store(out_nonzeros) + return list(primals_out) + nonzero_tangents_out def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 0adb582a7993..f4658ec2be29 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -327,11 +327,13 @@ def unregister_vmappable(data_type: type) -> None: def is_vmappable(x: Any) -> bool: return type(x) is Jumble or type(x) in vmappables -@lu.transformation_with_aux -def flatten_fun_for_vmap(in_tree, *args_flat): +@lu.transformation_with_aux2 +def flatten_fun_for_vmap(f, store, in_tree, *args_flat): py_args, py_kwargs = tree_unflatten(in_tree, args_flat) - ans = yield py_args, py_kwargs - yield tree_flatten(ans, is_leaf=is_vmappable) + ans = f(*py_args, **py_kwargs) + ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable) + store.store(out_tree) + return ans # Propagate ragged masking rules from invars to outvars # rule([params], [raggedness_per_invar], outvars) -> @@ -580,16 +582,16 @@ def batch(fun: lu.WrappedFun, axis_data, f = _batch_inner(fun, axis_data, out_dim_dests) return _batch_outer(f, axis_data, in_dims) -@lu.transformation -def _batch_outer(axis_data, in_dims, *in_vals): +@lu.transformation2 +def _batch_outer(f, axis_data, in_dims, *in_vals): tag = TraceTag() with source_info_util.transform_name_stack('vmap'): - outs, trace = yield (tag, in_dims, *in_vals), {} + outs, trace = f(tag, in_dims, *in_vals) with core.ensure_no_leaks(trace): del trace - yield outs + return outs -@lu.transformation -def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): +@lu.transformation2 +def _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) @@ -599,13 +601,13 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): - outs = yield in_tracers, {} + outs = f(*in_tracers) out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), outs, out_dim_dests) - yield out_vals, trace + return out_vals, trace # NOTE: This divides the in_axes by the tile_size and multiplies the out_axes by it. def vtile(f_flat: lu.WrappedFun, @@ -628,21 +630,21 @@ def untile_axis(out, axis: int | None): shape[axis:axis+2] = [shape[axis] * shape[axis+1]] return out.reshape(shape) - @lu.transformation - def _map_to_tile(*args_flat): + @lu.transformation2 + def _map_to_tile(f, *args_flat): sizes = (x.shape[i] for x, i in safe_zip(args_flat, in_axes_flat) if i is not None) tile_size_ = tile_size or next(sizes, None) assert tile_size_ is not None, "No mapped arguments?" 
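The linear_util rewrites above and below all follow one mechanical pattern: generator-style `lu.transformation` / `lu.transformation_with_aux` wrappers, which `yield` their call arguments and receive the wrapped function's result back at the yield point, become plain functions decorated with `lu.transformation2` / `lu.transformation_with_aux2` that take the wrapped callable `f` (plus a `store` for the aux variant) as leading arguments. A hypothetical transformation written in both styles, purely to make the pattern explicit (the `_scale_outputs` names are not from this change):

```python
from jax._src import linear_util as lu

# Old generator style: yield (args, kwargs), get the wrapped function's
# result back at the yield point, then yield the transformed result.
@lu.transformation
def _scale_outputs(factor, *args, **kwargs):
  outs = yield args, kwargs
  yield [factor * o for o in outs]

# New style: the wrapped callable is an explicit first argument and the
# transformed result is returned normally.
@lu.transformation2
def _scale_outputs2(f, factor, *args, **kwargs):
  outs = f(*args, **kwargs)
  return [factor * o for o in outs]

# Aux variant: the auxiliary output goes through store.store(...) instead of
# being yielded as a second element.
@lu.transformation_with_aux2
def _scale_outputs_aux(f, store, factor, *args, **kwargs):
  outs = f(*args, **kwargs)
  store.store(len(outs))
  return [factor * o for o in outs]
```

Call sites are unchanged: for example, `_scale_outputs2(lu.wrap_init(fn), 2.0)` still yields a `WrappedFun` that is invoked through `call_wrapped`.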
- outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} - yield map(untile_axis, outputs_flat, out_axes_flat) + outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat)) + return map(untile_axis, outputs_flat, out_axes_flat) axis_data = AxisData(axis_name, tile_size, None) return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat)) ### API for batching functions with jaxpr type inputs and outputs -@lu.transformation_with_aux -def batch_subtrace(tag, axis_data, in_dims, *in_vals): +@lu.transformation_with_aux2 +def batch_subtrace(f, store, tag, axis_data, in_dims, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) with core.set_current_trace(trace): @@ -650,10 +652,11 @@ def batch_subtrace(tag, axis_data, in_dims, *in_vals): in_vals, in_dims = resolve_ragged_axes(in_vals, in_dims) in_tracers = [BatchTracer(trace, x, dim, source_info_util.current()) if dim is not None else x for x, dim in zip(in_vals, in_dims)] - outs = yield in_tracers, {} + outs = f(*in_tracers) out_vals, out_dims = unzip2(map(trace.to_batch_info, outs)) segment_lens, out_dims = indirectify_ragged_axes(out_dims) - yield (*segment_lens, *out_vals), out_dims + store.store(out_dims) + return (*segment_lens, *out_vals) def indirectify_ragged_axes(dims): if not any(type(d) is RaggedAxis for d in dims): @@ -789,8 +792,8 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest): jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() -@lu.transformation_with_aux -def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): +@lu.transformation_with_aux2 +def _batch_jaxpr_inner(f, store, axis_data, tag, in_axes, *in_vals): with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) _, in_axes = resolve_ragged_axes(in_vals, in_axes) @@ -799,16 +802,17 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): with (core.set_current_trace(trace), core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), core.add_spmd_axis_names(axis_data.spmd_name)): - outs = yield in_tracers, {} + outs = f(*in_tracers) out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) new_out_axes = indirectify_ragged_axes_against_inputs_outputs( out_axes, in_vals, out_vals) - yield out_vals, new_out_axes + store.store(new_out_axes) + return out_vals -@lu.transformation_with_aux -def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, +@lu.transformation_with_aux2 +def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_axes, *in_vals): - out_vals = yield (trace, in_axes, *in_vals), {} + out_vals = f(trace, in_axes, *in_vals) out_axes = out_axes() out_axes_dest = [(None if src is not_mapped else 0) if dst is zero_if_mapped else dst @@ -819,16 +823,16 @@ def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, trace, in_axes, out_vals = map(partial(matchaxis, axis_data.name, axis_data.size), out_axes, out_axes_dest, out_vals) out_batched = [dst is not None for dst in out_axes_dest] - yield out_vals, out_batched + store.store(out_batched) + return out_vals -@lu.transformation -def _batch_jaxpr_outer(axis_data, in_dims, *in_vals): +@lu.transformation2 +def _batch_jaxpr_outer(f, axis_data, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax for x, ax in unsafe_zip(in_vals, 
in_dims)] tag = TraceTag() - out_vals = yield (tag, in_dims, *in_vals), {} - yield out_vals + return f(tag, in_dims, *in_vals) def _merge_bdims(x, y): if x == y: @@ -845,8 +849,8 @@ class ZeroIfMapped: pass ### functions for handling custom_vjp -@lu.transformation_with_aux -def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): +@lu.transformation_with_aux2 +def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals): size = axis_data.size with core.take_current_trace() as parent_trace: trace = BatchTrace(parent_trace, tag, axis_data) @@ -855,7 +859,7 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): if type(val) is SymbolicZero else BatchTracer(trace, val, dim) for val, dim in zip(in_vals, in_dims * 2)] with core.set_current_trace(trace): - outs = yield in_tracers, {} + outs = f(*in_tracers) # TODO(mattjj,frostig): instantiating any SymbolicZero output is easy, but can # be wasteful in the rare case it actually triggers; handle symbolically! outs = [instantiate(replace_rule_output_symbolic_zeros(x)) for x in outs] @@ -868,7 +872,8 @@ def batch_custom_jvp_subtrace(tag, axis_data, in_dims, *in_vals): out_primal_bds, out_dims, out_primals) out_tangents = map(partial(matchaxis, trace.axis_data.name, size), out_tangent_bds, out_dims, out_tangents) - yield out_primals + out_tangents, out_dims * 2 + store.store(out_dims * 2) + return out_primals + out_tangents def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests): axis_size = axis_data.size @@ -886,11 +891,11 @@ def new_bwd(*args): return bwd_.call_wrapped(*args) return new_bwd -@lu.transformation -def _match_axes_and_sum(axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): +@lu.transformation2 +def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals): # this is like _match_axes, but we do reduce-sums as needed - out_vals = yield in_vals, {} - yield map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, + out_vals = f(*in_vals) + return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name, sum_match=True), out_dims_thunk(), out_dim_dests, out_vals) def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False): diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index ad97ef325f64..943c15b6ea49 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -360,7 +360,7 @@ def const_out_axes_thunk(): staged_out_axes, _ = partition_list(out_knowns, out_axes) staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,) - # Create the input tracers for the staged-out (unkonwn-value) call. + # Create the input tracers for the staged-out (unknown-value) call. 
const_tracers = map(self.new_instantiated_const, res) env_tracers = map(self.to_jaxpr_tracer, env) unknown_arg_tracers = [t for t in tracers if not t.is_known()] @@ -475,18 +475,19 @@ def partition_pvals( consts = [pval.get_known() for pval in pvals if pval.is_known()] return knowns, avals, consts -@lu.transformation_with_aux +@lu.transformation_with_aux2 def partial_eval_wrapper_nounits( - in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], + f, store, in_knowns: Sequence[bool], in_avals: Sequence[AbstractValue], *in_consts: Any): in_avals_, in_consts_ = iter(in_avals), iter(in_consts) in_pvals = [PartialVal.known(next(in_consts_)) if known else PartialVal.unknown(next(in_avals_)) for known in in_knowns] sentinel = object() assert next(in_avals_, sentinel) is next(in_consts_, sentinel) is sentinel - jaxpr, (*maybe_fwds, out_pvals, res, env) = yield (in_pvals,), {} + jaxpr, (*maybe_fwds, out_pvals, res, env) = f(in_pvals) out_knowns, out_avals, out_consts = partition_pvals(out_pvals) - yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) + store.store((*maybe_fwds, out_knowns, out_avals, jaxpr, env)) + return (*out_consts, *res) custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} @@ -574,20 +575,22 @@ def trace_to_jaxpr_nounits( return jaxpr, out_pvals, consts # TODO(mattjj): superfluous wrapper...? -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits( + f, trace: JaxprTrace, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): assert all(isinstance(pv, PartialVal) for pv in in_pvals), in_pvals - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits2( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -596,19 +599,19 @@ def trace_to_subjaxpr_nounits2( current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] del out_tracers - yield jaxpr, (out_pvals, out_consts, env) + return jaxpr, (out_pvals, out_consts, env) -def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): +def _trace_to_subjaxpr_nounits(f, trace:JaxprTrace, instantiate, in_pvals): in_knowns = [pval.is_known() for pval in in_pvals] in_consts = [pval.get_known() for pval in in_pvals if pval.is_known()] in_tracers = [trace.new_arg(pval) for pval in in_pvals if not pval.is_known()] in_args = merge_lists(in_knowns, in_tracers, in_consts) with core.set_current_trace(trace): - ans = yield in_args, {} + ans = f(*in_args) assert isinstance(ans, (list, tuple)), ( f"Got unexpected return type when tracing function to jaxpr: {ans}") assert all(isinstance(x, core.Tracer) or core.valid_jaxtype(x) for x in ans), ( @@ -625,8 +628,9 @@ def _trace_to_subjaxpr_nounits(trace:JaxprTrace, instantiate, in_pvals): # The below variant 
implements an optimization where residuals which are also # inputs are indicated in auxiliary data rather than passed as outputs. # TODO(mattjj): update all callers to use this version, delete other version. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -635,8 +639,8 @@ def trace_to_subjaxpr_nounits_fwd( with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) with core.set_current_trace(trace): - out_tracers, jaxpr, out_consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, out_consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] # Which out_consts (aka residuals) are just forwarded inputs? Check obj id. @@ -646,15 +650,16 @@ def trace_to_subjaxpr_nounits_fwd( pruned_consts = [c for c, fwd in zip(out_consts, fwds) if fwd is None] del out_tracers - yield jaxpr, (fwds, out_pvals, pruned_consts, env) + return jaxpr, (fwds, out_pvals, pruned_consts, env) # The below variant implements two optimizations: # 1. residuals that are also primal inputs are indicated in aux data rather # than passed as outputs; # 2. residuals that are also primal outputs are indicated in aux data rather # than passed as redundant outputs. -@lu.transformation +@lu.transformation2 def trace_to_subjaxpr_nounits_fwd2( + f, tag: TraceTag, instantiate: bool | Sequence[bool], in_pvals: Sequence[PartialVal]): @@ -662,8 +667,8 @@ def trace_to_subjaxpr_nounits_fwd2( current_name_stack = source_info_util.current_name_stack() with core.take_current_trace() as parent_trace: trace = JaxprTrace(parent_trace, current_name_stack, tag) - out_tracers, jaxpr, consts, env = yield from _trace_to_subjaxpr_nounits( - trace, instantiate, in_pvals) + out_tracers, jaxpr, consts, env = _trace_to_subjaxpr_nounits( + f, trace, instantiate, in_pvals) out_pvals = [t.pval for t in out_tracers] # Which consts (aka residuals) are just forwarded inputs? Check obj id. @@ -680,7 +685,7 @@ def trace_to_subjaxpr_nounits_fwd2( if f1 is None and f2 is None] del out_tracers - yield jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) + return jaxpr, (input_fwds, output_fwds, out_pvals, pruned_consts, env) FreeVar = namedtuple('FreeVar', ['val']) @@ -1382,6 +1387,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], return new_jaxpr, used_consts, used_inputs +def has_effects(eqn: JaxprEqn) -> bool: + effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} + return bool(effs) + + @weakref_lru_cache def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...], instantiate: tuple[bool, ...] 
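`has_effects` is hoisted to module scope here because, in the hunks that follow, the early exit for equations whose outputs are all unused moves out of the generic `_dce_jaxpr` loop and into the individual DCE rules (`_default_dce_rule`, the call and closed-call rules, and the `cond`/`scan`/`pmap`/`_remat_opt_dce` rules touched elsewhere in this diff). Under the new contract, a rule that wants dead equations dropped must perform the check itself. A minimal sketch of a custom rule under that contract (the primitive and rule names are hypothetical):

```python
from jax._src import core
from jax._src.interpreters import partial_eval as pe

def _my_prim_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
                      ) -> tuple[list[bool], core.JaxprEqn | None]:
  # Each rule now handles the "nothing used, no (non-axis-name) effects"
  # case itself; returning None drops the equation entirely.
  if not any(used_outputs) and not pe.has_effects(eqn):
    return [False] * len(eqn.invars), None
  # Otherwise keep the equation and report which inputs are still needed.
  return [True] * len(eqn.invars), eqn

# Registration is unchanged (hypothetical primitive):
# pe.dce_rules[my_prim_p] = _my_prim_dce_rule
```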
@@ -1395,21 +1405,14 @@ def write(x: Atom, b: bool) -> None: if type(x) is Var: env[x] = read(x) or b - def has_effects(eqn: JaxprEqn) -> bool: - effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)} - return bool(effs) - new_eqns = [] map(write, jaxpr.outvars, used_outputs) for eqn in jaxpr.eqns[::-1]: used_outs = map(read, eqn.outvars) - if not any(used_outs) and not has_effects(eqn): - used_ins = [False] * len(eqn.invars) - else: - rule = dce_rules.get(eqn.primitive, _default_dce_rule) - used_ins, new_eqn = rule(used_outs, eqn) - if new_eqn is not None: - new_eqns.append(new_eqn) + rule = dce_rules.get(eqn.primitive, _default_dce_rule) + used_ins, new_eqn = rule(used_outs, eqn) + if new_eqn is not None: + new_eqns.append(new_eqn) map(write, eqn.invars, used_ins) used_inputs = map(read, jaxpr.invars) used_inputs = map(op.or_, instantiate, used_inputs) @@ -1433,7 +1436,9 @@ def has_effects(eqn: JaxprEqn) -> bool: def _default_dce_rule( used_outs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outs) and not has_effects(eqn): + return [False] * len(eqn.invars), None return [True] * len(eqn.invars), eqn dce_rules: dict[Primitive, DCERule] = {} @@ -1441,6 +1446,8 @@ def _default_dce_rule( def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn ) -> tuple[list[bool], JaxprEqn | None]: + if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) new_params = dict(eqn.params, call_jaxpr=new_jaxpr) update_params = call_param_updaters.get(eqn.primitive) @@ -1454,6 +1461,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn [v for v, used in zip(eqn.outvars, used_outputs) if used], eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx) return used_inputs, new_eqn + dce_rules[core.call_p] = dce_jaxpr_call_rule @@ -1465,8 +1473,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...] return core.ClosedJaxpr(new_jaxpr, consts), used_inputs def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn - ) -> tuple[list[bool], JaxprEqn]: + ) -> tuple[list[bool], JaxprEqn | None]: # TODO(mattjj): de-duplicate with above rule? 
+ if not any(used_outputs) and not has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr_ = eqn.params['call_jaxpr'] closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs)) new_params = dict(eqn.params, call_jaxpr=closed_jaxpr) @@ -2061,10 +2071,10 @@ def transpose_jaxpr_thunk(): custom_staging_rules: dict[Primitive, Callable] = {} -@lu.transformation -def _interleave_fun(every_others, *args, **kwargs): +@lu.transformation2 +def _interleave_fun(f, every_others, *args, **kwargs): args_ = [x for pair in zip(args, every_others) for x in pair] - yield (yield (args_, kwargs)) + return f(*args_, **kwargs) # TODO: consider renaming to "lazy_thunk" def _memoize(fn): @@ -2078,18 +2088,19 @@ def memoized(*args): return out return memoized -@lu.transformation_with_aux -def _jvp_jaxpr_zeros(in_zeros, zero_avals, *primal_tangent_avals): +@lu.transformation_with_aux2 +def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals): in_primals, nz_in_tangents = split_list(primal_tangent_avals, [len(in_zeros)]) symbolic_zeros = map(ad_util.SymbolicZero, zero_avals) tangents = merge_lists(in_zeros, nz_in_tangents, symbolic_zeros) - out = yield (*in_primals, *tangents), {} + out = f(*in_primals, *tangents) n, ragged = divmod(len(out), 2) assert not ragged out_primals, out_tangents = out[:n], out[n:] out_zeros = [type(t) is ad_util.SymbolicZero for t in out_tangents] out_nz_tangents, _ = partition_list(out_zeros, out_tangents) - yield [*out_primals, *out_nz_tangents], out_zeros + store.store(out_zeros) + return [*out_primals, *out_nz_tangents] # TODO(mattjj): remove this DebugInfo and helper functions, replace with # api_util.py versions diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9a17194d46c9..6c9e54441f8e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -690,15 +690,15 @@ def find_replicas( num_global_replicas = global_axis_size * jaxpr_replicas return ReplicaInfo(jaxpr_replicas, num_local_replicas, num_global_replicas) -@lu.transformation -def _change_argument_ranks(in_axes, out_axes_thunk, *args): +@lu.transformation2 +def _change_argument_ranks(f, in_axes, out_axes_thunk, *args): args = tuple( arg if in_axis is None else jax.lax.squeeze(arg, dimensions=(in_axis,)) for in_axis, arg in zip(in_axes, args) ) - results = yield (args, {}) + results = f(*args) out_axes = out_axes_thunk() - yield tuple( + return tuple( x if axis is None else jax.lax.expand_dims(x, dimensions=(axis,)) for x, axis in zip(results, out_axes) ) @@ -1353,6 +1353,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval): def _pmap_dce_rule(used_outputs, eqn): # just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None axis_name = eqn.params["axis_name"] with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]): new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs) @@ -1709,7 +1711,10 @@ class DeviceAssignmentMismatchError(Exception): def _get_default_device() -> xc.Device: - return config.default_device.value or xb.local_devices()[0] + if isinstance(config.default_device.value, str): + return xb.get_backend(config.default_device.value).local_devices()[0] + else: + return config.default_device.value or xb.local_devices()[0] def _get_and_check_device_assignment( diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 
6333638deae6..9e1f7e04c741 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -642,7 +642,11 @@ def _ordered_unique(xs): return list(d.keys()) def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn, - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + closed_branches = eqn.params['branches'] branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index b5bb8658e675..d15917b8b1da 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -944,7 +944,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params): return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params) def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn - ) -> tuple[list[bool], core.JaxprEqn]: + ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None jaxpr = eqn.params['jaxpr'] num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry'] num_xs = len(jaxpr.in_avals) - num_consts - num_carry diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 9781f67152c8..c45d8f5c80b2 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1915,7 +1915,7 @@ def batch_matmul(lhs: Array, rhs: Array, def square(x: ArrayLike) -> Array: r"""Elementwise square: :math:`x^2`.""" - return integer_pow(x, 2) + return square_p.bind(x) def reciprocal(x: ArrayLike) -> Array: r"""Elementwise reciprocal: :math:`1 \over x`.""" @@ -2203,14 +2203,13 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): for op, in_aval in zip(ops, in_avals): if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: out.append(op) + elif in_aval.sharding.mesh.are_all_axes_collective: + out.append(op) else: # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains # CompilerShardingAxis, then specify `unspecified_dims` via # `wrap_with_sharding_op`. 
- if config.use_shardy_partitioner.value: - sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim) - else: - sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) return out @@ -2227,10 +2226,9 @@ def _nary_lower_hlo(op: Callable, ctx, out = op(*args) if config.sharding_in_types.value: - if config.use_shardy_partitioner.value: - out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim) - else: - out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + if aval_out.sharding.mesh.are_all_axes_collective: + return [out] + out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] else: return [out] @@ -2524,6 +2522,27 @@ def _abs_jvp_rule(g, ans, x): lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) +square_p = standard_unop(_int | _float | _complex, 'square') + +def _square_complex(x): + a, b = real(x), imag(x) + # zero square(x).real is handled explicitly for abs(a)==abs(b) cases + # where for finite a, 2 * a is non-finite: + zero_re = is_finite(a) & (eq(a, b) | eq(a, -b)) + # equivalent to a**2 - b**2 but avoids overflow errors for large a + # and large b cases: + re = (a - b) * (a + b) + im = a * b * 2 + return select(zero_re, complex(_const(a, 0), im), complex(re, im)) + +def _square_lower_hlo(ctx, x): + if dtypes.issubdtype(ctx.avals_in[0].dtype, np.complexfloating): + return mlir.lower_fun(_square_complex, multiple_results=False)(ctx, x) + return [hlo.multiply(x, x)] + +ad.defjvp2(square_p, lambda g, ans, x: mul(g, mul(_const(x, 2), x))) +mlir.register_lowering(square_p, _square_lower_hlo) # TODO(pearu): use chlo.square + def _pow_dtype_rule(x, y): if (dtypes.issubdtype(x.dtype, np.inexact) and dtypes.issubdtype(y.dtype, np.integer)): @@ -2612,24 +2631,24 @@ def _integer_pow(x, *, y): def _integer_pow_lowering(ctx, x, *, y): # These cases are subsumed by the general case, but it's faster to emit these # common cases directly. - if y == 2: + if y == 1: + out = x + elif y == 2: out = hlo.multiply(x, x) elif y == 3: out = hlo.multiply(hlo.multiply(x, x), x) + elif y == -1: + out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x) else: lowering = mlir.lower_fun(_integer_pow, multiple_results=False) - # TODO(b/217551391): emitting an out-of-line call leads to a large - # expansion when the MLIR is lowered to HLO, because the HLO lowering - # clones the callee. Consider unconditionally caching when the MLIR->HLO - # lowering doesn't expand the program. 
- lowering = mlir.cache_lowering(lowering) - out = lowering(ctx, x, y=y) + if builtins.abs(y) >= 3: + lowering = mlir.cache_lowering(lowering) + out, = lowering(ctx, x, y=y) if config.sharding_in_types.value: aval_out, = ctx.avals_out proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() - out = out[0] if isinstance(out, list) else out return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)] - return out if isinstance(out, list) else [out] + return [out] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 08f94c6e8eda..37d812dec619 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -64,6 +64,7 @@ def trans1(static_arg, *dynamic_args, **kwargs): from __future__ import annotations from collections.abc import Callable +from functools import partial from typing import Any, NamedTuple import weakref @@ -149,10 +150,11 @@ class WrappedFun: params: extra parameters to pass as keyword arguments to `f`, along with the transformed keyword arguments. """ - __slots__ = ("f", "transforms", "stores", "params", "in_type", "debug_info") + __slots__ = ("f", "f_transformed", "transforms", "stores", "params", "in_type", "debug_info") - def __init__(self, f, transforms, stores, params, in_type, debug_info): + def __init__(self, f, f_transformed, transforms, stores, params, in_type, debug_info): self.f = f + self.f_transformed = f_transformed self.transforms = transforms self.stores = stores self.params = params @@ -165,8 +167,14 @@ def __name__(self): def wrap(self, gen, gen_static_args, out_store) -> WrappedFun: """Add another transform and its store.""" - return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms, - (out_store,) + self.stores, self.params, None, None) + if out_store is None: + return WrappedFun(self.f, partial(gen, self.f_transformed, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) + else: + return WrappedFun(self.f, partial(gen, self.f_transformed, out_store, *gen_static_args), + ((gen, gen_static_args),) + self.transforms, + (out_store,) + self.stores, self.params, None, None) def populate_stores(self, stores): """Copy the values from the `stores` into `self.stores`.""" @@ -175,47 +183,8 @@ def populate_stores(self, stores): self_store.store(other_store.val) def call_wrapped(self, *args, **kwargs): - """Calls the underlying function, applying the transforms. - - The positional `args` and keyword `kwargs` are passed to the first - transformation generator. - """ - stack = [] - for (gen, gen_static_args), out_store in zip(self.transforms, self.stores): - gen = gen(*(gen_static_args + tuple(args)), **kwargs) - args, kwargs = next(gen) - stack.append((gen, out_store)) - gen = gen_static_args = out_store = None - - try: - ans = self.f(*args, **dict(self.params, **kwargs)) - except: - # Some transformations yield from inside context managers, so we have to - # interrupt them before reraising the exception. Otherwise they will only - # get garbage-collected at some later time, running their cleanup tasks - # only after this exception is handled, which can corrupt the global - # state. 
- while stack: - stack.pop()[0].close() - raise - - args = kwargs = None - while stack: - gen, out_store = stack.pop() - try: - ans = gen.send(ans) - except: - # As above does for the first half of the transformation, exceptions - # raised in the second half of the transformation also require us to - # clean up references here. - while stack: - stack.pop()[0].close() - raise - if out_store is not None: - ans, side = ans - out_store.store(side) - - return ans + """Calls the transformed function""" + return self.f_transformed(*args, **kwargs) def __repr__(self): def transform_to_str(x): @@ -234,7 +203,7 @@ def __eq__(self, other): self.debug_info == other.debug_info) @curry -def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: +def transformation2(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """Adds one more transformation to a WrappedFun. Args: @@ -244,8 +213,28 @@ def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: """ return fun.wrap(gen, gen_static_args, None) +# Backwards compat only. TODO: deprecate +@curry +def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + return gen_inst.send(f(*args_, **kwargs_)) + return transformation2(gen2, fun, *gen_static_args)() + +# Backwards compat only. TODO: deprecate +@curry +def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun: + def gen2(f, store, *args, **kwargs): + gen_inst = gen(*args, **kwargs) + args_, kwargs_ = next(gen_inst) + ans, aux = gen_inst.send(f(*args_, **kwargs_)) + store.store(aux) + return ans + return transformation_with_aux2(gen2, fun, *gen_static_args)() + @curry -def transformation_with_aux( +def transformation_with_aux2( gen, fun: WrappedFun, *gen_static_args, use_eq_store: bool = False ) -> tuple[WrappedFun, Callable[[], Any]]: """Adds one more transformation with auxiliary output to a WrappedFun.""" @@ -261,8 +250,9 @@ def fun_name(f): def wrap_init(f, params=None) -> WrappedFun: """Wraps function `f` as a `WrappedFun`, suitable for transformation.""" + params_dict = {} if params is None else params params = () if params is None else tuple(sorted(params.items())) - return WrappedFun(f, (), (), params, None, None) + return WrappedFun(f, partial(f, **params_dict), (), (), params, None, None) def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: @@ -270,7 +260,7 @@ def annotate(f: WrappedFun, in_type: core.InputType | None) -> WrappedFun: if in_type is None: return f _check_input_type(in_type) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type, f.debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, in_type, f.debug_info) def _check_input_type(in_type: core.InputType) -> None: # Check that in_type is syntactically well-formed @@ -317,7 +307,7 @@ def add_debug_info(f: WrappedFun, debug_info: TracingDebugInfo | None assert f.debug_info is None if debug_info is None: return f - return WrappedFun(f.f, f.transforms, f.stores, f.params, f.in_type, debug_info) + return WrappedFun(f.f, f.f_transformed, f.transforms, f.stores, f.params, f.in_type, debug_info) def cache(call: Callable, *, explain: Callable | None = None): @@ -357,9 +347,9 @@ def _evict_function(f): cache_clearing_funs.add(memoized_fun.cache_clear) return memoized_fun -@transformation -def hashable_partial(*args): - yield (yield args, {}) +@transformation2 +def hashable_partial(f, *args): + return 
f(*args) def merge_linear_aux(aux1, aux2): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 43791f2e5f72..082c443fade4 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -18,6 +18,7 @@ import collections from collections.abc import Hashable, Sequence import contextlib +import enum import functools import math import threading @@ -101,6 +102,12 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh: return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names) +class AxisTypes(enum.Enum): + Auto = enum.auto() + User = enum.auto() + Collective = enum.auto() + + _mesh_object_dict = {} # type: ignore @@ -157,9 +164,11 @@ class Mesh(contextlib.ContextDecorator): devices: np.ndarray axis_names: tuple[MeshAxisName, ...] + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None def __new__(cls, devices: np.ndarray | Sequence[xc.Device], - axis_names: str | Sequence[MeshAxisName]): + axis_names: str | Sequence[MeshAxisName], + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None): if not isinstance(devices, np.ndarray): devices = np.array(devices) if isinstance(axis_names, str): @@ -175,7 +184,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], f"devices.ndim == {devices.ndim} and " f"len(axis_names) == {len(axis_names)}.") - key = (axis_names, devices.shape, tuple(devices.flat)) + # TODO(yashkatariya): If axis_types is None, set all axes to AUTO. + axis_types_tuple = (None if axis_types is None else + tuple(axis_types.items())) + key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple) val = _mesh_object_dict.get(key, None) if val is not None: return val @@ -184,11 +196,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device], self.devices = devices.copy() self.devices.flags.writeable = False self.axis_names = axis_names + self.axis_types = axis_types + self._axis_types_tuple = axis_types_tuple _mesh_object_dict[key] = self return self def __reduce__(self): - return (type(self), (self.devices, self.axis_names)) + return (type(self), (self.devices, self.axis_names, self.axis_types)) def __eq__(self, other): if not isinstance(other, Mesh): @@ -199,12 +213,14 @@ def __eq__(self, other): return True return (self.axis_names == other.axis_names and self.devices.shape == other.devices.shape and + self._axis_types_tuple == other._axis_types_tuple and self._internal_device_list == other._internal_device_list) def __hash__(self): if not hasattr(self, '_hash'): self._hash = hash( - (self.axis_names, self._internal_device_list, self.devices.shape)) + (self.axis_names, self._internal_device_list, self.devices.shape, + self._axis_types_tuple)) return self._hash def __setattr__(self, name, value): @@ -301,7 +317,8 @@ def __str__(self): def _repr(self): if self.empty: return "Mesh(device_ids=[], axis_names=())" - return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})" + atr = '' if self.axis_types is None else f", axis_types={self.axis_types}" + return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})" def __repr__(self): return self._repr @@ -313,7 +330,7 @@ def local_devices(self): @functools.cached_property def abstract_mesh(self): - return AbstractMesh(self.shape_tuple) + return AbstractMesh(self.shape_tuple, self.axis_types) EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -338,25 +355,32 @@ class AbstractMesh: details. 
""" - def __init__(self, shape_tuple: tuple[tuple[str, int], ...]): + def __init__(self, shape_tuple: tuple[tuple[str, int], ...], + axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None): self.shape_tuple = shape_tuple + self.axis_types = axis_types if self.shape_tuple: self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) else: self._axis_names, self._axis_sizes = (), () + # TODO(yashkatariya): If axis_types is None, set all axes to AUTO. + self._axis_types_tuple = (None if axis_types is None else + tuple(axis_types.items())) def __hash__(self): - return hash(self.shape_tuple) + return hash((self.shape_tuple, self._axis_types_tuple)) def __eq__(self, other): if not isinstance(other, AbstractMesh): return False if id(self) == id(other): return True - return self.shape_tuple == other.shape_tuple + return (self.shape_tuple == other.shape_tuple and + self._axis_types_tuple == other._axis_types_tuple) def __repr__(self): - return f"AbstractMesh({self.shape_tuple})" + atr = '' if self.axis_types is None else f", axis_types={self.axis_types}" + return f"AbstractMesh({self.shape_tuple}{atr})" @property def axis_names(self): @@ -382,6 +406,12 @@ def _internal_device_list(self): def empty(self): return self.size == 0 + @functools.cached_property + def are_all_axes_collective(self) -> bool: + if self.axis_types is None: + return False + return all(t == AxisTypes.Collective for t in self.axis_types.keys()) + @property def devices(self): _raise_value_error("devices") diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index b90004e19932..88ddc85a0a40 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4048,15 +4048,37 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str): def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array: nd = ndim(array) - constant_values = broadcast_to(constant_values, (nd, 2)) constant_values = lax_internal._convert_element_type( constant_values, array.dtype, dtypes.is_weakly_typed(array)) + constant_values_nd = ndim(constant_values) + + if constant_values_nd == 0: + widths = [(low, high, 0) for (low, high) in pad_width] + return lax.pad(array, constant_values, widths) + + if constant_values_nd == 1: + if constant_values.shape[-1] == 1: + widths = [(low, high, 0) for (low, high) in pad_width] + return lax.pad(array, squeeze(constant_values), widths) + elif constant_values.shape[-1] == 2: + widths = [(low, 0, 0) for (low, _) in pad_width] + array = lax.pad(array, constant_values[0], widths) + widths = [(0, high, 0) for (_, high) in pad_width] + return lax.pad(array, constant_values[1], widths) + else: + raise ValueError("jnp.pad: constant_values has unsupported shape " + f"{constant_values.shape}. 
If the shape is 1D or 2D, the " + "last dimension must be of size 1 or 2.") + + constant_values = broadcast_to(constant_values, (nd, 2)) for i in range(nd): widths = [(0, 0, 0)] * nd - widths[i] = (pad_width[i][0], 0, 0) - array = lax.pad(array, constant_values[i, 0], widths) - widths[i] = (0, pad_width[i][1], 0) - array = lax.pad(array, constant_values[i, 1], widths) + if pad_width[i][0] != 0: + widths[i] = (pad_width[i][0], 0, 0) + array = lax.pad(array, constant_values[i, 0], widths) + if pad_width[i][1] != 0: + widths[i] = (0, pad_width[i][1], 0) + array = lax.pad(array, constant_values[i, 1], widths) return array diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index be1e55675079..08f11d0cb6ad 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -192,6 +192,11 @@ def _cast_to_bool(operand: ArrayLike) -> Array: def _cast_to_numeric(operand: ArrayLike) -> Array: return promote_dtypes_numeric(operand)[0] +def _require_integer(operand: ArrayLike) -> Array: + arr = lax_internal.asarray(operand) + if not dtypes.isdtype(arr, ("bool", "integral")): + raise ValueError(f"integer argument required; got dtype={arr.dtype}") + return arr def _ensure_optional_axes(x: Axis) -> Axis: def force(x): @@ -652,6 +657,63 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None, return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + arr = lax_internal.asarray(a) + init_val = np.array(-1, dtype=dtype or arr.dtype) + return _reduction(arr, name="reduce_bitwise_and", np_fun=None, op=lax.bitwise_and, init_val=init_val, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_bitwise_or", np_fun=None, op=lax.bitwise_or, init_val=0, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_bitwise_xor", np_fun=None, op=lax.bitwise_xor, init_val=0, preproc=_require_integer, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_logical_and", np_fun=None, op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), 
dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_logical_or", np_fun=None, op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True) +def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: + return _reduction(a, name="reduce_logical_xor", np_fun=None, op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool, + axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + + def amin(a: ArrayLike, axis: Axis = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 8692c30a3e17..93e116fa4b6a 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -31,7 +31,7 @@ from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax from jax._src.lax import other as lax_other -from jax._src.typing import Array, ArrayLike, DTypeLike +from jax._src.typing import Array, ArrayLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, @@ -57,6 +57,24 @@ def _to_bool(x: Array) -> Array: return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) +def unary_ufunc(func: Callable[[ArrayLike], Array]) -> ufunc: + """An internal helper function for defining unary ufuncs.""" + func_jit = jit(func, inline=True) + return ufunc(func_jit, name=func.__name__, nin=1, nout=1, call=func_jit) + + +def binary_ufunc(identity: Any, reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None) -> Callable[[Callable[[ArrayLike, ArrayLike], Array]], ufunc]: + """An internal helper function for defining binary ufuncs.""" + def decorator(func: Callable[[ArrayLike, ArrayLike], Array]) -> ufunc: + func_jit = jit(func, inline=True) + return ufunc(func_jit, name=func.__name__, nin=2, nout=1, call=func_jit, + identity=identity, reduce=reduce, accumulate=accumulate, at=at, reduceat=reduceat) + return decorator + + @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: """Compute the element-wise absolute values of the real-valued input. @@ -160,8 +178,8 @@ def invert(x: ArrayLike, /) -> Array: return lax.bitwise_not(*promote_args('invert', x)) -@partial(jit, inline=True) -def _negative(x: ArrayLike, /) -> Array: +@unary_ufunc +def negative(x: ArrayLike, /) -> Array: """Return element-wise negative values of the input. JAX implementation of :obj:`numpy.negative`. 
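For context on the `unary_ufunc`/`binary_ufunc` helpers above: the wrapped functions remain ordinary callables, but they are now `jnp.ufunc` instances, so the NumPy-style ufunc methods are available on them. A minimal sketch of the user-facing behavior (values in comments are what the semantics imply; `inplace=False` is required because JAX arrays are immutable):

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# The wrapped functions are jnp.ufunc instances with NumPy-style methods.
print(isinstance(jnp.add, jnp.ufunc), isinstance(jnp.negative, jnp.ufunc))  # True True
print(jnp.add.reduce(x, axis=0))           # [5 7 9], same as jnp.sum(x, axis=0)
print(jnp.multiply.accumulate(x, axis=1))  # [[  1   2   6]
                                           #  [  4  20 120]]

# .at returns an updated copy; duplicate indices accumulate.
a = jnp.zeros(3, dtype=jnp.int32)
print(jnp.add.at(a, jnp.array([0, 1, 1]), 10, inplace=False))  # [10 20  0]
```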
@@ -1126,8 +1144,16 @@ def cbrt(x: ArrayLike, /) -> Array: """ return lax.cbrt(*promote_args_inexact('cbrt', x)) -@partial(jit, inline=True) -def _add(x: ArrayLike, y: ArrayLike, /) -> Array: +def _add_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.add.at.""" + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].add(b).astype(bool) + return a.at[indices].add(b) + +@binary_ufunc(identity=0, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) +def add(x: ArrayLike, y: ArrayLike, /) -> Array: """Add two arrays element-wise. JAX implementation of :obj:`numpy.add`. This is a universal function, @@ -1156,8 +1182,17 @@ def _add(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) -@partial(jit, inline=True) -def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: +def _multiply_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.multiply.at.""" + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].mul(b).astype(bool) + else: + return a.at[indices].mul(b) + +@binary_ufunc(identity=1, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) +def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: """Multiply two arrays element-wise. JAX implementation of :obj:`numpy.multiply`. This is a universal function, @@ -1186,8 +1221,8 @@ def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@partial(jit, inline=True) -def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=-1, reduce=reductions._reduce_bitwise_and) +def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise AND operation elementwise. JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, @@ -1215,8 +1250,8 @@ def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@partial(jit, inline=True) -def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_or) +def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise OR operation elementwise. JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, @@ -1244,8 +1279,8 @@ def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@partial(jit, inline=True) -def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=0, reduce=reductions._reduce_bitwise_xor) +def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the bitwise XOR operation elementwise. JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, @@ -1433,8 +1468,12 @@ def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.ne(*promote_args("not_equal", x, y)) -@partial(jit, inline=True) -def _subtract(x: ArrayLike, y: ArrayLike, /) -> Array: +def _subtract_at(a: Array, indices: Any, b: ArrayLike) -> Array: + """Implementation of jnp.subtract.at.""" + return a.at[indices].subtract(b) + +@binary_ufunc(identity=None, at=_subtract_at) +def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: """Subtract two arrays element-wise. JAX implementation of :obj:`numpy.subtract`. 
This is a universal function, @@ -1754,8 +1793,8 @@ def spacing(x: ArrayLike, /) -> Array: # Logical ops -@partial(jit, inline=True) -def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: +@binary_ufunc(identity=True, reduce=reductions._reduce_logical_and) +def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical AND operation elementwise. JAX implementation of :obj:`numpy.logical_and`. This is a universal function, @@ -1774,8 +1813,9 @@ def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) -@partial(jit, inline=True) -def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=False, reduce=reductions._reduce_logical_or) +def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical OR operation elementwise. JAX implementation of :obj:`numpy.logical_or`. This is a universal function, @@ -1794,8 +1834,9 @@ def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: """ return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) -@partial(jit, inline=True) -def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + +@binary_ufunc(identity=False, reduce=reductions._reduce_logical_xor) +def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: """Compute the logical XOR operation elementwise. JAX implementation of :obj:`numpy.logical_xor`. This is a universal function, @@ -3048,7 +3089,7 @@ def square(x: ArrayLike, /) -> Array: """ check_arraylike("square", x) x, = promote_dtypes_numeric(x) - return lax.integer_pow(x, 2) + return lax.square(x) @partial(jit, inline=True) @@ -3653,57 +3694,3 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t - - -def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_and.reduce()") - result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - - -def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, - out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, - where: ArrayLike | None = None): - if initial is not None: - raise ValueError("initial argument not supported in jnp.logical_or.reduce()") - result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) - return result if dtype is None else result.astype(dtype) - -def _add_at(a: Array, indices: Any, b: ArrayLike): - if a.dtype == bool: - a = a.astype('int32') - b = lax.convert_element_type(b, bool).astype('int32') - return a.at[indices].add(b).astype(bool) - return a.at[indices].add(b) - -def _subtract_at(a: Array, indices: Any, b: ArrayLike): - return a.at[indices].subtract(b) - -def _multiply_at(a: Array, indices: Any, b: ArrayLike): - if a.dtype == bool: - a = a.astype('int32') - b = lax.convert_element_type(b, bool).astype('int32') - return a.at[indices].mul(b).astype(bool) - else: - return a.at[indices].mul(b) - -# Generate ufunc interfaces for several common binary functions. -# We start with binary ufuncs that have well-defined identities.' -# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience? 
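A short illustration of what the `_reduce_bitwise_*` and `_reduce_logical_*` helpers wired in above provide through the public ufunc interface (a sketch; the printed values follow from the definitions of the reductions):

```python
import jax.numpy as jnp

flags = jnp.array([0b001, 0b010, 0b100])
print(jnp.bitwise_or.reduce(flags))                  # 7: OR over all elements
print(jnp.bitwise_and.reduce(jnp.array([7, 6, 3])))  # 2: 7 & 6 & 3

m = jnp.array([[True, True],
               [True, False]])
print(jnp.logical_and.reduce(m, axis=0))  # [ True False], same as jnp.all(m, axis=0)
print(jnp.logical_xor.reduce(m, axis=1))  # [False  True]: parity along each row
```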
-# TODO(jakevdp): optimize some implementations. -# - define add.at/multiply.at in terms of scatter_add/scatter_mul -# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod -# - define all monoidal reductions in terms of lax.reduce -add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) -multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) -bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) -bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) -bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) -logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) -logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) -logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) -negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative) -subtract = ufunc(_subtract, name="subtract", nin=2, nout=1, call=_subtract, at=_subtract_at) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 489aae59dcd2..3dbb410be29f 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -502,8 +502,13 @@ def err_details(): ) else: assert rank == 1 - # TODO(necula): test this for bool. What should it do? - tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) + # bools get a bitwidth of 32 due to how mosaic handles them + if bm.array_shape_dtype.dtype == jnp.bool_: + bitwidth = 32 + else: + bitwidth = lax_internal._bit_width(bm.array_shape_dtype.dtype) + packing = 32 // bitwidth + tiling_size = 128 * packing evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) if not evenly_divisible: raise ValueError( @@ -2079,6 +2084,15 @@ def _sqrt_lowering_rule(ctx: LoweringRuleContext, x): lowering_rules[lax.sqrt_p] = _sqrt_lowering_rule +def _square_lowering_rule(ctx: LoweringRuleContext, x): + if jnp.issubdtype(ctx.avals_in[0].dtype, jnp.integer): + return arith.muli(x, x) + return arith.mulf(x, x) + + +lowering_rules[lax.square_p] = _square_lowering_rule + + def _exp_lowering_rule(ctx: LoweringRuleContext, x): return math.exp(x) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 6f98c83fdfd8..ad418e2b936d 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -107,7 +107,10 @@ pytype_strict_library( ":core", ":primitives", "//jax", + "//jax:core", + "//jax:mosaic_gpu", "//jax:pallas", + "//jax:partial_eval", "//jax:util", "//jax/_src/pallas", ], diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index ba343cd923c3..6d30cdb0d4a3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -264,6 +264,7 @@ def scratch_view( class LoweringRuleContext: module_ctx: ModuleContext launch_ctx: mgpu.LaunchContext + predicate: ir.Value avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] @@ -878,6 +879,7 @@ def write_env(var: jax_core.Var, val): rule_ctx = LoweringRuleContext( module_ctx, launch_ctx, + 
predicate=mgpu.single_thread_predicate(per_block=False), avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], ) @@ -1120,6 +1122,12 @@ def _convert_element_type_lowering_rule( ) +mosaic_lowering_rules.update({ + lax.neg_p: lambda ctx, x: -x, + lax.not_p: lambda ctx, x: ~x, +}) + + def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out) return impl(x, y) @@ -1160,6 +1168,11 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): return x * x return NotImplementedError +@register_lowering_rule(lax.square_p) +def _square_lowering_rule(ctx: LoweringRuleContext, x): + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) + return x * x @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): @@ -1571,4 +1584,4 @@ def _as_index(v: object) -> ir.Value: case mgpu.FragmentedArray(layout=mgpu.WGSplatFragLayout()): return _as_index(v.registers.item()) case _: - raise ValueError(f"Unsupported index: {v}") + raise ValueError(f"Unsupported index: {v} of type {type(v)}") diff --git a/jax/_src/pallas/mosaic_gpu/pipeline.py b/jax/_src/pallas/mosaic_gpu/pipeline.py index 21267b50a007..91e1e1c45429 100644 --- a/jax/_src/pallas/mosaic_gpu/pipeline.py +++ b/jax/_src/pallas/mosaic_gpu/pipeline.py @@ -16,7 +16,7 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Callable, Sequence import dataclasses import functools import itertools as it @@ -25,7 +25,10 @@ import jax from jax import lax +from jax._src import core +from jax._src import linear_util as lu from jax._src import util +from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core from jax._src.pallas.mosaic_gpu import primitives as gpu_primitives @@ -37,17 +40,19 @@ zip = util.safe_zip +@jax.tree_util.register_dataclass @dataclasses.dataclass(frozen=True) class BufferedRef: - spec: pallas_core.BlockSpec + spec: pallas_core.BlockSpec = dataclasses.field(metadata={"static": True}) + is_index_invariant: bool = dataclasses.field(metadata={"static": True}) gmem_ref: pallas_core.AbstractMemoryRef smem_ref: pallas_core.AbstractMemoryRef # [num_slots, *spec.block_shape] - def compute_gmem_slice(self, grid_indices) -> tuple[Any, ...]: + def compute_gmem_slice(self, grid_indices) -> tuple[pl.Slice, ...]: index_map = self.spec.index_map assert index_map is not None return tuple( - pl.ds(idx * size, size) + pl.Slice(idx * size, size) # type: ignore[arg-type] for idx, size in zip( index_map(*grid_indices), self.spec.block_shape # type: ignore[arg-type] ) @@ -61,16 +66,31 @@ def copy_in(self, slot, grid_indices, barrier_ref): barrier=barrier_ref.at[slot], ) - def copy_out(self, slot, grid_indices): + def copy_out(self, slot, grid_indices, predicate=None): gmem_slices = self.compute_gmem_slice(grid_indices) gpu_primitives.copy_smem_to_gmem( - self.smem_ref.at[slot], self.gmem_ref.at[gmem_slices] # pytype: disable=unsupported-operands + self.smem_ref.at[slot], + self.gmem_ref.at[gmem_slices], # pytype: disable=unsupported-operands + predicate=predicate, ) -jax.tree_util.register_dataclass( - BufferedRef, data_fields=["gmem_ref", "smem_ref"], meta_fields=["spec"] -) +def _uses_arguments( + index_map: Callable[..., Any], num_args: int +) -> Sequence[bool]: + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + 
lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args + ) + _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars)) + return used_inputs + + +def _is_index_invariant( + spec: pallas_core.BlockSpec, grid: pallas_core.StaticGrid +) -> bool: + index_map = spec.index_map + assert index_map is not None + return not any(_uses_arguments(index_map, len(grid))) def _inc_grid_by_1( @@ -85,6 +105,25 @@ def _inc_grid_by_1( return tuple(reversed(next_indices)) +# ``pl.Slice`` uses a different pytree encoding, depending on whether the +# start/size are static or dynamic. This leads to pytree structure mismatch +# in the pipeline body. So, we define a different ``Slice`` class below. + + +@dataclasses.dataclass(frozen=True) +class _Slice: + start: int | jax.Array + size: int | jax.Array + + def __eq__(self, other: _Slice) -> jax.Array: # type: ignore + return lax.bitwise_and(self.start == other.start, self.size == other.size) + + +jax.tree_util.register_dataclass( + _Slice, data_fields=["start", "size"], meta_fields=[] +) + + def emit_pipeline( body, *, @@ -102,6 +141,16 @@ def emit_pipeline( max_concurrent_steps = num_steps def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): + for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)): + if any( + spec.block_shape[-idx] * grid[-idx] != gmem_ref.shape[-idx] # type: ignore + for idx in range(1, len(grid) + 1) + ): + raise NotImplementedError( + f"Cannot emit a pipeline over the {grid=} for {gmem_ref} with block" + f" shape {spec.block_shape}." + ) + in_gmem_refs, out_gmem_refs = util.split_list(gmem_refs, [len(in_specs)]) in_smem_refs, out_smem_refs = util.split_list( map( @@ -132,13 +181,18 @@ def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef): def scoped_pipeline( *, in_gmem_refs, out_gmem_refs, in_smem_refs, out_smem_refs, barrier_ref ): - - in_brefs: Sequence[BufferedRef] = map( - BufferedRef, in_specs, in_gmem_refs, in_smem_refs - ) - out_brefs: Sequence[BufferedRef] = map( - BufferedRef, out_specs, out_gmem_refs, out_smem_refs - ) + in_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref) + for spec, gmem_ref, smem_ref in zip( + in_specs, in_gmem_refs, in_smem_refs + ) + ] + out_brefs: Sequence[BufferedRef] = [ + BufferedRef(spec, _is_index_invariant(spec, grid), gmem_ref, smem_ref) + for spec, gmem_ref, smem_ref in zip( + out_specs, out_gmem_refs, out_smem_refs + ) + ] for step, indices in enumerate( it.islice(it.product(*map(range, grid)), max_concurrent_steps) @@ -147,10 +201,11 @@ def scoped_pipeline( def loop_body(step, carry): slot = step % max_concurrent_steps - indices, fetch_indices = carry + indices, fetch_indices, last_store_slices = carry - # Wait for the current GMEM->SMEM copy to complete. - gpu_primitives.barrier_wait(barrier_ref.at[slot]) + if in_specs: + # Wait for the current GMEM->SMEM copy to complete. + gpu_primitives.barrier_wait(barrier_ref.at[slot]) # Wait for the previous output SMEM->GMEM copy to complete. gpu_primitives.wait_smem_to_gmem(max_concurrent_steps - 1) @@ -159,9 +214,34 @@ def loop_body(step, carry): *(bref.smem_ref.at[slot] for bref in it.chain(in_brefs, out_brefs)) ) + if not all(bref.is_index_invariant for bref in out_brefs): + gpu_primitives.commit_smem() + # Copy the output from SMEM to GMEM. 
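The `BufferedRef` registration above relies on `jax.tree_util.register_dataclass` inferring which fields are static from `dataclasses.field(metadata={"static": True})`. A standalone sketch of that pattern (the `Box` class is purely illustrative):

```python
import dataclasses
import jax
import jax.numpy as jnp

@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class Box:
  label: str = dataclasses.field(metadata={"static": True})  # pytree metadata
  value: jax.Array                                           # pytree leaf

box = Box("weights", jnp.ones(3))
leaves, treedef = jax.tree_util.tree_flatten(box)
print(leaves)                               # only `value` is a leaf
print(jax.jit(lambda b: b.value * 2)(box))  # `label` rides along in the treedef
```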
- gpu_primitives.commit_smem() - map(lambda bref: bref.copy_out(slot, indices), out_brefs) + new_store_slices = last_store_slices[:] + for idx, bref in enumerate(out_brefs): + if bref.is_index_invariant: + assert last_store_slices[idx] is None + continue + assert last_store_slices[idx] is not None + new_store_slices[idx] = tuple( + _Slice(s.start, s.size) for s in bref.compute_gmem_slice(indices) + ) + are_same_slices = map( + lambda old, new: old == new, + last_store_slices[idx], + new_store_slices[idx], + ) + slices_changed = ~functools.reduce(lax.bitwise_and, are_same_slices) + is_last_step = step == num_steps - 1 + # TODO(apaszke,slebedev): This still diverges significantly from the + # TPU semantics in that it will move on to the next SMEM output slice + # even if it's not storing the previous one. + bref.copy_out( + slot, + indices, + predicate=lax.bitwise_or(slices_changed, is_last_step), + ) fetch_step = step + max_concurrent_steps fetch_slot = slot # (x + y) % y == x % y @@ -174,13 +254,34 @@ def loop_body(step, carry): lambda: [None] * len(in_brefs), ) - return _inc_grid_by_1(indices, grid), _inc_grid_by_1(fetch_indices, grid) + return ( + _inc_grid_by_1(indices, grid), + _inc_grid_by_1(fetch_indices, grid), + new_store_slices, + ) indices = (jnp.asarray(0, dtype=lax.dtype(0)),) * len(grid) fetch_indices = indices for _ in range(max_concurrent_steps): fetch_indices = _inc_grid_by_1(fetch_indices, grid) - lax.fori_loop(0, num_steps, loop_body, (indices, fetch_indices)) + last_store_slices = [ + None + if bref.is_index_invariant + else (_Slice(-1, -1),) * len(bref.spec.block_shape) + for bref in out_brefs + ] + last_indices, _, _ = lax.fori_loop( + 0, num_steps, loop_body, (indices, fetch_indices, last_store_slices) + ) + + # Outputs invariant to the sequential axis are never written from inside the + # loop. This is the only place where we store them. + if all(bref.is_index_invariant for bref in out_brefs): + gpu_primitives.commit_smem() + last_slot = (num_steps - 1) % max_concurrent_steps + for bref in out_brefs: + if bref.is_index_invariant: + bref.copy_out(last_slot, last_indices, predicate=None) # Finalize the pipeline. 
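To make the index-invariance handling above concrete: the check traces a `BlockSpec` index map with scalar placeholders and asks dead-code elimination which grid indices the outputs actually depend on. A minimal sketch mirroring `_uses_arguments` (it uses the same JAX-internal modules as the patch, which are subject to change; `uses_arguments` is a hypothetical name):

```python
import jax.numpy as jnp
from jax._src import core
from jax._src import linear_util as lu
from jax._src.interpreters import partial_eval as pe

def uses_arguments(index_map, num_args):
  # Trace with scalar int32 placeholders, then DCE against all outputs.
  jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic(
      lu.wrap_init(index_map), (core.ShapedArray((), jnp.int32),) * num_args)
  _, used_inputs = pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))
  return used_inputs

print(uses_arguments(lambda i, j: (i, 0), 2))           # [True, False]
print(not any(uses_arguments(lambda i, j: (0, 0), 2)))  # True: index-invariant
```

An output block whose index map ignores every grid index is written only once, after the loop, instead of on every pipeline step.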
gpu_primitives.wait_smem_to_gmem(0) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index 1ced213394ff..5fc4ed5e7afc 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -26,6 +26,7 @@ from jax._src import tree_util from jax._src import util from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import arith as arith_dialect from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import core as gpu_core @@ -34,6 +35,7 @@ from jax._src.state import indexing from jax._src.state import primitives as state_primitives import jax.experimental.mosaic.gpu as mgpu +import jax.numpy as jnp WARPGROUP_SIZE = 128 @@ -54,19 +56,31 @@ def _copy_smem_to_gmem_lowering( ctx: lowering.LoweringRuleContext, src, dst, - *flat_transforms, + *flat_args, src_transforms_treedef, dst_transforms_treedef, + has_user_predicate, ): + predicate = ctx.predicate + if has_user_predicate: + flat_args, user_predicate = flat_args[:-1], flat_args[-1] + predicate = arith_dialect.andi( + predicate, lowering._ensure_ir_value(user_predicate, jnp.bool) + ) flat_src_transforms, flat_dst_transforms = util.split_list( - flat_transforms, + flat_args, [src_transforms_treedef.num_leaves], ) src_transforms = src_transforms_treedef.unflatten(flat_src_transforms) dst_transforms = dst_transforms_treedef.unflatten(flat_dst_transforms) src, src_transforms = lowering._handle_indexing(src, src_transforms) copy_params = _extract_gmem_copy_params(dst_transforms) | _extract_smem_copy_params(src_transforms) - ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, **copy_params) + ctx.launch_ctx.async_copy( + src_ref=src, + dst_ref=dst, + predicate=predicate, + **copy_params, + ) return () @@ -98,10 +112,18 @@ def _extract_smem_copy_params(transforms): def copy_smem_to_gmem( - src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef + src: pallas_core.AbstractMemoryRef, + dst: pallas_core.AbstractMemoryRef, + predicate: jax.Array | None = None, ) -> None: """Asynchronously copies a SMEM reference to a GMEM reference. + Args: + src: The SMEM reference to copy from. + dst: The GMEM reference to copy to. + predicate: A boolean indicating whether the copy should be performed. If + ``None``, the copy is always performed. + See also: :func:`jax.experimental.mosaic.gpu.wait_smem_to_gmem` :func:`jax.experimental.mosaic.gpu.commit_smem` @@ -127,8 +149,10 @@ def copy_smem_to_gmem( dst, *flat_src_transforms, *flat_dst_transforms, + *[] if predicate is None else [predicate], src_transforms_treedef=src_transforms_treedef, dst_transforms_treedef=dst_transforms_treedef, + has_user_predicate=predicate is not None, ) return None diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index c7bd7dd7178f..d77ca86c152a 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -824,14 +824,13 @@ def debug_print_lowering_rule(ctx, *args, **params): # because they should appear as atomic JAX values to the users. # TODO(apaszke): This can be deleted once we make transforms in Mosaic GPU # inferred by the compiler. 
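The `wrap_with_transforms` change just below, and the later `linear_util` migrations in this patch, all follow the same generator-to-function conversion. A toy sketch of the two styles, assuming the `transformation2` export added to `jax.extend.linear_util` in this patch (`double_result_*` are illustrative names):

```python
from jax.extend import linear_util as lu

# Old style: a generator that yields (args, kwargs), then yields the result.
@lu.transformation
def double_result_old(*args):
  result = yield args, {}
  yield 2 * result

# New style: the wrapped callable `f` is passed in and called directly.
@lu.transformation2
def double_result_new(f, *args):
  return 2 * f(*args)

wrapped = double_result_new(lu.wrap_init(lambda x: x + 1))
print(wrapped.call_wrapped(3))  # 8
```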
-@lu.transformation -def wrap_with_transforms(transforms, *args): +@lu.transformation2 +def wrap_with_transforms(f, transforms, *args): new_args = tuple( state_types.TransformedRef(a, t) if t else a for a, t in zip(args, transforms) ) - res = yield new_args, {} - yield res + return f(*new_args) run_scoped_p = jax_core.Primitive("run_scoped") diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 1a0400ebf0db..fa49f3b7cbbf 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -780,6 +780,7 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], ), + lax.square_p: lambda ctx, x: _mul(x, x), lax.pow_p: _make_dispatch_table( "pow", cuda=[ @@ -2612,3 +2613,24 @@ def _dtype_to_ir_type(dtype: jnp.dtype) -> ir.Type: # All integer types in Triton are signless. return ir.IntegerType.get_signless(dtype.itemsize * 8) return mlir.dtype_to_ir_type(dtype) + + +@register_lowering(lax.bitcast_convert_type_p) +def _bitcast_convert_type_lowering_rule( + ctx: LoweringRuleContext, operand: ir.Value, *, new_dtype +) -> ir.Value: + # TODO(petebu) Handle case where src and dst types have different bitwidths + src_elem_type = _element_type(operand.type) + dst_elem_type = _element_type(_dtype_to_ir_type(new_dtype)) + assert isinstance(src_elem_type, (ir.IntegerType, ir.FloatType)) + assert isinstance(dst_elem_type, (ir.IntegerType, ir.FloatType)) + if src_elem_type.width != dst_elem_type.width: + raise NotImplementedError( + f"cannot cast {operand} to {new_dtype} because of different widths" + ) + if ir.RankedTensorType.isinstance(operand.type): + shape = ir.RankedTensorType(operand.type).shape + result_type = ir.RankedTensorType.get(shape, dst_elem_type) + else: + result_type = dst_elem_type + return tt_dialect.bitcast(result_type, operand) diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py index 18e7d18d931d..f9bc2b60cee9 100644 --- a/jax/_src/partition_spec.py +++ b/jax/_src/partition_spec.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
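For reference on the Triton rule above, which only accepts equal-width element types: at the JAX level `lax.bitcast_convert_type` reinterprets the underlying bits without changing them, so a same-width round trip is lossless. The printed values are the IEEE-754 bit patterns of 1.0 and -2.5:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, -2.5], dtype=jnp.float32)
bits = lax.bitcast_convert_type(x, jnp.int32)       # same 32-bit width
print(bits)                                         # [ 1065353216 -1071644672]
print(lax.bitcast_convert_type(bits, jnp.float32))  # [ 1.  -2.5]
```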
+from __future__ import annotations + class _UnconstrainedPartitionSingleton: def __repr__(self): @@ -48,3 +50,21 @@ def __repr__(self): def __reduce__(self): return (PartitionSpec, tuple(self)) + + def _normalized_spec(self, ndim: int) -> PartitionSpec: + out = [] # type: ignore + for p in self: + if p is None: + out.append(None) + elif p == self.UNCONSTRAINED: + out.append(p) + elif isinstance(p, (list, tuple)): + if len(p) == 1: + out.append(p[0]) + else: + out.append(tuple(p)) + else: + out.append(p) + if len(out) < ndim: + out.extend([None] * (ndim - len(out))) + return PartitionSpec(*out) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 6ab8c90811a6..f1844c7ba13b 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1608,15 +1608,22 @@ def _resolve_and_lower( lowering_parameters=lowering_parameters, pgle_profiler=pgle_profiler) +_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore + def _pjit_call_impl_python( *args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, compiler_options_kvs): pgle_compile_options, pgle_profiler = {}, None if config.enable_pgle.value and config.pgle_profiling_runs.value > 0: - pgle_profiler = profiler.PGLEProfiler( - config.pgle_profiling_runs.value, - config.pgle_aggregation_percentile.value) + compilation_target_key = jaxpr + pgle_profiler = _pgle_profiler_dict.get(compilation_target_key) + if pgle_profiler is None: + pgle_profiler = profiler.PGLEProfiler( + config.pgle_profiling_runs.value, + config.pgle_aggregation_percentile.value) + _pgle_profiler_dict[compilation_target_key] = pgle_profiler + # The method below will return FDO profile when module was profiled # config.jax_pgle_profiling_runs amount of times, otherwise the result will # be None. @@ -2319,6 +2326,10 @@ def _dce_jaxpr_pjit( def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None + dced_jaxpr, used_inputs = _dce_jaxpr_pjit( eqn.params['jaxpr'], tuple(used_outputs)) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index d014e5ceb24e..1c5eba988e6a 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -2004,7 +2004,7 @@ def hessenberg(a: ArrayLike, *, calc_q: bool = False, overwrite_a: bool = False, def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: - r"""Construct a Toeplitz matrix + r"""Construct a Toeplitz matrix. JAX implementation of :func:`scipy.linalg.toeplitz`. @@ -2023,13 +2023,13 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: Notice this implies that :math:`r_0` is ignored. Args: - c: array specifying the first column. Will be flattened - if not 1-dimensional. - r: (optional) array specifying the first row. If not specified, defaults - to ``conj(c)``. Will be flattened if not 1-dimensional. + c: array of shape ``(..., N)`` specifying the first column. + r: (optional) array of shape ``(..., M)`` specifying the first row. Leading + dimensions must be broadcast-compatible with those of ``c``. If not specified, + ``r`` defaults to ``conj(c)``. Returns: - toeplitz matrix of shape ``(c.size, r.size)``. + A Toeplitz matrix of shape ``(... N, M)``. 
Examples: Specifying ``c`` only: @@ -2059,32 +2059,40 @@ def toeplitz(c: ArrayLike, r: ArrayLike | None = None) -> Array: [1.+2.j, 2.+1.j, 1.+0.j]], dtype=complex64) >>> print("M is Hermitian:", jnp.all(M == M.conj().T)) M is Hermitian: True + + For N-dimensional ``c`` and/or ``r``, the result is a batch of Toeplitz matrices: + + >>> c = jnp.array([[1, 2, 3], [4, 5, 6]]) + >>> jax.scipy.linalg.toeplitz(c) + Array([[[1, 2, 3], + [2, 1, 2], + [3, 2, 1]], + + [[4, 5, 6], + [5, 4, 5], + [6, 5, 4]]], dtype=int32) """ if r is None: check_arraylike("toeplitz", c) r = jnp.conjugate(jnp.asarray(c)) else: check_arraylike("toeplitz", c, r) + return _toeplitz(jnp.atleast_1d(jnp.asarray(c)), jnp.atleast_1d(jnp.asarray(r))) - c_arr = jnp.asarray(c).flatten() - r_arr = jnp.asarray(r).flatten() - - ncols, = c_arr.shape - nrows, = r_arr.shape - +@partial(jnp.vectorize, signature="(m),(n)->(m,n)") +def _toeplitz(c: Array, r: Array) -> Array: + ncols, = c.shape + nrows, = r.shape if ncols == 0 or nrows == 0: - return jnp.empty((ncols, nrows), - dtype=jnp.promote_types(c_arr.dtype, r_arr.dtype)) - + return jnp.empty((ncols, nrows), dtype=jnp.promote_types(c.dtype, r.dtype)) nelems = ncols + nrows - 1 - elems = jnp.concatenate((c_arr[::-1], r_arr[1:])) + elems = jnp.concatenate((c[::-1], r[1:])) patches = lax.conv_general_dilated_patches( elems.reshape((1, nelems, 1)), (nrows,), (1,), 'VALID', dimension_numbers=('NTC', 'IOT', 'NTC'), precision=lax.Precision.HIGHEST)[0] return jnp.flip(patches, axis=0) - @partial(jit, static_argnames=("n",)) def hilbert(n: int) -> Array: r"""Create a Hilbert matrix of order n. diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index fa65bbe9328d..9b847f15d86a 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -361,19 +361,7 @@ def with_memory_kind(self, kind: str) -> NamedSharding: return NamedSharding(self.mesh, self.spec, memory_kind=kind) def _normalized_spec(self, ndim: int) -> PartitionSpec: - out = [] # type: ignore - for p in self._parsed_pspec: - if p is None: - raise ValueError("UNCONSTRAINED is not supported yet.") - if not p: - out.append(None) - elif isinstance(p, tuple) and len(p) == 1: - out.append(p[0]) - else: - out.append(p) - if len(out) < ndim: - out.extend([None] * (ndim - len(out))) - return PartitionSpec(*out) + return self.spec._normalized_spec(ndim) def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding: return named_sharding_to_xla_hlo_sharding(self, num_dimensions) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 23b255ef1750..28148761c8a4 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -90,6 +90,13 @@ help="Mock number of JAX processes in GPU client. Value zero turns " "off mocking.", ) +_MOCK_GPU_TOPOLOGY = config.string_flag( + name="jax_mock_gpu_topology", + default="", + help='Mock multi-host GPU topology in GPU client. The value should ' + 'be of the form " x x ' + '". 
Empty string turns off mocking.', +) _CPU_ENABLE_GLOO_COLLECTIVES = config.bool_flag( name="jax_cpu_enable_gloo_collectives", @@ -425,6 +432,14 @@ def _version_check(name: str, f'following issues with CUDA components:\n' f'{join_str.join(errors)}') +def _get_num_nodes_from_gpu_topology(topology: str) -> int: + try: + slices_str, hosts_per_slice_str, _ = topology.split("x", 2) + return int(slices_str) * int(hosts_per_slice_str) + except (IndexError, ValueError): + raise ValueError('Mock topology must be of the form ' + '" x x ' + '".') def make_gpu_client( *, platform_name: str, visible_devices_flag: config.Flag[str] @@ -434,12 +449,14 @@ def make_gpu_client( if visible_devices != "all": allowed_devices = {int(x) for x in visible_devices.split(",")} - use_mock_gpu_client = _MOCK_NUM_GPU_PROCESSES.value > 0 - num_nodes = ( - _MOCK_NUM_GPU_PROCESSES.value - if use_mock_gpu_client - else distributed.global_state.num_processes - ) + mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None + mock_num_gpu_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + + use_mock_gpu_client = mock_num_gpu_processes > 0 + num_nodes = (mock_num_gpu_processes if use_mock_gpu_client + else distributed.global_state.num_processes) + if platform_name == "cuda": if not os.getenv("JAX_SKIP_CUDA_CONSTRAINTS_CHECK"): _check_cuda_versions() @@ -634,10 +651,14 @@ def _options_from_jax_configs(plugin_name): visible_devices = CUDA_VISIBLE_DEVICES.value if visible_devices != 'all': options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - mock_gpu_processes = _MOCK_NUM_GPU_PROCESSES.value - options['enable_mock_nccl'] = mock_gpu_processes > 0 - if options['enable_mock_nccl']: - options['num_nodes'] = mock_gpu_processes + mock_gpu_topology = _MOCK_GPU_TOPOLOGY.value or None + mock_num_processes = (_get_num_nodes_from_gpu_topology(mock_gpu_topology) if + mock_gpu_topology else _MOCK_NUM_GPU_PROCESSES.value) + options['enable_mock_nccl'] = mock_num_processes > 0 + if mock_num_processes > 0: + options['num_nodes'] = mock_num_processes + if mock_gpu_topology: + options['mock_gpu_topology'] = mock_gpu_topology return options diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index a25d93a35c51..b4adbadfa6c5 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -97,34 +97,34 @@ def jvp(f, primals, tangents, attr_tangents): out_tangents = tree_unflatten(out_tree(), out_tangents_flat) return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def _set_attrs(attrs, attr_vals, *args): +@lu.transformation2 +def _set_attrs(f, attrs, attr_vals, *args): for (o, a), x in zip(attrs, attr_vals): jax_setattr(o, a, x) - yield (yield args, {}) + return f(*args) def _jvp(fun: lu.WrappedFun): return jvpfun2(jvp_subtrace2(fun)) -@lu.transformation -def jvpfun2(primals, tangents): +@lu.transformation2 +def jvpfun2(f, primals, tangents): tag = core.TraceTag() tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = source_info_util.transform_name_stack('jvp') with ctx: - out_primals, out_tangents, tangent_attrs_out = yield (tag, primals, tangents), {} - yield out_primals, out_tangents, tangent_attrs_out + out_primals, out_tangents, tangent_attrs_out = f(tag, primals, tangents) + return out_primals, out_tangents, tangent_attrs_out -@lu.transformation -def jvp_subtrace2(tag, primals, tangents): +@lu.transformation2 +def jvp_subtrace2(f, tag, primals, 
tangents): with core.take_current_trace() as parent_trace: trace = ad.JVPTrace(parent_trace, tag) tag.attrs_tracked = [] # attrs written to in_tracers = [ad.JVPTracer(trace, x, t) if type(t) is not ad.Zero else x for x, t in zip(primals, tangents)] with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans)) tangent_attrs_out = [] for (obj, name) in tag.attrs_tracked: @@ -133,7 +133,7 @@ def jvp_subtrace2(tag, primals, tangents): if type(tangent) is not ad.Zero: tangent_attrs_out.append((obj, name, tangent)) del tag.attrs_tracked - yield out_primals, out_tangents, tangent_attrs_out + return out_primals, out_tangents, tangent_attrs_out def _setattr_jvp(trace, obj, attr, maybe_tracer): primal, tangent = trace.to_primal_tangent_pair(maybe_tracer) @@ -175,11 +175,12 @@ def _linearize(traceable: lu.WrappedFun, *primals): return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals], jaxpr, consts, attrs()) -@lu.transformation_with_aux -def _split_attrs(*args, **kwargs): - primals, tangents, tangent_attrs = yield args, kwargs +@lu.transformation_with_aux2 +def _split_attrs(f, store, *args, **kwargs): + primals, tangents, tangent_attrs = f(*args, **kwargs) attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs) - yield (primals, tangents, tangent_attr_vals), attrs + store.store(attrs) + return primals, tangents, tangent_attr_vals def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs): in_tree, out_tree = io_tree diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index c6d920918074..c41eda693d7f 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -1040,20 +1040,20 @@ def impl_multiple_results_jax(*args_jax): return wrapped_tf -@lu.transformation -def _interpret_subtrace(in_avals: Sequence[core.ShapedArray], +@lu.transformation2 +def _interpret_subtrace(f, in_avals: Sequence[core.ShapedArray], *in_vals: TfVal): trace = TensorFlowTrace() in_tracers = tuple( TensorFlowTracer(trace, val, aval) for val, aval in zip(in_vals, in_avals)) with core.set_current_trace(trace): - outs = yield in_tracers, {} # type: Sequence[TfVal] + outs = f(*in_tracers) out_tracers: Iterable[TensorFlowTracer] = ( map(trace.to_tf_tracer, outs)) out_vals_with_avals: Sequence[tuple[TfVal, core.ShapedArray]] = ( tuple((t.val, t.aval) for t in out_tracers)) - yield out_vals_with_avals + return out_vals_with_avals def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args_tf: TfVal, @@ -1726,6 +1726,7 @@ def _atan2(y, x, **kwargs): tf_impl[lax.asinh_p] = tf.math.asinh tf_impl[lax.sqrt_p] = tf.math.sqrt +tf_impl[lax.square_p] = tf.math.square tf_impl[lax.rsqrt_p] = tf.math.rsqrt def _cbrt(x): diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index e59084041306..8993d044cb3b 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -979,8 +979,8 @@ def caller_jax(x): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) + if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) def 
test_bfloat16_constant(self): # Re: https://github.com/jax-ml/jax/issues/3942 diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 827e4d01b390..2681ad1a2a7b 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -141,40 +141,43 @@ def jet(fun, primals, series): if not treedef_is_leaf(treedef): raise ValueError(f"term {j} for argument {i} is not an array") - @lu.transformation_with_aux - def flatten_fun_output(*args): - ans = yield args, {} - yield tree_flatten(ans) + @lu.transformation_with_aux2 + def flatten_fun_output(f, store, *args): + ans = f(*args) + ans, tree = tree_flatten(ans) + store.store(tree) + return ans f, out_tree = flatten_fun_output(lu.wrap_init(fun)) out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series) return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms) -@lu.transformation -def jet_fun(order, primals, series): +@lu.transformation2 +def jet_fun(f, order, primals, series): tag = core.TraceTag() - out_primals, out_terms = yield (tag, order, primals, series), {} + out_primals, out_terms = f(tag, order, primals, series) out_terms = [[jnp.zeros_like(p)] * order if s is zero_series else s for p, s in zip(out_primals, out_terms)] - yield out_primals, out_terms + return out_primals, out_terms -@lu.transformation -def jet_subtrace(tag, order, primals, series): +@lu.transformation2 +def jet_subtrace(f, tag, order, primals, series): with core.take_current_trace() as parent_trace: trace = JetTrace(tag, parent_trace, order) in_tracers = map(partial(JetTracer, trace), primals, series) with core.set_current_trace(trace): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_primals, out_terms = unzip2(map(trace.to_primal_terms_pair, ans)) - yield out_primals, out_terms + return out_primals, out_terms -@lu.transformation_with_aux -def traceable(in_tree_def, *primals_and_series): +@lu.transformation_with_aux2 +def traceable(f, store, in_tree_def, *primals_and_series): primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series) - primals_out, series_out = yield (primals_in, series_in), {} + primals_out, series_out = f(primals_in, series_in) out_flat, out_tree_def = tree_flatten((primals_out, series_out)) - yield out_flat, out_tree_def + store.store(out_tree_def) + return out_flat class JetTracer(core.Tracer): @@ -405,6 +408,7 @@ def def_comp(prim, comp): def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1) def_comp(lax.log1p_p, lambda x: lax.log(1 + x)) def_comp(lax.sqrt_p, lambda x: x ** 0.5) +def_comp(lax.square_p, lambda x: x * x) def_comp(lax.rsqrt_p, lambda x: x ** -0.5) def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1))) def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1))) diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py index e4949b325507..337581c54b86 100644 --- a/jax/experimental/mosaic/gpu/profiler.py +++ b/jax/experimental/mosaic/gpu/profiler.py @@ -14,16 +14,15 @@ # ============================================================================== import contextlib -import ctypes -import functools import itertools import json import math +from typing import Callable, ParamSpec, TypeVar import warnings import jax -from jax._src.interpreters import mlir from jax._src.lib import xla_client +from jax.extend import ffi import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -34,72 +33,71 @@ from .utils import * # noqa: F403 - try: from jax._src.lib import 
mosaic_gpu as mosaic_gpu_lib - - xla_client.register_custom_call_target( - "mosaic_gpu_record_event", - mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(), - platform="CUDA", - ) except ImportError: pass +else: + for name, handler in mosaic_gpu_lib._mosaic_gpu_ext.registrations(): + xla_client.register_custom_call_target( + name, handler, platform="CUDA", api_version=1 + ) # ruff: noqa: F405 # mypy: ignore-errors +T = TypeVar("T") +P = ParamSpec("P") -record_event_p = jax.core.Primitive("record_event") -record_event_p.multiple_results = True - -@record_event_p.def_abstract_eval -def _record_event_abstract_eval(*args, event): - del event # Unused. - return args - -@functools.partial(mlir.register_lowering, record_event_p, platform="cuda") -def _record_event_lowering_rule(ctx, *args, event): - ptr_bytes = ctypes.cast(event, ctypes.c_void_p).value.to_bytes( - 8, byteorder="little" - ) # pytype: disable=attribute-error - op = mlir.custom_call( - "mosaic_gpu_record_event", - result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands=args, - backend_config=ptr_bytes, - operand_output_aliases={i: i for i in range(len(args))}, - ) - return op.results - -def _record_event(args, event): +def _event_record(args, *, copy_before): flat_args, treedef = jax.tree.flatten(args) - return jax.tree.unflatten( - treedef, record_event_p.bind(*flat_args, event=event) - ) - -def measure(f, *args, **kwargs): - # TODO(apaszke): Raise if this is called under jit. - start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() - end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create() - try: - - @jax.jit - def run(*args, **kwargs): - flat_args, treedef = jax.tree.flatten((args, kwargs)) - flat_args = _record_event(flat_args, start_event) - args, kwargs = jax.tree.unflatten(treedef, flat_args) - return _record_event(f(*args, **kwargs), end_event) - - jax.block_until_ready(run(*args, **kwargs)) # Warmup. - results = jax.block_until_ready(run(*args, **kwargs)) - elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed( - start_event, end_event + event, *flat_outs = ffi.ffi_call( + "mgpu_event_record", + result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args), + input_output_aliases={i: i + 1 for i in range(len(flat_args))}, + )(*flat_args, copy_before=copy_before) + return event, treedef.unflatten(flat_outs) + + +def _event_elapsed(start_event, end_event): + return ffi.ffi_call( + "mgpu_event_elapsed", + result_shape_dtypes=jax.core.ShapedArray((), jnp.float32), + )(start_event, end_event) + + +def measure( + f: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> tuple[T, float]: + """Measures the time it takes to execute the function on the GPU. + + Args: + f: The function to measure. It must accept at least one argument and return + at least one output to be measurable. + *args: The arguments to pass to ``f``. + **kwargs: The keyword arguments to pass to ``f``. + + Returns: + The return value of ``f`` and the elapsed time in milliseconds. + """ + if not (args or kwargs): + # We require at least one argument and at least one output to ensure + # that there is a data dependency between `_event_record` calls in + # the resulting HLO program. 
+ raise ValueError("Can only measure functions with arguments") + + @jax.jit + def run(*args, **kwargs): + start_event, (args, kwargs) = _event_record( + (args, kwargs), copy_before=True ) - finally: - mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(start_event) - mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(end_event) - return results, elapsed + end_event, outs = _event_record(f(*args, **kwargs), copy_before=False) + if jax.tree.structure(outs).num_leaves == 0: + raise ValueError("Can only measure functions with at least one output") + return outs, _event_elapsed(start_event, end_event) + + outs, elapsed = run(*args, **kwargs) + return outs, float(elapsed) class ProfilerSpec: diff --git a/jax/experimental/ode.py b/jax/experimental/ode.py index b8e3daee48c8..987e461a39b2 100644 --- a/jax/experimental/ode.py +++ b/jax/experimental/ode.py @@ -47,12 +47,12 @@ def ravel_first_arg(f, unravel): return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped -@lu.transformation -def ravel_first_arg_(unravel, y_flat, *args): +@lu.transformation2 +def ravel_first_arg_(f, unravel, y_flat, *args): y = unravel(y_flat) - ans = yield (y,) + args, {} + ans = f(y, *args) ans_flat, _ = ravel_pytree(ans) - yield ans_flat + return ans_flat def interp_fit_dopri(y0, y1, k, dt): # Fit a polynomial to the results of a Runge-Kutta step. diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 7ddd3805b5d0..9391d7ddf546 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -46,7 +46,7 @@ from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer -from jax._src.mesh import AbstractMesh, Mesh +from jax._src.mesh import AbstractMesh, Mesh, AxisTypes from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, @@ -528,17 +528,30 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue raise NotImplementedError(f"Unsupported aval type: {type(aval)}") def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + new_mesh = AbstractMesh( + mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names}) + new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim)) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array def _unshard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + spec = _names_to_pspec(names)._normalized_spec(aval.ndim) + new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking @@ -1274,6 
+1287,32 @@ def _scan_rewrite(mesh, in_rep, *args, jaxpr, num_consts, num_carry, **params): *args, jaxpr=jaxpr_, num_consts=num_consts, num_carry=num_carry, **params) return out_vals, out_rep +@register_check(control_flow.conditionals.cond_p) +def _cond_rule(mesh, *in_rep, branches): + _, *args_rep = in_rep + true_out_rep = _check_rep(mesh, branches[0].jaxpr, args_rep) + false_out_rep = _check_rep(mesh, branches[1].jaxpr, args_rep) + if not true_out_rep == false_out_rep: + raise Exception("The true and false branches of cond produced mismatched " + f"replication types {true_out_rep} and {false_out_rep}. " + "Please open an issue at " + "https://github.com/jax-ml/jax/issues, and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return true_out_rep + +@register_rewrite(control_flow.conditionals.cond_p) +def _cond_rewrite(mesh, in_rep, *args, branches): + pred_rep, *args_rep = in_rep + _, true_out_rep = _replication_rewrite_nomatch(mesh, branches[0], args_rep) + _, false_out_rep = _replication_rewrite_nomatch(mesh, branches[1], args_rep) + out_rep = map(op.and_, true_out_rep, false_out_rep) + out_rep = map(partial(op.and_, pred_rep), out_rep) + branches_ = ( + _replication_rewrite_match(mesh, branches[0], args_rep, out_rep), + _replication_rewrite_match(mesh, branches[1], args_rep, out_rep), + ) + out_vals = control_flow.conditionals.cond_p.bind(*args, branches=branches_) + return out_vals, out_rep @register_rewrite(core.closed_call_p) def _closed_call_rewrite(mesh, in_rep, *args, call_jaxpr, **kwargs): @@ -1479,15 +1518,15 @@ def known_out_names(): return pe.merge_lists(out_knowns, out_tracers, out_consts) pe.JaxprTrace.process_shard_map = _shard_map_partial_eval -@lu.transformation -def _promote_scalar_residuals(*args, **kwargs): - jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = yield args, kwargs +@lu.transformation2 +def _promote_scalar_residuals(f, *args, **kwargs): + jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) = f(*args, **kwargs) which = [f1 is None and f2 is None and not v.aval.shape for f1, f2, v in zip(in_fwds, out_fwds, jaxpr.constvars)] jaxpr = _promote_scalar_residuals_jaxpr(jaxpr, which) out_consts = [jax.lax.broadcast(x, (1,)) if not getattr(x, 'shape', ()) else x for x in out_consts] - yield jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) + return jaxpr, (in_fwds, out_fwds, out_pvals, out_consts, env) def _promote_scalar_residuals_jaxpr(jaxpr, which): @lu.wrap_init @@ -1660,6 +1699,8 @@ def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: # TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule? 
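With the cond check/rewrite rules registered above, `lax.cond` can appear inside `shard_map` under the default `check_rep=True`. A hedged sketch (assumes at least one available device; the mesh and function below are illustrative only):

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:1]), ("x",))  # use more devices if available

@functools.partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
def f(x):
  # Both branches see the local shard; replication of the result is checked
  # via the cond rules added above.
  return jax.lax.cond(x.sum() > 0, jnp.sin, jnp.cos, x)

print(f(jnp.arange(4.0)))
```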
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn ) -> tuple[list[bool], core.JaxprEqn | None]: + if not any(used_outputs) and not pe.has_effects(eqn): + return [False] * len(eqn.invars), None mesh = eqn.params["mesh"] with core.extend_axis_env_nd(mesh.shape.items()): jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs) @@ -1726,13 +1767,13 @@ def _cached_shard_map(flat_fun, mesh, in_axes_flat, out_axes_thunk, axis_name): check_rep=False, auto=frozenset()), in_specs, out_specs) -@lu.transformation -def _handle_reshapes(in_axes, out_axes_thunk, *args, **kwargs): +@lu.transformation2 +def _handle_reshapes(f, in_axes, out_axes_thunk, *args, **kwargs): args = tree_map(lambda x, ax: x if ax is None else jnp.squeeze(x, axis=ax), list(args), list(in_axes)) - out = yield args, {} - yield tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), - list(out), list(out_axes_thunk())) + out = f(*args) + return tree_map(lambda x, ax: x if ax is None else jnp.expand_dims(x, axis=ax), + list(out), list(out_axes_thunk())) def _axis_to_spec(axis_name, ax): if isinstance(ax, int): @@ -1853,27 +1894,28 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) -@lu.transformation_with_aux -def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args): +@lu.transformation_with_aux2 +def _efficient_transpose_rewrite_nomatch(f, store, mesh, in_reps, *args): with core.take_current_trace() as parent: tag = core.TraceTag() t = RewriteTrace(parent_trace = parent, tag = tag, mesh=mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, args) with core.set_current_trace(t): - ans = yield in_tracers, {} + ans = f(*in_tracers) out_vals, out_reps = unzip2(map(t.to_val_rep_pair, ans)) del t, in_tracers, ans - yield out_vals, out_reps + store.store(out_reps) + return out_vals -@lu.transformation -def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): - outs = yield args, {} +@lu.transformation2 +def _match_rep(f, mesh, out_reps_src_, out_reps_dst_, *args): + outs = f(*args) out_reps_src = out_reps_src_() if callable(out_reps_src_) else out_reps_src_ out_reps_dst = out_reps_dst_() if callable(out_reps_dst_) else out_reps_dst_ _check_reps2(mesh, out_reps_dst, out_reps_src) outs = [pbroadcast(x, tuple(n for n in src if n not in dst)) if src - dst else x for x, src, dst in zip(outs, out_reps_src, out_reps_dst)] - yield outs + return outs # TODO(mattjj): caching def _replication_rewrite_match( @@ -1899,16 +1941,17 @@ def _replication_rewrite_nomatch( jaxpr_, _, consts, () = pe.trace_to_jaxpr_dynamic(f, jaxpr.in_avals) return core.ClosedJaxpr(jaxpr_, consts), out_rep() -@lu.transformation_with_aux -def _rewrite_subtrace(tag, mesh, in_reps, *in_vals): +@lu.transformation_with_aux2 +def _rewrite_subtrace(f, store, tag, mesh, in_reps, *in_vals): with core.take_current_trace() as parent_trace: assert len(in_reps) == len(in_vals), (len(in_reps), len(in_vals)) t = RewriteTrace(parent_trace, tag, mesh) in_tracers = map(partial(RewriteTracer, t), in_reps, in_vals) with core.set_current_trace(t): - outs = yield in_tracers, {} - ans = unzip2(map(t.to_val_rep_pair, outs)) - yield ans + outs = f(*in_tracers) + out_vals, out_reps = unzip2(map(t.to_val_rep_pair, outs)) + store.store(out_reps) + return out_vals def _rewrite_bwd(bwd, mesh, in_reps, reps_dst): def new_bwd(*args): diff --git a/jax/experimental/sparse/transform.py 
b/jax/experimental/sparse/transform.py index 7c5a966500f7..c83d9a667888 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -97,6 +97,7 @@ lax.sin_p, lax.sinh_p, lax.sqrt_p, + lax.square_p, lax.tan_p, lax.tanh_p, lax.convert_element_type_p, @@ -340,16 +341,17 @@ def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zero with core.set_current_trace(self): return fun.call_wrapped(*tracers) -@lu.transformation_with_aux -def sparsify_subtrace(tag, spenv, spvalues, *bufs): +@lu.transformation_with_aux2 +def sparsify_subtrace(f, store, tag, spenv, spvalues, *bufs): with core.take_current_trace() as parent: trace = SparseTrace(parent, tag, spenv) with core.set_current_trace(trace): in_tracers = [SparseTracer(trace, spvalue=spvalue) for spvalue in spvalues] - outs = yield in_tracers, {} + outs = f(*in_tracers) out_traces = [trace.to_sparse_tracer(out) for out in outs] buffers = spenv._buffers - yield buffers, [out._spvalue for out in out_traces] + store.store([out._spvalue for out in out_traces]) + return buffers def sparsify_fun(wrapped_fun, args: list[ArrayOrSparse]): tag = core.TraceTag() diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index feb70b5171be..02f0657cc371 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -127,6 +127,7 @@ sinh_p as sinh_p, sort_p as sort_p, sqrt_p as sqrt_p, + square_p as square_p, squeeze_p as squeeze_p, sub_p as sub_p, tan_p as tan_p, diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 74c52dddbae8..8b80d033fa5c 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -22,5 +22,7 @@ merge_linear_aux as merge_linear_aux, transformation as transformation, transformation_with_aux as transformation_with_aux, + transformation2 as transformation2, + transformation_with_aux2 as transformation_with_aux2, wrap_init as wrap_init, ) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 1aa3ebc67b06..dca438996229 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -63,6 +63,7 @@ debug_info_final as debug_info_final, def_trivial_padding as def_trivial_padding, forwarding_rules as forwarding_rules, + has_effects as has_effects, infer_lambda_input_type as infer_lambda_input_type, instantiate_const_at as instantiate_const_at, make_jaxpr_effects as make_jaxpr_effects, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index d2fb6a9bae3c..d569ed641138 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -206,6 +206,7 @@ sqrt as sqrt, sqrt_p as sqrt_p, square as square, + square_p as square_p, squeeze as squeeze, squeeze_p as squeeze_p, stop_gradient as stop_gradient, diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 3affd31e51d6..b312bca7a7d3 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -640,7 +640,7 @@ def TPU_SemaphoreSignalOp : TPU_Op<"sem_signal", [AttrSizedOperandSegments]> { Optional:$core_id // For megacore ); let assemblyFormat = [{ - $semaphore `,` $amount (`,` $device_id^)? (`,` $core_id^)? attr-dict `:` type($semaphore) + $semaphore `,` $amount (`device_id` $device_id^)? (`core_id` $core_id^)? 
attr-dict `:` type($semaphore) }]; let hasVerifier = 1; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 8cb01ee67ad4..8792503f4636 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -2664,7 +2664,6 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const auto bitwidth = res_ty.getElementTypeBitWidth(); const int packing = res_layout->packing(); - SmallVector out_idx; vreg.Each([&](absl::Span idx, Value *v) { out_idx.assign(idx.begin(), idx.end()); @@ -2674,17 +2673,29 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, const VectorType vmask_ty = getNativeVregOrVmaskType( builder.getI1Type(), bitwidth, ctx.target_shape); if (tiling_dim.value() == 0) { // sublane - mask = builder.create( - op.getLoc(), vmask_ty, - ArrayRef{boundIdxConst(0), boundIdxConst(0)}, - ArrayRef{boundIdxConst(operand_offset * packing), - boundIdxConst(layout->tiling()[1])}); + if (operand_offset % packing != 0) { + // Packed case, degenerate where we have a half or quarter + // sublane. + // TODO(mvoz): We can probably always use the + // CreateSubelementMaskOp if (1) optimize it on TPUv4 and (2) Add + // support for unpacked types in some of the invariants in + // lower_to_llo. + mask = builder.create( + op.getLoc(), vmask_ty, 0, operand_offset, packing); + } else { + auto sublane_offset = operand_offset / packing; + mask = builder.create( + op.getLoc(), vmask_ty, + ArrayRef{boundIdxConst(0), boundIdxConst(0)}, + ArrayRef{boundIdxConst(sublane_offset), + boundIdxConst(layout->tiling()[1])}); + } } else { // lane mask = builder.create( op.getLoc(), vmask_ty, ArrayRef{boundIdxConst(0), boundIdxConst(0)}, ArrayRef{boundIdxConst(layout->tiling()[0]), - boundIdxConst(operand_offset * packing)}); + boundIdxConst(operand_offset)}); } // Blend the current value with the existing value in the output. *v = builder.create(op.getLoc(), mask, @@ -4712,6 +4723,11 @@ FailureOr> disassemble( TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value()); TPU_ASSERT_LOC(val.getLoc(), def_layout->generalizes(layout, vty.getShape(), target_shape)); + auto layout_product = + xla::Product(layout.tileArrayShape(vty.getShape(), target_shape)); + auto def_layout_product = + xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape)); + TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product); // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of // having `tileArrayShape` and `tileArrayImplicitShape`. SmallVector layout_shape = @@ -6313,11 +6329,50 @@ FailureOr> relayout(RewriteContext &ctx, if (src.generalizes(dst, vty.getShape(), target_shape)) { // A value with a replicated offset might use fewer vregs than a value with // a non-zero offset. 
- if (xla::Product(src.tileArrayShape(vty.getShape(), target_shape)) != - xla::Product(dst.tileArrayShape(vty.getShape(), target_shape))) { - return emitError(v.getLoc(), - "Not implemented: source layout is more general, but " - "vreg count changes"); + auto src_product = + xla::Product(src.tileArrayShape(vty.getShape(), target_shape)); + auto dst_product = + xla::Product(dst.tileArrayShape(vty.getShape(), target_shape)); + if (src_product != dst_product) { + TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product); + auto src_offsets = src.offsets(); + + TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets()); + TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth()); + + if (src.implicit_dim() != dst.implicit_dim()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and implicit dims are mismatched"); + } + + if (src.tiling() != dst.tiling()) { + return emitError(v.getLoc(), + "Not implemented: Source layout is more general, but " + "vreg count changes and tiling are mismatched"); + } + + // This case is moving from a replicated to a non replicated layout. + // As such, we need to make a new destination shape that is the + // materialization of the src shape with replication. + FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs, + disassemble(builder, src, v, target_shape, + /*use_implicit_shape=*/true)); + auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape); + xla::Array dst_vregs(dst_vregs_shape); + dst_vregs.Each([&](const absl::Span idx, Value *vreg) { + SmallVector local_idx(idx.begin(), idx.end()); + if (!src_offsets[0].has_value()) { + local_idx[local_idx.size() - 2] = 0; + } + if (!src_offsets[1].has_value()) { + local_idx[local_idx.size() - 1] = 0; + } + *vreg = src_vregs(local_idx); + }); + return assemble(builder, vty, dst, std::move(dst_vregs), target_shape, + /*use_implicit_shape=*/true) + .getResult(); } src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape)); return assemble(builder, vty, dst, std::move(src_tiles), target_shape, @@ -6400,8 +6455,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (vector_operand == nullptr) { continue; } - auto vty = vector_operand.getType(); - // The operand should always be an Operation (and not a BlockArgument) // since we expect the FuncOp to have only memrefs and semaphores as // arguments. @@ -6416,9 +6469,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayouts(*def_op, ctx.target_shape)); const Layout lo = def_layouts[res_idx]; TPU_ASSERT_OP(lo.has_value()); - if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) { - continue; - } OpBuilder builder(&op); FAILUREOR_ASSIGN_OR_RETURN( Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo, diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index 3f6050f31dab..fd68c9e6c95e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -92,6 +92,9 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { EnqueueDMAOp>::getOperandSegmentSizeAttr(), mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0})); } else if (op->getNumOperands() == 3) { // Remote signal. + // Hardcoding that one optional value is device_id, not core_id. This + // could misinterpret sem_signals where core_id is specified, but + // device_id isn't. 
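The new relayout branch above handles the case where a replicated source layout generalizes the destination but uses fewer vregs: the single source vreg along each replicated (offset-less) dimension is reused for every destination position before reassembling. A rough NumPy illustration of that index mapping, using made-up vreg labels and shapes rather than Mosaic data structures:

    import numpy as np

    # Source vreg grid, replicated over the second-minor dimension
    # (src_offsets[0] is None), so it holds a single row of tiles.
    src_vregs = np.array([["v00", "v01"]], dtype=object)   # shape (1, 2)
    src_offsets = (None, 0)        # None marks a replicated dimension
    dst_shape = (4, 2)             # destination materializes 4 rows of tiles

    dst_vregs = np.empty(dst_shape, dtype=object)
    for idx in np.ndindex(dst_shape):
      local_idx = list(idx)
      if src_offsets[0] is None:   # replicated second-minor dim: read row 0
        local_idx[-2] = 0
      if src_offsets[1] is None:   # replicated minor dim: read column 0
        local_idx[-1] = 0
      dst_vregs[idx] = src_vregs[tuple(local_idx)]

    # Every destination row now reuses the same replicated source vregs.
    print(dst_vregs)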
op->setAttr(OpTrait::AttrSizedOperandSegments< EnqueueDMAOp>::getOperandSegmentSizeAttr(), mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0})); diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 2fb8f0103e65..1f78782a0891 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -185,9 +185,11 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/strings", "@nanobind", - "@xla//xla/service:custom_call_status", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 922d13d213f5..608270239882 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -13,19 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include #include #include "nanobind/nanobind.h" +#include "absl/cleanup/cleanup.h" #include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" -#include "xla/service/custom_call_status.h" +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" namespace jax::cuda { namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + static std::string ToString(CUresult result) { const char* error_name; if (cuGetErrorName(result, &error_name)) { @@ -38,45 +43,88 @@ static std::string ToString(CUresult result) { return absl::StrCat(error_name, ": ", error_string); } -void EventRecordCall(void* stream, void** buffers, char* opaque, - size_t opaque_len, XlaCustomCallStatus* status) { - auto* event = reinterpret_cast(opaque); - if (auto res = gpuEventRecord(**event, reinterpret_cast(stream)); - res) { - auto message = absl::StrCat("Failed to record event: ", ToString(res)); - XlaCustomCallStatusSetFailure(status, message.c_str(), message.size()); - } +// Ensure it is safe to store gpuEvent_t in a uint64_t buffer. 
+static_assert(sizeof(gpuEvent_t) <= sizeof(uint64_t)); + +static const auto* kEventRecord = + ffi::Ffi::Bind() + .Ctx>() + .Attr("copy_before") + .RemainingArgs() + .Ret>() // event + .RemainingRets() + .To([](gpuStream_t stream, bool copy_before, + auto remaining_args, auto ret, auto remaining_rets) { + static auto* event = new gpuEvent_t; + if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); + res) { + return ffi::Error::Internal( + absl::StrCat("Failed to create event: ", ToString(res))); + } + auto do_copy = [&]() { + gpuMemcpyAsync(ret->untyped_data(), event, + sizeof(gpuEvent_t), gpuMemcpyHostToDevice, stream); + }; + if (copy_before) { + do_copy(); + } + if (auto res = gpuEventRecord(*event, stream); res) { + return ffi::Error::Internal( + absl::StrCat("Failed to record event: ", ToString(res))); + } + if (!copy_before) { + do_copy(); + } + return ffi::Error::Success(); + }) + .release(); + +XLA_FFI_Error* EventRecord(XLA_FFI_CallFrame* call_frame) { + return kEventRecord->Call(call_frame); +} + +static const auto* kEventElapsed = + ffi::Ffi::Bind() + .Ctx>() + .Arg>() // start_event + .Arg>() // end_event + .Ret>() // elapsed_ms + .To([](gpuStream_t stream, auto start, auto end, auto out) { + gpuStreamSynchronize(stream); + auto start_event = std::make_unique(); + auto end_event = std::make_unique(); + absl::MakeCleanup([&]() { + gpuEventDestroy(*start_event); + gpuEventDestroy(*end_event); + }); + gpuMemcpy(start_event.get(), start.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpyDeviceToHost); + gpuMemcpy(end_event.get(), end.untyped_data(), sizeof(gpuEvent_t), + gpuMemcpyDeviceToHost); + float elapsed; + if (auto res = + gpuEventElapsedTime(&elapsed, *start_event, *end_event); + res) { + return ffi::Error::Internal(absl::StrCat( + "Failed to get elapsed time between events: ", ToString(res))); + } + gpuMemcpy(out->untyped_data(), &elapsed, sizeof(float), + gpuMemcpyHostToDevice); + return ffi::Error::Success(); + }) + .release(); + +XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) { + return kEventElapsed->Call(call_frame); } NB_MODULE(_mosaic_gpu_ext, m) { - m.def("_gpu_event_create", []() { - gpuEvent_t* event = new gpuEvent_t(); - if (auto res = gpuEventCreate(event, GPU_EVENT_DEFAULT); res) { - throw std::runtime_error( - absl::StrCat("Failed to create event: ", ToString(res))); - } - return reinterpret_cast(event); - }); - m.def("_gpu_event_destroy", [](uintptr_t event) { - if (auto res = gpuEventDestroy(*reinterpret_cast(event)); - res) { - throw std::runtime_error( - absl::StrCat("Failed to destroy event: ", ToString(res))); - } - }); - m.def("_gpu_event_elapsed", [](uintptr_t start_event, uintptr_t end_event) { - float elapsed_ms = -1; - if (auto res = gpuEventElapsedTime( - &elapsed_ms, *reinterpret_cast(start_event), - *reinterpret_cast(end_event)); - res) { - throw std::runtime_error(absl::StrCat( - "Failed to get elapsed time between events: ", ToString(res))); - } - return elapsed_ms; + m.def("registrations", []() { + return nb::make_tuple( + nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)), + nb::make_tuple("mgpu_event_elapsed", EncapsulateFunction(EventElapsed)) + ); }); - m.def("_record_event_capsule", - []() { return EncapsulateFunction(EventRecordCall); }); m.def("_sync_all_devices", []() { int devices = 0; if (cudaGetDeviceCount(&devices) != gpuSuccess) { diff --git a/tests/BUILD b/tests/BUILD index dc81c408c4ce..c80f63e6d7d6 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -321,6 +321,21 @@ jax_multiplatform_test( ], ) 
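With the move to the typed XLA FFI, the extension above no longer exposes individual _gpu_event_* bindings; it returns its handlers from a single registrations() function as (name, capsule) pairs. The Python-side wiring is not part of this patch; the following is only one plausible consumer, assuming the jax.extend.ffi.register_ffi_target helper and an illustrative import path for the extension module:

    # Sketch only: the module path and the registration call site used by JAX
    # internals are assumptions, not taken from this patch.
    from jax.extend import ffi as jex_ffi
    from jaxlib.mosaic.gpu import _mosaic_gpu_ext  # assumed import path

    for name, handler in _mosaic_gpu_ext.registrations():
      # api_version=1 selects the typed-FFI calling convention produced by
      # ffi::Ffi::Bind() handlers such as EventRecord / EventElapsed.
      jex_ffi.register_ffi_target(name, handler, platform="CUDA", api_version=1)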
+jax_multiplatform_test( + name = "mock_gpu_topology_test", + srcs = ["mock_gpu_topology_test.py"], + enable_backends = ["gpu"], + enable_configs = [ + "gpu_h100", + ], + tags = [ + "config-cuda-only", + ], + deps = [ + "//jax:experimental", + ], +) + jax_multiplatform_test( name = "array_test", srcs = ["array_test.py"], @@ -1523,6 +1538,7 @@ jax_multiplatform_test( srcs = ["cudnn_fusion_test.py"], enable_backends = [], enable_configs = [ + "gpu_a100", "gpu_h100", ], tags = ["multiaccelerator"], diff --git a/tests/api_test.py b/tests/api_test.py index 8ab5d90f6e07..49cd33ee464c 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -287,13 +287,15 @@ def test_jit_default_device(self, module): self.assertEqual(f(sticky).devices(), system_default_devices) self.assertEqual(f(1).devices(), system_default_devices) - # TODO(skye): make this work! def test_jit_default_platform(self): - with self.assertRaisesWithLiteralMatch( - ValueError, "jax.default_device must be passed a Device object " - "(e.g. `jax.devices('cpu')[0]`), got: 'cpu'"): with jax.default_device("cpu"): - jax.jit(lambda x: x + 1)(1) + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, "cpu") + self.assertEqual(result.device, jax.local_devices(backend="cpu")[0]) + + result = jax.jit(lambda x: x + 1)(1) + self.assertEqual(result.device.platform, jax.default_backend()) + self.assertEqual(result.device, jax.local_devices()[0]) def test_complex_support(self): self.assertEqual(jit(lambda x: x + 1)(1 + 1j), 2 + 1j) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index 40c2181a9e3c..d10558afbe16 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -40,6 +40,8 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.compilation_cache_interface import CacheInterface +from jax._src.lib import xla_client as xc +from jax._src.lib import version as jaxlib_version from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P import numpy as np @@ -535,6 +537,42 @@ def test_backend_serialization_deserialization(self): self.assertEqual( executable.fingerprint, deserialized_executable.fingerprint) + def test_persistent_cache_enable_xla_caches(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("Test requires AutotuneCacheMode bindings") + with config.compilation_cache_dir("jax-cache"): + with config.persistent_cache_enable_xla_caches("none"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("all"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "jax-cache/xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) + 
self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "jax-cache/xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("xla_gpu_kernel_cache_file"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "jax-cache/xla_gpu_kernel_cache_file") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, True) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + with config.persistent_cache_enable_xla_caches("xla_gpu_per_fusion_autotune_cache_dir"): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "jax-cache/xla_gpu_per_fusion_autotune_cache_dir") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) @jtu.with_config( jax_enable_compilation_cache=False, @@ -570,5 +608,17 @@ def test_tasks_disable_cache_metric(self): "/jax/compilation_cache/task_disabled_cache"] self.assertEqual(count_after_second_use, count_after_first_use) + def test_persistent_cache_enable_xla_caches_disabled(self): + if jaxlib_version <= (0, 4, 35): + self.skipTest("Test requires AutotuneCacheMode bindings") + with config.enable_compilation_cache(False): + compile_options = compiler.get_compile_options( + num_replicas=1, num_partitions=1 + ) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_kernel_cache_file, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism, False) + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir, "") + self.assertEqual(compile_options.executable_build_options.debug_options.xla_gpu_experimental_autotune_cache_mode, xc.AutotuneCacheMode.UPDATE) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py index 151cb72be8dc..7dc0571bc172 100644 --- a/tests/cudnn_fusion_test.py +++ b/tests/cudnn_fusion_test.py @@ -15,6 +15,7 @@ from absl.testing import absltest, parameterized from unittest import SkipTest from jax._src import test_util as jtu +from jax._src.lib import cuda_versions import jax import jax.numpy as jnp from jax._src.cudnn import cudnn_fusion @@ -26,8 +27,9 @@ class CudnnFusionTest(jtu.JaxTestCase): def setUp(self): if (not jtu.test_device_matches(["cuda"]) or - not jtu.is_cuda_compute_capability_at_least("9.0")): - self.skipTest("Only works on >= sm90 GPUs") + not jtu.is_cuda_compute_capability_at_least("8.0") or + 
cuda_versions.cudnn_get_version() < 90110): + self.skipTest("Only works on >= sm80 GPUs with cuDNN 9.1.1+") super().setUp() @parameterized.parameters(["", "pmap"]) diff --git a/tests/lax_test.py b/tests/lax_test.py index 17132996c429..14f453b38e7c 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -4362,12 +4362,6 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'sign': regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4') - elif name == 'square': - if is_cuda: - regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.real', 'pinf.real', 'ninfj.real', 'pinfj.real') - if is_cpu: - regions_with_inaccuracies_keep('ninf.real', 'pinf.real', 'q1.real', 'q2.real', 'q3.real', 'q4.real') - elif name == 'log': regions_with_inaccuracies_keep('q1.real', 'q2.real', 'q3.real', 'q4.real', 'ninf.imag', 'pinf.imag', 'ninfj.imag', 'pinfj.imag') @@ -4411,7 +4405,7 @@ def regions_with_inaccuracies_keep(*to_keep): regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}: + 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh', 'square'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable diff --git a/tests/layout_test.py b/tests/layout_test.py index 31f3d71d0537..afddab916723 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -655,6 +655,44 @@ def f(x): f(sparecore_arr) + def test_sparsecore_and_host_compute(self): + if not ( + jax.devices()[0].device_kind == 'TPU v5' + or jtu.is_device_tpu_at_least(6) + ): + self.skipTest('Does not have a sparsecore present') + shape = (128, 128) + inp = jnp.arange(math.prod(shape)).reshape(shape) + s = SingleDeviceSharding(jax.devices()[0]) + + sparse_dll = DLL(major_to_minor=(0, 1), _tiling=((8,),)) + sparse_layout = Layout(sparse_dll, s) + sparecore_arr = jax.device_put(inp, sparse_layout) + + host_dll = DLL(major_to_minor=(0, 1), _tiling=((1,),)) + host_layout = Layout(host_dll, s) + host_arr = jax.device_put(inp, host_layout) + + @compute_on('tpu_sparsecore') + @jax.jit + def sparsecore_compute(x): + return x * x + + @compute_on('device_host') + @jax.jit + def host_compute(x): + return x + x + + @partial( + jax.jit, + in_shardings=(sparse_layout, host_layout), + out_shardings=(sparse_layout, host_layout), + ) + def f(x, y): + return sparsecore_compute(x), host_compute(y) + + f(sparecore_arr, host_arr) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5ace4b5ecf18..d3fe8f476722 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -53,6 +53,22 @@ def _is_required_cuda_version_satisfied(cuda_version): else: return int(version.split()[-1]) >= cuda_version + +def osp_linalg_toeplitz(c: np.ndarray, r: np.ndarray | None = None) -> np.ndarray: + """scipy.linalg.toeplitz with v1.17+ batching semantics.""" + if scipy_version >= (1, 17, 0): + return scipy.linalg.toeplitz(c, r) + elif r is None: + c = np.atleast_1d(c) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m)->(m,m)", otypes=(c.dtype,))(c) + else: + c = np.atleast_1d(c) + r = np.atleast_1d(r) + return np.vectorize( + scipy.linalg.toeplitz, signature="(m),(n)->(m,n)", otypes=(np.result_type(c, r),))(c, r) + + class NumpyLinalgTest(jtu.JaxTestCase): @jtu.sample_product( @@ -1990,11 +2006,11 @@ def testSqrtmEdgeCase(self, diag, expected, dtype): self.assertAllClose(root, expected, check_dtypes=False) 
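The osp_linalg_toeplitz reference helper above encodes the batched semantics the updated toeplitz tests check against: leading dimensions of the inputs act as batch dimensions, as made explicit by the vectorize signatures "(m)->(m,m)" and "(m),(n)->(m,n)". A small shape-level sketch of what jax.scipy.linalg.toeplitz is expected to produce under those semantics:

    import jax.numpy as jnp
    from jax.scipy.linalg import toeplitz

    c = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)

    # Each length-3 row becomes its own 3x3 Toeplitz matrix.
    print(toeplitz(c).shape)             # (2, 3, 3)

    # A flattened (1-D) input yields a single unbatched 6x6 matrix.
    print(toeplitz(jnp.ravel(c)).shape)  # (6, 6)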
@jtu.sample_product( - cshape=[(), (4,), (8,), (3, 7), (0, 5, 1)], + cshape=[(), (4,), (8,), (4, 7), (2, 1, 5)], cdtype=float_types + complex_types, - rshape=[(), (3,), (7,), (2, 1, 4), (19, 0)], + rshape=[(), (3,), (7,), (4, 4), (2, 4, 0)], rdtype=float_types + complex_types + int_types) - def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): + def testToeplitzConstruction(self, rshape, rdtype, cshape, cdtype): if ((rdtype in [np.float64, np.complex128] or cdtype in [np.float64, np.complex128]) and not config.enable_x64.value): @@ -2007,10 +2023,11 @@ def testToeplitzConstrcution(self, rshape, rdtype, cshape, cdtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(cshape, cdtype), rng(rshape, rdtype)] - with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) - self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) + with jax.numpy_rank_promotion("allow"): + with jtu.strict_promotion_if_dtypes_match([rdtype, cdtype]): + self._CheckAgainstNumpy(jtu.promote_like_jnp(osp_linalg_toeplitz), + jsp.linalg.toeplitz, args_maker) + self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) @jtu.sample_product( shape=[(), (3,), (1, 4), (1, 5, 9), (11, 0, 13)], @@ -2028,8 +2045,7 @@ def testToeplitzSymmetricConstruction(self, shape, dtype): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(jtu.promote_like_jnp(osp.linalg.toeplitz), - jsp.linalg.toeplitz, args_maker) + self._CheckAgainstNumpy(osp_linalg_toeplitz, jsp.linalg.toeplitz, args_maker) self._CompileAndCheck(jsp.linalg.toeplitz, args_maker) def testToeplitzConstructionWithKnownCases(self): diff --git a/tests/mock_gpu_topology_test.py b/tests/mock_gpu_topology_test.py new file mode 100644 index 000000000000..44ec4e2f9529 --- /dev/null +++ b/tests/mock_gpu_topology_test.py @@ -0,0 +1,60 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +import jax +from jax._src import test_util as jtu +import jax.numpy as jnp +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec as P + +jax.config.parse_flags_with_absl() + +NUM_SLICES = 2 +NUM_HOSTS_PER_SLICE = 4 + + +@jtu.with_config( + jax_mock_gpu_topology=f"{NUM_SLICES}x{NUM_HOSTS_PER_SLICE}x1", + jax_cuda_visible_devices="0") +class MockGPUTopologyTest(jtu.JaxTestCase): + + def setUp(self): + if not jtu.test_device_matches(["gpu"]): + self.skipTest("Mocking devices only works on the GPU backend.") + super().setUp() + + @jtu.skip_under_pytest("Test must run in an isolated process") + def testMockDeviceCount(self): + self.assertEqual(jax.device_count(), NUM_SLICES * NUM_HOSTS_PER_SLICE) + + @jtu.skip_under_pytest("Test must run in an isolated process") + def testMockWithSharding(self): + mesh = jax.sharding.Mesh(jax.devices(), ('x',)) + f = jax.jit(jnp.sum, + in_shardings=NamedSharding(mesh, P('x')), + out_shardings=NamedSharding(mesh, P())) + + f_lowered = f.lower(jnp.arange(16)) + hlo = f_lowered.compiler_ir() + + mocked_count = NUM_SLICES * NUM_HOSTS_PER_SLICE + self.assertIn(f'num_partitions = {mocked_count}', str(hlo)) + self.assertIn( + f'sharding = "{{devices=[{mocked_count}]<=[{mocked_count}]}}"', + str(hlo) + ) + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index cbbe8da54972..83202937503d 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -1146,6 +1146,37 @@ def kernel_body(x_smem, o_smem): ) np.testing.assert_array_equal(kernel_fn(x), x + 1.0) + def test_emit_with_grid_invariant_output(self): + num_steps = 4 + + def kernel(x_gmem, o_gmem): + plgpu.emit_pipeline( + kernel_body, + in_specs=[pl.BlockSpec((32, 16), lambda i: (0, i))], + out_specs=[pl.BlockSpec((32, 16), lambda i: (0, 0))], + grid=(num_steps,), + max_concurrent_steps=2, + )(x_gmem, o_gmem) + + def kernel_body(x_smem, o_smem): + o_smem[...] = x_smem[...] + 1.0 + + x = jnp.arange(32 * num_steps * 16) + x = x.reshape(-1, num_steps * 16).astype(jnp.float32) + kernel_fn = pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=plgpu.GMEM)], + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), + ) + y = jnp.empty_like(x) + for i in range(num_steps): + i_slice = slice(16 * i, 16 * (i + 1)) + y = y.at[:, :16].set(x[:, i_slice] + 1) + # We only compare the elements in the first 16 columns, because the rest + # are never written to. + np.testing.assert_array_equal(kernel_fn(x)[:, :16], y[:, :16]) + def test_emit_with_parallel_grid(self): self.skipTest("Enable once we support multiple levels of indexing") diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index b8a42ecf1835..41670137c39f 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -1927,6 +1927,47 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.triu(x, k=k)) + @parameterized.parameters( + (jnp.float16, jnp.float16), # Noop + (jnp.int16, jnp.float16), + (jnp.int16, jnp.bfloat16), + (jnp.float32, jnp.int32), + ) + def test_bitcast_convert_type(self, in_dtype, out_dtype): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + m, n = 4, 4 + out_shape = jax.ShapeDtypeStruct((m, n), out_dtype) + grid = () + + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] 
= jax.lax.bitcast_convert_type(x_ref[...], out_shape) + + x = jnp.arange(m * n, dtype=in_dtype).reshape((m, n)) + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + + def test_bitcast_convert_type_scalar(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Not implemented on TPU") + + x = jnp.int32(42) + out_dtype = jnp.float32 + out_shape = jax.ShapeDtypeStruct(x.shape, out_dtype) + grid = () + + @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) + def convert(x_ref, y_ref): + y_ref[...] = jax.lax.bitcast_convert_type(x_ref[...], out_dtype) + + y = convert(x) + y_ref = jax.lax.bitcast_convert_type(x, out_dtype) + np.testing.assert_array_equal(y, y_ref) + + class OpsInterpretTest(OpsTest): INTERPRET = True diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 5f0c28541b62..fa574df18f29 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -12,50 +12,80 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import ExitStack from functools import partial import glob import logging import math import os +import shutil import tempfile -import unittest from absl.testing import absltest import jax +from jax._src import api +from jax._src import compilation_cache as cc from jax._src import config -from jax._src import profiler -from jax._src import pjit from jax._src import monitoring +from jax._src import pjit +from jax._src import profiler from jax._src import test_util as jtu -from jax._src import api from jax.experimental import profiler as exp_profiler -import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec -from jax._src import compilation_cache as cc -import numpy as np - from jax.experimental.serialize_executable import ( deserialize_and_load, serialize, ) +import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec +import numpy as np jax.config.parse_flags_with_absl() -dump_dir = tempfile.TemporaryDirectory().name -os.environ['XLA_FLAGS'] = ( - f'--xla_dump_to={dump_dir}' - ' --xla_gpu_experimental_dump_fdo_profiles=true' - ' --xla_gpu_enable_latency_hiding_scheduler=true' -) @jtu.pytest_mark_if_available('multiaccelerator') class PgleTest(jtu.JaxTestCase): + _dump_exit_stack: ExitStack | None = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._dump_exit_stack = ExitStack() + + cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory()) + if 'XLA_FLAGS' in os.environ: + cls.old_xla_flags = os.environ['XLA_FLAGS'] + else: + cls.old_xla_flags = None + + os.environ['XLA_FLAGS'] = ( + f'--xla_dump_to={cls.dump_dir}' + ' --xla_gpu_experimental_dump_fdo_profiles=true' + ' --xla_gpu_enable_latency_hiding_scheduler=true' + # TODO(patrios): Remove this flag once b/376647494 is fixed. 
+ ' --xla_gpu_graph_level=0' + ) + if cls.old_xla_flags: + os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags + + @classmethod + def tearDownClass(cls): + if cls.old_xla_flags: + os.environ['XLA_FLAGS'] = cls.old_xla_flags + cls._dump_exit_stack.close() + super().tearDownClass() + def setUp(self): super().setUp() cc.set_cache_dir(None) cc.reset_cache() def tearDown(self): + # Cleanup dump directory + for file in os.listdir(self.dump_dir): + file_path = os.path.join(self.dump_dir, file) + if os.path.isfile(file_path): + os.remove(file_path) + cc.set_cache_dir(None) super().tearDown() @@ -87,7 +117,6 @@ def f(x, y): self.assertIsNotNone(fdo_profile) self.assertIn(b'custom', fdo_profile) - @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfileLarge(self): mesh = jtu.create_mesh((2,), ('x',)) its = 500 @@ -106,14 +135,10 @@ def f(x): shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - with config.pgle_profiling_runs(0): - f_lowered = f.lower(x) - f_compiled = f_lowered.compile() - pgle_profiler = profiler.PGLEProfiler(1, 90) with config.enable_pgle(False): with profiler.PGLEProfiler.trace(pgle_profiler): - f_compiled(x) + f(x) fdo_profile = pgle_profiler.consume_fdo_profile() self.assertEqual(fdo_profile.count(b'custom'), its) @@ -177,7 +202,6 @@ def f(x): self.assertArraysEqual(compiled(x), expected) self.assertEqual(cache_miss_count[0], 0) - @unittest.skip("Test failing in CI") def testAutoPgleWithPersistentCache(self): its = 50 mesh = jtu.create_mesh((2,), ('x',)) @@ -196,8 +220,6 @@ def f(x): shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) - profilers_dict = ( - pjit._most_recent_pjit_call_executable.weak_pgle_profiler_dict) with (config.enable_compilation_cache(True), config.enable_pgle(True), config.raise_persistent_cache_errors(True), @@ -206,11 +228,12 @@ def f(x): config.persistent_cache_min_compile_time_secs(0), config.pgle_profiling_runs(2), tempfile.TemporaryDirectory() as cache_dir): + cc.reset_cache() cc.set_cache_dir(cache_dir) # Run 1: Module should be compiled without FDO with jtu.count_cached_compilation_cache_miss() as cache_miss_count: f(x) - self.assertEqual(cache_miss_count[0], 1) + self.assertGreater(cache_miss_count[0], 0) # Non-pgle profiled version of module should be saved non_pgle_profiled_files = os.listdir(cache_dir) @@ -221,26 +244,24 @@ def f(x): f(x) self.assertEqual(cache_miss_count[0], 0) - module_before_pgle = os.listdir(dump_dir) - print(module_before_pgle) + module_before_pgle = os.listdir(self.dump_dir) self.assertNotEmpty(module_before_pgle) # Run 3: Module should be compiled with FDO and stored to persistent cache with jtu.count_cached_compilation_cache_miss() as cache_miss_count: - # Add xla_dump_to to env flags f(x) - self.assertEqual(cache_miss_count[0], 1) + self.assertGreater(cache_miss_count[0], 0) # Check if FDO profile file of the biggest module is not empty module_after_pgle = [ x - for x in os.listdir(dump_dir) + for x in os.listdir(self.dump_dir) if x not in module_before_pgle ] self.assertNotEmpty(module_after_pgle) biggest_module_after_pgle = max( module_after_pgle, key=lambda x: os.path.getsize( - os.path.join(dump_dir, x) + os.path.join(self.dump_dir, x) ), ) base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1]) @@ -251,10 +272,10 @@ def f(x): '.fdo_profile' ): self.assertGreater( - os.path.getsize(os.path.join(dump_dir, module)), 0 + os.path.getsize(os.path.join(self.dump_dir, module)), 0 ) - for pgle_profiler in 
profilers_dict.values(): + for pgle_profiler in pjit._pgle_profiler_dict.values(): self.assertTrue(pgle_profiler.is_enabled()) self.assertTrue(pgle_profiler.is_fdo_consumed()) @@ -266,10 +287,14 @@ def f(x): # Removing non-pgle profiled module from cache to check that later pgle # profiled version will be used. for non_pgle_file in non_pgle_profiled_files: - os.remove(os.path.join(cache_dir, non_pgle_file)) + path = os.path.join(cache_dir, non_pgle_file) + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) api.clear_caches() - profilers_dict.clear() + pjit._pgle_profiler_dict.clear() # Run 4: Persistent compilation cache should be hit PGLE profiler should # be disabled @@ -283,7 +308,7 @@ def check_if_cache_hit(event): f(x) monitoring._unregister_event_listener_by_callback(check_if_cache_hit) - self.assertEqual(cache_hit, 1) + self.assertGreater(cache_hit, 0) def testPassingFDOProfile(self): mesh = jtu.create_mesh((2,), ('x',)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index a9760d02fc0f..8a63bbe39099 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -17,6 +17,7 @@ import re from functools import partial import logging +import json import math import textwrap import threading @@ -59,6 +60,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -3825,6 +3827,16 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) self.assertEqual(out.sharding.memory_kind, 'device') + @unittest.skipIf(xla_extension_version < 297, "Requires jaxlib 0.4.36+") + def test_jit_static_argnames_non_interned(self): + def do_nothing(foobar: int): + return foobar + + argname = "foobar" + # Has the side effect of ensuring argname is not interned. 
+ argname = str(json.loads(json.dumps(argname))) + jax.jit(do_nothing, static_argnames=[argname])(foobar=2) # doesn't crash + def test_most_recent_executable_outer_inner_cache(self): x = np.zeros((20, 20), dtype=jnp.float64) @@ -5189,6 +5201,30 @@ def f(x): self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_shard_map_full_manual(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh.are_all_axes_collective) + self.assertTrue(y.sharding.mesh.are_all_axes_collective) + return x * y + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', 'y'))(x, y) + self.assertEqual(z.sharding.spec, P('x', 'y')) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', 'y')) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp * np_inp) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 48850c8da66a..df24315ce110 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -993,6 +993,63 @@ def body(c, _): shard_map(g, mesh, in_specs=(P(None), P('x'), P(('x', 'y'))), out_specs=[P(None), P(None), P(('x', 'y'))])(x, x, x) + def test_cond_rep_rule(self): + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) + x = jnp.arange(4) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(True, true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return y + return jax.lax.cond(True, true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(jnp.any(x > 0), true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(None))(x, x) + + def f(x, y): + def true_fn(x, y): + return x + def false_fun(x, y): + return x + 1 + return jax.lax.cond(jnp.any(y > 0), true_fn, false_fun, x, y) + + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P(('x', 'y')))(x, x) + with self.assertRaisesRegex(ValueError, "require replication"): + shard_map(f, mesh, in_specs=(P('x'), P('y')), out_specs=P('x'))(x, x) + + # https://github.com/jax-ml/jax/issues/24418 + def f(a): + c = jax.lax.cond(jnp.any(a), lambda: 1, lambda: 0) + return jnp.reshape(c, a.shape) + + mesh = jtu.create_mesh((2,), ('x',)) + a = jnp.array([True, False]) + shard_map(f, mesh, in_specs=P('x'), out_specs=P('x'))(a) + def test_eager_custom_jvp_basic(self): @jax.custom_jvp def foo(x): diff --git a/tests/util_test.py b/tests/util_test.py index 5f07d2f50880..5e99fff4b347 
100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -42,8 +42,8 @@ def f(*args, **kwargs): assert not kwargs return tuple(a * factor for a in args) - @lu.transformation_with_aux - def kw_to_positional(factor, *args, **kwargs): + @lu.transformation_with_aux2 + def kw_to_positional(f, store, factor, *args, **kwargs): """A transformation with auxiliary output. Turns all keyword parameters into positional ones. @@ -55,12 +55,12 @@ def kw_to_positional(factor, *args, **kwargs): kwargs_keys = kwargs.keys() new_args = tuple(kwargs[k] for k in kwargs_keys) new_kwargs = dict(factor=factor) - results = yield args + new_args, new_kwargs # Yield transformed (args, kwargs) + results = f(*(args + new_args), **new_kwargs) # Yield transformed (args, kwargs) # Assume results correspond 1:1 to the args + new_args assert len(results) == len(args) + len(new_args) aux_output = len(new_args) - yield (results[0:len(args)], - dict(zip(kwargs_keys, results[len(args):]))), aux_output + store.store(aux_output) + return (results[0:len(args)], dict(zip(kwargs_keys, results[len(args):]))) wf = lu.wrap_init(f) # Wraps `f` as a `WrappedFun`. wf, out_thunk = kw_to_positional(wf, 2) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f74c74077198..fdb6b1607816 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "e93a258e4494231626c7d3b6a6447e746ea72f9c" -XLA_SHA256 = "99f3a6b06230becf013f00009afeee4c89f52818e7a4a1ea4851157dc853830e" +XLA_COMMIT = "2a7890387f812c17fb5f17eec961ee52ac3e059d" +XLA_SHA256 = "cfe1eebc643355f55e6422451cbd750ac6a7f096ed8d6a0605238e4d8ce6d0d1" def repo(): tf_http_archive(